Skip to content

Commit 2a9e5b7

Browse files
Aftabbsaftabbs
authored and committed
fix: replace in-place dataclass mutations with dataclasses.replace()
Resolves the in-place mutation warnings introduced by the _warn_on_inplace_mutation guard in PR #10650. Running `hatch run test:unit | grep "Mutating attribute"` surfaced mutations across five components. Each is replaced with `dataclasses.replace(instance, field=new_value)` so that dataclass instances are never mutated after creation. Changed files: - components/builders/chat_prompt_builder.py: replace _content mutation on rendered ChatMessage copy with dataclasses.replace() - core/pipeline/pipeline.py: replace two-field mutation on PipelineSnapshot (agent_snapshot + break_point) with a single dataclasses.replace() call - components/converters/image/file_to_image.py: replace ByteStream.mime_type mutation with dataclasses.replace() - components/extractors/llm_metadata_extractor.py: replace Document.content mutation with dataclasses.replace() (already imported `replace`) - components/fetchers/link_content.py: replace ByteStream.mime_type mutations in both sync and async run() methods - components/joiners/document_joiner.py: replace Document.score mutations in _score_norm, _reciprocal_rank_fusion, and _distribution_based_rank_fusion with non-mutating list comprehensions using dataclasses.replace() Also updates test_document_joiner.py::test_list_with_one_empty_list to compare by document ID rather than object identity, since the test previously relied on the mutation side-effect to make the assertion pass. Fixes #10659
1 parent 2261284 commit 2a9e5b7

7 files changed

Lines changed: 35 additions & 27 deletions

File tree

haystack/components/builders/chat_prompt_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
from copy import deepcopy
7+
from dataclasses import replace
78
from typing import Any, Literal
89

