Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions aider/repomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,12 @@ def get_ranked_tags(
progress(f"{UPDATING_REPO_MAP_MESSAGE}: {src}")

src_rank = ranked[src]
total_weight = sum(data["weight"] for _src, _dst, data in G.out_edges(src, data=True))
out_edges = list(G.out_edges(src, data=True))
total_weight = sum(edge[-1]["weight"] for edge in out_edges)
# dump(src, src_rank, total_weight)
for _src, dst, data in G.out_edges(src, data=True):
for edge in out_edges:
dst = edge[1]
data = edge[-1]
data["rank"] = src_rank * data["weight"] / total_weight
ident = data["ident"]
ranked_definitions[(dst, ident)] += data["rank"]
Expand Down
41 changes: 41 additions & 0 deletions tests/basic/test_repomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import time
import unittest
from pathlib import Path
from unittest.mock import patch

import git
import networkx as nx

from aider.dump import dump # noqa: F401
from aider.io import InputOutput
Expand Down Expand Up @@ -273,6 +275,45 @@ def test_get_repo_map_excludes_added_files(self):
# close the open cache files, so Windows won't error
del repo_map

def test_get_repo_map_handles_keyed_out_edges(self):
original_multidigraph = nx.MultiDiGraph

class KeyedOutEdgesMultiDiGraph(original_multidigraph):
@property
def out_edges(self):
base_view = original_multidigraph.out_edges.__get__(self, type(self))

def keyed_out_edges(nbunch=None, data=False, default=None):
return base_view(
nbunch=nbunch,
data=data,
default=default,
keys=True if data else False,
)

return keyed_out_edges

with GitTemporaryDirectory() as temp_dir:
with open(os.path.join(temp_dir, "defs.py"), "w", encoding="utf-8") as f:
f.write("def shared_name():\n return 1\n")

with open(os.path.join(temp_dir, "refs.py"), "w", encoding="utf-8") as f:
f.write("from defs import shared_name\n\nshared_name()\nshared_name()\n")

io = InputOutput()
repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
other_files = [
os.path.join(temp_dir, "defs.py"),
os.path.join(temp_dir, "refs.py"),
]

with patch("networkx.MultiDiGraph", KeyedOutEdgesMultiDiGraph):
result = repo_map.get_repo_map([], other_files)

self.assertIn("shared_name", result)

del repo_map


class TestRepoMapTypescript(unittest.TestCase):
def setUp(self):
Expand Down
Loading