From a3a3629f652f8f1e9c9d6a35e9b3001849c18da0 Mon Sep 17 00:00:00 2001 From: Murillo Alves <114107747+murilloimparavel@users.noreply.github.com> Date: Sat, 4 Apr 2026 00:59:10 -0300 Subject: [PATCH 1/2] feat: CLI-first expansion with 10 subcommands + CommonJS require() parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds 10 new CLI subcommands that expose MCP tool functions directly, enabling CLI-first workflows without requiring the MCP server: query, impact, search, flows, flow, communities, community, architecture, large-functions, refactor Also adds CLI post-processing after build/update (fixes #93): - Signature computation - FTS5 index rebuild - Execution flow detection - Community detection CommonJS require() parsing (parser.py): - `_js_get_require_target()`: extracts module paths from require(), path.join(), path.resolve(), template literals, dynamic import() - `_extract_js_require_constructs()`: creates IMPORTS_FROM edges - `_collect_js_require_names()`: populates import map for call resolution - Empty-string guards at all call sites - Depth-limited recursion (max_depth=50) for nested require walks Results on test monorepo (14.5K files): - IMPORTS_FROM edges increased 4x (525 → 2,075) - Flows detected: 1,185 - Communities: 605 - FTS indexed: 7,766 nodes CLI UX improvements: - Updated banner and docstring with all new commands - "Graph not built" warning when DB is missing - Argparse validation for flow/community (require --id or --name) - Argparse validation for refactor rename (require --old-name/--new-name) - DELEGATED_COMMANDS skip redundant GraphStore creation Co-Authored-By: Claude Opus 4.6 (1M context) --- code_review_graph/cli.py | 340 +++++++++++++++++++++++++++++++++++- code_review_graph/parser.py | 299 +++++++++++++++++++++++++++++++ 2 files changed, 633 insertions(+), 6 deletions(-) diff --git a/code_review_graph/cli.py b/code_review_graph/cli.py index 0bf8e18..8f42ec6 100644 --- 
a/code_review_graph/cli.py +++ b/code_review_graph/cli.py @@ -14,6 +14,16 @@ code-review-graph register [--alias name] code-review-graph unregister code-review-graph repos + code-review-graph query + code-review-graph impact [--files ...] [--depth N] + code-review-graph search [--kind KIND] + code-review-graph flows [--sort COLUMN] + code-review-graph flow [--id ID | --name NAME] + code-review-graph communities [--sort COLUMN] + code-review-graph community [--id ID | --name NAME] + code-review-graph architecture + code-review-graph large-functions [--min-lines N] + code-review-graph refactor """ from __future__ import annotations @@ -89,6 +99,18 @@ def _print_banner() -> None: {g}eval{r} Run evaluation benchmarks {g}serve{r} Start MCP server + {b}Graph queries:{r} + {g}query{r} Query relationships {d}(callers_of, callees_of, ...){r} + {g}impact{r} Analyze blast radius of changes + {g}search{r} Search code entities by name/keyword + {g}flows{r} List execution flows by criticality + {g}flow{r} Get details of a single execution flow + {g}communities{r} List detected code communities + {g}community{r} Get details of a single community + {g}architecture{r} Architecture overview from communities + {g}large-functions{r} Find functions exceeding line threshold + {g}refactor{r} Rename preview, dead code, suggestions + {d}Run{r} {b}code-review-graph --help{r} {d}for details{r} """) @@ -150,6 +172,71 @@ def _handle_init(args: argparse.Namespace) -> None: print(" 2. Restart your AI coding tool to pick up the new config") +def _run_post_processing(store: object) -> None: + """Run post-build steps: signatures, FTS indexing, flow detection, communities. + + Mirrors the post-processing in tools/build.py. Each step is non-fatal; + failures are logged and skipped so the build result is never lost. 
+ """ + import sqlite3 + + # Compute signatures for nodes that don't have them + try: + rows = store.get_nodes_without_signature() # type: ignore[attr-defined] + for row in rows: + node_id, name, kind, params, ret = row[0], row[1], row[2], row[3], row[4] + if kind in ("Function", "Test"): + sig = f"def {name}({params or ''})" + if ret: + sig += f" -> {ret}" + elif kind == "Class": + sig = f"class {name}" + else: + sig = name + store.update_node_signature(node_id, sig[:512]) # type: ignore[attr-defined] + store.commit() # type: ignore[attr-defined] + sig_count = len(rows) if rows else 0 + if sig_count: + print(f"Signatures computed: {sig_count} nodes") + except (sqlite3.OperationalError, TypeError, KeyError, AttributeError) as e: + store.rollback() + logging.warning("Signature computation skipped: %s", e) + + # Rebuild FTS index + try: + from .search import rebuild_fts_index + + fts_count = rebuild_fts_index(store) + print(f"FTS indexed: {fts_count} nodes") + except (sqlite3.OperationalError, ImportError, AttributeError) as e: + store.rollback() + logging.warning("FTS index rebuild skipped: %s", e) + + # Trace execution flows + try: + from .flows import store_flows as _store_flows + from .flows import trace_flows as _trace_flows + + flows = _trace_flows(store) + count = _store_flows(store, flows) + print(f"Flows detected: {count}") + except (sqlite3.OperationalError, ImportError, AttributeError) as e: + store.rollback() + logging.warning("Flow detection skipped: %s", e) + + # Detect communities + try: + from .communities import detect_communities as _detect_communities + from .communities import store_communities as _store_communities + + comms = _detect_communities(store) + count = _store_communities(store, comms) + print(f"Communities: {count}") + except (sqlite3.OperationalError, ImportError, AttributeError) as e: + store.rollback() + logging.warning("Community detection skipped: %s", e) + + def main() -> None: """Main CLI entry point.""" ap = 
argparse.ArgumentParser( @@ -298,6 +385,120 @@ def main() -> None: serve_cmd = sub.add_parser("serve", help="Start MCP server (stdio transport)") serve_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + # --- CLI-first commands (expose MCP tool functions directly) --- + + # query + query_cmd = sub.add_parser( + "query", help="Query graph relationships (callers_of, callees_of, imports_of, etc.)" + ) + query_cmd.add_argument( + "pattern", + choices=[ + "callers_of", "callees_of", "imports_of", "importers_of", + "children_of", "tests_for", "inheritors_of", "file_summary", + ], + help="Query pattern", + ) + query_cmd.add_argument("target", help="Node name, qualified name, or file path") + query_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # impact + impact_cmd = sub.add_parser( + "impact", help="Analyze blast radius of changed files" + ) + impact_cmd.add_argument( + "--files", nargs="*", default=None, + help="Changed files (auto-detected from git if omitted)", + ) + impact_cmd.add_argument("--depth", type=int, default=2, help="BFS hops (default: 2)") + impact_cmd.add_argument("--base", default="HEAD~1", help="Git diff base (default: HEAD~1)") + impact_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # search + search_cmd = sub.add_parser("search", help="Search code entities by name/keyword") + search_cmd.add_argument("query", help="Search string") + search_cmd.add_argument( + "--kind", default=None, + choices=["File", "Class", "Function", "Type", "Test"], + help="Filter by entity kind", + ) + search_cmd.add_argument("--limit", type=int, default=20, help="Max results (default: 20)") + search_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # flows + flows_cmd = sub.add_parser("flows", help="List execution flows by criticality") + flows_cmd.add_argument( + "--sort", default="criticality", + choices=["criticality", "depth", 
"node_count", "file_count", "name"], + help="Sort column (default: criticality)", + ) + flows_cmd.add_argument("--limit", type=int, default=20, help="Max results (default: 20)") + flows_cmd.add_argument( + "--kind", default=None, help="Filter by entry point kind (e.g. Test, Function)" + ) + flows_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # flow (single) + flow_cmd = sub.add_parser("flow", help="Get details of a single execution flow") + flow_cmd.add_argument("--id", type=int, default=None, help="Flow ID") + flow_cmd.add_argument("--name", default=None, help="Flow name (partial match)") + flow_cmd.add_argument( + "--source", action="store_true", help="Include source code snippets" + ) + flow_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # communities + comms_cmd = sub.add_parser("communities", help="List detected code communities") + comms_cmd.add_argument( + "--sort", default="size", choices=["size", "cohesion", "name"], + help="Sort column (default: size)", + ) + comms_cmd.add_argument("--min-size", type=int, default=0, help="Min community size") + comms_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # community (single) + comm_cmd = sub.add_parser("community", help="Get details of a single community") + comm_cmd.add_argument("--id", type=int, default=None, help="Community ID") + comm_cmd.add_argument("--name", default=None, help="Community name (partial match)") + comm_cmd.add_argument( + "--members", action="store_true", help="Include member node details" + ) + comm_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # architecture + arch_cmd = sub.add_parser("architecture", help="Architecture overview from community structure") + arch_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # large-functions + large_cmd = sub.add_parser( + "large-functions", help="Find 
functions/classes exceeding line threshold" + ) + large_cmd.add_argument("--min-lines", type=int, default=50, help="Min lines (default: 50)") + large_cmd.add_argument( + "--kind", default=None, choices=["Function", "Class", "File", "Test"], + help="Filter by kind", + ) + large_cmd.add_argument("--path", default=None, help="Filter by file path substring") + large_cmd.add_argument("--limit", type=int, default=50, help="Max results (default: 50)") + large_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + + # refactor + refactor_cmd = sub.add_parser( + "refactor", help="Rename preview, dead code detection, suggestions" + ) + refactor_cmd.add_argument( + "mode", choices=["rename", "dead_code", "suggest"], + help="Operation mode", + ) + refactor_cmd.add_argument("--old-name", default=None, help="(rename) Current symbol name") + refactor_cmd.add_argument("--new-name", default=None, help="(rename) New symbol name") + refactor_cmd.add_argument( + "--kind", default=None, choices=["Function", "Class"], + help="(dead_code) Filter by kind", + ) + refactor_cmd.add_argument("--path", default=None, help="(dead_code) Filter by file path") + refactor_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)") + args = ap.parse_args() if args.version: @@ -392,18 +593,22 @@ def main() -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - from .graph import GraphStore from .incremental import ( find_project_root, find_repo_root, - full_build, get_db_path, - incremental_update, - watch, ) - if args.command in ("update", "detect-changes"): - # update and detect-changes require git for diffing + # Commands that delegate to tool functions which create their own GraphStore. + # We skip opening a redundant store for these. 
+ DELEGATED_COMMANDS = { + "query", "impact", "search", "flows", "flow", + "communities", "community", "architecture", + "large-functions", "refactor", + } + + if args.command in ("update", "detect-changes", "impact"): + # update, detect-changes, and impact require git for diffing repo_root = Path(args.repo) if args.repo else find_repo_root() if not repo_root: logging.error( @@ -416,6 +621,35 @@ def main() -> None: repo_root = Path(args.repo) if args.repo else find_project_root() db_path = get_db_path(repo_root) + + # For delegated commands, warn if the graph hasn't been built yet, then + # delegate directly to the tool functions (they manage their own store). + if args.command in DELEGATED_COMMANDS: + if not db_path.exists(): + print( + "WARNING: Graph not built yet. " + "Run 'code-review-graph build' first." + ) + print() + _run_delegated_command(args, repo_root) + return + + # For non-delegated commands that need the graph DB (everything except build), + # warn if the DB is missing. + if args.command != "build" and not db_path.exists(): + print( + "WARNING: Graph not built yet. " + "Run 'code-review-graph build' first." 
+ ) + print() + + from .graph import GraphStore + from .incremental import ( + full_build, + incremental_update, + watch, + ) + store = GraphStore(db_path) try: @@ -427,6 +661,7 @@ def main() -> None: ) if result["errors"]: print(f"Errors: {len(result['errors'])}") + _run_post_processing(store) elif args.command == "update": result = incremental_update(repo_root, store, base=args.base) @@ -434,6 +669,8 @@ def main() -> None: f"Incremental: {result['files_updated']} files updated, " f"{result['total_nodes']} nodes, {result['total_edges']} edges" ) + if result.get("files_updated", 0) > 0: + _run_post_processing(store) elif args.command == "status": stats = store.get_stats() @@ -524,3 +761,94 @@ def main() -> None: finally: store.close() + + +def _run_delegated_command(args: argparse.Namespace, repo_root: Path) -> None: + """Run commands that delegate to tool functions with their own GraphStore.""" + if args.command == "query": + from .tools import query_graph + result = query_graph( + pattern=args.pattern, target=args.target, + repo_root=str(repo_root), + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "impact": + from .tools import get_impact_radius + result = get_impact_radius( + changed_files=args.files, max_depth=args.depth, + repo_root=str(repo_root), base=args.base, + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "search": + from .tools import semantic_search_nodes + result = semantic_search_nodes( + query=args.query, kind=args.kind, limit=args.limit, + repo_root=str(repo_root), + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "flows": + from .tools import list_flows + result = list_flows( + repo_root=str(repo_root), sort_by=args.sort, + limit=args.limit, kind=args.kind, + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "flow": + if args.id is None and args.name is None: + print("Error: provide --id or --name to select a flow.") + 
sys.exit(1) + from .tools import get_flow + result = get_flow( + flow_id=args.id, flow_name=args.name, + include_source=args.source, repo_root=str(repo_root), + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "communities": + from .tools import list_communities_func + result = list_communities_func( + repo_root=str(repo_root), sort_by=args.sort, + min_size=args.min_size, + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "community": + if args.id is None and args.name is None: + print("Error: provide --id or --name to select a community.") + sys.exit(1) + from .tools import get_community_func + result = get_community_func( + community_name=args.name, community_id=args.id, + include_members=args.members, repo_root=str(repo_root), + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "architecture": + from .tools import get_architecture_overview_func + result = get_architecture_overview_func(repo_root=str(repo_root)) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "large-functions": + from .tools import find_large_functions + result = find_large_functions( + min_lines=args.min_lines, kind=args.kind, + file_path_pattern=args.path, limit=args.limit, + repo_root=str(repo_root), + ) + print(json.dumps(result, indent=2, default=str)) + + elif args.command == "refactor": + if args.mode == "rename" and (not args.old_name or not args.new_name): + print("Error: refactor rename requires --old-name and --new-name.") + sys.exit(1) + from .tools import refactor_func + result = refactor_func( + mode=args.mode, old_name=args.old_name, + new_name=args.new_name, kind=args.kind, + file_pattern=args.path, repo_root=str(repo_root), + ) + print(json.dumps(result, indent=2, default=str)) diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index bded99f..c854020 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -905,6 +905,14 @@ def 
_extract_from_tree( ): continue + # --- JS/TS CommonJS require() -> IMPORTS_FROM edges --- + if language in ("javascript", "typescript", "tsx"): + self._extract_js_require_constructs( + child, node_type, source, language, file_path, + nodes, edges, enclosing_class, enclosing_func, + import_map, defined_names, _depth, + ) + # --- JS/TS variable-assigned functions (const foo = () => {}) --- if ( language in ("javascript", "typescript", "tsx") @@ -1297,6 +1305,236 @@ def _lua_get_require_target(call_node) -> Optional[str]: return raw.strip("'\"") return None + # ------------------------------------------------------------------ + # JS/TS: CommonJS require() detection + # ------------------------------------------------------------------ + + @staticmethod + def _js_get_require_target(call_node) -> Optional[str]: + """Extract the module path from a JS/TS require() or dynamic import(). + + Handles these patterns: + require('./module') -> './module' + require(path.join(__dirname, 'mod')) -> './mod' + require(path.resolve(__dirname, 'mod.js')) -> './mod.js' + require(`./commands/${name}`) -> './commands' + await import('./module') -> './module' + await import(`./utils/${name}`) -> './utils' + + Returns the module path string, or None if this is not a + recognisable require/import call. 
+ """ + if not call_node.children: + return None + + first = call_node.children[0] + is_require = (first.type == "identifier" and first.text == b"require") + is_import = (first.type == "import" or first.text == b"import") + + if not is_require and not is_import: + return None + + # Find the arguments node + args_node = None + for child in call_node.children: + if child.type == "arguments": + args_node = child + break + if not args_node: + return None + + # Walk argument children (skip parentheses) + for arg in args_node.children: + # --- Static string literal: require('./module') --- + if arg.type == "string": + for sub in arg.children: + if sub.type == "string_fragment": + return sub.text.decode("utf-8", errors="replace") + # Fallback: strip quotes + raw = arg.text.decode("utf-8", errors="replace") + return raw.strip("'\"") + + # --- Template literal: require(`./commands/${name}`) --- + if arg.type == "template_string": + # Extract leading static fragment before first interpolation + static_parts = [] + for sub in arg.children: + if sub.type == "string_fragment": + # tree-sitter calls template literal fragments + # "string_fragment" inside template_string + static_parts.append( + sub.text.decode("utf-8", errors="replace"), + ) + break + elif sub.type == "template_substitution": + break + if static_parts and static_parts[0]: + prefix = static_parts[0].rstrip("/") + return prefix if prefix else None + + # --- path.join/path.resolve: require(path.join(dir, 'mod')) --- + if arg.type == "call_expression": + callee = arg.children[0] if arg.children else None + if ( + callee + and callee.type == "member_expression" + and callee.text + ): + callee_text = callee.text.decode("utf-8", errors="replace") + if callee_text in ("path.join", "path.resolve"): + # Extract the last string argument as the module hint + inner_args = None + for sub in arg.children: + if sub.type == "arguments": + inner_args = sub + break + if inner_args: + last_string = None + for sub in 
inner_args.children: + if sub.type == "string": + raw = sub.text.decode( + "utf-8", errors="replace", + ) + last_string = raw.strip("'\"") + if last_string: + # Return as relative path for resolution + if not last_string.startswith("."): + last_string = "./" + last_string + return last_string + + return None + + def _extract_js_require_constructs( + self, + child, + node_type: str, + source: bytes, + language: str, + file_path: str, + nodes: list[NodeInfo], + edges: list[EdgeInfo], + enclosing_class: Optional[str], + enclosing_func: Optional[str], + import_map: Optional[dict[str, str]], + defined_names: Optional[set[str]], + _depth: int, + ) -> bool: + """Handle JS/TS CommonJS require() patterns as IMPORTS_FROM edges. + + Returns True if the child was fully handled and should be skipped + by the main loop. + + Handles: + - variable_declaration/lexical_declaration with require(): + ``const X = require('./module')`` -> IMPORTS_FROM edge + ``const { A, B } = require('./module')`` -> IMPORTS_FROM edge + - Top-level call_expression that is require(): + ``require('./module')`` -> IMPORTS_FROM edge + - expression_statement wrapping require(): + ``require('./side-effect-module')`` -> IMPORTS_FROM edge + - call_expression with dynamic import(): + ``await import('./module')`` -> IMPORTS_FROM edge + """ + # --- variable/lexical declaration with require() --- + if node_type in ("lexical_declaration", "variable_declaration"): + found_require = False + for declarator in child.children: + if declarator.type != "variable_declarator": + continue + # Find the value: look for call_expression child + for sub in declarator.children: + if sub.type == "call_expression": + req_target = self._js_get_require_target(sub) + if req_target is not None and req_target != "": + resolved = self._resolve_module_to_file( + req_target, file_path, language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=resolved if resolved else req_target, + file_path=file_path, + 
line=child.start_point[0] + 1, + )) + found_require = True + break + if found_require: + # Still let _extract_js_var_functions run if there are also + # function assignments in the same declaration — but we've + # already captured the require edge. + return False + return False + + # --- expression_statement wrapping a require() call --- + # Covers: require('./side-effect'), module.exports = { X: require() } + if node_type == "expression_statement": + self._js_collect_require_edges( + child, file_path, language, edges, + ) + # Always return False — let generic recursion continue + return False + + return False + + def _js_collect_require_edges( + self, node, file_path: str, language: str, + edges: list[EdgeInfo], + depth: int = 0, max_depth: int = 50, + ) -> None: + """Recursively find require()/import() calls in a subtree and add + IMPORTS_FROM edges for each. + + This catches patterns like: + require('./side-effect') + module.exports = { A: require('./a'), B: require('./b') } + await import('./module') + """ + if depth >= max_depth: + return + if not node.children: + return + for child in node.children: + if child.type == "call_expression": + req_target = self._js_get_require_target(child) + if req_target is not None and req_target != "": + resolved = self._resolve_module_to_file( + req_target, file_path, language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=resolved if resolved else req_target, + file_path=file_path, + line=child.start_point[0] + 1, + )) + # Don't recurse into this call_expression's children + continue + if child.type == "await_expression": + for await_child in child.children: + if await_child.type == "call_expression": + req_target = self._js_get_require_target(await_child) + if req_target is not None and req_target != "": + resolved = self._resolve_module_to_file( + req_target, file_path, language, + ) + edges.append(EdgeInfo( + kind="IMPORTS_FROM", + source=file_path, + target=( + resolved if resolved else 
req_target + ), + file_path=file_path, + line=child.start_point[0] + 1, + )) + continue + # Recurse into assignment_expression, object, pair, etc. + self._js_collect_require_edges( + child, file_path, language, edges, + depth=depth + 1, max_depth=max_depth, + ) + + return + # ------------------------------------------------------------------ # JS/TS: variable-assigned functions (const foo = () => {}) # ------------------------------------------------------------------ @@ -1950,8 +2188,69 @@ def _collect_file_scope( if node_type in import_types: self._collect_import_names(child, language, source, import_map) + # JS/TS: CommonJS require() into import_map + # const X = require('./module') -> {X: './module'} + # const { A, B } = require('./module') -> {A: './module', B: ...} + if ( + language in ("javascript", "typescript", "tsx") + and node_type in ("lexical_declaration", "variable_declaration") + ): + self._collect_js_require_names(child, import_map) + return import_map, defined_names + def _collect_js_require_names( + self, decl_node, import_map: dict[str, str], + ) -> None: + """Extract imported names from CommonJS require() declarations. + + Handles: + const X = require('./module') -> {X: './module'} + const { A, B } = require('./module') -> {A: './module', B: ...} + var X = require('./module') -> {X: './module'} + """ + for declarator in decl_node.children: + if declarator.type != "variable_declarator": + continue + + # Collect the variable name(s) and the initialiser + names: list[str] = [] + call_node = None + for sub in declarator.children: + if sub.type == "identifier": + names.append(sub.text.decode("utf-8", errors="replace")) + elif sub.type == "object_pattern": + # Destructured: const { A, B } = require(...) 
+ for pat_child in sub.children: + if pat_child.type == "shorthand_property_identifier_pattern": + names.append( + pat_child.text.decode("utf-8", errors="replace"), + ) + elif pat_child.type == "pair_pattern": + # { A: localA } — last identifier is the local name + ids = [ + c for c in pat_child.children + if c.type == "identifier" + ] + if ids: + names.append( + ids[-1].text.decode( + "utf-8", errors="replace", + ), + ) + elif sub.type == "call_expression": + call_node = sub + + if not call_node: + continue + + req_target = self._js_get_require_target(call_node) + if not req_target: + continue + + for name in names: + import_map[name] = req_target + def _collect_import_names( self, node, language: str, source: bytes, import_map: dict[str, str], ) -> None: From dd6b6398c552a7958e1339b2fc4c734a5317dd0f Mon Sep 17 00:00:00 2001 From: Murillo Alves <114107747+murilloimparavel@users.noreply.github.com> Date: Sat, 4 Apr 2026 11:02:14 -0300 Subject: [PATCH 2/2] fix: extract shared post-processing pipeline, DRY up CLI + MCP + watch (#93) Extract the 4-step post-processing pipeline (signatures, FTS, flows, communities) from tools/build.py into a shared postprocessing.py module. Wire it into CLI build/update and watch mode via callback. 
- Add code_review_graph/postprocessing.py with run_post_processing() - tools/build.py now delegates to run_post_processing() (-62 lines) - cli.py uses thin _cli_post_process() wrapper (+summary printing) - watch() accepts on_files_updated callback for post-build steps - Hard error (sys.exit(1)) when querying without built graph - Fix: IndexError instead of KeyError in signature except clause - Fix: Callable type annotation instead of Any for watch callback - 9 new tests covering pipeline, idempotency, step isolation Closes #93 Co-Authored-By: Claude Opus 4.6 (1M context) --- code_review_graph/cli.py | 94 ++++++--------------- code_review_graph/incremental.py | 19 ++++- code_review_graph/postprocessing.py | 123 ++++++++++++++++++++++++++++ code_review_graph/tools/build.py | 71 +--------------- tests/test_postprocessing.py | 117 ++++++++++++++++++++++++++ 5 files changed, 284 insertions(+), 140 deletions(-) create mode 100644 code_review_graph/postprocessing.py create mode 100644 tests/test_postprocessing.py diff --git a/code_review_graph/cli.py b/code_review_graph/cli.py index 8f42ec6..0db30eb 100644 --- a/code_review_graph/cli.py +++ b/code_review_graph/cli.py @@ -172,69 +172,22 @@ def _handle_init(args: argparse.Namespace) -> None: print(" 2. Restart your AI coding tool to pick up the new config") -def _run_post_processing(store: object) -> None: - """Run post-build steps: signatures, FTS indexing, flow detection, communities. - - Mirrors the post-processing in tools/build.py. Each step is non-fatal; - failures are logged and skipped so the build result is never lost. 
- """ - import sqlite3 - - # Compute signatures for nodes that don't have them - try: - rows = store.get_nodes_without_signature() # type: ignore[attr-defined] - for row in rows: - node_id, name, kind, params, ret = row[0], row[1], row[2], row[3], row[4] - if kind in ("Function", "Test"): - sig = f"def {name}({params or ''})" - if ret: - sig += f" -> {ret}" - elif kind == "Class": - sig = f"class {name}" - else: - sig = name - store.update_node_signature(node_id, sig[:512]) # type: ignore[attr-defined] - store.commit() # type: ignore[attr-defined] - sig_count = len(rows) if rows else 0 - if sig_count: - print(f"Signatures computed: {sig_count} nodes") - except (sqlite3.OperationalError, TypeError, KeyError, AttributeError) as e: - store.rollback() - logging.warning("Signature computation skipped: %s", e) - - # Rebuild FTS index - try: - from .search import rebuild_fts_index - - fts_count = rebuild_fts_index(store) - print(f"FTS indexed: {fts_count} nodes") - except (sqlite3.OperationalError, ImportError, AttributeError) as e: - store.rollback() - logging.warning("FTS index rebuild skipped: %s", e) - - # Trace execution flows - try: - from .flows import store_flows as _store_flows - from .flows import trace_flows as _trace_flows - - flows = _trace_flows(store) - count = _store_flows(store, flows) - print(f"Flows detected: {count}") - except (sqlite3.OperationalError, ImportError, AttributeError) as e: - store.rollback() - logging.warning("Flow detection skipped: %s", e) - - # Detect communities - try: - from .communities import detect_communities as _detect_communities - from .communities import store_communities as _store_communities - - comms = _detect_communities(store) - count = _store_communities(store, comms) - print(f"Communities: {count}") - except (sqlite3.OperationalError, ImportError, AttributeError) as e: - store.rollback() - logging.warning("Community detection skipped: %s", e) +def _cli_post_process(store: object) -> None: + """Run shared 
post-processing pipeline and print summary for each step.""" + from .postprocessing import run_post_processing + + pp = run_post_processing(store) # type: ignore[arg-type] + if pp.get("signatures_computed"): + print(f"Signatures: {pp['signatures_computed']} nodes") + if pp.get("fts_indexed"): + print(f"FTS indexed: {pp['fts_indexed']} nodes") + if pp.get("flows_detected") is not None: + print(f"Flows: {pp['flows_detected']}") + if pp.get("communities_detected") is not None: + print(f"Communities: {pp['communities_detected']}") + if pp.get("warnings"): + for w in pp["warnings"]: + print(f" Warning: {w}") def main() -> None: @@ -601,7 +554,7 @@ def main() -> None: # Commands that delegate to tool functions which create their own GraphStore. # We skip opening a redundant store for these. - DELEGATED_COMMANDS = { + delegated_commands = { "query", "impact", "search", "flows", "flow", "communities", "community", "architecture", "large-functions", "refactor", @@ -624,13 +577,13 @@ def main() -> None: # For delegated commands, warn if the graph hasn't been built yet, then # delegate directly to the tool functions (they manage their own store). - if args.command in DELEGATED_COMMANDS: + if args.command in delegated_commands: if not db_path.exists(): print( - "WARNING: Graph not built yet. " + "Error: Graph not built yet. " "Run 'code-review-graph build' first." 
) - print() + sys.exit(1) _run_delegated_command(args, repo_root) return @@ -661,7 +614,7 @@ def main() -> None: ) if result["errors"]: print(f"Errors: {len(result['errors'])}") - _run_post_processing(store) + _cli_post_process(store) elif args.command == "update": result = incremental_update(repo_root, store, base=args.base) @@ -670,7 +623,7 @@ def main() -> None: f"{result['total_nodes']} nodes, {result['total_edges']} edges" ) if result.get("files_updated", 0) > 0: - _run_post_processing(store) + _cli_post_process(store) elif args.command == "status": stats = store.get_stats() @@ -696,7 +649,8 @@ def main() -> None: ) elif args.command == "watch": - watch(repo_root, store) + from .postprocessing import run_post_processing + watch(repo_root, store, on_files_updated=run_post_processing) elif args.command == "visualize": from .visualization import generate_html diff --git a/code_review_graph/incremental.py b/code_review_graph/incremental.py index af72d29..eefce82 100644 --- a/code_review_graph/incremental.py +++ b/code_review_graph/incremental.py @@ -14,7 +14,7 @@ import subprocess import time from pathlib import Path -from typing import Optional +from typing import Any, Callable, Optional from .graph import GraphStore from .parser import CodeParser @@ -435,10 +435,19 @@ def incremental_update( _DEBOUNCE_SECONDS = 0.3 -def watch(repo_root: Path, store: GraphStore) -> None: +def watch( + repo_root: Path, + store: GraphStore, + on_files_updated: Callable[..., Any] | None = None, +) -> None: """Watch for file changes and auto-update the graph. Uses a 300ms debounce to batch rapid-fire saves into a single update. + + Args: + on_files_updated: Optional callback invoked after each flush with + the store as its only argument. Used to run post-processing + (signatures, FTS, flows, communities) after file updates. 
""" import threading @@ -517,6 +526,12 @@ def _flush(self): for abs_path in paths: self._update_file(abs_path) + if paths and on_files_updated is not None: + try: + on_files_updated(store) + except Exception as e: + logger.error("Post-update callback failed: %s", e) + def _update_file(self, abs_path: str): path = Path(abs_path) if not path.is_file(): diff --git a/code_review_graph/postprocessing.py b/code_review_graph/postprocessing.py new file mode 100644 index 0000000..9892ef7 --- /dev/null +++ b/code_review_graph/postprocessing.py @@ -0,0 +1,123 @@ +"""Shared post-processing pipeline for CLI, MCP tools, and watch mode. + +Runs 4 non-fatal steps after graph build/update: +1. Signature computation (human-readable function/class signatures) +2. FTS5 full-text search index rebuild +3. Execution flow detection and storage +4. Community detection via Leiden algorithm or file grouping + +Each step is independent — failure in one does not block the others. +""" + +from __future__ import annotations + +import logging +import sqlite3 +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .graph import GraphStore + +logger = logging.getLogger(__name__) + + +def run_post_processing(store: GraphStore) -> dict[str, Any]: + """Run all post-build steps and return a result dict. + + Returns: + Dict with keys: signatures_computed, fts_indexed, flows_detected, + communities_detected. If any step fails, a 'warnings' list is included. 
+ """ + result: dict[str, Any] = {} + warnings: list[str] = [] + + _compute_signatures(store, result, warnings) + _rebuild_fts_index(store, result, warnings) + _trace_flows(store, result, warnings) + _detect_communities(store, result, warnings) + + if warnings: + result["warnings"] = warnings + return result + + +def _compute_signatures( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Compute human-readable signatures for nodes that lack them.""" + try: + rows = store.get_nodes_without_signature() + for row in rows: + node_id, name, kind, params, ret = ( + row[0], + row[1], + row[2], + row[3], + row[4], + ) + if kind in ("Function", "Test"): + sig = f"def {name}({params or ''})" + if ret: + sig += f" -> {ret}" + elif kind == "Class": + sig = f"class {name}" + else: + sig = name + store.update_node_signature(node_id, sig[:512]) + store.commit() + result["signatures_computed"] = len(rows) + except (sqlite3.OperationalError, TypeError, IndexError) as e: + logger.warning("Signature computation failed: %s", e) + warnings.append(f"Signature computation failed: {type(e).__name__}: {e}") + + +def _rebuild_fts_index( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Rebuild the FTS5 full-text search index.""" + try: + from .search import rebuild_fts_index + + fts_count = rebuild_fts_index(store) + result["fts_indexed"] = fts_count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("FTS index rebuild failed: %s", e) + warnings.append(f"FTS index rebuild failed: {type(e).__name__}: {e}") + + +def _trace_flows( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Trace execution flows from entry points.""" + try: + from .flows import store_flows, trace_flows + + flows = trace_flows(store) + count = store_flows(store, flows) + result["flows_detected"] = count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("Flow detection 
failed: %s", e) + warnings.append(f"Flow detection failed: {type(e).__name__}: {e}") + + +def _detect_communities( + store: GraphStore, + result: dict[str, Any], + warnings: list[str], +) -> None: + """Detect code communities via Leiden algorithm or file grouping.""" + try: + from .communities import detect_communities, store_communities + + comms = detect_communities(store) + count = store_communities(store, comms) + result["communities_detected"] = count + except (sqlite3.OperationalError, ImportError) as e: + logger.warning("Community detection failed: %s", e) + warnings.append(f"Community detection failed: {type(e).__name__}: {e}") diff --git a/code_review_graph/tools/build.py b/code_review_graph/tools/build.py index ad832ec..e6eca64 100644 --- a/code_review_graph/tools/build.py +++ b/code_review_graph/tools/build.py @@ -2,15 +2,12 @@ from __future__ import annotations -import logging -import sqlite3 from typing import Any from ..incremental import full_build, incremental_update +from ..postprocessing import run_post_processing from ._common import _get_store -logger = logging.getLogger(__name__) - def build_or_update_graph( full_rebuild: bool = False, @@ -64,70 +61,8 @@ def build_or_update_graph( **result, } - # -- Post-build steps (non-fatal; failures are surfaced as warnings) -- - warnings: list[str] = [] - - # Compute signatures for nodes that don't have them - try: - rows = store.get_nodes_without_signature() - for row in rows: - node_id, name, kind, params, ret = ( - row[0], row[1], row[2], row[3], row[4], - ) - if kind in ("Function", "Test"): - sig = f"def {name}({params or ''})" - if ret: - sig += f" -> {ret}" - elif kind == "Class": - sig = f"class {name}" - else: - sig = name - store.update_node_signature(node_id, sig[:512]) - store.commit() - except (sqlite3.OperationalError, TypeError, KeyError) as e: - logger.warning("Signature computation failed: %s", e) - warnings.append(f"Signature computation failed: {type(e).__name__}: {e}") - - # Rebuild FTS 
index - try: - from code_review_graph.search import rebuild_fts_index - - fts_count = rebuild_fts_index(store) - build_result["fts_indexed"] = fts_count - except (sqlite3.OperationalError, ImportError) as e: - logger.warning("FTS index rebuild failed: %s", e) - warnings.append(f"FTS index rebuild failed: {type(e).__name__}: {e}") - - # Trace execution flows - try: - from code_review_graph.flows import store_flows as _store_flows - from code_review_graph.flows import trace_flows as _trace_flows - - flows = _trace_flows(store) - count = _store_flows(store, flows) - build_result["flows_detected"] = count - except (sqlite3.OperationalError, ImportError) as e: - logger.warning("Flow detection failed: %s", e) - warnings.append(f"Flow detection failed: {type(e).__name__}: {e}") - - # Detect communities - try: - from code_review_graph.communities import ( - detect_communities as _detect_communities, - ) - from code_review_graph.communities import ( - store_communities as _store_communities, - ) - - comms = _detect_communities(store) - count = _store_communities(store, comms) - build_result["communities_detected"] = count - except (sqlite3.OperationalError, ImportError) as e: - logger.warning("Community detection failed: %s", e) - warnings.append(f"Community detection failed: {type(e).__name__}: {e}") - - if warnings: - build_result["warnings"] = warnings + pp = run_post_processing(store) + build_result.update(pp) return build_result finally: store.close() diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py new file mode 100644 index 0000000..f4fefb9 --- /dev/null +++ b/tests/test_postprocessing.py @@ -0,0 +1,117 @@ +"""Tests for the shared post-processing pipeline.""" + +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from code_review_graph.graph import GraphStore +from code_review_graph.parser import EdgeInfo, NodeInfo +from code_review_graph.postprocessing import run_post_processing + + +@pytest.fixture 
def store_with_data():
    """Yield a GraphStore seeded with a function, class, test node, and CALLS edge."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as handle:
        db_file = handle.name
    store = GraphStore(db_file)

    seed_nodes = [
        NodeInfo(
            kind="Function", name="handle", file_path="/repo/app.py",
            line_start=10, line_end=20, language="python",
            parent_name="Service", params="request", return_type="Response",
        ),
        NodeInfo(
            kind="Class", name="Service", file_path="/repo/app.py",
            line_start=5, line_end=40, language="python",
        ),
        NodeInfo(
            kind="Test", name="test_handle", file_path="/repo/test_app.py",
            line_start=1, line_end=10, language="python", is_test=True,
        ),
    ]
    for node in seed_nodes:
        store.upsert_node(node)
    store.upsert_edge(EdgeInfo(
        kind="CALLS", source="/repo/app.py::Service.handle",
        target="/repo/app.py::process", file_path="/repo/app.py", line=15,
    ))
    store.commit()

    yield store

    store.close()
    Path(db_file).unlink(missing_ok=True)


@pytest.fixture
def empty_store():
    """Yield a GraphStore containing no nodes or edges."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as handle:
        db_file = handle.name
    store = GraphStore(db_file)
    yield store
    store.close()
    Path(db_file).unlink(missing_ok=True)


class TestRunPostProcessing:
    def test_returns_dict_with_expected_keys(self, store_with_data):
        outcome = run_post_processing(store_with_data)
        assert "signatures_computed" in outcome
        assert isinstance(outcome["signatures_computed"], int)

    def test_computes_signatures(self, store_with_data):
        outcome = run_post_processing(store_with_data)
        assert outcome["signatures_computed"] > 0

    def test_idempotent(self, store_with_data):
        # A second pass should find nothing left to sign.
        run_post_processing(store_with_data)
        rerun = run_post_processing(store_with_data)
        assert rerun["signatures_computed"] == 0

    def test_empty_store_no_crash(self, empty_store):
        outcome = run_post_processing(empty_store)
        assert outcome["signatures_computed"] == 0

    def test_no_warnings_on_healthy_store(self, store_with_data):
        outcome = run_post_processing(store_with_data)
        assert "warnings" not in outcome


class TestStepIsolation:
    def test_fts_failure_does_not_block_flows(self, store_with_data):
        fts_boom = patch(
            "code_review_graph.search.rebuild_fts_index",
            side_effect=ImportError("fts boom"),
        )
        with fts_boom:
            outcome = run_post_processing(store_with_data)
        # Signature step already ran; later steps must still have run too.
        assert outcome["signatures_computed"] > 0
        for key in ("flows_detected", "communities_detected", "warnings"):
            assert key in outcome

    def test_fts_import_failure_produces_warning(self, store_with_data):
        with patch(
            "code_review_graph.search.rebuild_fts_index",
            side_effect=ImportError("no search module"),
        ):
            outcome = run_post_processing(store_with_data)
        fts_warnings = [w for w in outcome.get("warnings", []) if "FTS" in w]
        assert fts_warnings

    def test_flow_import_failure_produces_warning(self, store_with_data):
        with patch(
            "code_review_graph.flows.trace_flows",
            side_effect=ImportError("no flows module"),
        ):
            outcome = run_post_processing(store_with_data)
        assert outcome.get("warnings")
        assert any("Flow" in w for w in outcome["warnings"])

    def test_community_import_failure_still_has_signatures(self, store_with_data):
        with patch(
            "code_review_graph.communities.detect_communities",
            side_effect=ImportError("no communities"),
        ):
            outcome = run_post_processing(store_with_data)
        assert outcome["signatures_computed"] > 0
        assert any("Community" in w for w in outcome.get("warnings", []))