910
from jinja2.sandbox import SandboxedEnvironment
@@ -267,9 +268,8 @@ def run(
267268
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
268269
compiled_template = self._env.from_string(message.text)
269270
rendered_text = compiled_template.render(template_variables_combined)
270-
# deep copy the message to avoid modifying the original message
271-
rendered_message: ChatMessage = deepcopy(message)
272-
rendered_message._content = [TextContent(text=rendered_text)]
271+
# use dataclasses.replace to avoid in-place mutation of the copied message
272+
rendered_message: ChatMessage = replace(deepcopy(message), _content=[TextContent(text=rendered_text)])
273273
processed_messages.append(rendered_message)
274274
else:
275275
processed_messages.append(message)

haystack/components/converters/image/file_to_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import mimetypes
6+
from dataclasses import replace
67
from pathlib import Path
78
from typing import Any, Literal
89

@@ -124,7 +125,7 @@ def run(
124125
continue
125126

126127
if bytestream.mime_type is None and isinstance(source, Path):
127-
bytestream.mime_type = mimetypes.guess_type(source.as_posix())[0]
128+
bytestream = replace(bytestream, mime_type=mimetypes.guess_type(source.as_posix())[0])
128129

129130
if bytestream.data == _EMPTY_BYTE_STRING:
130131
logger.warning("File {source} is empty. Skipping it.", source=source)

haystack/components/extractors/llm_metadata_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _prepare_prompts(
263263
for idx, page in enumerate(pages["documents"]):
264264
if idx + 1 in expanded_range:
265265
content += page.content
266-
doc_copy.content = content
266+
doc_copy = replace(doc_copy, content=content)
267267
else:
268268
doc_copy = document
269269

haystack/components/fetchers/link_content.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict
77
from collections.abc import Callable
88
from concurrent.futures import ThreadPoolExecutor
9+
from dataclasses import replace
910
from fnmatch import fnmatch
1011
from typing import cast
1112

@@ -248,7 +249,7 @@ def run(self, urls: list[str]):
248249
if len(urls) == 1:
249250
stream_metadata, stream = self._fetch(urls[0])
250251
stream.meta.update(stream_metadata)
251-
stream.mime_type = stream.meta.get("content_type", None)
252+
stream = replace(stream, mime_type=stream.meta.get("content_type", None))
252253
streams.append(stream)
253254
else:
254255
with ThreadPoolExecutor() as executor:
@@ -257,7 +258,7 @@ def run(self, urls: list[str]):
257258
for stream_metadata, stream in results: # type: ignore
258259
if stream_metadata is not None and stream is not None:
259260
stream.meta.update(stream_metadata)
260-
stream.mime_type = stream.meta.get("content_type", None)
261+
stream = replace(stream, mime_type=stream.meta.get("content_type", None))
261262
streams.append(stream)
262263

263264
return {"streams": streams}
@@ -302,7 +303,7 @@ async def run_async(self, urls: list[str]):
302303
stream_metadata, stream = result_tuple
303304
if stream_metadata is not None and stream is not None:
304305
stream.meta.update(stream_metadata)
305-
stream.mime_type = stream.meta.get("content_type", None)
306+
stream = replace(stream, mime_type=stream.meta.get("content_type", None))
306307
streams.append(stream)
307308

308309
return {"streams": streams}

haystack/components/joiners/document_joiner.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import itertools
66
from collections import defaultdict
7+
from dataclasses import replace
78
from enum import Enum
89
from math import inf
910
from typing import Any
@@ -190,10 +191,7 @@ def _merge(self, document_lists: list[list[Document]]) -> list[Document]:
190191
scores_map[doc.id] += (doc.score if doc.score else 0) * weight
191192
documents_map[doc.id] = doc
192193

193-
for doc in documents_map.values():
194-
doc.score = scores_map[doc.id]
195-
196-
return list(documents_map.values())
194+
return [replace(doc, score=scores_map[doc.id]) for doc in documents_map.values()]
197195

198196
def _reciprocal_rank_fusion(self, document_lists: list[list[Document]]) -> list[Document]:
199197
"""
@@ -223,10 +221,7 @@ def _reciprocal_rank_fusion(self, document_lists: list[list[Document]]) -> list[
223221
for _id in scores_map:
224222
scores_map[_id] /= len(document_lists) / k
225223

226-
for doc in documents_map.values():
227-
doc.score = scores_map[doc.id]
228-
229-
return list(documents_map.values())
224+
return [replace(doc, score=scores_map[doc.id]) for doc in documents_map.values()]
230225

231226
@staticmethod
232227
def _distribution_based_rank_fusion(document_lists: list[list[Document]]) -> list[Document]:
@@ -236,26 +231,29 @@ def _distribution_based_rank_fusion(document_lists: list[list[Document]]) -> lis
236231
(https://medium.com/plain-simple-software/distribution-based-score-fusion-dbsf-a-new-approach-to-vector-search-ranking-f87c37488b18)
237232
If a Document is in more than one retriever, the one with the highest score is used.
238233
"""
234+
rescaled_lists: list[list[Document]] = []
239235
for documents in document_lists:
240236
if len(documents) == 0:
237+
rescaled_lists.append(documents)
241238
continue
242239

243-
scores_list = []
244-
245-
for doc in documents:
246-
scores_list.append(doc.score if doc.score is not None else 0)
240+
scores_list = [doc.score if doc.score is not None else 0 for doc in documents]
247241

248242
mean_score = sum(scores_list) / len(scores_list)
249243
std_dev = (sum((x - mean_score) ** 2 for x in scores_list) / len(scores_list)) ** 0.5
250244
min_score = mean_score - 3 * std_dev
251245
max_score = mean_score + 3 * std_dev
252246
delta_score = max_score - min_score
253247

254-
for doc in documents:
255-
doc.score = (doc.score - min_score) / delta_score if delta_score != 0.0 else 0.0
256-
# if all docs have the same score delta_score is 0, the docs are uninformative for the query
248+
# if all docs have the same score delta_score is 0, the docs are uninformative for the query
249+
rescaled_lists.append(
250+
[
251+
replace(doc, score=(doc.score - min_score) / delta_score if delta_score != 0.0 else 0.0)
252+
for doc in documents
253+
]
254+
)
257255

258-
return DocumentJoiner._concatenate(document_lists=document_lists)
256+
return DocumentJoiner._concatenate(document_lists=rescaled_lists)
259257

260258
def to_dict(self) -> dict[str, Any]:
261259
"""

haystack/core/pipeline/pipeline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from collections.abc import Mapping
6+
from dataclasses import replace
67
from typing import Any
78

89
from haystack import logging, tracing
@@ -409,8 +410,11 @@ def run( # noqa: PLR0915, PLR0912, C901
409410
# agent snapshot and attach it to the pipeline snapshot we create here.
410411
# We also update the break_point to be an AgentBreakpoint.
411412
if error.pipeline_snapshot and error.pipeline_snapshot.agent_snapshot:
412-
pipeline_snapshot.agent_snapshot = error.pipeline_snapshot.agent_snapshot
413-
pipeline_snapshot.break_point = error.pipeline_snapshot.agent_snapshot.break_point
413+
pipeline_snapshot = replace(
414+
pipeline_snapshot,
415+
agent_snapshot=error.pipeline_snapshot.agent_snapshot,
416+
break_point=error.pipeline_snapshot.agent_snapshot.break_point,
417+
)
414418

415419
# Attach the pipeline snapshot to the error before re-raising
416420
error.pipeline_snapshot = pipeline_snapshot

test/components/joiners/test_document_joiner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ def test_list_with_one_empty_list(self, join_mode: JoinMode):
102102
joiner = DocumentJoiner(join_mode=join_mode)
103103
documents = [Document(content="a"), Document(content="b"), Document(content="c")]
104104
result = joiner.run([[], documents])
105-
assert result == {"documents": documents}
105+
# Verify the same documents are returned (scoring functions assign scores to the results;
106+
# compare by ID to avoid relying on in-place score mutation of the input list).
107+
result_ids = {doc.id for doc in result["documents"]}
108+
expected_ids = {doc.id for doc in documents}
109+
assert result_ids == expected_ids
106110

107111
def test_unsupported_join_mode(self):
108112
unsupported_mode = "unsupported_mode"

0 commit comments

Comments
 (0)