diff --git a/aider/repomap.py b/aider/repomap.py index 541bba6ef4a..9bf8cb765a2 100644 --- a/aider/repomap.py +++ b/aider/repomap.py @@ -537,9 +537,12 @@ def get_ranked_tags( progress(f"{UPDATING_REPO_MAP_MESSAGE}: {src}") src_rank = ranked[src] - total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True)) + out_edges = list(G.out_edges(src, data=True)) + total_weight = sum(edge[-1]["weight"] for edge in out_edges) # dump(src, src_rank, total_weight) - for _src, dst, data in G.out_edges(src, data=True): + for edge in out_edges: + dst = edge[1] + data = edge[-1] data["rank"] = src_rank * data["weight"] / total_weight ident = data["ident"] ranked_definitions[(dst, ident)] += data["rank"] diff --git a/tests/basic/test_repomap.py b/tests/basic/test_repomap.py index 9df806194ac..d0a2e63923b 100644 --- a/tests/basic/test_repomap.py +++ b/tests/basic/test_repomap.py @@ -4,8 +4,10 @@ import time import unittest from pathlib import Path +from unittest.mock import patch import git +import networkx as nx from aider.dump import dump # noqa: F401 from aider.io import InputOutput @@ -273,6 +275,45 @@ def test_get_repo_map_excludes_added_files(self): # close the open cache files, so Windows won't error del repo_map + def test_get_repo_map_handles_keyed_out_edges(self): + original_multidigraph = nx.MultiDiGraph + + class KeyedOutEdgesMultiDiGraph(original_multidigraph): + @property + def out_edges(self): + base_view = original_multidigraph.out_edges.__get__(self, type(self)) + + def keyed_out_edges(nbunch=None, data=False, default=None): + return base_view( + nbunch=nbunch, + data=data, + default=default, + keys=True if data else False, + ) + + return keyed_out_edges + + with GitTemporaryDirectory() as temp_dir: + with open(os.path.join(temp_dir, "defs.py"), "w", encoding="utf-8") as f: + f.write("def shared_name():\n return 1\n") + + with open(os.path.join(temp_dir, "refs.py"), "w", encoding="utf-8") as f: + f.write("from defs import shared_name\n\nshared_name()\nshared_name()\n") + + io = InputOutput() + repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io) + other_files = [ + os.path.join(temp_dir, "defs.py"), + os.path.join(temp_dir, "refs.py"), + ] + + with patch("networkx.MultiDiGraph", KeyedOutEdgesMultiDiGraph): + result = repo_map.get_repo_map([], other_files) + + self.assertIn("shared_name", result) + + del repo_map + class TestRepoMapTypescript(unittest.TestCase): def setUp(self):