diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2bb6841..134016f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -52,6 +52,12 @@ jobs: cache: "npm" cache-dependency-path: libs/openant-core/parsers/javascript/package-lock.json + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: libs/openant-core/parsers/go/go_parser/go.mod + cache-dependency-path: libs/openant-core/parsers/go/go_parser/go.mod + - name: Install Python dependencies working-directory: libs/openant-core run: pip install -r requirements.txt && pip install ".[dev]" @@ -68,6 +74,16 @@ jobs: working-directory: libs/openant-core/parsers/javascript run: npm ci + - name: Build go_parser binary (Linux/macOS) + if: runner.os != 'Windows' + working-directory: libs/openant-core/parsers/go/go_parser + run: go build -o go_parser . + + - name: Build go_parser binary (Windows) + if: runner.os == 'Windows' + working-directory: libs/openant-core/parsers/go/go_parser + run: go build -o go_parser.exe . + - name: Run Python and parser tests working-directory: libs/openant-core run: python -m pytest tests/ -v diff --git a/.gitignore b/.gitignore index 599ac15..d432e6c 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,7 @@ __pycache__/ node_modules/ apps/openant-cli/bin/ libs/openant-core/parsers/go/go_parser/go_parser +libs/openant-core/parsers/javascript/.openant-npm-install.lock _docs/ +docs/ +.worktrees/ diff --git a/CHANGELOG.md b/CHANGELOG.md index bbe7a9a..bcd0d15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,101 @@ All notable changes to OpenAnt are documented in this file. +## [2026-05-12] — Parser depth, dependency UX, and LLM reachability (opt-in) + +### Fixed + +- **`openant parse` now defaults `--level` to `reachable`.** The Go CLI's + `parse` command previously defaulted to `--level all`, contradicting + `scan` and the Python CLI which both default to `reachable`. The + documentation has always said the default is `reachable`. Anyone running + `openant parse ` without `-l` now gets the same dataset as + `openant scan --steps parse` — the documented behavior. Set + `--level all` explicitly to restore the previous output. (#35) + +- **JS parser dependencies are now auto-installed on first use.** + `openant parse` on a JavaScript/TypeScript repository previously failed + out of the box with `Cannot find module 'ts-morph'` because nothing in + the install flow ran `npm install` for `parsers/javascript/`. The Python + parser adapter now runs `npm install` once on first JS parse using + `node_modules/.package-lock.json` as the completion sentinel (catches + Ctrl+C-interrupted installs). Python/Go-only users still never need + `npm`. Includes a cross-platform file lock to prevent concurrent install + corruption. Closes #6. (#37) + +- **TypeScript parser now resolves dependency-injected service calls.** + NestJS-style `this.userService.findById()` calls were previously + unresolved in the call graph because the parser didn't extract + constructor parameter types. Adds DI-aware resolution covering + constructor injection (`constructor(private svc: SvcType)`), + field-decorator injection (`@Inject` / `@InjectRepository` / etc.), and + Angular's functional `inject()` API. Resolution priority: exact type → + nominal (`implements`/`extends`) → unambiguous prefix (e.g. + `CallService` → `CallServiceV1`). All steps return `null` on ambiguity + to preserve the resolver's no-false-positive guarantee. Class-level + metadata is keyed by `relativePath:className` so multi-module monorepos + with same-named classes work. (#39) + +- **Express anonymous route handler callbacks are now extracted as units.** + `router.post('/orders', authenticateToken, async (req, res) => {...})` — + the anonymous handler callback was previously invisible to the analyzer + because the call-expression argument list wasn't walked. Synth units + now carry `route_handler` (last callback) or `route_middleware` (earlier + callbacks) with HTTP method/path metadata. Both unit types are now in + `ENTRY_POINT_TYPES` so the reachability filter doesn't drop them. The + receiver filter (`app` / `router` / `routes` / `server` / `web` / `api` + / `endpoints` / `controller`) prevents false positives on + `myCache.get(...)` style calls. Named middleware identifiers become + call-graph edges so `authenticateToken` shows up as an upstream + dependency of the handler. Closes #21. (#49) + +### Added + +- **Auto-reinstall when `pyproject.toml` changes.** The Go CLI now hashes + `libs/openant-core/pyproject.toml` (SHA-256) and stores the hash at + `~/.openant/venv/.deps-hash`. Every `EnsureRuntime` call compares the + stored hash against the current file and re-runs `pip install -e ` + automatically when they differ. Eliminates the "user did `git pull`, + dependencies changed, but venv is stale" silent failure mode that + previously required manual reinstall. Best-effort: hash read/write + failures degrade gracefully with stderr warnings rather than crashing + the CLI. (#36) + +- **`openant init` no longer requires a git repository for local paths.** + Init on a non-git directory (tarball download, generated code, locally + modified tree) now succeeds with `commit_sha` set to the `"nogit"` + placeholder. `--commit` on a non-git directory warns and is ignored + rather than hard-failing. Adds a shared `config/languages.json` + consumed by both the Go CLI and the Python parser adapter — single + source of truth for file-extension mappings and skip directories, + eliminating Go↔Python drift. Language auto-detection is exposed as + opt-in via `-l auto` (experimental dominance heuristic — see #61 for + the validation work needed before it becomes the default). (#40) + +- **`--llm-reachability` opt-in stage on `openant scan`.** A new optional + review pass that uses Opus (default) to surface reachability signals + the structural analysis misses — likely entry points (framework + handlers, plugin/CLI registrations, message queues), external content + ingestion sites (HTTP request bodies, file/network reads, env/argv, + IPC), and async/cross-process data flows. Promote-only semantics: + signals can mark units as entry points but never demote a unit the + structural pass kept. When enabled, parse runs with `processing_level + = "all"` so the LLM sees the full unfiltered codebase, then the + structural reachability filter re-runs with LLM-promoted entry points + added as additional BFS seeds. Output: `llm_reachability.json` plus + per-unit `llm_reachability_signals` field on `dataset.json`. + Cost-conscious: opt-in only, batched (default 25 units per Opus call), + scales with total repo size rather than the filtered unit count. Off + by default. (#50) + +- **All parsers now write `call_graph.json`.** Previously only the Python + and Zig parsers persisted this file; JS, Go, C, Ruby, and PHP did + reachability filtering internally and didn't expose the graph. Required + for the new `--llm-reachability` re-filter to work across all + languages. Defensive WARNING in `scanner.py` fires with a cost-impact + message if the file is ever missing for a language that should support + it. (#50) + ## [2026-05-10] — Windows compatibility & CI hardening ### Fixed diff --git a/apps/openant-cli/cmd/init.go b/apps/openant-cli/cmd/init.go index 50934aa..01c4a60 100644 --- a/apps/openant-cli/cmd/init.go +++ b/apps/openant-cli/cmd/init.go @@ -1,7 +1,9 @@ package cmd import ( + "encoding/json" "fmt" + "io/fs" "os" "os/exec" "path/filepath" @@ -26,6 +28,7 @@ After init, all commands (parse, scan, etc.) work without path arguments. Examples: openant init https://github.com/grafana/grafana -l go openant init https://github.com/grafana/grafana -l go --commit 591ceb2eec0 + openant init https://github.com/grafana/grafana -l auto openant init ./repos/grafana -l go openant init ./repos/grafana -l go --name myorg/grafana`, Args: cobra.ExactArgs(1), @@ -44,7 +47,7 @@ var ( ) func init() { - initCmd.Flags().StringVarP(&initLanguage, "language", "l", "", "Language to analyze: python, javascript, go, c, ruby, php (required)") + initCmd.Flags().StringVarP(&initLanguage, "language", "l", "", "Language to analyze: python, javascript, go, c, ruby, php, zig, auto (auto = experimental dominance heuristic; see #61)") initCmd.Flags().StringVar(&initCommit, "commit", "", "Specific commit SHA (default: HEAD)") initCmd.Flags().StringVar(&initName, "name", "", "Override project name (default: derived from URL/path)") initCmd.Flags().BoolVar(&initFull, "full", false, "Force full scan (rejects --incremental/--diff-base/--pr)") @@ -118,7 +121,7 @@ func runInit(cmd *cobra.Command, args []string) { } } } else { - // Local: verify it's a git repo and resolve absolute path + // Local: resolve absolute path source = "local" absPath, err := filepath.Abs(input) @@ -127,29 +130,48 @@ func runInit(cmd *cobra.Command, args []string) { os.Exit(1) } - if _, err := os.Stat(filepath.Join(absPath, ".git")); err != nil { - output.PrintError(fmt.Sprintf("%s is not a git repository (no .git directory)", absPath)) + repoPath = absPath + } + + // Auto-detect language if not specified + if initLanguage == "" || initLanguage == "auto" { + fmt.Fprintf(os.Stderr, "Auto-detecting language...\n") + detected, err := detectLanguage(repoPath) + if err != nil { + output.PrintError(fmt.Sprintf("Language auto-detection failed: %s\nSpecify manually with -l/--language", err)) os.Exit(1) } + initLanguage = detected + fmt.Fprintf(os.Stderr, "Detected language: %s\n", initLanguage) + } - repoPath = absPath + // Get commit SHA (best-effort — not all local paths are git repos) + isGit := false + if _, err := os.Stat(filepath.Join(repoPath, ".git")); err == nil { + isGit = true } - // Get commit SHA commitSHA := initCommit - if commitSHA == "" { - out, err := exec.Command("git", "-C", repoPath, "rev-parse", "HEAD").Output() - if err != nil { - output.PrintError(fmt.Sprintf("Failed to get HEAD commit: %s", err)) - os.Exit(1) + if isGit { + if commitSHA == "" { + out, err := exec.Command("git", "-C", repoPath, "rev-parse", "HEAD").Output() + if err != nil { + output.PrintError(fmt.Sprintf("Failed to get HEAD commit: %s", err)) + os.Exit(1) + } + commitSHA = strings.TrimSpace(string(out)) + } else { + // Resolve short SHA to full SHA + out, err := exec.Command("git", "-C", repoPath, "rev-parse", commitSHA).Output() + if err == nil { + commitSHA = strings.TrimSpace(string(out)) + } } - commitSHA = strings.TrimSpace(string(out)) } else { - // Resolve short SHA to full SHA - out, err := exec.Command("git", "-C", repoPath, "rev-parse", commitSHA).Output() - if err == nil { - commitSHA = strings.TrimSpace(string(out)) + if commitSHA != "" { + output.PrintWarning("--commit ignored: not a git repository") } + commitSHA = "nogit" } // Create project @@ -224,3 +246,125 @@ func runInit(cmd *cobra.Command, args []string) { output.PrintSuccess("Set as active project") fmt.Println() } + +// languagesConfig is the structure of config/languages.json. +type languagesConfig struct { + SkipDirs []string `json:"skip_dirs"` + Extensions map[string]string `json:"extensions"` +} + +// findLanguagesConfig locates config/languages.json by walking up from the +// executable path and then the current working directory. +func findLanguagesConfig() (string, error) { + rel := filepath.Join("config", "languages.json") + + // Strategy 1: walk up from the executable. + if exePath, err := os.Executable(); err == nil { + exePath, _ = filepath.EvalSymlinks(exePath) + dir := filepath.Dir(exePath) + for range 6 { + candidate := filepath.Join(dir, rel) + if info, err := os.Stat(candidate); err == nil && !info.IsDir() { + return candidate, nil + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + } + + // Strategy 2: walk up from CWD. + if cwd, err := os.Getwd(); err == nil { + dir := cwd + for range 6 { + candidate := filepath.Join(dir, rel) + if info, err := os.Stat(candidate); err == nil && !info.IsDir() { + return candidate, nil + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + } + + return "", fmt.Errorf("could not find config/languages.json from executable or working directory") +} + +// loadLanguagesConfig loads the shared language detection config. +func loadLanguagesConfig() (*languagesConfig, error) { + path, err := findLanguagesConfig() + if err != nil { + return nil, err + } + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", path, err) + } + var cfg languagesConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse %s: %w", path, err) + } + return &cfg, nil +} + +// detectLanguage walks a repository and returns the dominant language by file count. +// Extension mappings and skip directories are loaded from config/languages.json +// (shared with libs/openant-core/core/parser_adapter.py::detect_language()). +func detectLanguage(repoPath string) (string, error) { + cfg, err := loadLanguagesConfig() + if err != nil { + return "", fmt.Errorf("failed to load language config: %w", err) + } + + skipDirs := make(map[string]bool, len(cfg.SkipDirs)) + for _, d := range cfg.SkipDirs { + skipDirs[d] = true + } + + counts := make(map[string]int) + + err = filepath.WalkDir(repoPath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil // skip inaccessible paths + } + if d.IsDir() { + if skipDirs[d.Name()] { + return filepath.SkipDir + } + return nil + } + + ext := strings.ToLower(filepath.Ext(d.Name())) + if lang, ok := cfg.Extensions[ext]; ok { + counts[lang]++ + } + return nil + }) + if err != nil { + return "", fmt.Errorf("failed to walk repository: %w", err) + } + + // Find the dominant language + bestLang := "" + bestCount := 0 + for lang, count := range counts { + if count > bestCount { + bestCount = count + bestLang = lang + } + } + + if bestLang == "" { + return "", fmt.Errorf( + "no supported source files found in %s. "+ + "Supported languages: Python, JavaScript/TypeScript, Go, C/C++, Ruby, PHP, Zig", + repoPath, + ) + } + + return bestLang, nil +} diff --git a/apps/openant-cli/cmd/parse.go b/apps/openant-cli/cmd/parse.go index 988801f..563ca5a 100644 --- a/apps/openant-cli/cmd/parse.go +++ b/apps/openant-cli/cmd/parse.go @@ -34,12 +34,33 @@ var ( func init() { parseCmd.Flags().StringVarP(&parseOutput, "output", "o", "", "Output directory (default: project scan dir)") parseCmd.Flags().StringVarP(&parseLanguage, "language", "l", "", "Language: python, javascript, go, c, ruby, php, auto") - parseCmd.Flags().StringVar(&parseLevel, "level", "all", "Processing level: all, reachable, codeql, exploitable") + parseCmd.Flags().StringVar(&parseLevel, "level", "reachable", "Processing level: all, reachable, codeql, exploitable") parseCmd.Flags().StringVar(&parseDiffBase, "diff-base", "", "Incremental mode: tag units overlapping diff vs this ref") parseCmd.Flags().IntVar(&parsePR, "pr", 0, "Incremental mode against a GitHub PR number (mutex with --diff-base)") parseCmd.Flags().StringVar(&parseDiffScope, "diff-scope", "changed_functions", "Diff scope: changed_files, changed_functions, callers") } +// buildParsePyArgs assembles the argv passed to the Python `openant parse` +// subprocess. Defaults that match the Python CLI (language=auto, +// level=reachable) are omitted so the Python side stays in charge of the +// canonical default value. +func buildParsePyArgs(repoPath, output, datasetName, language, level, manifestPath string) []string { + pyArgs := []string{"parse", repoPath, "--output", output} + if datasetName != "" { + pyArgs = append(pyArgs, "--name", datasetName) + } + if language != "auto" { + pyArgs = append(pyArgs, "--language", language) + } + if level != "reachable" { + pyArgs = append(pyArgs, "--level", level) + } + if manifestPath != "" { + pyArgs = append(pyArgs, "--diff-manifest", manifestPath) + } + return pyArgs +} + func runParse(cmd *cobra.Command, args []string) { repoPath, ctx, err := resolveRepoArg(args) if err != nil { @@ -92,19 +113,7 @@ func runParse(cmd *cobra.Command, args []string) { } } - pyArgs := []string{"parse", repoPath, "--output", parseOutput} - if datasetName != "" { - pyArgs = append(pyArgs, "--name", datasetName) - } - if parseLanguage != "auto" { - pyArgs = append(pyArgs, "--language", parseLanguage) - } - if parseLevel != "all" { - pyArgs = append(pyArgs, "--level", parseLevel) - } - if manifestPath != "" { - pyArgs = append(pyArgs, "--diff-manifest", manifestPath) - } + pyArgs := buildParsePyArgs(repoPath, parseOutput, datasetName, parseLanguage, parseLevel, manifestPath) result, err := python.Invoke(rt.Path, pyArgs, "", quiet, resolvedAPIKey()) if err != nil { diff --git a/apps/openant-cli/cmd/parse_test.go b/apps/openant-cli/cmd/parse_test.go new file mode 100644 index 0000000..e080df2 --- /dev/null +++ b/apps/openant-cli/cmd/parse_test.go @@ -0,0 +1,87 @@ +package cmd + +import ( + "strings" + "testing" +) + +func TestParseLevelFlagDefaultIsReachable(t *testing.T) { + flag := parseCmd.Flag("level") + if flag == nil { + t.Fatal("parseCmd has no --level flag") + } + if got, want := flag.DefValue, "reachable"; got != want { + t.Errorf("--level default = %q, want %q", got, want) + } +} + +func TestParseLevelFlagUsageMentionsChoices(t *testing.T) { + flag := parseCmd.Flag("level") + if flag == nil { + t.Fatal("parseCmd has no --level flag") + } + for _, choice := range []string{"all", "reachable", "codeql", "exploitable"} { + if !strings.Contains(flag.Usage, choice) { + t.Errorf("--level usage missing %q: %q", choice, flag.Usage) + } + } +} + +func TestBuildParsePyArgsLevelForwarding(t *testing.T) { + tests := []struct { + name string + level string + wantLevel bool // true if --level should appear in argv + }{ + {"default reachable is omitted", "reachable", false}, + {"all is forwarded", "all", true}, + {"codeql is forwarded", "codeql", true}, + {"exploitable is forwarded", "exploitable", true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + args := buildParsePyArgs("/repo", "/out", "", "auto", tc.level, "") + gotLevel, gotValue := findFlag(args, "--level") + if gotLevel != tc.wantLevel { + t.Errorf("--level present = %v, want %v (argv=%v)", gotLevel, tc.wantLevel, args) + } + if tc.wantLevel && gotValue != tc.level { + t.Errorf("--level value = %q, want %q (argv=%v)", gotValue, tc.level, args) + } + }) + } +} + +func TestBuildParsePyArgsBaseline(t *testing.T) { + args := buildParsePyArgs("/repo", "/out", "org-repo-abc1234", "python", "exploitable", "/tmp/manifest.json") + want := []string{ + "parse", "/repo", + "--output", "/out", + "--name", "org-repo-abc1234", + "--language", "python", + "--level", "exploitable", + "--diff-manifest", "/tmp/manifest.json", + } + if len(args) != len(want) { + t.Fatalf("argv = %v, want %v", args, want) + } + for i := range want { + if args[i] != want[i] { + t.Errorf("argv[%d] = %q, want %q (full=%v)", i, args[i], want[i], args) + } + } +} + +// findFlag returns whether name is present in argv, and its following value +// (or "" if it has no value). +func findFlag(argv []string, name string) (bool, string) { + for i, a := range argv { + if a == name { + if i+1 < len(argv) { + return true, argv[i+1] + } + return true, "" + } + } + return false, "" +} diff --git a/apps/openant-cli/cmd/scan.go b/apps/openant-cli/cmd/scan.go index 2a646b5..c9ce974 100644 --- a/apps/openant-cli/cmd/scan.go +++ b/apps/openant-cli/cmd/scan.go @@ -51,6 +51,8 @@ var ( scanDiffBase string scanPR int scanDiffScope string + scanLLMReachability bool + scanLLMReachabilityMaxCodeBytes int ) func init() { @@ -79,6 +81,8 @@ func registerScanFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&scanDiffBase, "diff-base", "", "Incremental mode: filter pipeline to units overlapping diff vs this ref (e.g. origin/main, HEAD~5)") cmd.Flags().IntVar(&scanPR, "pr", 0, "Incremental mode against a GitHub PR number (requires gh; mutex with --diff-base)") cmd.Flags().StringVar(&scanDiffScope, "diff-scope", "changed_functions", "Diff scope: changed_files, changed_functions, callers") + cmd.Flags().BoolVar(&scanLLMReachability, "llm-reachability", false, "Enable the LLM reachability review stage (Opus). Surfaces entry points and external-input sites the structural pass would miss by reviewing the full codebase before the reachability filter is applied. Off by default — enabling this incurs cost proportional to total repo size, not the filtered unit count (~one Opus call per 25 units across the whole codebase).") + cmd.Flags().IntVar(&scanLLMReachabilityMaxCodeBytes, "llm-reachability-max-code-bytes", 1500, "Max code bytes per unit sent to the LLM reachability stage (default: 1500). Higher values (e.g. 4096, 8192) catch entry-point indicators past byte 1500 in long handlers / generated code, at proportional Opus cost increase. Only meaningful with --llm-reachability.") } func runScan(cmd *cobra.Command, args []string) { @@ -197,6 +201,12 @@ func runScan(cmd *cobra.Command, args []string) { if manifestPath != "" { pyArgs = append(pyArgs, "--diff-manifest", manifestPath) } + if scanLLMReachability { + pyArgs = append(pyArgs, "--llm-reachability") + } + if scanLLMReachabilityMaxCodeBytes != 1500 { + pyArgs = append(pyArgs, "--llm-reachability-max-code-bytes", fmt.Sprintf("%d", scanLLMReachabilityMaxCodeBytes)) + } // Pass repository metadata from project context so reports don't show // [NOT PROVIDED] placeholders. diff --git a/apps/openant-cli/internal/python/runtime.go b/apps/openant-cli/internal/python/runtime.go index 20a1631..ddaa67e 100644 --- a/apps/openant-cli/internal/python/runtime.go +++ b/apps/openant-cli/internal/python/runtime.go @@ -2,6 +2,8 @@ package python import ( + "crypto/sha256" + "encoding/hex" "fmt" "os" "os/exec" @@ -198,6 +200,16 @@ func CheckOpenantInstalled(pythonPath string) error { ) } + // Save dependency hash so CheckDepsStale knows this is the baseline. + pyprojectPath := filepath.Join(corePath, "pyproject.toml") + if h, err := hashFile(pyprojectPath); err == nil { + if err := writeStoredHash(h); err != nil { + fmt.Fprintf(os.Stderr, + "warning: could not save dependency hash at %s: %v (next run may reinstall)\n", + depsHashPath(), err) + } + } + fmt.Fprintln(os.Stderr, "openant installed successfully.") return nil } @@ -220,13 +232,125 @@ func EnsureRuntime() (*RuntimeInfo, error) { vp := venvPython() if rt.Path != vp && fileExists(vp) && isOpenantImportable(vp) { if info, err := checkPython(vp); err == nil { - return info, nil + rt = info } } + // Check if dependencies have changed since last install. + if err := CheckDepsStale(rt.Path); err != nil { + return nil, err + } + return rt, nil } +// depsHashPath returns the path to the stored dependency hash inside the venv. +func depsHashPath() string { + return filepath.Join(venvDir(), ".deps-hash") +} + +// hashFile returns the hex-encoded SHA-256 of a file's contents. +func hashFile(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]), nil +} + +// readHashAt reads a stored hash from the given path, or "" if absent. +func readHashAt(path string) string { + data, err := os.ReadFile(path) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// writeHashAt saves a hash to the given path, creating the parent directory +// if it does not already exist. +func writeHashAt(path, hash string) error { + if dir := filepath.Dir(path); dir != "" && dir != "." { + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + } + return os.WriteFile(path, []byte(hash+"\n"), 0644) +} + +// readStoredHash reads the previously stored dependency hash, or "" if absent. +func readStoredHash() string { return readHashAt(depsHashPath()) } + +// writeStoredHash saves the dependency hash to the venv marker file. +func writeStoredHash(hash string) error { return writeHashAt(depsHashPath(), hash) } + +// depsStalenessAt inspects pyproject.toml at corePath and the hash stored at +// hashPath, and reports whether a reinstall is needed. The boolean is true +// when deps are stale (i.e. the hash differs and a reinstall is warranted). +// The caller is expected to skip the check on any error. +func depsStalenessAt(corePath, hashPath string) (stale bool, currentHash string, err error) { + pyprojectPath := filepath.Join(corePath, "pyproject.toml") + currentHash, err = hashFile(pyprojectPath) + if err != nil { + return false, "", err + } + return currentHash != readHashAt(hashPath), currentHash, nil +} + +// depsStaleness is the production wrapper around depsStalenessAt that uses +// the real venv hash path. +func depsStaleness(corePath string) (stale bool, currentHash string, err error) { + return depsStalenessAt(corePath, depsHashPath()) +} + +// CheckDepsStale checks if pyproject.toml has changed since the last install. +// If stale, it re-runs pip install -e and updates the stored hash. +// Returns nil if deps are up-to-date or were successfully refreshed. +func CheckDepsStale(pythonPath string) error { + return checkDepsStaleWith(pythonPath, findOpenantCore) +} + +// checkDepsStaleWith is the testable core of CheckDepsStale; coreFinder is +// injected so tests can avoid os.Chdir to simulate a missing source tree. +func checkDepsStaleWith(pythonPath string, coreFinder func() (string, error)) error { + corePath, err := coreFinder() + if err != nil { + // Can't find source — skip staleness check + return nil + } + + stale, currentHash, err := depsStaleness(corePath) + if err != nil { + // Can't read pyproject.toml — skip check + return nil + } + if !stale { + return nil // deps are up-to-date + } + + fmt.Fprintln(os.Stderr, "Dependencies changed, updating openant installation...") + // Known limitation: concurrent invocations that both detect stale deps + // will race to pip-install into the same venv. pip does not support + // concurrent writes; an OS-level lock would be needed to close this gap. + if err := installOpenant(pythonPath, corePath); err != nil { + return fmt.Errorf( + "failed to update openant dependencies: %w\n"+ + "Try manually: %s -m pip install -e %s", + err, pythonPath, corePath, + ) + } + + // Store the new hash + if err := writeStoredHash(currentHash); err != nil { + // Non-fatal — install succeeded, just can't cache the hash + fmt.Fprintf(os.Stderr, "Warning: could not save dependency hash: %v\n", err) + } + + fmt.Fprintln(os.Stderr, "Dependencies updated successfully.") + return nil +} + // createVenv creates a new venv at ~/.openant/venv/ using the given Python. func createVenv(pythonPath string) error { dir := venvDir() diff --git a/apps/openant-cli/internal/python/runtime_test.go b/apps/openant-cli/internal/python/runtime_test.go index 573814e..ead1f21 100644 --- a/apps/openant-cli/internal/python/runtime_test.go +++ b/apps/openant-cli/internal/python/runtime_test.go @@ -1,6 +1,9 @@ package python import ( + "crypto/sha256" + "encoding/hex" + "errors" "os" "path/filepath" "runtime" @@ -8,6 +11,10 @@ import ( "testing" ) +// --------------------------------------------------------------------------- +// venvPython / venvDir +// --------------------------------------------------------------------------- + func TestVenvPython_Windows(t *testing.T) { if runtime.GOOS != "windows" { t.Skip("test only runs on Windows") @@ -45,3 +52,286 @@ func TestVenvDir_ReturnsAbsolutePath(t *testing.T) { t.Errorf("venvDir() should return absolute path, got %q", vd) } } + +// --------------------------------------------------------------------------- +// hashFile +// --------------------------------------------------------------------------- + +func TestHashFile_KnownContent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + content := []byte("[project]\nname = \"openant\"\n") + if err := os.WriteFile(path, content, 0644); err != nil { + t.Fatal(err) + } + + got, err := hashFile(path) + if err != nil { + t.Fatalf("hashFile returned error: %v", err) + } + + sum := sha256.Sum256(content) + want := hex.EncodeToString(sum[:]) + if got != want { + t.Errorf("hashFile = %q, want %q", got, want) + } +} + +func TestHashFile_EmptyFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty") + if err := os.WriteFile(path, []byte{}, 0644); err != nil { + t.Fatal(err) + } + + got, err := hashFile(path) + if err != nil { + t.Fatalf("hashFile returned error: %v", err) + } + + sum := sha256.Sum256([]byte{}) + want := hex.EncodeToString(sum[:]) + if got != want { + t.Errorf("hashFile = %q, want %q", got, want) + } +} + +func TestHashFile_MissingFile(t *testing.T) { + _, err := hashFile(filepath.Join(t.TempDir(), "nonexistent")) + if err == nil { + t.Error("expected error for missing file, got nil") + } +} + +func TestHashFile_DifferentContent(t *testing.T) { + dir := t.TempDir() + pathA := filepath.Join(dir, "a.toml") + pathB := filepath.Join(dir, "b.toml") + os.WriteFile(pathA, []byte("version 1"), 0644) + os.WriteFile(pathB, []byte("version 2"), 0644) + + hashA, _ := hashFile(pathA) + hashB, _ := hashFile(pathB) + if hashA == hashB { + t.Error("different files should produce different hashes") + } +} + +// --------------------------------------------------------------------------- +// readStoredHash / writeStoredHash +// --------------------------------------------------------------------------- + +// readStoredHash / writeStoredHash delegate to readHashAt/writeHashAt with +// a path under the user's real ~/.openant/venv/. The tests exercise the +// underlying readHashAt/writeHashAt helpers directly to avoid touching the +// real venv directory. + +func TestWriteAndReadHashAt_RoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, ".deps-hash") + + hash := "abc123def456" + if err := writeHashAt(path, hash); err != nil { + t.Fatalf("writeHashAt: %v", err) + } + + got := readHashAt(path) + if got != hash { + t.Errorf("readHashAt = %q, want %q (trailing newline should be trimmed)", got, hash) + } +} + +func TestReadHashAt_MissingFile_ReturnsEmpty(t *testing.T) { + got := readHashAt(filepath.Join(t.TempDir(), "nonexistent")) + if got != "" { + t.Errorf("readHashAt missing file = %q, want \"\"", got) + } +} + +func TestReadHashAt_TrimsWhitespace(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, ".deps-hash") + if err := os.WriteFile(path, []byte(" abc\n\n"), 0644); err != nil { + t.Fatal(err) + } + if got := readHashAt(path); got != "abc" { + t.Errorf("readHashAt = %q, want %q", got, "abc") + } +} + +func TestReadStoredHash_DoesNotPanic(t *testing.T) { + // Smoke test: reading from the real ~/.openant/venv/.deps-hash must + // not panic regardless of whether the file exists. + _ = readStoredHash() +} + +func TestWriteHashAt_CreatesMissingParentDir(t *testing.T) { + dir := t.TempDir() + // nested directory that does not yet exist + path := filepath.Join(dir, "a", "b", ".deps-hash") + if err := writeHashAt(path, "deadbeef"); err != nil { + t.Fatalf("writeHashAt should create missing parents: %v", err) + } + if got := readHashAt(path); got != "deadbeef" { + t.Errorf("readHashAt after writeHashAt = %q, want %q", got, "deadbeef") + } +} + +// --------------------------------------------------------------------------- +// depsStalenessAt — covers the trigger detection logic without invoking pip +// --------------------------------------------------------------------------- + +// writeFakeCore creates a minimal pyproject.toml under a fake core dir and +// returns the core dir path. +func writeFakeCore(t *testing.T, contents string) string { + t.Helper() + core := t.TempDir() + if err := os.WriteFile(filepath.Join(core, "pyproject.toml"), []byte(contents), 0644); err != nil { + t.Fatal(err) + } + return core +} + +func TestDepsStalenessAt_FreshState_NoHashStored_IsStale(t *testing.T) { + core := writeFakeCore(t, "[project]\nname = \"x\"\n") + hashPath := filepath.Join(t.TempDir(), ".deps-hash") + + stale, cur, err := depsStalenessAt(core, hashPath) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !stale { + t.Error("expected stale=true when no hash has been stored") + } + if cur == "" { + t.Error("expected non-empty current hash") + } +} + +func TestDepsStalenessAt_MatchingHash_NotStale(t *testing.T) { + core := writeFakeCore(t, "[project]\nname = \"x\"\n") + hashPath := filepath.Join(t.TempDir(), ".deps-hash") + + // First call: capture the hash and write it out. + _, cur, err := depsStalenessAt(core, hashPath) + if err != nil { + t.Fatalf("first call: %v", err) + } + if err := writeHashAt(hashPath, cur); err != nil { + t.Fatal(err) + } + + // Second call: hash matches, should not be stale. + stale, _, err := depsStalenessAt(core, hashPath) + if err != nil { + t.Fatalf("second call: %v", err) + } + if stale { + t.Error("expected stale=false when stored hash matches current") + } +} + +func TestDepsStalenessAt_ModifiedPyproject_IsStale(t *testing.T) { + core := writeFakeCore(t, "[project]\nname = \"x\"\nversion = \"0.1\"\n") + hashPath := filepath.Join(t.TempDir(), ".deps-hash") + + _, originalHash, err := depsStalenessAt(core, hashPath) + if err != nil { + t.Fatal(err) + } + if err := writeHashAt(hashPath, originalHash); err != nil { + t.Fatal(err) + } + + // Mutate pyproject.toml — simulating a `git pull` that bumped a dep. + if err := os.WriteFile( + filepath.Join(core, "pyproject.toml"), + []byte("[project]\nname = \"x\"\nversion = \"0.2\"\ndependencies = [\"requests\"]\n"), + 0644, + ); err != nil { + t.Fatal(err) + } + + stale, newHash, err := depsStalenessAt(core, hashPath) + if err != nil { + t.Fatal(err) + } + if !stale { + t.Error("expected stale=true after pyproject.toml was modified") + } + if newHash == originalHash { + t.Error("expected new hash to differ from original after content change") + } +} + +func TestDepsStalenessAt_MissingPyproject_ReturnsError(t *testing.T) { + core := t.TempDir() // no pyproject.toml inside + hashPath := filepath.Join(t.TempDir(), ".deps-hash") + + stale, _, err := depsStalenessAt(core, hashPath) + if err == nil { + t.Error("expected error when pyproject.toml is missing") + } + if stale { + t.Error("expected stale=false on error") + } +} + +func TestDepsStalenessAt_StoredHashEqualsEmpty_StillStale(t *testing.T) { + // If the hash file is present but empty (e.g. truncated write), the + // stored hash trims to "" and we should treat the deps as stale so the + // next run heals the state by reinstalling. + core := writeFakeCore(t, "[project]\nname = \"x\"\n") + hashPath := filepath.Join(t.TempDir(), ".deps-hash") + if err := os.WriteFile(hashPath, []byte("\n"), 0644); err != nil { + t.Fatal(err) + } + + stale, _, err := depsStalenessAt(core, hashPath) + if err != nil { + t.Fatal(err) + } + if !stale { + t.Error("expected stale=true when stored hash is empty") + } +} + +// --------------------------------------------------------------------------- +// CheckDepsStale — integration-style tests with temp dirs +// --------------------------------------------------------------------------- + +func TestCheckDepsStale_SkipsWhenCoreNotFound(t *testing.T) { + err := checkDepsStaleWith("/nonexistent/python", func() (string, error) { + return "", errors.New("simulated: core not found") + }) + if err != nil { + t.Errorf("expected nil when core not found, got: %v", err) + } +} + +// --------------------------------------------------------------------------- +// fileExists +// --------------------------------------------------------------------------- + +func TestFileExists_True(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "exists.txt") + os.WriteFile(path, []byte("hi"), 0644) + + if !fileExists(path) { + t.Error("fileExists should return true for existing file") + } +} + +func TestFileExists_False_Missing(t *testing.T) { + if fileExists(filepath.Join(t.TempDir(), "nope")) { + t.Error("fileExists should return false for missing file") + } +} + +func TestFileExists_False_Directory(t *testing.T) { + dir := t.TempDir() + if fileExists(dir) { + t.Error("fileExists should return false for directories") + } +} diff --git a/config/languages.json b/config/languages.json new file mode 100644 index 0000000..7a99dde --- /dev/null +++ b/config/languages.json @@ -0,0 +1,34 @@ +{ + "skip_dirs": [ + "node_modules", + "__pycache__", + "venv", + ".venv", + "dist", + "build", + ".git", + "vendor" + ], + "extensions": { + ".py": "python", + ".js": "javascript", + ".ts": "javascript", + ".jsx": "javascript", + ".tsx": "javascript", + ".mjs": "javascript", + ".cjs": "javascript", + ".go": "go", + ".c": "c", + ".h": "c", + ".cpp": "c", + ".hpp": "c", + ".cc": "c", + ".cxx": "c", + ".hxx": "c", + ".hh": "c", + ".rb": "ruby", + ".rake": "ruby", + ".php": "php", + ".zig": "zig" + } +} diff --git a/libs/openant-core/core/llm_reachability.py b/libs/openant-core/core/llm_reachability.py new file mode 100644 index 0000000..8e19d1d --- /dev/null +++ b/libs/openant-core/core/llm_reachability.py @@ -0,0 +1,461 @@ +""" +LLM-based reachability review stage. + +A complementary, advisory pass over the **full, unfiltered** codebase that +uses a strong LLM (Opus by default) to surface reachability signals beyond +what the structural analysis catches: + +- Likely entry points the structural pass missed (framework-specific + handlers, plugin registrations, lambdas, message handlers, etc.). +- External content ingestion sites (HTTP request bodies, file/network + reads, env/argv, IPC channels). +- Cross-process or async data flow indicators. + +Pipeline ordering (managed by ``core/scanner.py``): + +1. Parse with ``processing_level="all"`` so every unit is available. +2. ``analyze_reachability`` reviews all units and returns signals. +3. ``apply_signals`` promotes high-confidence ``entry_point`` signals by + setting ``is_entry_point=True`` on the target unit. +4. The structural reachability filter re-runs with LLM-promoted entry + points added as extra BFS seeds, yielding a dataset filtered to the + user's requested ``processing_level`` but expanded by LLM findings. + +Signals are **promote-only** — they never DEMOTE a unit that structural +analysis already kept. This matches the "complements, not replaces" intent +in issue #17. + +Output: +- ``analyze_reachability(...)`` returns a list of ``ReachabilitySignal`` + dicts. +- ``apply_signals(dataset, signals)`` mutates the dataset in place so each + unit gains an ``llm_reachability_signals`` field, and high-confidence + ``entry_point`` signals set ``is_entry_point = True`` on the target unit. + +Usage: + from core.llm_reachability import analyze_reachability, apply_signals + + signals = analyze_reachability(dataset, app_context=app_ctx) + apply_signals(dataset, signals) +""" + +from __future__ import annotations + +import json +import re +import sys +from dataclasses import dataclass, asdict +from typing import Any, Callable, Dict, List, Optional + + +# Models — aligns with core/analyzer.py which uses "claude-opus-4-6" for Opus. +MODEL_PRIMARY = "claude-opus-4-6" +MODEL_SECONDARY = "claude-sonnet-4-20250514" + + +# Maximum number of units to send in a single LLM call. Larger batches save +# round trips but risk token-limit errors and degraded recall. +DEFAULT_BATCH_SIZE = 25 + +# Default maximum bytes of code we send per unit. Trimmed to keep prompts +# tractable. Callers can override via the ``max_code_bytes`` parameter on +# :func:`analyze_reachability` (exposed as ``--llm-reachability-max-code-bytes`` +# on ``openant scan``); higher values catch entry-point indicators past the +# default cutoff in long handlers / generated code, at proportional cost. +DEFAULT_MAX_CODE_BYTES = 1500 +# Backward-compatible alias for any external caller importing the old name. +MAX_CODE_BYTES = DEFAULT_MAX_CODE_BYTES + + +# --------------------------------------------------------------------------- +# Public dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ReachabilitySignal: + """A single LLM-emitted reachability signal for one unit. + + ``kind`` is one of: + - ``entry_point`` — unit is itself a likely entry point. + - ``external_input`` — unit receives external/untrusted input. + - ``cross_process`` — unit participates in async / cross-process data flow. + + ``confidence`` is one of ``high``, ``medium``, ``low``. + """ + + unit_id: str + kind: str + confidence: str + reason: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +# --------------------------------------------------------------------------- +# Prompt construction +# --------------------------------------------------------------------------- + + +PROMPT_TEMPLATE = """You are a senior application-security engineer auditing +a codebase for REACHABILITY signals — places where untrusted input can enter +the system. A previous structural pass has already flagged some entry points +and reachable units; your job is to surface ADDITIONAL signals it may have +missed (framework-specific handlers, plugin/CLI registrations, message +queues, async tasks, file/network ingestion, env/argv, IPC, etc.). + +Be conservative. Only emit a signal when the code clearly indicates one of: + + - "entry_point" — this unit is itself a likely entry point reachable + by an external actor (HTTP/CLI/queue/stream handler, + scheduled task, framework lifecycle hook, etc.). + - "external_input" — this unit reads or accepts data from an external + source (request body, file, socket, env, argv, stdin, + child-process output, untrusted message, etc.). + - "cross_process" — this unit dispatches or receives data across async + / process / queue boundaries (so taint may flow in + or out via a path the static call-graph misses). + +Confidence levels: + - "high" — the code unambiguously demonstrates the pattern. + - "medium" — the pattern is present but partially obscured. + - "low" — only suggestive; emit only if you'd want a human reviewer. + +Return STRICT JSON of the form: + + {{ + "signals": [ + {{"unit_id": "", "kind": "entry_point|external_input|cross_process", + "confidence": "high|medium|low", "reason": ""}}, + ... + ] + }} + +If no signals apply, return ``{{"signals": []}}``. Do NOT wrap the JSON in +markdown fences. Do NOT include any prose outside the JSON. + +{app_context_block} + +UNITS TO REVIEW (existing structural flags shown for context — your job is to +ADD signals beyond what those already capture): + +{units_block} +""" + + +def _build_app_context_block(app_context: Optional[Dict[str, Any]]) -> str: + """Render an optional app-context section for the prompt.""" + if not app_context: + return "APPLICATION CONTEXT: (none provided)" + try: + ctx_json = json.dumps(app_context, indent=2, sort_keys=True) + except (TypeError, ValueError): + ctx_json = str(app_context) + return f"APPLICATION CONTEXT:\n{ctx_json}" + + +def _trim_code(code: str, max_bytes: int = DEFAULT_MAX_CODE_BYTES) -> str: + """Truncate a code blob so the batch fits in a reasonable prompt window.""" + if not code: + return "" + if len(code) <= max_bytes: + return code + return code[:max_bytes] + "\n# ...[truncated]" + + +def _unit_for_prompt( + unit: Dict[str, Any], + max_code_bytes: int = DEFAULT_MAX_CODE_BYTES, +) -> Dict[str, Any]: + """Project a unit into the minimal shape we send to the LLM.""" + code_blob = "" + code = unit.get("code") or {} + if isinstance(code, dict): + code_blob = code.get("primary_code") or code.get("source") or "" + elif isinstance(code, str): + code_blob = code + + return { + "unit_id": unit.get("id", ""), + "unit_type": unit.get("unit_type", "function"), + "is_entry_point": bool(unit.get("is_entry_point", False)), + "reachable": unit.get("reachable"), + "code": _trim_code(code_blob, max_bytes=max_code_bytes), + } + + +def build_prompt( + units: List[Dict[str, Any]], + app_context: Optional[Dict[str, Any]] = None, + max_code_bytes: int = DEFAULT_MAX_CODE_BYTES, +) -> str: + """Assemble the LLM prompt for a batch of units.""" + app_block = _build_app_context_block(app_context) + payload = [_unit_for_prompt(u, max_code_bytes=max_code_bytes) for u in units] + units_block = json.dumps(payload, indent=2) + return PROMPT_TEMPLATE.format( + app_context_block=app_block, + units_block=units_block, + ) + + +# --------------------------------------------------------------------------- +# Response parsing +# --------------------------------------------------------------------------- + + +_VALID_KINDS = {"entry_point", "external_input", "cross_process"} +_VALID_CONFIDENCES = {"high", "medium", "low"} + + +def _extract_json(text: str) -> Optional[Dict[str, Any]]: + """Best-effort JSON extraction from a model response. + + Strips common markdown fences and falls back to the first ``{...}`` + block in the text. Returns ``None`` if nothing valid is found. + """ + if not text: + return None + cleaned = text.strip() + + # Strip ```json ... ``` or ``` ... ``` fences. + fence = re.match( + r"^```(?:json)?\s*(?P.*?)\s*```\s*$", + cleaned, + re.DOTALL | re.IGNORECASE, + ) + if fence: + cleaned = fence.group("body").strip() + + try: + return json.loads(cleaned) + except json.JSONDecodeError: + pass + + # Fall back to the first balanced JSON object in the response. + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end > start: + snippet = cleaned[start : end + 1] + try: + return json.loads(snippet) + except json.JSONDecodeError: + return None + return None + + +def parse_response( + response_text: str, + valid_unit_ids: Optional[set] = None, + on_error: Optional[Callable[[str], None]] = None, +) -> List[ReachabilitySignal]: + """Parse a single LLM response into validated ``ReachabilitySignal``s. + + Malformed entries are skipped (not raised); the optional ``on_error`` + callback receives a one-line description per skipped item, useful for + logging. + """ + log = on_error or (lambda msg: print(f"[LLMReach] {msg}", file=sys.stderr)) + + data = _extract_json(response_text) + if not isinstance(data, dict): + log("malformed response: not a JSON object — skipping batch") + return [] + + raw_signals = data.get("signals") + if not isinstance(raw_signals, list): + log("malformed response: 'signals' missing or not a list — skipping batch") + return [] + + out: List[ReachabilitySignal] = [] + for idx, item in enumerate(raw_signals): + if not isinstance(item, dict): + log(f"signal #{idx}: not an object — skipped") + continue + unit_id = item.get("unit_id") + kind = item.get("kind") + confidence = item.get("confidence") + reason = item.get("reason", "") + + if not isinstance(unit_id, str) or not unit_id: + log(f"signal #{idx}: missing unit_id — skipped") + continue + if kind not in _VALID_KINDS: + log(f"signal #{idx}: invalid kind {kind!r} — skipped") + continue + if confidence not in _VALID_CONFIDENCES: + log(f"signal #{idx}: invalid confidence {confidence!r} — skipped") + continue + if valid_unit_ids is not None and unit_id not in valid_unit_ids: + log(f"signal #{idx}: unknown unit_id {unit_id!r} — skipped") + continue + + out.append( + ReachabilitySignal( + unit_id=unit_id, + kind=kind, + confidence=confidence, + reason=str(reason)[:500], + ) + ) + return out + + +# --------------------------------------------------------------------------- +# Main entry points +# --------------------------------------------------------------------------- + + +def _chunk(items: List[Any], size: int) -> List[List[Any]]: + """Split ``items`` into batches of ``size``. + + A non-positive ``size`` is treated as "everything in one batch" so callers + that disable batching never hit a NameError or empty-output surprise. + """ + if size <= 0: + return [list(items)] if items else [] + return [items[i : i + size] for i in range(0, len(items), size)] + + +def analyze_reachability( + dataset: Dict[str, Any], + app_context: Optional[Dict[str, Any]] = None, + client: Any = None, + model: str = MODEL_PRIMARY, + batch_size: int = DEFAULT_BATCH_SIZE, + max_code_bytes: int = DEFAULT_MAX_CODE_BYTES, + max_units: Optional[int] = None, + on_error: Optional[Callable[[str], None]] = None, +) -> List[ReachabilitySignal]: + """Run the LLM reachability review stage over a parsed dataset. + + Args: + dataset: Parsed dataset with a ``units`` list, as produced by the + parser stage. Units are expected to expose ``id``, ``code``, and + optionally ``is_entry_point`` / ``reachable_from_entry``. + app_context: Optional application context dict; included in the + prompt to help the model reason about expected entry points + (e.g. ``{"application_type": "web_app"}``). + client: An object exposing ``analyze_sync(prompt, max_tokens=..., + model=...)``. If omitted, an :class:`AnthropicClient` is + instantiated lazily. + model: Model id to use (defaults to Opus). + batch_size: Units per LLM call. + max_code_bytes: Per-unit code-blob truncation limit. Higher values + give the LLM more context (better recall on long handlers / + generated code) at proportional Opus cost. Default 1500. + max_units: Optional cap on how many units to review. + on_error: Optional callback for parse/validation issues. + + Returns: + A flat list of :class:`ReachabilitySignal` for every unit the model + flagged. Unknown unit ids and malformed entries are filtered out. + """ + units = dataset.get("units") or [] + if max_units is not None and max_units >= 0: + units = units[:max_units] + if not units: + return [] + + if client is None: + # Lazy import so unit tests can stub this out without an API key. + from utilities.llm_client import AnthropicClient + + client = AnthropicClient(model=model) + + valid_ids = {u.get("id") for u in units if u.get("id")} + + signals: List[ReachabilitySignal] = [] + batches = _chunk(units, batch_size) + for i, batch in enumerate(batches): + prompt = build_prompt( + batch, app_context=app_context, max_code_bytes=max_code_bytes + ) + try: + text = client.analyze_sync(prompt, max_tokens=4096, model=model) + except Exception as exc: # noqa: BLE001 — advisory stage; never crash pipeline + msg = f"batch {i + 1}/{len(batches)} failed: {exc}" + if on_error: + on_error(msg) + else: + print(f"[LLMReach] {msg}", file=sys.stderr) + continue + + parsed = parse_response( + text, valid_unit_ids=valid_ids, on_error=on_error + ) + signals.extend(parsed) + + return signals + + +# --------------------------------------------------------------------------- +# Signal application (promote-only) +# --------------------------------------------------------------------------- + + +# Confidences at or above this threshold promote ``entry_point`` signals to +# ``is_entry_point = True`` on the target unit. +_PROMOTE_ENTRY_POINT_AT = {"high"} + + +def apply_signals( + dataset: Dict[str, Any], + signals: List[ReachabilitySignal], +) -> Dict[str, int]: + """Merge LLM signals back into ``dataset`` (in place, promote-only). + + For each unit referenced by a signal: + - The signal is appended to a per-unit ``llm_reachability_signals`` list. + - If the signal kind is ``entry_point`` AND its confidence is in + :data:`_PROMOTE_ENTRY_POINT_AT`, the unit's ``is_entry_point`` field + is set to ``True`` (never set back to ``False``). + + Crucially, this never DEMOTES a unit. ``is_entry_point=True`` set by the + structural pass remains true regardless of what the LLM said. + + Returns a small summary dict:: + + { + "signals_applied": , + "entry_points_promoted": , + "units_touched": , + } + """ + units = dataset.get("units") or [] + by_id = {u.get("id"): u for u in units if u.get("id")} + + promoted = 0 + touched: set = set() + applied = 0 + + for sig in signals: + unit = by_id.get(sig.unit_id) + if unit is None: + continue + + existing = unit.setdefault("llm_reachability_signals", []) + existing.append(sig.to_dict()) + applied += 1 + touched.add(sig.unit_id) + + if ( + sig.kind == "entry_point" + and sig.confidence in _PROMOTE_ENTRY_POINT_AT + and not unit.get("is_entry_point", False) + ): + unit["is_entry_point"] = True + unit["entry_point_reason"] = f"llm_reachability: {sig.reason}" + promoted += 1 + + return { + "signals_applied": applied, + "entry_points_promoted": promoted, + "units_touched": len(touched), + } + + +def signals_to_json(signals: List[ReachabilitySignal]) -> List[Dict[str, Any]]: + """Serialize a list of signals for JSON persistence.""" + return [s.to_dict() for s in signals] diff --git a/libs/openant-core/core/parser_adapter.py b/libs/openant-core/core/parser_adapter.py index 46fc08c..605450a 100644 --- a/libs/openant-core/core/parser_adapter.py +++ b/libs/openant-core/core/parser_adapter.py @@ -9,58 +9,60 @@ sys.path hacks in the original code. """ +import contextlib import json import os +import shutil import subprocess import sys from pathlib import Path from core.schemas import ParseResult -from utilities.file_io import read_json, write_json +from utilities.file_io import open_utf8, read_json, write_json # Root of openant-core (where parsers/ lives) _CORE_ROOT = Path(__file__).parent.parent +# JS parser directory (holds its own package.json / node_modules) +_JS_PARSER_DIR = _CORE_ROOT / "parsers" / "javascript" + +# Shared language detection config (single source of truth: config/languages.json) +_LANGUAGES_CONFIG = Path(__file__).parent.parent.parent.parent / "config" / "languages.json" + + +def _load_language_config() -> dict: + return read_json(_LANGUAGES_CONFIG) + def detect_language(repo_path: str) -> str: """Auto-detect the primary language of a repository. Counts source files by extension and returns the dominant language. + Extension mappings and skip directories are loaded from config/languages.json. Returns: - "python", "javascript", or "go" + One of: "python", "javascript", "go", "c", "ruby", "php", "zig" """ + config = _load_language_config() + skip_dirs = set(config["skip_dirs"]) + extensions = config["extensions"] + repo = Path(repo_path) - counts = {"python": 0, "javascript": 0, "go": 0, "c": 0, "ruby": 0, "php": 0, "zig": 0} + counts: dict[str, int] = {} for f in repo.rglob("*"): if not f.is_file(): continue - # Skip common non-source dirs - parts = f.parts - if any(p in parts for p in ( - "node_modules", "__pycache__", "venv", ".venv", - "dist", "build", ".git", "vendor", - )): + # Skip configured non-source dirs + if any(p in skip_dirs for p in f.parts): continue suffix = f.suffix.lower() - if suffix == ".py": - counts["python"] += 1 - elif suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".cjs"): - counts["javascript"] += 1 - elif suffix == ".go": - counts["go"] += 1 - elif suffix in (".c", ".h", ".cpp", ".hpp", ".cc", ".cxx", ".hxx", ".hh"): - counts["c"] += 1 - elif suffix in (".rb", ".rake"): - counts["ruby"] += 1 - elif suffix == ".php": - counts["php"] += 1 - elif suffix == ".zig": - counts["zig"] += 1 - - if not any(counts.values()): + if suffix in extensions: + lang = extensions[suffix] + counts[lang] = counts.get(lang, 0) + 1 + + if not counts: raise ValueError( f"No supported source files found in {repo_path}. " "Supported languages: Python, JavaScript/TypeScript, Go, C/C++, Ruby, PHP, Zig." @@ -189,10 +191,11 @@ def _maybe_apply_diff_filter( # Reachability filter (shared by Python path; JS/Go handle it internally) # --------------------------------------------------------------------------- -def _apply_reachability_filter( +def apply_reachability_filter( dataset: dict, output_dir: str, processing_level: str, + extra_entry_points: "set[str] | None" = None, ) -> dict: """Filter dataset units to only those reachable from entry points. @@ -200,6 +203,12 @@ def _apply_reachability_filter( detects entry points, computes reachability via BFS, and removes unreachable units from the dataset. + ``extra_entry_points`` supplements the structurally-detected seed set. + Pass LLM-promoted unit IDs here so the BFS propagates from them even if + the structural heuristics missed them. Any unit that already has + ``is_entry_point=True`` in the dataset (e.g. set by the LLM reachability + stage) keeps that flag — this function never demotes it. + For ``codeql`` and ``exploitable`` levels the reachability filter is still applied (it is a prerequisite), but the additional CodeQL / LLM-classification filters are not yet wired into the Python path @@ -209,6 +218,7 @@ def _apply_reachability_filter( dataset: The full, unfiltered dataset dict (mutated in place). output_dir: Directory containing call_graph.json from the parser. processing_level: One of "reachable", "codeql", "exploitable". + extra_entry_points: Additional unit IDs to seed the BFS (e.g. from LLM). Returns: The (possibly filtered) dataset dict. @@ -246,9 +256,11 @@ def _load_module(name, filename): call_graph = call_graph_data.get("call_graph", {}) reverse_call_graph = call_graph_data.get("reverse_call_graph", {}) - # Detect entry points + # Detect entry points structurally, then seed with any extras (e.g. LLM-promoted). detector = EntryPointDetector(functions, call_graph) entry_points = detector.detect_entry_points() + if extra_entry_points: + entry_points = entry_points | extra_entry_points # Compute reachable set (BFS forward from entry points) reachability = ReachabilityAnalyzer( @@ -266,8 +278,9 @@ def _load_module(name, filename): unit_id = u.get("id", "") if unit_id in reachable_ids: u["reachable"] = True - u["is_entry_point"] = unit_id in entry_points - if unit_id in entry_points: + # Preserve any is_entry_point=True already set (e.g. by LLM stage). + u["is_entry_point"] = (unit_id in entry_points) or u.get("is_entry_point", False) + if unit_id in entry_points and not u.get("entry_point_reason"): u["entry_point_reason"] = detector.get_entry_point_reason(unit_id) filtered_units.append(u) @@ -311,6 +324,10 @@ def _load_module(name, filename): return dataset +# Private alias kept for the Python parser path which calls it directly. +_apply_reachability_filter = apply_reachability_filter + + # --------------------------------------------------------------------------- # Python parser # --------------------------------------------------------------------------- @@ -364,12 +381,114 @@ def _parse_python(repo_path: str, output_dir: str, processing_level: str, skip_t # JavaScript/TypeScript parser # --------------------------------------------------------------------------- +def _js_deps_installed() -> bool: + """Return True only if a *complete* npm install has previously succeeded. + + Checking that ``node_modules/`` exists is not enough: a prior install that + was killed (Ctrl+C, OOM, disk full) leaves a partial directory. npm writes + ``node_modules/.package-lock.json`` at the *end* of a successful install, + so we use that as the completion sentinel. + """ + return (_JS_PARSER_DIR / "node_modules" / ".package-lock.json").is_file() + + +def _ensure_js_parser_dependencies() -> None: + """Install the JS parser's Node dependencies on first use. + + Mirrors the Go CLI's venv bootstrap (apps/openant-cli/internal/python/runtime.go): + the first invocation installs, subsequent invocations are a no-op. Runs only + when a JS repo is actually being parsed, so Python/Go-only users never need npm. + + Concurrency: uses a lockfile so two parallel parses don't both run + ``npm install`` in the same directory (which can corrupt node_modules). + """ + if _js_deps_installed(): + return + + if not (_JS_PARSER_DIR / "package.json").is_file(): + raise RuntimeError( + f"JS parser package.json not found at {_JS_PARSER_DIR / 'package.json'}. " + "The openant-core install may be incomplete." + ) + + npm = shutil.which("npm") + if npm is None: + raise RuntimeError( + "JavaScript parser dependencies are not installed and `npm` is not on PATH. " + f"Install Node.js/npm, then run: npm install (from {_JS_PARSER_DIR})" + ) + + # Serialize concurrent bootstraps. The lockfile lives next to package.json so + # it's always on the same filesystem as the install target. + lock_path = _JS_PARSER_DIR / ".openant-npm-install.lock" + with _file_lock(lock_path): + # Re-check under the lock: another process may have finished while we waited. + if _js_deps_installed(): + return + + print( + "[Parser] Installing JS parser dependencies (first run, this may take a minute)...", + file=sys.stderr, + ) + result = subprocess.run( + [npm, "install"], + cwd=str(_JS_PARSER_DIR), + stdout=sys.stderr, + stderr=sys.stderr, + ) + if result.returncode != 0: + raise RuntimeError( + f"`npm install` failed in {_JS_PARSER_DIR} with exit code " + f"{result.returncode}. See npm output above for details; you can " + f"reproduce with: npm install (from {_JS_PARSER_DIR})" + ) + + +@contextlib.contextmanager +def _file_lock(lock_path: Path): + """Cross-platform exclusive file lock as a context manager. + + Uses ``msvcrt`` on Windows and ``fcntl`` elsewhere. Blocks until the lock is + acquired, releases on exit. The lockfile itself is left in place; only the + OS-level lock matters for mutual exclusion. + """ + lock_path.parent.mkdir(parents=True, exist_ok=True) + # "w" (not "a+") so the file pointer is at byte 0 — msvcrt.locking locks a + # range starting at the *current* file position, so different positions + # would mean non-overlapping (i.e. non-exclusive) locks. + f = open_utf8(lock_path, "w") + try: + if os.name == "nt": + import msvcrt + + f.seek(0) + # LK_LOCK blocks (with retries) until the byte range is exclusive. + msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1) + try: + yield + finally: + f.seek(0) + msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1) + else: + import fcntl + + fcntl.flock(f.fileno(), fcntl.LOCK_EX) + try: + yield + finally: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + finally: + f.close() + + def _parse_javascript(repo_path: str, output_dir: str, processing_level: str, skip_tests: bool = True, name: str = None) -> ParseResult: """Invoke the JavaScript/TypeScript parser. The JS parser is a PipelineTest class that runs Node.js subprocesses. We invoke it via subprocess to avoid the sys.path hacks. """ + _ensure_js_parser_dependencies() + print("[Parser] Running JavaScript parser...", file=sys.stderr) parser_script = _CORE_ROOT / "parsers" / "javascript" / "test_pipeline.py" diff --git a/libs/openant-core/core/scanner.py b/libs/openant-core/core/scanner.py index 2eba6ee..0424672 100644 --- a/libs/openant-core/core/scanner.py +++ b/libs/openant-core/core/scanner.py @@ -27,7 +27,7 @@ ) from core.step_report import step_context from core import tracking -from utilities.file_io import read_json +from utilities.file_io import read_json, write_json # Import app context generator (optional) try: @@ -60,6 +60,8 @@ def scan_repository( repo_url: str | None = None, commit_sha: str | None = None, diff_manifest: str | None = None, + llm_reachability: bool = False, + llm_reachability_max_code_bytes: int = 1500, ) -> ScanResult: """Scan a repository for vulnerabilities. @@ -107,6 +109,7 @@ def scan_repository( # Count total steps for progress display total_steps = _count_steps( generate_context, enhance, verify, generate_report, dynamic_test, + llm_reachability=llm_reachability, ) step_num = 0 @@ -124,19 +127,31 @@ def _step_label(name: str) -> str: # --------------------------------------------------------------- from core.parser_adapter import parse_repository + # When LLM reachability is enabled the stage must see ALL units so it can + # identify entry points the structural pass would miss. Parse with "all" + # here; the structural filter is re-applied after LLM signals are merged. + effective_parse_level = ( + "all" if (llm_reachability and processing_level != "all") else processing_level + ) + print(_step_label("Parsing repository..."), file=sys.stderr) + if effective_parse_level != processing_level: + print( + " [LLM reachability] parsing all units; structural filter runs after LLM signals", + file=sys.stderr, + ) with step_context("parse", output_dir, inputs={ "repo_path": repo_path, "language": language, - "processing_level": processing_level, + "processing_level": effective_parse_level, "skip_tests": skip_tests, }) as ctx: parse_result = parse_repository( repo_path=repo_path, output_dir=output_dir, language=language, - processing_level=processing_level, + processing_level=effective_parse_level, skip_tests=skip_tests, diff_manifest=diff_manifest, ) @@ -174,7 +189,7 @@ def _step_label(name: str) -> str: # --------------------------------------------------------------- # Step 2: Application Context (optional) # --------------------------------------------------------------- - app_context_path = None + app_context_path: str | None = None if generate_context and HAS_APP_CONTEXT: print(_step_label("Generating application context..."), file=sys.stderr) @@ -205,6 +220,133 @@ def _step_label(name: str) -> str: result.skipped_steps.append("app-context") print(file=sys.stderr) + # --------------------------------------------------------------- + # Step 2.5: LLM Reachability review (optional, opt-in) + # --------------------------------------------------------------- + # Runs after parse + app-context and before enhance/analyze. Because parse + # was done with processing_level="all" (when filtering is requested), the + # LLM sees every unit in the codebase and can identify entry points the + # structural heuristics would miss. After signals are applied the + # structural reachability filter is re-run with LLM-promoted entry points + # added as extra BFS seeds, so the final dataset honours the user's + # requested processing_level. Threading app_context into the prompt helps + # the model reason about expected entry points (e.g. "this is a web_app, + # look for HTTP handlers"). + if llm_reachability: + from core.llm_reachability import ( + MODEL_PRIMARY as _LLM_REACH_MODEL, + analyze_reachability, + apply_signals, + signals_to_json, + ) + + print(_step_label("Running LLM reachability review..."), file=sys.stderr) + + with step_context("llm-reachability", output_dir, inputs={ + "dataset_path": active_dataset_path, + "model": _LLM_REACH_MODEL, + }) as ctx: + try: + dataset = read_json(active_dataset_path) + except (OSError, json.JSONDecodeError) as exc: + print(f" WARNING: failed to load dataset: {exc}", file=sys.stderr) + ctx.summary = {"skipped": True, "reason": str(exc)} + dataset = None + + if dataset is not None: + app_ctx_payload = None + if app_context_path and os.path.exists(app_context_path): + try: + app_ctx_payload = read_json(app_context_path) + except (OSError, json.JSONDecodeError): + app_ctx_payload = None + + # --limit governs the analyze stage, not how many units the + # LLM reachability pass reviews — it must see the full + # codebase to find missed entry points. + signals = analyze_reachability( + dataset=dataset, + app_context=app_ctx_payload, + max_code_bytes=llm_reachability_max_code_bytes, + ) + summary = apply_signals(dataset, signals) + + signals_path = os.path.join(output_dir, "llm_reachability.json") + write_json(signals_path, {"signals": signals_to_json(signals)}, indent=2) + + pre_filter_count = len(dataset.get("units", [])) + post_filter_count = pre_filter_count + refilter_supported = False + + # Re-apply the structural reachability filter using + # LLM-promoted entry points as additional BFS seeds. + # Only possible when call_graph.json was written by the parser + # (Python and Zig paths do this; JS/Go/C/Ruby/PHP handle + # reachability filtering internally and don't persist it). + if processing_level != "all": + call_graph_path = os.path.join(output_dir, "call_graph.json") + if os.path.exists(call_graph_path): + from core.parser_adapter import apply_reachability_filter + llm_promoted_ids = { + u["id"] for u in dataset.get("units", []) + if u.get("is_entry_point") and u.get("id") + } + dataset = apply_reachability_filter( + dataset, + output_dir, + processing_level, + extra_entry_points=llm_promoted_ids, + ) + post_filter_count = len(dataset.get("units", [])) + result.units_count = post_filter_count + refilter_supported = True + else: + # Parser doesn't persist call_graph.json — the full + # unfiltered dataset will flow to downstream stages. + # Warn loudly so the cost impact is visible. + print( + f"\n WARNING: --llm-reachability with " + f"--level {processing_level}: " + f"{parse_result.language} does not yet support " + f"post-LLM re-filtering (call_graph.json not found). " + f"Downstream stages will process all " + f"{pre_filter_count} units instead of the filtered " + f"subset — this may significantly increase cost.", + file=sys.stderr, + ) + + # Persist final dataset so downstream stages see promoted + # entry points, per-unit signals, and the applied filter. + write_json(active_dataset_path, dataset, indent=2) + + ctx.summary = { + "units_reviewed": pre_filter_count, + "signals_added": summary["signals_applied"], + "entry_points_promoted": summary["entry_points_promoted"], + "units_touched": summary["units_touched"], + "post_filter_units": post_filter_count, + "refilter_supported": refilter_supported, + } + ctx.outputs = {"signals_path": signals_path} + + print( + f" LLM reachability: {summary['signals_applied']} signals, " + f"{summary['entry_points_promoted']} new entry points", + file=sys.stderr, + ) + if processing_level != "all" and refilter_supported: + print( + f" After reachability filter: {post_filter_count} units", + file=sys.stderr, + ) + + collected_step_reports.append( + _load_step_report(output_dir, "llm-reachability") + ) + else: + result.skipped_steps.append("llm-reachability") + print(file=sys.stderr) + # --------------------------------------------------------------- # Step 3: Enhance (optional) # --------------------------------------------------------------- @@ -522,6 +664,7 @@ def _count_steps( verify: bool, generate_report: bool, dynamic_test: bool, + llm_reachability: bool = False, ) -> int: """Count total steps for progress display (always includes parse, detect, build-output).""" count = 3 # parse + detect + build-output (always run) @@ -535,6 +678,8 @@ def _count_steps( count += 1 if dynamic_test: count += 1 + if llm_reachability: + count += 1 return count diff --git a/libs/openant-core/openant/cli.py b/libs/openant-core/openant/cli.py index e521b22..c303c64 100644 --- a/libs/openant-core/openant/cli.py +++ b/libs/openant-core/openant/cli.py @@ -75,6 +75,10 @@ def cmd_scan(args): repo_url=getattr(args, "repo_url", None), commit_sha=getattr(args, "commit_sha", None), diff_manifest=getattr(args, "diff_manifest", None), + llm_reachability=getattr(args, "llm_reachability", False), + llm_reachability_max_code_bytes=getattr( + args, "llm_reachability_max_code_bytes", 1500 + ), ) scan_payload = result.to_dict() @@ -988,6 +992,29 @@ def main(): scan_p.add_argument("--backoff", type=int, default=30, help="Seconds to wait when rate-limited (default: 30)") scan_p.add_argument("--diff-manifest", help="Path to diff_manifest.json for incremental scanning") + scan_p.add_argument( + "--llm-reachability", + action="store_true", + dest="llm_reachability", + help="Enable the LLM reachability review stage (Opus). " + "Surfaces entry points and external-input sites the structural " + "pass would miss by reviewing the full codebase before the " + "reachability filter is applied. Off by default — enabling " + "this incurs cost proportional to total repo size, not the " + "filtered unit count (~one Opus call per 25 units across the " + "whole codebase).", + ) + scan_p.add_argument( + "--llm-reachability-max-code-bytes", + type=int, + default=1500, + dest="llm_reachability_max_code_bytes", + help="Max code bytes per unit sent to the LLM reachability stage " + "(default: 1500). Higher values (e.g. 4096, 8192) catch " + "entry-point indicators past byte 1500 in long handlers / " + "generated code, at proportional Opus cost increase. Only " + "meaningful with --llm-reachability.", + ) scan_p.set_defaults(func=cmd_scan) # --------------------------------------------------------------- diff --git a/libs/openant-core/parsers/c/test_pipeline.py b/libs/openant-core/parsers/c/test_pipeline.py index 5072d68..f325a7f 100644 --- a/libs/openant-core/parsers/c/test_pipeline.py +++ b/libs/openant-core/parsers/c/test_pipeline.py @@ -42,10 +42,11 @@ from enum import Enum from pathlib import Path from typing import Set -from utilities.file_io import open_utf8, read_json, run_utf8, write_json -# Add parent directory to path for utilities import +# Add parent directory to path so utilities/ imports resolve when this script +# is invoked as a subprocess by core/parser_adapter.py (cwd may not include it). sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +from utilities.file_io import open_utf8, read_json, run_utf8, write_json from utilities.context_enhancer import ContextEnhancer from utilities.agentic_enhancer import EntryPointDetector, ReachabilityAnalyzer @@ -184,6 +185,10 @@ def run_parser_pipeline(self) -> bool: analyzer_output = generator.generate_analyzer_output() write_json(self.analyzer_output_file, analyzer_output) + # Write call graph for post-LLM reachability re-filtering + call_graph_file = os.path.join(self.output_dir, 'call_graph.json') + write_json(call_graph_file, graph_result) + elapsed = (datetime.now() - start_time).total_seconds() summary = { diff --git a/libs/openant-core/parsers/go/test_pipeline.py b/libs/openant-core/parsers/go/test_pipeline.py index 7e2aa11..5abdf83 100644 --- a/libs/openant-core/parsers/go/test_pipeline.py +++ b/libs/openant-core/parsers/go/test_pipeline.py @@ -283,6 +283,58 @@ def run_go_parser_all(self) -> bool: print(f"Warning: Could not apply dataset name: {e}") self.results['stages']['go_parser'] = result + + # Write call_graph.json immediately after parsing so the post-LLM + # reachability re-filter can use it regardless of processing_level. + # Go's analyzer_output.json has functions; the call graph edges live + # in each unit's metadata.direct_calls / direct_callers. + if ( + result.get('success', False) + and self.analyzer_output_file and os.path.exists(self.analyzer_output_file) + and self.dataset_file and os.path.exists(self.dataset_file) + ): + try: + analyzer = read_json(self.analyzer_output_file) + dataset_for_cg = read_json(self.dataset_file) + + raw_functions = analyzer.get("functions", {}) + # Normalise to the camelCase shape EntryPointDetector expects. + normalized_functions = { + func_id: { + 'name': fd.get('name', ''), + 'unitType': fd.get('unit_type', fd.get('unitType', 'function')), + 'code': fd.get('code', ''), + 'filePath': fd.get('file_path', fd.get('filePath', '')), + 'startLine': fd.get('start_line', fd.get('startLine', 0)), + 'endLine': fd.get('end_line', fd.get('endLine', 0)), + 'package': fd.get('package', ''), + 'receiver': fd.get('receiver', ''), + 'isExported': fd.get('is_exported', fd.get('isExported', False)), + } + for func_id, fd in raw_functions.items() + } + + call_graph: dict = {} + reverse_call_graph: dict = {} + for unit in dataset_for_cg.get('units', []): + unit_id = unit.get('id') + metadata = unit.get('metadata', {}) + direct_calls = metadata.get('direct_calls', metadata.get('directCalls', [])) + direct_callers = metadata.get('direct_callers', metadata.get('directCallers', [])) + if direct_calls: + call_graph[unit_id] = direct_calls + if direct_callers: + reverse_call_graph[unit_id] = direct_callers + + call_graph_file = os.path.join(self.output_dir, 'call_graph.json') + write_json(call_graph_file, { + "functions": normalized_functions, + "call_graph": call_graph, + "reverse_call_graph": reverse_call_graph, + }) + except (OSError, json.JSONDecodeError, KeyError) as e: + print(f"Warning: could not write call_graph.json: {e}") + return result.get('success', False) def apply_reachability_filter(self) -> bool: diff --git a/libs/openant-core/parsers/javascript/dependency_resolver.js b/libs/openant-core/parsers/javascript/dependency_resolver.js index 52d130e..39644c4 100644 --- a/libs/openant-core/parsers/javascript/dependency_resolver.js +++ b/libs/openant-core/parsers/javascript/dependency_resolver.js @@ -20,6 +20,7 @@ const path = require('path'); class DependencyResolver { constructor(analyzerOutput, options = {}) { this.functions = analyzerOutput.functions || {}; + this.classes = analyzerOutput.classes || {}; // "filePath:className" -> { constructorDeps, fieldDeps, baseTypes } this.callGraph = {}; // functionId -> [calledFunctionIds] this.reverseCallGraph = {}; // functionId -> [callerFunctionIds] this.maxDepth = options.maxDepth || 3; @@ -29,6 +30,7 @@ class DependencyResolver { this.functionsByName = Object.create(null); // simpleName -> [functionIds] this.functionsByFile = Object.create(null); // filePath -> [functionIds] this.imports = Object.create(null); // filePath -> { importedName -> { source, originalName } } + this.classesByBaseType = Object.create(null); // baseTypeName -> ["filePath:className", ...] this._buildIndexes(); } @@ -52,6 +54,13 @@ class DependencyResolver { } this.functionsByFile[filePath].push(funcId); } + + for (const [classKey, classData] of Object.entries(this.classes)) { + for (const baseType of (classData.baseTypes || [])) { + if (!this.classesByBaseType[baseType]) this.classesByBaseType[baseType] = []; + this.classesByBaseType[baseType].push(classKey); + } + } } /** @@ -60,6 +69,21 @@ class DependencyResolver { buildCallGraph() { for (const [funcId, funcData] of Object.entries(this.functions)) { const calls = this._extractCalls(funcData.code, funcId); + + // Merge in any explicit call edges declared by the analyzer. + // This is used for cases the body-text regex can't see — e.g. + // Express middleware identifiers passed as sibling args: + // app.post('/x', authenticateToken, async (req,res) => {...}) + const explicitCalls = funcData.explicitCalls || []; + const callerFile = funcId.split(':')[0]; + for (const name of explicitCalls) { + if (!name) continue; + const resolved = this._resolveCall(name, callerFile, funcId); + if (resolved && !calls.includes(resolved)) { + calls.push(resolved); + } + } + this.callGraph[funcId] = calls; // Build reverse graph @@ -134,7 +158,7 @@ class DependencyResolver { // Skip 'this' (handled above) and common built-ins if (objectName === 'this' || this._isBuiltIn(objectName)) continue; - const resolved = this._resolveMethodCall(objectName, methodName, callerFile); + const resolved = this._resolveMethodCall(objectName, methodName, callerFile, callerFuncId); if (resolved && !seenCalls.has(resolved)) { seenCalls.add(resolved); calls.push(resolved); @@ -240,16 +264,20 @@ class DependencyResolver { /** * Resolve an object.method call + * + * Supports two resolution strategies: + * 1. Direct class name match: objectName === className + * 2. DI-aware resolution: objectName is a constructor-injected parameter, + * use its type annotation to find the target class */ - _resolveMethodCall(objectName, methodName, callerFile) { - // Check if objectName matches a class name - const qualifiedName = `${objectName}.${methodName}`; + _resolveMethodCall(objectName, methodName, callerFile, callerFuncId = null) { const candidates = this.functionsByName[methodName]; if (!candidates || !Array.isArray(candidates)) { return null; } + // 1. Exact class name match (existing behavior) for (const funcId of candidates) { const funcData = this.functions[funcId]; if (funcData && funcData.className === objectName) { @@ -257,6 +285,47 @@ class DependencyResolver { } } + // 2. DI-aware resolution: look up objectName in caller's constructorDeps + // e.g., this.callService.getById() -> constructorDeps says callService: CallService + // -> resolve to CallService.getById + if (callerFuncId) { + const callerFunc = this.functions[callerFuncId]; + const classEntry = callerFunc && callerFunc.className && + this.classes[callerFile + ':' + callerFunc.className]; + if (classEntry && (classEntry.constructorDeps || classEntry.fieldDeps)) { + const typeName = (classEntry.constructorDeps || {})[objectName] + ?? (classEntry.fieldDeps || {})[objectName]; + if (typeName) { + // 2a. Exact type match + for (const funcId of candidates) { + const funcData = this.functions[funcId]; + if (funcData && funcData.className === typeName) { + return funcId; + } + } + + // 2b. Nominal type match: prefer candidates whose class implements or extends typeName. + // If exactly one such candidate exists, the resolution is unambiguous. + const nominalClassKeys = this.classesByBaseType[typeName] || []; + const nominalMatches = candidates.filter(funcId => { + const funcData = this.functions[funcId]; + if (!funcData || !funcData.className) return false; + const funcClassKey = funcId.split(':')[0] + ':' + funcData.className; + return nominalClassKeys.includes(funcClassKey); + }); + if (nominalMatches.length === 1) return nominalMatches[0]; + + // 2c. Prefix match: last resort for versioned names (e.g., CallService -> CallServiceV1). + // Skip if multiple candidates match to preserve no-false-positive property. + const prefixMatches = candidates.filter(funcId => { + const funcData = this.functions[funcId]; + return funcData && funcData.className && funcData.className.startsWith(typeName); + }); + if (prefixMatches.length === 1) return prefixMatches[0]; + } + } + } + return null; } diff --git a/libs/openant-core/parsers/javascript/test_pipeline.py b/libs/openant-core/parsers/javascript/test_pipeline.py index 667bf1f..2eee6bd 100644 --- a/libs/openant-core/parsers/javascript/test_pipeline.py +++ b/libs/openant-core/parsers/javascript/test_pipeline.py @@ -307,6 +307,23 @@ def run_typescript_analyzer(self, files: list = None) -> bool: ) self.results['stages']['typescript_analyzer'] = result + + # Write call_graph.json immediately after the analyzer output is + # available so the post-LLM reachability re-filter can use it + # regardless of processing_level (which may be "all"). + if result.get('success', False) and os.path.exists(self.analyzer_output_file): + try: + analyzer = read_json(self.analyzer_output_file) + call_graph_data = { + "functions": analyzer.get("functions", {}), + "call_graph": analyzer.get("call_graph", analyzer.get("callGraph", {})), + "reverse_call_graph": analyzer.get("reverse_call_graph", analyzer.get("reverseCallGraph", {})), + } + call_graph_file = os.path.join(self.output_dir, 'call_graph.json') + write_json(call_graph_file, call_graph_data) + except (OSError, json.JSONDecodeError, KeyError) as e: + print(f"Warning: could not write call_graph.json: {e}") + return result.get('success', False) def run_stage_with_stdout_capture(self, name: str, command: list, output_file: str) -> dict: diff --git a/libs/openant-core/parsers/javascript/typescript_analyzer.js b/libs/openant-core/parsers/javascript/typescript_analyzer.js index 7121acd..8b6b0e7 100644 --- a/libs/openant-core/parsers/javascript/typescript_analyzer.js +++ b/libs/openant-core/parsers/javascript/typescript_analyzer.js @@ -58,6 +58,7 @@ class TypeScriptAnalyzer { compilerOptions: PERMISSIVE_COMPILER_OPTIONS, }); this.functions = {}; // functionId -> function metadata + this.classes = {}; // "filePath:className" -> { constructorDeps, fieldDeps, baseTypes } this.callGraph = {}; // callerId -> array of call info } @@ -163,6 +164,7 @@ class TypeScriptAnalyzer { return { functions: this.functions, + classes: this.classes, callGraph: this.callGraph, }; } @@ -240,6 +242,87 @@ class TypeScriptAnalyzer { className: className, }; } + + // Build class-level metadata: constructorDeps and baseTypes + const classEntry = {}; + + // Extract base types (implements + extends) for nominal DI resolution. + // Strips generics: implements Repository -> Repository + const baseTypes = []; + const extendsExpr = classDecl.getExtends(); + if (extendsExpr) { + const name = extendsExpr.getExpression().getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(name)) baseTypes.push(name); + } + for (const impl of classDecl.getImplements()) { + const name = impl.getExpression().getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(name)) baseTypes.push(name); + } + if (baseTypes.length > 0) classEntry.baseTypes = baseTypes; + + // Extract constructor DI metadata. + // DI classes have a single primary constructor; overloads are unusual in NestJS/Angular. + const constructors = classDecl.getConstructors(); + if (constructors.length > 0) { + const ctor = constructors[0]; + const injections = {}; // paramName -> typeName + + for (const param of ctor.getParameters()) { + const paramName = param.getName(); + const typeNode = param.getTypeNode(); + if (typeNode) { + // Strip generic parameters so Repository resolves as Repository + const typeName = typeNode.getText().replace(/<.*$/, ''); + // Only store simple PascalCase type names (skip union types, primitives) + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(typeName)) { + injections[paramName] = typeName; + } + } + } + + if (Object.keys(injections).length > 0) classEntry.constructorDeps = injections; + } + + // Extract field/property injection metadata. + // Covers decorator-based (@Inject, @InjectRepository, etc.) and Angular's inject() function. + const fieldDeps = {}; + for (const prop of classDecl.getProperties()) { + const propName = prop.getName(); + let typeName = null; + + // Decorator-based: any @Inject* decorator signals an injection point; + // the injected type comes from the TypeScript type annotation. + const hasInjectDecorator = prop.getDecorators().some(d => /^Inject/.test(d.getName())); + if (hasInjectDecorator) { + const typeNode = prop.getTypeNode(); + if (typeNode) { + const t = typeNode.getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(t)) typeName = t; + } + } + + // Functional: private svc = inject(SvcType) (Angular inject() API) + if (!typeName) { + const init = prop.getInitializer(); + if (init && init.getKindName() === 'CallExpression') { + const expr = init.getExpression(); + if (expr && expr.getText() === 'inject') { + const args = init.getArguments(); + if (args.length > 0) { + const t = args[0].getText().replace(/<.*$/, ''); + if (/^[A-Z][a-zA-Z0-9_$]*$/.test(t)) typeName = t; + } + } + } + } + + if (typeName) fieldDeps[propName] = typeName; + } + if (Object.keys(fieldDeps).length > 0) classEntry.fieldDeps = fieldDeps; + + if (Object.keys(classEntry).length > 0) { + this.classes[`${relativePath}:${className}`] = classEntry; + } } // Extract methods from object literals in export default @@ -253,6 +336,242 @@ class TypeScriptAnalyzer { // Extract functions from module.exports.propertyName = function() {...} // Pattern used by DVNA and similar CommonJS codebases this._extractModuleExportsPropertyFunctions(sourceFile, relativePath); + + // Extract anonymous callbacks used as Express route handlers / middleware + // Pattern: app.get('/x', auth, async (req, res) => {...}) + this._extractExpressRouteCallbacks(sourceFile, relativePath); + } + + /** + * Express HTTP verbs we recognise on a router/app object. + * `use` is included to pick up middleware-mount callbacks. + */ + static EXPRESS_VERBS = new Set([ + "get", + "post", + "put", + "patch", + "delete", + "options", + "head", + "all", + "use", + ]); + + /** + * Walk a source file looking for Express-style route registrations and + * emit a synthetic function entry for each anonymous arrow / function + * expression used as a callback. + * + * Recognises patterns of the form: + * .(, ...callbacks) + * .(...callbacks) // only for `use` + * where `` is one of the Express HTTP verbs (or `use`) and the + * first argument (when present) is a string-literal path. + * + * For each anonymous callback at index >= 1 we synthesise a function + * entry. The last anonymous-or-named callback is treated as the route + * handler; earlier callbacks are middleware. Named identifiers in + * callback positions are recorded as explicit call edges from the + * synthesised callbacks (e.g. `authenticateToken` becomes an upstream + * dependency of the handler so call-graph based analyses see the + * relationship). + */ + /** + * Heuristic: does `receiver` look like an Express app / router? + * + * We accept identifiers whose name ends with or contains one of the common + * Express app/router stems (case-insensitive), and chained calls like + * `app.route(...)` or `router.route(...)`. We deliberately reject other + * receivers so generic `.get(...)` calls on caches / clients / query-builders + * aren't misread as routes. + * + * Accepted stems: app, router, routes, server, web, api, endpoints, controller. + * Codebases using single-word identifiers outside this list (e.g. `http`) will + * not be extracted; add the stem here if needed. + */ + // Stems that strongly suggest an Express app/router object. + static EXPRESS_RECEIVER_STEMS = + "app|router|routes|server|web|api|endpoints|controller"; + + _isPlausibleExpressReceiver(receiver) { + if (!receiver) return false; + const kind = receiver.getKindName(); + const stems = TypeScriptAnalyzer.EXPRESS_RECEIVER_STEMS; + + if (kind === "Identifier") { + const name = receiver.getText().toLowerCase(); + // Accept exact stems, suffix matches (myApp), and underscore-prefixed + // variants (app_server) while rejecting generic short names. + return new RegExp(`(^|_)(${stems})(\\d|$|_)`).test(name) + || new RegExp(`(${stems})$`).test(name); + } + if (kind === "CallExpression") { + // e.g. app.route('/x').get(...) — receiver is the .route() call + const inner = receiver.getExpression && receiver.getExpression(); + if (inner && inner.getKindName && inner.getKindName() === "PropertyAccessExpression") { + const innerName = inner.getName && inner.getName(); + if (innerName === "route" || innerName === "Router") return true; + } + return false; + } + if (kind === "PropertyAccessExpression") { + // e.g. this.app.get(...) or express.Router().get(...) — accept when + // the trailing identifier matches our identifier pattern. + const trailing = receiver.getName && receiver.getName(); + if (!trailing) return false; + const lower = trailing.toLowerCase(); + return new RegExp(`(${stems})$`).test(lower); + } + return false; + } + + _extractExpressRouteCallbacks(sourceFile, relativePath) { + const callExpressions = sourceFile + .getDescendantsOfKind(ts.SyntaxKind.CallExpression); + + for (const callExpr of callExpressions) { + const expression = callExpr.getExpression(); + if (!expression || expression.getKindName() !== "PropertyAccessExpression") { + continue; + } + + const methodName = expression.getName ? expression.getName() : null; + if (!methodName || !TypeScriptAnalyzer.EXPRESS_VERBS.has(methodName)) { + continue; + } + + // Filter to plausibly-Express receivers. Without this we'd match any + // `foo.get('x', () => {})` style call (e.g. cache lookups, query + // builders) and synthesise bogus route units. + const receiver = expression.getExpression + ? expression.getExpression() + : null; + if (!this._isPlausibleExpressReceiver(receiver)) { + continue; + } + + const args = callExpr.getArguments(); + if (args.length === 0) continue; + + // Determine whether the first argument is a path string literal. + const firstArg = args[0]; + const firstKind = firstArg.getKindName(); + let httpPath = null; + let callbackStartIndex = 0; + if (firstKind === "StringLiteral" || firstKind === "NoSubstitutionTemplateLiteral") { + httpPath = firstArg.getLiteralValue + ? firstArg.getLiteralValue() + : firstArg.getText().slice(1, -1); + callbackStartIndex = 1; + } else if (methodName === "use") { + // `app.use(middleware)` — no path, all args are callbacks. + httpPath = null; + callbackStartIndex = 0; + } else { + // Not an Express-shaped call (no string path and not `use`). + continue; + } + + // Gather the callback arguments (functions + named identifiers). + const callbacks = args.slice(callbackStartIndex); + if (callbacks.length === 0) continue; + + // We only emit units when at least one callback is an inline + // anonymous function. Otherwise the existing extraction logic + // already handles named handlers. + const hasInline = callbacks.some((a) => { + const k = a.getKindName(); + return k === "ArrowFunction" || k === "FunctionExpression"; + }); + if (!hasInline) continue; + + const httpMethod = methodName.toUpperCase(); + const lastCallbackIndex = callbacks.length - 1; + + // Collect named middleware identifiers (Identifier / PropertyAccess) + // that appear as siblings in the args list. They become explicit + // call-graph edges from each synthesised callback. + const namedMiddleware = []; + for (let i = 0; i < callbacks.length; i++) { + const arg = callbacks[i]; + const k = arg.getKindName(); + if (k === "Identifier") { + namedMiddleware.push(arg.getText()); + } else if (k === "PropertyAccessExpression") { + // Stores only the trailing name (e.g. "auth" from "middleware.auth"). + // dependency_resolver._resolveCall looks up by simple name, so if + // another unrelated function shares the same name the edge may + // resolve to the wrong target (silent false-positive). This is a + // known limitation of the current simple-name resolution model. + const name = arg.getName ? arg.getName() : arg.getText(); + namedMiddleware.push(name); + } + } + + for (let i = 0; i < callbacks.length; i++) { + const arg = callbacks[i]; + const k = arg.getKindName(); + if (k !== "ArrowFunction" && k !== "FunctionExpression") continue; + + // Only emit for *anonymous* function expressions. A function + // expression with a name like `function named(req,res){}` is + // already extracted elsewhere. + if (k === "FunctionExpression" && arg.getName && arg.getName()) { + continue; + } + + const isHandler = i === lastCallbackIndex; + const role = isHandler ? "handler" : `middleware:${i}`; + const pathLabel = httpPath !== null ? httpPath : ""; + const baseName = pathLabel + ? `${httpMethod} ${pathLabel} [${role}]` + : `${httpMethod} [${role}]`; + const synthName = baseName; + + const code = arg.getFullText(); + const startLine = arg.getStartLineNumber(); + const endLine = arg.getEndLineNumber(); + // Synthesise an ID that's stable per file/line so two routes on + // the same line+path don't collide. + const idSuffix = `${httpMethod}:${pathLabel}:${startLine}:${i}`; + const functionId = `${relativePath}:express(${idSuffix})`; + + if (this.functions[functionId]) continue; + + const unitType = isHandler ? "route_handler" : "route_middleware"; + const explicitCalls = namedMiddleware.filter((n) => n && n !== synthName); + + this.functions[functionId] = { + name: synthName, + code: code, + isExported: false, + unitType: unitType, + startLine: startLine, + endLine: endLine, + isEntryPoint: isHandler, + routeMetadata: { + http_method: httpMethod, + http_path: httpPath, + callback_index: i, + total_callbacks: callbacks.length, + named_middleware: explicitCalls, + }, + explicitCalls: explicitCalls, + }; + + // Emit a callGraph entry for the synthesised callback so the + // invariant `callGraph keys ≡ functions keys` holds. The named + // middleware identifiers are recorded as upstream dependencies via + // explicitCalls (merged downstream by dependency_resolver.js); here + // we capture any inline call expressions from the callback body so + // call-graph based analyses can see them too. + this.callGraph[functionId] = this.extractCallsFromFunction( + arg, + relativePath, + ); + } + } } /** diff --git a/libs/openant-core/parsers/javascript/unit_generator.js b/libs/openant-core/parsers/javascript/unit_generator.js index 3650792..7b76219 100644 --- a/libs/openant-core/parsers/javascript/unit_generator.js +++ b/libs/openant-core/parsers/javascript/unit_generator.js @@ -239,6 +239,19 @@ class UnitGenerator { unitType = 'route_handler'; } + // If the analyzer attached Express route metadata directly to the + // function (anonymous arrow handler / middleware), surface it on the + // unit's `route` field even when no external routes.json was given. + if (!routeData && funcData.routeMetadata) { + const meta = funcData.routeMetadata; + routeData = { + method: meta.http_method, + path: meta.http_path, + handler: funcData.name, + middleware: meta.named_middleware || [], + }; + } + // Get upstream dependencies (functions this calls) const upstreamIds = this.resolver.getDependencies(functionId); const upstreamDependencies = []; @@ -314,6 +327,10 @@ class UnitGenerator { handler: routeData.handler, middleware: routeData.middleware || [] } : null, + is_entry_point: funcData.isEntryPoint === true ? true : undefined, + http_method: funcData.routeMetadata ? funcData.routeMetadata.http_method : undefined, + http_path: funcData.routeMetadata ? funcData.routeMetadata.http_path : undefined, + callback_index: funcData.routeMetadata ? funcData.routeMetadata.callback_index : undefined, ground_truth: { status: 'UNKNOWN', vulnerability_types: [], diff --git a/libs/openant-core/parsers/php/test_pipeline.py b/libs/openant-core/parsers/php/test_pipeline.py index 7529ea9..32d269e 100644 --- a/libs/openant-core/parsers/php/test_pipeline.py +++ b/libs/openant-core/parsers/php/test_pipeline.py @@ -42,10 +42,11 @@ from enum import Enum from pathlib import Path from typing import Set -from utilities.file_io import open_utf8, read_json, run_utf8, write_json -# Add parent directory to path for utilities import +# Add parent directory to path so utilities/ imports resolve when this script +# is invoked as a subprocess by core/parser_adapter.py (cwd may not include it). sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +from utilities.file_io import open_utf8, read_json, run_utf8, write_json from utilities.context_enhancer import ContextEnhancer from utilities.agentic_enhancer import EntryPointDetector, ReachabilityAnalyzer @@ -184,6 +185,10 @@ def run_parser_pipeline(self) -> bool: analyzer_output = generator.generate_analyzer_output() write_json(self.analyzer_output_file, analyzer_output) + # Write call graph for post-LLM reachability re-filtering + call_graph_file = os.path.join(self.output_dir, 'call_graph.json') + write_json(call_graph_file, graph_result) + elapsed = (datetime.now() - start_time).total_seconds() summary = { diff --git a/libs/openant-core/parsers/ruby/test_pipeline.py b/libs/openant-core/parsers/ruby/test_pipeline.py index 947d495..01e29d5 100644 --- a/libs/openant-core/parsers/ruby/test_pipeline.py +++ b/libs/openant-core/parsers/ruby/test_pipeline.py @@ -42,10 +42,11 @@ from enum import Enum from pathlib import Path from typing import Set -from utilities.file_io import open_utf8, read_json, run_utf8, write_json -# Add parent directory to path for utilities import +# Add parent directory to path so utilities/ imports resolve when this script +# is invoked as a subprocess by core/parser_adapter.py (cwd may not include it). sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +from utilities.file_io import open_utf8, read_json, run_utf8, write_json from utilities.context_enhancer import ContextEnhancer from utilities.agentic_enhancer import EntryPointDetector, ReachabilityAnalyzer @@ -184,6 +185,10 @@ def run_parser_pipeline(self) -> bool: analyzer_output = generator.generate_analyzer_output() write_json(self.analyzer_output_file, analyzer_output) + # Write call graph for post-LLM reachability re-filtering + call_graph_file = os.path.join(self.output_dir, 'call_graph.json') + write_json(call_graph_file, graph_result) + elapsed = (datetime.now() - start_time).total_seconds() summary = { diff --git a/libs/openant-core/tests/parsers/javascript/__init__.py b/libs/openant-core/tests/parsers/javascript/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/openant-core/tests/parsers/javascript/test_express_route_handlers.py b/libs/openant-core/tests/parsers/javascript/test_express_route_handlers.py new file mode 100644 index 0000000..804e207 --- /dev/null +++ b/libs/openant-core/tests/parsers/javascript/test_express_route_handlers.py @@ -0,0 +1,400 @@ +"""Tests for Express anonymous route handler extraction in the JS parser. + +These exercise the typescript_analyzer.js + unit_generator.js pipeline by +running the Node.js scripts as subprocesses (mirroring tests/test_js_parser.py). + +Skips when Node.js or the parser's npm dependencies aren't installed. +""" +import json +import shutil +import subprocess +from pathlib import Path + +import pytest + + +PARSERS_JS_DIR = Path(__file__).parent.parent.parent.parent / "parsers" / "javascript" +NODE_MODULES = PARSERS_JS_DIR / "node_modules" + +pytestmark = pytest.mark.skipif( + not shutil.which("node") or not NODE_MODULES.exists(), + reason="Node.js or JS parser npm dependencies not available", +) + + +def _run_node(script_name, *args): + cmd = ["node", str(PARSERS_JS_DIR / script_name)] + list(args) + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + +def _analyze(repo_path, file_path): + """Run the analyzer on a single file and return parsed output.""" + result = _run_node("typescript_analyzer.js", str(repo_path), str(file_path)) + assert result.returncode == 0, ( + f"analyzer failed:\nstdout={result.stdout}\nstderr={result.stderr}" + ) + return json.loads(result.stdout) + + +def _generate_units(analyzer_output_path, dataset_output_path): + result = _run_node( + "unit_generator.js", + str(analyzer_output_path), + "--output", str(dataset_output_path), + ) + assert result.returncode == 0, ( + f"unit_generator failed:\nstdout={result.stdout}\nstderr={result.stderr}" + ) + return json.loads(Path(dataset_output_path).read_text()) + + +def _write_fixture(tmp_path: Path, name: str, content: str) -> Path: + repo = tmp_path / name + repo.mkdir(parents=True, exist_ok=True) + file_path = repo / "server.js" + file_path.write_text(content) + return file_path + + +def _express_units(dataset): + return [u for u in dataset["units"] if "express(" in u["id"]] + + +def test_anonymous_handler_with_named_middleware(tmp_path): + """router.post(path, namedMiddleware, async (req, res) => {...}).""" + file_path = _write_fixture( + tmp_path, + "anon_with_mw", + """ +const express = require('express'); +const router = express.Router(); + +function authenticateToken(req, res, next) { next(); } + +router.post('/orders', authenticateToken, async (req, res) => { + const { productId, quantity } = req.body; + res.json({ productId, quantity }); +}); + +module.exports = router; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 1, f"expected 1 anon handler, got {express_funcs}" + + fid, fdata = next(iter(express_funcs.items())) + assert fdata["unitType"] == "route_handler" + assert fdata["isEntryPoint"] is True + meta = fdata["routeMetadata"] + assert meta["http_method"] == "POST" + assert meta["http_path"] == "/orders" + assert meta["named_middleware"] == ["authenticateToken"] + + # Run unit_generator and verify the call-graph edge to authenticateToken. + analyzer_path = tmp_path / "analyzer.json" + analyzer_path.write_text(json.dumps(out)) + dataset_path = tmp_path / "dataset.json" + dataset = _generate_units(analyzer_path, dataset_path) + + handler_unit = next(u for u in dataset["units"] if u["id"] == fid) + assert handler_unit["unit_type"] == "route_handler" + assert handler_unit["is_entry_point"] is True + assert handler_unit["http_method"] == "POST" + assert handler_unit["http_path"] == "/orders" + assert handler_unit["route"]["method"] == "POST" + assert handler_unit["route"]["path"] == "/orders" + assert handler_unit["route"]["middleware"] == ["authenticateToken"] + + # Call-graph edge: handler -> authenticateToken + upstream_ids = handler_unit["metadata"]["direct_calls"] + auth_id = "server.js:authenticateToken" + assert auth_id in upstream_ids, ( + f"expected handler to call authenticateToken; direct_calls={upstream_ids}" + ) + + +def test_handler_no_middleware(tmp_path): + """app.get(path, (req, res) => res.json([])) — no extra edges.""" + file_path = _write_fixture( + tmp_path, + "no_mw", + """ +const express = require('express'); +const app = express(); +app.get('/users', (req, res) => res.json([])); +module.exports = app; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 1 + fid, fdata = next(iter(express_funcs.items())) + meta = fdata["routeMetadata"] + assert meta["http_method"] == "GET" + assert meta["http_path"] == "/users" + assert meta["named_middleware"] == [] + assert fdata["isEntryPoint"] is True + + +def test_use_with_multiple_anonymous_callbacks(tmp_path): + """router.use(path, anonMw1, anonMw2, anonHandler) — + one route_handler + two route_middleware units.""" + file_path = _write_fixture( + tmp_path, + "use_multi", + """ +const express = require('express'); +const router = express.Router(); + +router.use('/api', + (req, res, next) => { req.start = Date.now(); next(); }, + (req, res, next) => { console.log(req.path); next(); }, + async (req, res, next) => { + if (!req.headers.authorization) return res.status(401).end(); + next(); + } +); + +module.exports = router; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 3, f"expected 3 callbacks, got {list(express_funcs)}" + + by_type = {} + for fdata in express_funcs.values(): + by_type.setdefault(fdata["unitType"], []).append(fdata) + + assert len(by_type.get("route_handler", [])) == 1 + assert len(by_type.get("route_middleware", [])) == 2 + + handler = by_type["route_handler"][0] + assert handler["isEntryPoint"] is True + assert handler["routeMetadata"]["http_method"] == "USE" + assert handler["routeMetadata"]["http_path"] == "/api" + + for mw in by_type["route_middleware"]: + assert mw["isEntryPoint"] is False or mw.get("isEntryPoint") is None + assert mw["routeMetadata"]["http_method"] == "USE" + assert mw["routeMetadata"]["http_path"] == "/api" + assert mw["routeMetadata"]["callback_index"] < 2 + + +def test_non_express_call_is_skipped(tmp_path): + """myCache.get('foo', () => {}) must not be claimed as a route.""" + file_path = _write_fixture( + tmp_path, + "non_express", + """ +const myCache = makeCache(); +myCache.get('foo', () => { return 1; }); +const queryBuilder = makeBuilder(); +queryBuilder.post('users', () => {}); +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert express_funcs == {}, ( + f"non-Express receivers must not be extracted; got {list(express_funcs)}" + ) + + +def test_synthetic_handlers_have_call_graph_entries(tmp_path): + """Synthetic Express handlers must also appear as callGraph keys. + + Regression for the invariant `len(callGraph) == len(functions)` that + other tests (e.g. test_js_parser.test_builds_call_graph) rely on. + """ + file_path = _write_fixture( + tmp_path, + "callgraph_invariant", + """ +const express = require('express'); +const router = express.Router(); + +function authenticateToken(req, res, next) { next(); } + +router.post('/orders', authenticateToken, async (req, res) => { + const { productId, quantity } = req.body; + res.json({ productId, quantity }); +}); + +module.exports = router; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 1 + + # Every synthetic Express function must have a callGraph entry. + for fid in express_funcs: + assert fid in out["callGraph"], ( + f"synthetic function {fid} missing from callGraph; " + f"callGraph keys={list(out['callGraph'])}" + ) + + # Global invariant: callGraph keys ≡ functions keys. + assert len(out["callGraph"]) == len(out["functions"]), ( + f"callGraph/functions size mismatch: " + f"{len(out['callGraph'])} vs {len(out['functions'])}" + ) + + +def test_typescript_typed_callback(tmp_path): + """TS callback with type annotations: + `(req: Request, res: Response, next: NextFunction) => {...}`. + + Type annotations on the parameters and return type must not prevent + the AST walk from recognising the callback as an ArrowFunction. + """ + repo = tmp_path / "ts_typed" + repo.mkdir(parents=True, exist_ok=True) + file_path = repo / "server.ts" + file_path.write_text( + """ +import express, { Request, Response, NextFunction } from 'express'; +const app = express(); + +function authenticateToken(req: Request, res: Response, next: NextFunction): void { next(); } + +app.post('/orders', authenticateToken, async (req: Request, res: Response): Promise => { + const { productId, quantity } = req.body; + res.json({ productId, quantity }); +}); + +export default app; +""" + ) + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 1, ( + f"expected 1 anon TS handler, got {express_funcs}" + ) + fid, fdata = next(iter(express_funcs.items())) + assert fdata["unitType"] == "route_handler" + assert fdata["isEntryPoint"] is True + meta = fdata["routeMetadata"] + assert meta["http_method"] == "POST" + assert meta["http_path"] == "/orders" + assert meta["named_middleware"] == ["authenticateToken"] + + +def test_dynamic_path_does_not_crash(tmp_path): + """`app.get('/' + prefix, handler)` — first arg isn't a string literal. + + The extractor should skip such calls without throwing. We can't + reliably extract a path from a runtime-built expression. + """ + file_path = _write_fixture( + tmp_path, + "dynamic_path", + """ +const express = require('express'); +const app = express(); +const prefix = 'foo'; +app.get('/' + prefix, (req, res) => res.send('ok')); +module.exports = app; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert express_funcs == {}, ( + f"dynamic path should be skipped, got {list(express_funcs)}" + ) + + +def test_use_no_path_anonymous_middleware(tmp_path): + """`app.use((req, res, next) => {...})` — middleware with no path. + + The synthetic unit should be emitted with http_path=null and + http_method='USE'. + """ + file_path = _write_fixture( + tmp_path, + "use_no_path", + """ +const express = require('express'); +const app = express(); +app.use((req, res, next) => { req.start = Date.now(); next(); }); +module.exports = app; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 1, ( + f"expected 1 anon middleware unit, got {list(express_funcs)}" + ) + fid, fdata = next(iter(express_funcs.items())) + meta = fdata["routeMetadata"] + assert meta["http_method"] == "USE" + assert meta["http_path"] is None + + +def test_anon_middleware_named_handler_mixed(tmp_path): + """`app.get(path, anonMw, namedHandler)` — anon middleware before + named handler. Anon gets a route_middleware unit; the named handler + is left to the regular extractor (no synthetic unit for it).""" + file_path = _write_fixture( + tmp_path, + "mixed", + """ +const express = require('express'); +const app = express(); +function namedHandler(req, res) { res.send('ok'); } +app.get('/x', (req, res, next) => { console.log('mw'); next(); }, namedHandler); +module.exports = app; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert len(express_funcs) == 1, ( + f"expected 1 anon middleware unit, got {list(express_funcs)}" + ) + fid, fdata = next(iter(express_funcs.items())) + assert fdata["unitType"] == "route_middleware" + # named_middleware should include the namedHandler sibling + assert fdata["routeMetadata"]["named_middleware"] == ["namedHandler"] + # namedHandler must still be extracted normally + assert any( + f.get("name") == "namedHandler" for f in out["functions"].values() + ) + + +def test_named_handler_no_anonymous_unit(tmp_path): + """router.get('/x', namedHandler) — no anon unit synthesised.""" + file_path = _write_fixture( + tmp_path, + "named", + """ +const express = require('express'); +const router = express.Router(); + +function namedHandler(req, res) { res.send('ok'); } + +router.get('/x', namedHandler); + +module.exports = router; +""", + ) + repo = file_path.parent + out = _analyze(repo, file_path) + express_funcs = {k: v for k, v in out["functions"].items() if "express(" in k} + assert express_funcs == {}, ( + f"named-only callbacks must not synthesise anon units; got {list(express_funcs)}" + ) + # namedHandler should still be picked up by the regular extractor. + assert any( + f.get("name") == "namedHandler" for f in out["functions"].values() + ) diff --git a/libs/openant-core/tests/test_call_graph_output.py b/libs/openant-core/tests/test_call_graph_output.py new file mode 100644 index 0000000..8aac7a6 --- /dev/null +++ b/libs/openant-core/tests/test_call_graph_output.py @@ -0,0 +1,422 @@ +"""Tests that each parser writes call_graph.json to the output directory. + +The call_graph.json file is required by apply_reachability_filter (and the +post-LLM re-filter path) so it must be present regardless of processing_level, +including when --llm-reachability causes a parse with processing_level="all". + +Structure expected by apply_reachability_filter: + { + "functions": {: {}, ...}, + "call_graph": {: [, ...], ...}, + "reverse_call_graph": {: [, ...], ...}, + } + +Parser availability gates (identical to patterns used in test_js_parser.py): +- Python: always available +- JavaScript: requires Node.js + parsers/javascript/node_modules +- Go: requires parsers/go/go_parser/go_parser binary +- C: requires tree_sitter_c Python package +- Ruby: requires tree_sitter_ruby Python package +- PHP: requires tree_sitter_php Python package +""" + +from __future__ import annotations + +import json +import shutil +import sys +from pathlib import Path + +import pytest + +from core.parser_adapter import apply_reachability_filter, parse_repository + +TESTS_DIR = Path(__file__).parent +FIXTURES_DIR = TESTS_DIR / "fixtures" +PARSERS_DIR = Path(__file__).parent.parent / "parsers" + +# --------------------------------------------------------------------------- +# Availability checks (used by skipif marks) +# --------------------------------------------------------------------------- + +def _node_available() -> bool: + return bool(shutil.which("node")) and (PARSERS_DIR / "javascript" / "node_modules").exists() + +def _go_parser_available() -> bool: + go_dir = PARSERS_DIR / "go" / "go_parser" + # Check both Unix and Windows binary names. + candidates = [go_dir / "go_parser", go_dir / "go_parser.exe"] + binary = next((p for p in candidates if p.exists() and p.stat().st_size > 0), None) + if binary is None: + return False + import subprocess + try: + subprocess.run([str(binary), "--help"], capture_output=True, timeout=5) + return True + except (OSError, subprocess.TimeoutExpired): + return False + +def _ts_c_available() -> bool: + try: + import tree_sitter_c # noqa: F401 + return True + except ImportError: + return False + +def _ts_ruby_available() -> bool: + try: + import tree_sitter_ruby # noqa: F401 + return True + except ImportError: + return False + +def _ts_php_available() -> bool: + try: + import tree_sitter_php # noqa: F401 + return True + except ImportError: + return False + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_REQUIRED_KEYS = {"functions", "call_graph", "reverse_call_graph"} + + +def _assert_call_graph_valid(output_dir: str) -> dict: + """Load call_graph.json from output_dir and assert it has the right shape.""" + cg_path = Path(output_dir) / "call_graph.json" + assert cg_path.exists(), f"call_graph.json not found in {output_dir}" + with open(cg_path) as f: + data = json.load(f) + assert _REQUIRED_KEYS <= data.keys(), ( + f"call_graph.json missing keys: {_REQUIRED_KEYS - data.keys()}" + ) + assert isinstance(data["functions"], dict) + assert isinstance(data["call_graph"], dict) + assert isinstance(data["reverse_call_graph"], dict) + return data + + +# --------------------------------------------------------------------------- +# apply_reachability_filter unit tests (always run — no external deps) +# --------------------------------------------------------------------------- + + +class TestApplyReachabilityFilterPublicAPI: + """apply_reachability_filter is the consumer of call_graph.json. + These tests verify it works correctly with a synthetic fixture.""" + + def _make_call_graph_json(self, tmp_path: Path) -> None: + """Write a minimal call_graph.json that apply_reachability_filter can parse. + + route_handler uses the ``@app.route`` decorator pattern that + EntryPointDetector recognises, making it a structural entry point. + """ + cg = { + "functions": { + "app.py:route_handler": { + "name": "route_handler", + "filePath": "app.py", + "unitType": "function", + "isExported": False, + "decorators": ["@app.route('/foo')"], + }, + "app.py:helper": { + "name": "helper", + "filePath": "app.py", + "unitType": "function", + "isExported": False, + "decorators": [], + }, + "app.py:orphan": { + "name": "orphan", + "filePath": "app.py", + "unitType": "function", + "isExported": False, + "decorators": [], + }, + }, + "call_graph": { + "app.py:route_handler": ["app.py:helper"], + }, + "reverse_call_graph": { + "app.py:helper": ["app.py:route_handler"], + }, + } + (tmp_path / "call_graph.json").write_text(json.dumps(cg)) + + def _make_dataset(self, unit_ids: list[str]) -> dict: + return { + "units": [ + {"id": uid, "code": {"primary_code": "pass"}, "unit_type": "function"} + for uid in unit_ids + ] + } + + def test_filters_to_reachable_units(self, tmp_path): + self._make_call_graph_json(tmp_path) + dataset = self._make_dataset( + ["app.py:route_handler", "app.py:helper", "app.py:orphan"] + ) + result = apply_reachability_filter(dataset, str(tmp_path), "reachable") + unit_ids = {u["id"] for u in result["units"]} + assert "app.py:route_handler" in unit_ids + assert "app.py:helper" in unit_ids + assert "app.py:orphan" not in unit_ids + + def test_extra_entry_points_expand_reachable_set(self, tmp_path): + self._make_call_graph_json(tmp_path) + dataset = self._make_dataset( + ["app.py:route_handler", "app.py:helper", "app.py:orphan"] + ) + # Promote orphan as an extra entry point (simulating LLM signal). + result = apply_reachability_filter( + dataset, str(tmp_path), "reachable", + extra_entry_points={"app.py:orphan"}, + ) + unit_ids = {u["id"] for u in result["units"]} + assert "app.py:orphan" in unit_ids + + def test_is_entry_point_set_on_structural_entry_points(self, tmp_path): + self._make_call_graph_json(tmp_path) + dataset = self._make_dataset(["app.py:route_handler", "app.py:helper"]) + result = apply_reachability_filter(dataset, str(tmp_path), "reachable") + by_id = {u["id"]: u for u in result["units"]} + assert by_id["app.py:route_handler"]["is_entry_point"] is True + assert by_id["app.py:helper"]["is_entry_point"] is False + + def test_llm_promoted_is_entry_point_preserved(self, tmp_path): + self._make_call_graph_json(tmp_path) + dataset = self._make_dataset(["app.py:route_handler", "app.py:helper"]) + # Pre-set is_entry_point=True on helper (simulating LLM promotion). + dataset["units"][1]["is_entry_point"] = True + result = apply_reachability_filter( + dataset, str(tmp_path), "reachable", + extra_entry_points={"app.py:helper"}, + ) + by_id = {u["id"]: u for u in result["units"]} + assert by_id["app.py:helper"]["is_entry_point"] is True + + def test_missing_call_graph_returns_dataset_unchanged(self, tmp_path): + dataset = self._make_dataset(["app.py:route_handler"]) + result = apply_reachability_filter(dataset, str(tmp_path), "reachable") + assert len(result["units"]) == 1 + + +# --------------------------------------------------------------------------- +# Python parser — always runs +# --------------------------------------------------------------------------- + + +class TestPythonCallGraphOutput: + def test_call_graph_json_written(self, sample_python_repo, tmp_output_dir): + parse_repository( + repo_path=sample_python_repo, + output_dir=tmp_output_dir, + language="python", + processing_level="all", + ) + _assert_call_graph_valid(tmp_output_dir) + + def test_call_graph_json_written_with_reachable_level( + self, sample_python_repo, tmp_output_dir + ): + parse_repository( + repo_path=sample_python_repo, + output_dir=tmp_output_dir, + language="python", + processing_level="reachable", + ) + _assert_call_graph_valid(tmp_output_dir) + + +# --------------------------------------------------------------------------- +# JavaScript parser +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _node_available(), reason="Node.js or JS parser npm deps not available") +class TestJavaScriptCallGraphOutput: + def test_call_graph_json_written(self, sample_js_repo, tmp_output_dir): + parse_repository( + repo_path=sample_js_repo, + output_dir=tmp_output_dir, + language="javascript", + processing_level="all", + ) + _assert_call_graph_valid(tmp_output_dir) + + def test_call_graph_json_written_with_reachable_level( + self, sample_js_repo, tmp_output_dir + ): + parse_repository( + repo_path=sample_js_repo, + output_dir=tmp_output_dir, + language="javascript", + processing_level="reachable", + ) + _assert_call_graph_valid(tmp_output_dir) + + +# --------------------------------------------------------------------------- +# Go parser +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_go_repo(tmp_path): + """Minimal Go repository fixture.""" + repo = tmp_path / "go_repo" + repo.mkdir() + (repo / "go.mod").write_text("module example.com/myapp\n\ngo 1.21\n") + (repo / "main.go").write_text( + 'package main\n\nimport "fmt"\n\n' + "func main() {\n\tgreet()\n}\n\n" + 'func greet() {\n\tfmt.Println("hello")\n}\n' + ) + return str(repo) + + +@pytest.mark.skipif(not _go_parser_available(), reason="go_parser binary not available") +class TestGoCallGraphOutput: + def test_call_graph_json_written(self, sample_go_repo, tmp_output_dir): + parse_repository( + repo_path=sample_go_repo, + output_dir=tmp_output_dir, + language="go", + processing_level="all", + ) + _assert_call_graph_valid(tmp_output_dir) + + def test_call_graph_json_written_with_reachable_level( + self, sample_go_repo, tmp_output_dir + ): + parse_repository( + repo_path=sample_go_repo, + output_dir=tmp_output_dir, + language="go", + processing_level="reachable", + ) + _assert_call_graph_valid(tmp_output_dir) + + +# --------------------------------------------------------------------------- +# C parser +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_c_repo(tmp_path): + """Minimal C repository fixture.""" + repo = tmp_path / "c_repo" + repo.mkdir() + (repo / "main.c").write_text( + "#include \n\nvoid greet() {\n printf(\"hello\\n\");\n}\n\n" + "int main() {\n greet();\n return 0;\n}\n" + ) + return str(repo) + + +@pytest.mark.skipif(not _ts_c_available(), reason="tree_sitter_c not installed") +class TestCCallGraphOutput: + def test_call_graph_json_written(self, sample_c_repo, tmp_output_dir): + parse_repository( + repo_path=sample_c_repo, + output_dir=tmp_output_dir, + language="c", + processing_level="all", + ) + _assert_call_graph_valid(tmp_output_dir) + + def test_call_graph_json_written_with_reachable_level( + self, sample_c_repo, tmp_output_dir + ): + parse_repository( + repo_path=sample_c_repo, + output_dir=tmp_output_dir, + language="c", + processing_level="reachable", + ) + _assert_call_graph_valid(tmp_output_dir) + + +# --------------------------------------------------------------------------- +# Ruby parser +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_ruby_repo(tmp_path): + """Minimal Ruby repository fixture.""" + repo = tmp_path / "ruby_repo" + repo.mkdir() + (repo / "app.rb").write_text( + "def greet\n puts 'hello'\nend\n\ndef main\n greet\nend\n" + ) + return str(repo) + + +@pytest.mark.skipif(not _ts_ruby_available(), reason="tree_sitter_ruby not installed") +class TestRubyCallGraphOutput: + def test_call_graph_json_written(self, sample_ruby_repo, tmp_output_dir): + parse_repository( + repo_path=sample_ruby_repo, + output_dir=tmp_output_dir, + language="ruby", + processing_level="all", + ) + _assert_call_graph_valid(tmp_output_dir) + + def test_call_graph_json_written_with_reachable_level( + self, sample_ruby_repo, tmp_output_dir + ): + parse_repository( + repo_path=sample_ruby_repo, + output_dir=tmp_output_dir, + language="ruby", + processing_level="reachable", + ) + _assert_call_graph_valid(tmp_output_dir) + + +# --------------------------------------------------------------------------- +# PHP parser +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_php_repo(tmp_path): + """Minimal PHP repository fixture.""" + repo = tmp_path / "php_repo" + repo.mkdir() + (repo / "index.php").write_text( + " None: + """Create a file with ``content`` at ``p``, including parent dirs.""" + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(content) + + +class TestDetectLanguagePython: + def test_single_python_file(self, tmp_path: Path) -> None: + _write(tmp_path / "main.py", "print('hi')\n") + assert detect_language(str(tmp_path)) == "python" + + def test_dominant_python_with_unrelated_files(self, tmp_path: Path) -> None: + for i in range(5): + _write(tmp_path / f"mod_{i}.py") + _write(tmp_path / "README.md", "# project") + _write(tmp_path / "data.json", "{}") + assert detect_language(str(tmp_path)) == "python" + + +class TestDetectLanguageJavaScript: + def test_plain_javascript(self, tmp_path: Path) -> None: + _write(tmp_path / "index.js", "module.exports = {};\n") + _write(tmp_path / "lib.js") + assert detect_language(str(tmp_path)) == "javascript" + + def test_typescript_classified_as_javascript(self, tmp_path: Path) -> None: + # The shared config maps .ts/.tsx/.jsx/.mjs/.cjs to "javascript". + for name in ("app.ts", "comp.tsx", "old.jsx", "esm.mjs", "cjs.cjs"): + _write(tmp_path / name) + assert detect_language(str(tmp_path)) == "javascript" + + def test_typescript_dominant_over_python(self, tmp_path: Path) -> None: + for i in range(4): + _write(tmp_path / f"src_{i}.ts") + _write(tmp_path / "scripts" / "release.py") + assert detect_language(str(tmp_path)) == "javascript" + + +class TestDetectLanguageGo: + def test_single_go_file(self, tmp_path: Path) -> None: + _write(tmp_path / "main.go", "package main\n") + assert detect_language(str(tmp_path)) == "go" + + def test_go_dominant_over_other_extensions(self, tmp_path: Path) -> None: + for i in range(6): + _write(tmp_path / f"pkg_{i}.go") + _write(tmp_path / "tools" / "fix.py") + _write(tmp_path / "web" / "ui.js") + assert detect_language(str(tmp_path)) == "go" + + +class TestDetectLanguageMixed: + """Mixed-language repos must report the dominant language by file count. + + Unlike the per-language classes above which lean on skip_dirs to mask + competing extensions, these cases place real source from two languages + side-by-side at the root so the dominance heuristic itself is exercised. + """ + + def test_ts_dominant_over_python_at_root(self, tmp_path: Path) -> None: + # 6 TS source files vs 4 Python tooling files at the same level — + # mirrors a typical Node project that ships a few Python build + # scripts. No skip_dirs trickery involved. + for i in range(6): + _write(tmp_path / "src" / f"mod_{i}.ts") + for i in range(4): + _write(tmp_path / "scripts" / f"tool_{i}.py") + assert detect_language(str(tmp_path)) == "javascript" + + def test_python_dominant_over_javascript_at_root(self, tmp_path: Path) -> None: + # Inverse case: Python repo with a small JS frontend. + for i in range(7): + _write(tmp_path / f"pkg_{i}.py") + for i in range(3): + _write(tmp_path / "frontend" / f"page_{i}.js") + assert detect_language(str(tmp_path)) == "python" + + +class TestDetectLanguageSkipDirs: + def test_node_modules_ignored(self, tmp_path: Path) -> None: + # Two real .py files at the root, plus a noisy node_modules tree. + # If skip_dirs weren't honoured, JS would (wrongly) win. + _write(tmp_path / "main.py") + _write(tmp_path / "lib.py") + for i in range(20): + _write(tmp_path / "node_modules" / f"pkg_{i}" / "index.js") + assert detect_language(str(tmp_path)) == "python" + + def test_vendor_ignored(self, tmp_path: Path) -> None: + _write(tmp_path / "cmd" / "main.go") + _write(tmp_path / "internal" / "svc.go") + for i in range(20): + _write(tmp_path / "vendor" / f"dep_{i}" / "lib.py") + assert detect_language(str(tmp_path)) == "go" + + +class TestDetectLanguageEmpty: + def test_empty_directory_raises(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="No supported source files"): + detect_language(str(tmp_path)) + + def test_only_unsupported_files_raises(self, tmp_path: Path) -> None: + _write(tmp_path / "README.md", "# hi") + _write(tmp_path / "data.json", "{}") + with pytest.raises(ValueError, match="No supported source files"): + detect_language(str(tmp_path)) + + +class TestDetectLanguageNonGit: + """Auto-detection is purely extension-based and must not require .git.""" + + def test_non_git_directory_detected(self, tmp_path: Path) -> None: + _write(tmp_path / "main.py") + assert not (tmp_path / ".git").exists() + assert detect_language(str(tmp_path)) == "python" diff --git a/libs/openant-core/tests/test_di_resolution.py b/libs/openant-core/tests/test_di_resolution.py new file mode 100644 index 0000000..70fa4ea --- /dev/null +++ b/libs/openant-core/tests/test_di_resolution.py @@ -0,0 +1,719 @@ +"""Tests for dependency injection-aware call resolution. + +Tests that the TypeScript analyzer extracts constructor parameter types +and the dependency resolver uses them to resolve DI-injected service calls. + +Requires Node.js and npm dependencies installed: + cd parsers/javascript && npm install +""" +import json +import subprocess +import shutil +from pathlib import Path + +import pytest + +PARSERS_JS_DIR = Path(__file__).parent.parent / "parsers" / "javascript" +NODE_MODULES = PARSERS_JS_DIR / "node_modules" + +pytestmark = pytest.mark.skipif( + not shutil.which("node") or not NODE_MODULES.exists(), + reason="Node.js or JS parser npm dependencies not available", +) + + +def run_node(script_name, *args): + """Run a Node.js script from the JS parsers directory.""" + cmd = ["node", str(PARSERS_JS_DIR / script_name)] + list(args) + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + +# -- Fixture: NestJS-style DI codebase -- + +RESOLVER_TS = """\ +import { Injectable } from '@nestjs/common'; +import { CallService } from './call.service'; +import { AuthService } from './auth.service'; + +@Injectable() +export class CallResolver { + constructor( + private callService: CallService, + private authService: AuthService, + ) {} + + async getCall(id: string) { + return await this.callService.getById(id); + } + + async deleteCall(id: string) { + return await this.callService.remove(id); + } +} +""" + +SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class CallService { + async getById(id: string) { + const call = await this.repository.findOne(id); + await this.authService.can('read', call); + return call; + } + + async remove(id: string) { + return await this.repository.delete(id); + } +} +""" + +AUTH_SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class AuthService { + async can(action: string, resource: any) { + // authorization check + return true; + } +} +""" + +# Versioned implementation (interface CallService, impl CallServiceV2) +VERSIONED_SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class CallServiceV2 { + async getById(id: string) { + return { id }; + } + + async remove(id: string) { + return true; + } +} +""" + +# Interface + implementing class for nominal type tests +ICALL_SERVICE_TS = """\ +export interface ICallService { + getById(id: string): Promise; +} +""" + +IMPL_CALL_SERVICE_TS = """\ +import { Injectable } from '@nestjs/common'; +import { ICallService } from './icall.service'; + +@Injectable() +export class CallServiceImpl implements ICallService { + async getById(id: string) { + return { id }; + } +} +""" + +NOMINAL_RESOLVER_TS = """\ +import { Injectable } from '@nestjs/common'; +import { ICallService } from './icall.service'; + +@Injectable() +export class CallResolver { + constructor(private callService: ICallService) {} + + async getCall(id: string) { + return this.callService.getById(id); + } +} +""" + + +@pytest.fixture +def nestjs_repo(tmp_path): + """Create a minimal NestJS-style repo with DI patterns.""" + src = tmp_path / "src" + src.mkdir() + (src / "call.resolver.ts").write_text(RESOLVER_TS) + (src / "call.service.ts").write_text(SERVICE_TS) + (src / "auth.service.ts").write_text(AUTH_SERVICE_TS) + return tmp_path + + +@pytest.fixture +def nestjs_repo_versioned(tmp_path): + """Create a repo where the DI type doesn't exactly match the class name.""" + src = tmp_path / "src" + src.mkdir() + (src / "call.resolver.ts").write_text(RESOLVER_TS) + (src / "call.service.ts").write_text(VERSIONED_SERVICE_TS) + return tmp_path + + +@pytest.fixture +def nestjs_repo_nominal(tmp_path): + """Create a repo where injection is via interface and impl uses implements.""" + src = tmp_path / "src" + src.mkdir() + (src / "icall.service.ts").write_text(ICALL_SERVICE_TS) + (src / "call.service.impl.ts").write_text(IMPL_CALL_SERVICE_TS) + (src / "call.resolver.ts").write_text(NOMINAL_RESOLVER_TS) + return tmp_path + + +def find_class(classes, class_name): + """Find a class entry in the file-qualified classes dict (key is "filePath:ClassName").""" + for key, val in classes.items(): + if key.endswith(':' + class_name): + return val + return None + + +def analyze_and_resolve(repo_path, files): + """Run analyzer + resolver on given files and return resolved data.""" + analyzer_out = repo_path / "analyzer_output.json" + resolved_out = repo_path / "resolved.json" + + file_paths = [str(f) for f in files] + result = run_node( + "typescript_analyzer.js", str(repo_path), + *file_paths, + "--output", str(analyzer_out), + ) + assert result.returncode == 0, f"Analyzer failed: {result.stderr}" + + result = run_node( + "dependency_resolver.js", str(analyzer_out), + "--output", str(resolved_out), + ) + assert result.returncode == 0, f"Resolver failed: {result.stderr}" + + return json.loads(resolved_out.read_text()) + + +class TestConstructorDepsExtraction: + """Test that the analyzer extracts constructorDeps into the classes table.""" + + def test_extracts_constructor_deps(self, nestjs_repo): + analyzer_out = nestjs_repo / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(nestjs_repo), + "src/call.resolver.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + classes = data["classes"] + + call_resolver = find_class(classes, "CallResolver") + assert call_resolver is not None, "CallResolver not in classes table" + deps = call_resolver.get("constructorDeps", {}) + assert deps.get("callService") == "CallService" + assert deps.get("authService") == "AuthService" + + # Methods themselves should NOT carry constructorDeps (stored in classes table instead) + for fid, func in data["functions"].items(): + if "CallResolver" in fid: + assert "constructorDeps" not in func, f"{fid} should not have constructorDeps" + + def test_skips_primitive_types(self, tmp_path): + """Constructor params with primitive types should not be included.""" + src = tmp_path / "src" + src.mkdir() + (src / "example.ts").write_text("""\ +export class Example { + constructor( + private name: string, + private count: number, + private service: MyService, + ) {} + + doWork() { + return this.service.run(); + } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/example.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + deps = (find_class(data["classes"], "Example") or {}).get("constructorDeps", {}) + # Only MyService should be captured (PascalCase), not string/number + assert "service" in deps + assert deps["service"] == "MyService" + assert "name" not in deps + assert "count" not in deps + + + def test_same_name_different_file_no_collision(self, tmp_path): + """Two classes with the same name in different files must not collide. + + Pre-fix: this.classes["UserController"] is last-write-wins, so the first + class's constructorDeps are silently overwritten and its DI calls miss. + Post-fix: both entries are keyed by "filePath:ClassName". + """ + (tmp_path / "admin").mkdir() + (tmp_path / "v2").mkdir() + (tmp_path / "admin" / "user_controller.ts").write_text("""\ +export class UserController { + constructor(private fooService: FooService) {} + getFoo() { return this.fooService.get(); } +} +""") + (tmp_path / "v2" / "user_controller.ts").write_text("""\ +export class UserController { + constructor(private barService: BarService) {} + getBar() { return this.barService.get(); } +} +""") + (tmp_path / "foo_service.ts").write_text("""\ +export class FooService { + get() { return 'foo'; } +} +""") + (tmp_path / "bar_service.ts").write_text("""\ +export class BarService { + get() { return 'bar'; } +} +""") + + # 1. Analyzer: both class entries present (no last-write-wins collision) + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "admin/user_controller.ts", + "v2/user_controller.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + classes = data["classes"] + + admin_entry = next((v for k, v in classes.items() if "admin" in k and k.endswith(":UserController")), None) + v2_entry = next((v for k, v in classes.items() if "v2" in k and k.endswith(":UserController")), None) + assert admin_entry is not None, "admin/UserController missing from classes table" + assert v2_entry is not None, "v2/UserController missing from classes table" + assert admin_entry.get("constructorDeps", {}).get("fooService") == "FooService" + assert v2_entry.get("constructorDeps", {}).get("barService") == "BarService" + + # 2. Resolver: each class resolves calls to the right service (not the other's) + data = analyze_and_resolve(tmp_path, [ + "admin/user_controller.ts", + "v2/user_controller.ts", + "foo_service.ts", + "bar_service.ts", + ]) + call_graph = data["callGraph"] + + admin_calls = next((calls for fid, calls in call_graph.items() if "admin" in fid and "UserController.getFoo" in fid), None) + v2_calls = next((calls for fid, calls in call_graph.items() if "v2" in fid and "UserController.getBar" in fid), None) + + assert admin_calls is not None, "admin/UserController.getFoo not in call graph" + assert v2_calls is not None, "v2/UserController.getBar not in call graph" + assert any("FooService.get" in c for c in admin_calls), \ + f"admin/UserController.getFoo should resolve to FooService.get, got: {admin_calls}" + assert any("BarService.get" in c for c in v2_calls), \ + f"v2/UserController.getBar should resolve to BarService.get, got: {v2_calls}" + + +class TestBaseTypesExtraction: + """Test that the analyzer extracts implements/extends into baseTypes.""" + + def test_extracts_implements(self, nestjs_repo_nominal): + analyzer_out = nestjs_repo_nominal / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(nestjs_repo_nominal), + "src/call.service.impl.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + base_types = (find_class(data["classes"], "CallServiceImpl") or {}).get("baseTypes", []) + assert "ICallService" in base_types + + def test_generic_implements_stripped(self, tmp_path): + """implements Repository should store as Repository.""" + src = tmp_path / "src" + src.mkdir() + (src / "impl.ts").write_text("""\ +export class UserRepo implements Repository { + findOne(id: string) { return null; } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/impl.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + base_types = (find_class(data["classes"], "UserRepo") or {}).get("baseTypes", []) + assert "Repository" in base_types + assert not any("<" in t for t in base_types) + + def test_extracts_extends(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "impl.ts").write_text("""\ +export class ConcreteService extends BaseService { + run() { return true; } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/impl.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + + data = json.loads(analyzer_out.read_text()) + base_types = (find_class(data["classes"], "ConcreteService") or {}).get("baseTypes", []) + assert "BaseService" in base_types + + +class TestNominalTypeResolution: + """Test that implements/extends clauses are used for DI resolution.""" + + def test_resolves_via_implements(self, nestjs_repo_nominal): + """this.callService.getById() resolves to CallServiceImpl.getById via implements.""" + data = analyze_and_resolve(nestjs_repo_nominal, [ + "src/call.resolver.ts", + "src/call.service.impl.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.getCall" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None, "CallResolver.getCall not in call graph" + assert any( + "CallServiceImpl.getById" in c for c in resolver_calls + ), f"Expected CallServiceImpl.getById via implements, got: {resolver_calls}" + + def test_nominal_ambiguity_skips_resolution(self, tmp_path): + """Two classes implementing same interface → no resolution (ambiguous).""" + src = tmp_path / "src" + src.mkdir() + (src / "resolver.ts").write_text("""\ +export class MyResolver { + constructor(private svc: IMyService) {} + work() { return this.svc.run(); } +} +""") + (src / "impl_a.ts").write_text("""\ +export class ImplA implements IMyService { + run() { return 'a'; } +} +""") + (src / "impl_b.ts").write_text("""\ +export class ImplB implements IMyService { + run() { return 'b'; } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/resolver.ts", + "src/impl_a.ts", + "src/impl_b.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "MyResolver.work" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None + assert not any( + "ImplA.run" in c or "ImplB.run" in c for c in resolver_calls + ), f"Should not resolve ambiguous implements, got: {resolver_calls}" + + +class TestFieldDepsExtraction: + """Test that @Inject* decorators and inject() function are captured as fieldDeps.""" + + def test_extracts_inject_decorator(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +import { Injectable, Inject } from '@nestjs/common'; + +@Injectable() +export class MyService { + @Inject('TOKEN') + private depService: DepService; + + run() { return this.depService.execute(); } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/service.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "MyService") or {}).get("fieldDeps", {}) + assert field_deps.get("depService") == "DepService" + + def test_extracts_inject_repository_decorator(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +import { Injectable } from '@nestjs/common'; +import { InjectRepository } from '@nestjs/typeorm'; + +@Injectable() +export class UserService { + @InjectRepository(User) + private userRepo: Repository; + + findAll() { return this.userRepo.find(); } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/service.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "UserService") or {}).get("fieldDeps", {}) + assert field_deps.get("userRepo") == "Repository" + + def test_extracts_functional_inject(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "component.ts").write_text("""\ +import { inject } from '@angular/core'; + +export class MyComponent { + private authService = inject(AuthService); + + login() { return this.authService.signIn(); } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/component.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "MyComponent") or {}).get("fieldDeps", {}) + assert field_deps.get("authService") == "AuthService" + + def test_ignores_non_inject_decorator(self, tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +export class MyService { + @Column() + private name: string; + + getName() { return this.name; } +} +""") + analyzer_out = tmp_path / "analyzer_output.json" + result = run_node( + "typescript_analyzer.js", str(tmp_path), + "src/service.ts", + "--output", str(analyzer_out), + ) + assert result.returncode == 0 + data = json.loads(analyzer_out.read_text()) + field_deps = (find_class(data["classes"], "MyService") or {}).get("fieldDeps", {}) + assert "name" not in field_deps + + def test_resolves_field_injection_calls(self, tmp_path): + """Calls via @Inject field deps resolve correctly through the full pipeline.""" + src = tmp_path / "src" + src.mkdir() + (src / "service.ts").write_text("""\ +import { Injectable, Inject } from '@nestjs/common'; + +@Injectable() +export class MyService { + @Inject('TOKEN') + private depService: DepService; + + run() { return this.depService.execute(); } +} +""") + (src / "dep_service.ts").write_text("""\ +export class DepService { + execute() { return 'done'; } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/service.ts", + "src/dep_service.ts", + ]) + call_graph = data["callGraph"] + service_calls = next( + (calls for fid, calls in call_graph.items() if "MyService.run" in fid), None + ) + assert service_calls is not None, "MyService.run not in call graph" + assert any("DepService.execute" in c for c in service_calls), \ + f"Expected DepService.execute via field injection, got: {service_calls}" + + +class TestDIAwareCallResolution: + """Test that the dependency resolver uses constructorDeps for DI resolution.""" + + def test_resolves_exact_type_match(self, nestjs_repo): + """this.callService.getById() resolves to CallService.getById.""" + data = analyze_and_resolve(nestjs_repo, [ + "src/call.resolver.ts", + "src/call.service.ts", + ]) + + call_graph = data["callGraph"] + + # Find CallResolver.getCall's call graph + resolver_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.getCall" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None, "CallResolver.getCall not in call graph" + assert any( + "CallService.getById" in c for c in resolver_calls + ), f"Expected CallService.getById in calls, got: {resolver_calls}" + + def test_resolves_versioned_implementation(self, nestjs_repo_versioned): + """this.callService.getById() resolves to CallServiceV2.getById via prefix match.""" + data = analyze_and_resolve(nestjs_repo_versioned, [ + "src/call.resolver.ts", + "src/call.service.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.getCall" in fid: + resolver_calls = calls + break + + assert resolver_calls is not None + assert any( + "CallServiceV2.getById" in c for c in resolver_calls + ), f"Expected CallServiceV2.getById in calls, got: {resolver_calls}" + + def test_resolves_multiple_di_methods(self, nestjs_repo): + """Both getById and remove should resolve to CallService methods.""" + data = analyze_and_resolve(nestjs_repo, [ + "src/call.resolver.ts", + "src/call.service.ts", + ]) + + call_graph = data["callGraph"] + + # deleteCall should resolve to CallService.remove + delete_calls = None + for fid, calls in call_graph.items(): + if "CallResolver.deleteCall" in fid: + delete_calls = calls + break + + assert delete_calls is not None + assert any( + "CallService.remove" in c for c in delete_calls + ), f"Expected CallService.remove in calls, got: {delete_calls}" + + def test_ambiguous_prefix_skips_resolution(self, tmp_path): + """When multiple classes share a type-name prefix, resolution is skipped.""" + src = tmp_path / "src" + src.mkdir() + (src / "resolver.ts").write_text("""\ +export class MyResolver { + constructor(private callService: CallService) {} + getCall(id: string) { + return this.callService.getById(id); + } +} +""") + (src / "call_service.ts").write_text("""\ +export class CallServiceV1 { + getById(id: string) { return 'v1'; } +} +""") + (src / "call_service_mock.ts").write_text("""\ +export class CallServiceMock { + getById(id: string) { return 'mock'; } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/resolver.ts", + "src/call_service.ts", + "src/call_service_mock.ts", + ]) + + call_graph = data["callGraph"] + resolver_calls = None + for fid, calls in call_graph.items(): + if "MyResolver.getCall" in fid: + resolver_calls = calls + break + + # Two classes match the CallService prefix — should not resolve to either + assert resolver_calls is not None + assert not any( + "CallServiceV1.getById" in c or "CallServiceMock.getById" in c + for c in resolver_calls + ), f"Should not resolve ambiguous prefix match, got: {resolver_calls}" + + def test_no_false_positives_without_di(self, tmp_path): + """Methods without constructor deps should not spuriously resolve.""" + src = tmp_path / "src" + src.mkdir() + (src / "plain.ts").write_text("""\ +export class PlainService { + doWork() { + return this.unknownService.process(); + } +} +""") + (src / "other.ts").write_text("""\ +export class UnknownService { + process() { + return 42; + } +} +""") + data = analyze_and_resolve(tmp_path, [ + "src/plain.ts", + "src/other.ts", + ]) + + call_graph = data["callGraph"] + plain_calls = None + for fid, calls in call_graph.items(): + if "PlainService.doWork" in fid: + plain_calls = calls + break + + # Without constructor deps, unknownService.process() should NOT resolve + assert plain_calls is not None + assert not any( + "UnknownService.process" in c for c in plain_calls + ), f"Should not resolve without DI metadata, got: {plain_calls}" diff --git a/libs/openant-core/tests/test_entry_point_detector.py b/libs/openant-core/tests/test_entry_point_detector.py new file mode 100644 index 0000000..8250485 --- /dev/null +++ b/libs/openant-core/tests/test_entry_point_detector.py @@ -0,0 +1,56 @@ +"""Tests for EntryPointDetector — specifically that Express unit types +produced by the JS analyzer are recognised as entry points and therefore +survive the reachability filter. +""" +import pytest + +from utilities.agentic_enhancer.entry_point_detector import ( + ENTRY_POINT_TYPES, + EntryPointDetector, +) + + +def _make_detector(unit_type: str) -> EntryPointDetector: + functions = { + "server.js:fn": { + "name": "fn", + "unit_type": unit_type, + "code": "async (req, res, next) => { next(); }", + } + } + return EntryPointDetector(functions, call_graph={}) + + +def test_route_handler_is_entry_point(): + detector = _make_detector("route_handler") + entry_points = detector.detect_entry_points() + assert "server.js:fn" in entry_points + + +def test_route_middleware_is_entry_point(): + """route_middleware units must be detected as entry points so they are not + silently dropped by the reachability filter. + + Regression for the gap where `route_middleware` was missing from + ENTRY_POINT_TYPES: Express anonymous middleware bodies (which receive req + directly and can be doing anything dangerous) were filtered out before the + LLM ever saw them. + """ + assert "route_middleware" in ENTRY_POINT_TYPES, ( + "route_middleware must be in ENTRY_POINT_TYPES so the reachability " + "filter treats anonymous Express middleware as entry points" + ) + + detector = _make_detector("route_middleware") + entry_points = detector.detect_entry_points() + assert "server.js:fn" in entry_points, ( + "route_middleware unit was filtered out — it must survive as an entry point" + ) + + +def test_unknown_unit_type_is_not_entry_point(): + """A unit with an unrecognised unit_type is not an entry point unless it + matches a decorator or user-input pattern.""" + detector = _make_detector("utility") + entry_points = detector.detect_entry_points() + assert "server.js:fn" not in entry_points diff --git a/libs/openant-core/tests/test_go_cli.py b/libs/openant-core/tests/test_go_cli.py index 42ad294..86be273 100644 --- a/libs/openant-core/tests/test_go_cli.py +++ b/libs/openant-core/tests/test_go_cli.py @@ -79,6 +79,14 @@ def test_scan_help(self): output = result.stdout + result.stderr assert "pipeline" in output.lower() + def test_scan_help_advertises_llm_reachability(self): + """The opt-in --llm-reachability flag (issue #17) should be discoverable + from `openant scan --help`.""" + result = run_cli("scan", "--help") + assert result.returncode == 0 + output = result.stdout + result.stderr + assert "llm-reachability" in output.lower() + class TestParse: def test_parse_python_repo(self, sample_python_repo, tmp_path): @@ -166,3 +174,205 @@ def test_scan_requires_api_key(self, sample_python_repo): output = result.stderr + result.stdout assert result.returncode != 0 assert "api key" in output.lower() + + +class TestInit: + """Integration tests for ``openant init`` covering item 13 of #16: + auto-detect language and tolerate non-git directories. + """ + + @pytest.fixture + def isolated_home(self, tmp_path): + """Override home so init writes into a tmp ~/.openant/.""" + home = str(tmp_path / "fakehome") + os.makedirs(home) + # USERPROFILE for Windows, HOME for Unix. + return {"USERPROFILE": home, "HOME": home} + + def _read_project_json(self, home_dir, project_name): + project_json = ( + Path(home_dir) + / ".openant" + / "projects" + / project_name + / "project.json" + ) + assert project_json.exists(), ( + f"project.json not found at {project_json}" + ) + return json.loads(project_json.read_text()) + + @staticmethod + def _make_repo(tmp_path, name, files): + repo = tmp_path / name + repo.mkdir() + for rel, content in files.items(): + target = repo / rel + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content) + return repo + + def test_auto_detect_python_from_fixture( + self, sample_python_repo, isolated_home + ): + """Init with -l auto on a Python fixture detects ``python``.""" + result = run_cli( + "init", sample_python_repo, + "--name", "test/python-repo", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + assert "Detected language: python" in result.stderr + + project = self._read_project_json( + isolated_home["HOME"], "test/python-repo", + ) + assert project["language"] == "python" + + def test_auto_detect_javascript_from_fixture( + self, sample_js_repo, isolated_home + ): + """Init with -l auto on a JS fixture detects ``javascript``.""" + result = run_cli( + "init", sample_js_repo, + "--name", "test/js-repo", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + assert "Detected language: javascript" in result.stderr + + project = self._read_project_json( + isolated_home["HOME"], "test/js-repo", + ) + assert project["language"] == "javascript" + + def test_auto_detect_typescript_synthetic(self, tmp_path, isolated_home): + """A TS-only tree (no .git) is detected as ``javascript``.""" + repo = self._make_repo( + tmp_path, "ts_repo", + { + "src/app.ts": "export const x = 1;\n", + "src/comp.tsx": "export default () => null;\n", + "src/util.ts": "export const y = 2;\n", + }, + ) + result = run_cli( + "init", str(repo), + "--name", "test/ts-synth", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + assert "Detected language: javascript" in result.stderr + + project = self._read_project_json( + isolated_home["HOME"], "test/ts-synth", + ) + assert project["language"] == "javascript" + + def test_auto_detect_go_synthetic(self, tmp_path, isolated_home): + """A Go-only tree (no .git) is detected as ``go``.""" + repo = self._make_repo( + tmp_path, "go_repo", + { + "main.go": "package main\nfunc main() {}\n", + "internal/svc.go": "package internal\n", + "cmd/cli.go": "package cmd\n", + }, + ) + result = run_cli( + "init", str(repo), + "--name", "test/go-synth", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + assert "Detected language: go" in result.stderr + + project = self._read_project_json( + isolated_home["HOME"], "test/go-synth", + ) + assert project["language"] == "go" + + def test_explicit_language_overrides_auto_detect( + self, sample_python_repo, isolated_home + ): + """An explicit ``-l`` flag wins over auto-detection.""" + result = run_cli( + "init", sample_python_repo, + "--name", "test/explicit-lang", + "-l", "go", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + # Auto-detect path must not run when -l is supplied. + assert "Auto-detecting" not in result.stderr + + project = self._read_project_json( + isolated_home["HOME"], "test/explicit-lang", + ) + assert project["language"] == "go" + + def test_non_git_directory_uses_nogit_sha(self, tmp_path, isolated_home): + """Init on a plain (non-.git) dir succeeds with ``nogit`` placeholder.""" + repo = self._make_repo( + tmp_path, "plain_repo", + {"main.py": "print('hello')\n"}, + ) + # Sanity: not a git repo. + assert not (repo / ".git").exists() + + result = run_cli( + "init", str(repo), + "--name", "test/no-git", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + + project = self._read_project_json( + isolated_home["HOME"], "test/no-git", + ) + assert project["language"] == "python" + assert project["commit_sha"] == "nogit" + assert project["commit_sha_short"] == "nogit" + + def test_non_git_directory_warns_on_commit_flag( + self, tmp_path, isolated_home + ): + """``--commit`` on a non-git directory warns and falls back to ``nogit``.""" + repo = self._make_repo( + tmp_path, "plain_repo", + {"main.py": "print('hello')\n"}, + ) + result = run_cli( + "init", str(repo), + "--name", "test/no-git-commit", + "--commit", "abc123", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode == 0, f"init failed:\n{result.stderr}" + assert "ignored" in result.stderr.lower() + + project = self._read_project_json( + isolated_home["HOME"], "test/no-git-commit", + ) + assert project["commit_sha"] == "nogit" + + def test_empty_dir_fails_with_clear_error(self, tmp_path, isolated_home): + """Init on a directory with no source files fails cleanly.""" + empty = tmp_path / "empty_repo" + empty.mkdir() + + result = run_cli( + "init", str(empty), + "--name", "test/empty", + "-l", "auto", + env_override=isolated_home, + ) + assert result.returncode != 0 + combined = (result.stderr + result.stdout).lower() + assert "no supported source files" in combined diff --git a/libs/openant-core/tests/test_js_parser_bootstrap.py b/libs/openant-core/tests/test_js_parser_bootstrap.py new file mode 100644 index 0000000..a6c5b99 --- /dev/null +++ b/libs/openant-core/tests/test_js_parser_bootstrap.py @@ -0,0 +1,181 @@ +"""Tests for the JS parser's lazy npm-install bootstrap. + +Covers `_ensure_js_parser_dependencies` in core.parser_adapter: behavior when +node_modules is present, missing, partially installed, npm is unavailable, or +`npm install` fails. These tests monkeypatch subprocess and shutil.which so +they don't need Node. +""" +from pathlib import Path + +import pytest + +from core import parser_adapter + + +@pytest.fixture +def fake_parser_dir(tmp_path, monkeypatch): + """Point _JS_PARSER_DIR at a tmp dir (with package.json) so tests don't + touch the real one.""" + monkeypatch.setattr(parser_adapter, "_JS_PARSER_DIR", tmp_path) + # All happy-path tests assume package.json exists. Tests that need to + # exercise the missing-package.json branch can delete it. + (tmp_path / "package.json").write_text('{"name": "fake"}') + return tmp_path + + +def _mark_installed(parser_dir: Path) -> None: + """Create the success sentinel npm writes after a complete install.""" + nm = parser_dir / "node_modules" + nm.mkdir(exist_ok=True) + (nm / ".package-lock.json").write_text("{}") + + +def test_skips_install_when_deps_already_installed(fake_parser_dir, monkeypatch): + _mark_installed(fake_parser_dir) + + calls = [] + monkeypatch.setattr(parser_adapter.subprocess, "run", lambda *a, **kw: calls.append((a, kw))) + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + parser_adapter._ensure_js_parser_dependencies() + + assert calls == [] + + +def test_retries_install_when_node_modules_partially_installed(fake_parser_dir, monkeypatch): + """A killed prior install leaves node_modules/ but no .package-lock.json + sentinel. The bootstrap must retry rather than skip.""" + (fake_parser_dir / "node_modules").mkdir() # no .package-lock.json -> partial + + calls = [] + + class _Ok: + returncode = 0 + + def _fake_run(cmd, **kwargs): + calls.append((cmd, kwargs)) + # Simulate npm completing the install by writing the sentinel. + _mark_installed(fake_parser_dir) + return _Ok() + + monkeypatch.setattr(parser_adapter.subprocess, "run", _fake_run) + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + parser_adapter._ensure_js_parser_dependencies() + + assert len(calls) == 1, "Partial node_modules should trigger a re-install" + + +def test_runs_npm_install_when_node_modules_missing(fake_parser_dir, monkeypatch): + calls = [] + + class _Ok: + returncode = 0 + + def _fake_run(cmd, **kwargs): + calls.append((cmd, kwargs)) + return _Ok() + + monkeypatch.setattr(parser_adapter.subprocess, "run", _fake_run) + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + parser_adapter._ensure_js_parser_dependencies() + + assert len(calls) == 1 + cmd, kwargs = calls[0] + assert cmd == ["/usr/bin/npm", "install"] + assert kwargs["cwd"] == str(fake_parser_dir) + + +def test_raises_when_npm_not_on_path(fake_parser_dir, monkeypatch): + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: None) + + with pytest.raises(RuntimeError, match="npm"): + parser_adapter._ensure_js_parser_dependencies() + + +def test_raises_when_package_json_missing(fake_parser_dir, monkeypatch): + """If the JS parser dir has no package.json, surface a clear error rather + than silently letting npm create an empty install.""" + (fake_parser_dir / "package.json").unlink() + + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + with pytest.raises(RuntimeError, match="package.json not found"): + parser_adapter._ensure_js_parser_dependencies() + + +def test_raises_when_npm_install_fails(fake_parser_dir, monkeypatch): + class _Fail: + returncode = 1 + + monkeypatch.setattr(parser_adapter.subprocess, "run", lambda *a, **kw: _Fail()) + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + with pytest.raises(RuntimeError, match="npm install.*exit code 1"): + parser_adapter._ensure_js_parser_dependencies() + + +def test_install_failure_message_includes_repro_command(fake_parser_dir, monkeypatch): + """The error message must tell the user how to reproduce the install + locally so they can read npm's diagnostics.""" + class _Fail: + returncode = 1 + + monkeypatch.setattr(parser_adapter.subprocess, "run", lambda *a, **kw: _Fail()) + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + with pytest.raises(RuntimeError) as exc_info: + parser_adapter._ensure_js_parser_dependencies() + + msg = str(exc_info.value) + assert "npm install" in msg + assert str(fake_parser_dir) in msg + + +def test_parse_javascript_surfaces_bootstrap_error(fake_parser_dir, monkeypatch): + """When bootstrap fails, _parse_javascript must not run the Node subprocess.""" + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: None) + + ran_node = [] + monkeypatch.setattr( + parser_adapter.subprocess, + "run", + lambda *a, **kw: ran_node.append((a, kw)), + ) + + with pytest.raises(RuntimeError, match="npm"): + parser_adapter._parse_javascript( + repo_path="/tmp/fake-repo", + output_dir="/tmp/fake-out", + processing_level="all", + ) + + assert ran_node == [], "Node subprocess should not run when bootstrap fails" + + +def test_concurrent_bootstrap_serialized_by_lock(fake_parser_dir, monkeypatch): + """The lockfile must serialize installs: the second caller, blocked behind + the first, must observe the sentinel on entry and skip its own install.""" + install_count = 0 + + class _Ok: + returncode = 0 + + def _fake_run(cmd, **kwargs): + nonlocal install_count + install_count += 1 + _mark_installed(fake_parser_dir) + return _Ok() + + monkeypatch.setattr(parser_adapter.subprocess, "run", _fake_run) + monkeypatch.setattr(parser_adapter.shutil, "which", lambda name: "/usr/bin/npm") + + # Two sequential calls in the same process: first installs and writes the + # sentinel, second sees the sentinel and is a no-op. (True multi-process + # concurrency is exercised by the OS lock; we just verify the + # re-check-under-lock + sentinel logic.) + parser_adapter._ensure_js_parser_dependencies() + parser_adapter._ensure_js_parser_dependencies() + + assert install_count == 1 diff --git a/libs/openant-core/tests/test_llm_reachability.py b/libs/openant-core/tests/test_llm_reachability.py new file mode 100644 index 0000000..bbad813 --- /dev/null +++ b/libs/openant-core/tests/test_llm_reachability.py @@ -0,0 +1,490 @@ +"""Tests for the LLM reachability review stage (issue #17). + +The stage is opt-in and advisory: signals may PROMOTE a unit's +reachability but never demote one that the structural analysis kept. +These tests pin that behavior down with a fully mocked LLM client so they +run without network access or an API key. +""" + +from __future__ import annotations + +import json +from typing import List + +import pytest + +from core.llm_reachability import ( + ReachabilitySignal, + analyze_reachability, + apply_signals, + build_prompt, + parse_response, + signals_to_json, +) + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +class FakeClient: + """Minimal stand-in for AnthropicClient. + + Records calls and replays a fixed sequence of canned responses. + """ + + def __init__(self, responses: List[str]): + self._responses = list(responses) + self.calls: List[dict] = [] + + def analyze_sync(self, prompt: str, max_tokens: int = 4096, model: str = ""): + self.calls.append( + {"prompt": prompt, "max_tokens": max_tokens, "model": model} + ) + if not self._responses: + return '{"signals": []}' + return self._responses.pop(0) + + +def _make_unit(unit_id: str, code: str = "pass", **kw) -> dict: + unit = { + "id": unit_id, + "unit_type": kw.pop("unit_type", "function"), + "code": {"primary_code": code}, + } + unit.update(kw) + return unit + + +# --------------------------------------------------------------------------- +# parse_response +# --------------------------------------------------------------------------- + + +class TestParseResponse: + def test_parses_well_formed_signal(self): + text = json.dumps( + { + "signals": [ + { + "unit_id": "app.py:handler", + "kind": "entry_point", + "confidence": "high", + "reason": "Express handler", + } + ] + } + ) + sigs = parse_response(text, valid_unit_ids={"app.py:handler"}) + assert len(sigs) == 1 + assert sigs[0].unit_id == "app.py:handler" + assert sigs[0].kind == "entry_point" + assert sigs[0].confidence == "high" + assert "Express" in sigs[0].reason + + def test_strips_markdown_fences(self): + text = "```json\n" + json.dumps( + {"signals": [ + {"unit_id": "x.py:f", "kind": "external_input", + "confidence": "medium", "reason": "reads argv"}]} + ) + "\n```" + sigs = parse_response(text, valid_unit_ids={"x.py:f"}) + assert len(sigs) == 1 + assert sigs[0].kind == "external_input" + + def test_falls_back_to_first_object(self): + text = "Sure! Here you go:\n" + json.dumps( + {"signals": [ + {"unit_id": "a.py:g", "kind": "cross_process", + "confidence": "low", "reason": "queue"}]} + ) + "\nEnd." + sigs = parse_response(text, valid_unit_ids={"a.py:g"}) + assert len(sigs) == 1 + + def test_malformed_json_returns_empty(self): + errors: List[str] = [] + sigs = parse_response( + "not json at all", + valid_unit_ids={"x"}, + on_error=errors.append, + ) + assert sigs == [] + assert any("malformed" in e for e in errors) + + def test_invalid_kind_skipped(self): + text = json.dumps( + {"signals": [ + {"unit_id": "x.py:f", "kind": "garbage", + "confidence": "high", "reason": "n/a"}]} + ) + errors: List[str] = [] + sigs = parse_response( + text, valid_unit_ids={"x.py:f"}, on_error=errors.append + ) + assert sigs == [] + assert any("invalid kind" in e for e in errors) + + def test_unknown_unit_id_skipped(self): + text = json.dumps( + {"signals": [ + {"unit_id": "ghost.py:f", "kind": "entry_point", + "confidence": "high", "reason": "n/a"}]} + ) + errors: List[str] = [] + sigs = parse_response( + text, valid_unit_ids={"real.py:f"}, on_error=errors.append + ) + assert sigs == [] + + def test_signals_not_a_list_returns_empty(self): + text = json.dumps({"signals": "nope"}) + errors: List[str] = [] + sigs = parse_response(text, on_error=errors.append) + assert sigs == [] + + +# --------------------------------------------------------------------------- +# build_prompt / app_context threading +# --------------------------------------------------------------------------- + + +class TestBuildPrompt: + def test_includes_unit_ids_and_code(self): + units = [_make_unit("app.py:handler", code="def handler(): ...")] + prompt = build_prompt(units) + assert "app.py:handler" in prompt + assert "def handler()" in prompt + + def test_no_app_context_marker(self): + prompt = build_prompt([_make_unit("a:f")]) + assert "(none provided)" in prompt + + def test_includes_app_context_when_provided(self): + ctx = {"application_type": "web_app", "framework": "Express"} + prompt = build_prompt([_make_unit("a:f")], app_context=ctx) + assert "web_app" in prompt + assert "Express" in prompt + + def test_truncates_overly_long_code(self): + big = "x = 1\n" * 5000 + prompt = build_prompt([_make_unit("a:f", code=big)]) + assert "[truncated]" in prompt + + def test_max_code_bytes_override_keeps_more_context(self): + """Larger max_code_bytes should preserve content past the default cutoff.""" + # 3000 bytes of unique markers — past default 1500, within 4096 + big = ("# unique-marker\n" * 3) + ("x = 1\n" * 600) + "FINAL_MARKER = True\n" + # default 1500: FINAL_MARKER is past the cutoff and should be missing + default_prompt = build_prompt([_make_unit("a:f", code=big)]) + assert "FINAL_MARKER" not in default_prompt + assert "[truncated]" in default_prompt + # 4096: FINAL_MARKER fits and should appear + wide_prompt = build_prompt( + [_make_unit("a:f", code=big)], max_code_bytes=4096 + ) + assert "FINAL_MARKER" in wide_prompt + + +# --------------------------------------------------------------------------- +# analyze_reachability — full call with a mocked client +# --------------------------------------------------------------------------- + + +class TestAnalyzeReachability: + def test_parses_signals_from_mocked_llm(self): + dataset = { + "units": [ + _make_unit("app.py:handler"), + _make_unit("util.py:helper"), + ] + } + canned = json.dumps( + { + "signals": [ + { + "unit_id": "app.py:handler", + "kind": "entry_point", + "confidence": "high", + "reason": "Express handler", + }, + { + "unit_id": "util.py:helper", + "kind": "external_input", + "confidence": "medium", + "reason": "reads file", + }, + ] + } + ) + client = FakeClient([canned]) + signals = analyze_reachability(dataset, client=client) + assert len(signals) == 2 + assert {s.kind for s in signals} == {"entry_point", "external_input"} + assert len(client.calls) == 1 + + def test_app_context_threaded_into_prompt(self): + dataset = {"units": [_make_unit("a:f")]} + client = FakeClient(['{"signals": []}']) + ctx = {"application_type": "web_app", "framework": "Flask"} + analyze_reachability(dataset, app_context=ctx, client=client) + assert "Flask" in client.calls[0]["prompt"] + assert "web_app" in client.calls[0]["prompt"] + + def test_malformed_response_handled_gracefully(self): + dataset = {"units": [_make_unit("a:f")]} + errors: List[str] = [] + client = FakeClient(["this is not JSON"]) + sigs = analyze_reachability( + dataset, client=client, on_error=errors.append + ) + assert sigs == [] + assert errors # at least one error logged + + def test_empty_dataset_returns_empty(self): + client = FakeClient([]) + sigs = analyze_reachability({"units": []}, client=client) + assert sigs == [] + assert client.calls == [] # no LLM calls when nothing to review + + def test_batch_size_chunks_units(self): + dataset = {"units": [_make_unit(f"a:{i}") for i in range(7)]} + client = FakeClient(['{"signals": []}'] * 5) + analyze_reachability(dataset, client=client, batch_size=3) + # 7 units / 3 per batch = 3 calls + assert len(client.calls) == 3 + + def test_non_positive_batch_size_uses_single_batch(self): + """``batch_size <= 0`` historically tripped a NameError. Guard the + contract: non-positive size collapses to a single batch covering all + units (and never raises).""" + dataset = {"units": [_make_unit(f"a:{i}") for i in range(4)]} + client = FakeClient(['{"signals": []}']) + analyze_reachability(dataset, client=client, batch_size=0) + assert len(client.calls) == 1 + + def test_client_exception_does_not_crash(self): + class Boom: + def analyze_sync(self, *a, **kw): + raise RuntimeError("api boom") + + errors: List[str] = [] + sigs = analyze_reachability( + {"units": [_make_unit("a:f")]}, + client=Boom(), + on_error=errors.append, + ) + assert sigs == [] + assert any("api boom" in e for e in errors) + + +# --------------------------------------------------------------------------- +# apply_signals — promote-only semantics +# --------------------------------------------------------------------------- + + +class TestApplySignals: + def test_high_confidence_entry_point_promotes(self): + dataset = {"units": [_make_unit("a:f", is_entry_point=False)]} + sigs = [ + ReachabilitySignal("a:f", "entry_point", "high", "framework hook") + ] + summary = apply_signals(dataset, sigs) + assert dataset["units"][0]["is_entry_point"] is True + assert summary["entry_points_promoted"] == 1 + assert summary["signals_applied"] == 1 + assert summary["units_touched"] == 1 + + def test_medium_confidence_does_not_promote(self): + dataset = {"units": [_make_unit("a:f", is_entry_point=False)]} + sigs = [ + ReachabilitySignal("a:f", "entry_point", "medium", "maybe") + ] + summary = apply_signals(dataset, sigs) + assert dataset["units"][0]["is_entry_point"] is False + assert summary["entry_points_promoted"] == 0 + # but the signal is still attached for the reviewer + assert summary["signals_applied"] == 1 + + def test_external_input_does_not_set_entry_point(self): + dataset = {"units": [_make_unit("a:f", is_entry_point=False)]} + sigs = [ + ReachabilitySignal("a:f", "external_input", "high", "argv") + ] + apply_signals(dataset, sigs) + # external_input never sets is_entry_point regardless of confidence + assert dataset["units"][0]["is_entry_point"] is False + + def test_does_not_demote_existing_entry_point(self): + """Crucial promote-only invariant: a unit the structural pass + already marked as an entry point must never be unmarked, even if + the LLM emits no signal (or a low-confidence one) for it.""" + dataset = {"units": [_make_unit("a:f", is_entry_point=True)]} + # Empty signal list — apply_signals must not flip the flag. + apply_signals(dataset, []) + assert dataset["units"][0]["is_entry_point"] is True + + # Even a stray "low" entry_point signal must not flip it back. + sigs = [ReachabilitySignal("a:f", "entry_point", "low", "weak")] + apply_signals(dataset, sigs) + assert dataset["units"][0]["is_entry_point"] is True + + def test_signal_attached_to_unit(self): + dataset = {"units": [_make_unit("a:f")]} + sigs = [ + ReachabilitySignal("a:f", "external_input", "medium", "reads stdin") + ] + apply_signals(dataset, sigs) + unit = dataset["units"][0] + assert "llm_reachability_signals" in unit + assert len(unit["llm_reachability_signals"]) == 1 + attached = unit["llm_reachability_signals"][0] + assert attached["kind"] == "external_input" + assert attached["reason"] == "reads stdin" + + def test_multiple_signals_accumulate_on_same_unit(self): + dataset = {"units": [_make_unit("a:f")]} + sigs = [ + ReachabilitySignal("a:f", "external_input", "medium", "argv"), + ReachabilitySignal("a:f", "cross_process", "low", "queue"), + ] + apply_signals(dataset, sigs) + attached = dataset["units"][0]["llm_reachability_signals"] + assert len(attached) == 2 + + def test_unknown_unit_id_skipped(self): + dataset = {"units": [_make_unit("a:f")]} + sigs = [ReachabilitySignal("ghost:x", "entry_point", "high", "n/a")] + summary = apply_signals(dataset, sigs) + assert summary["signals_applied"] == 0 + assert summary["entry_points_promoted"] == 0 + + +class TestSerialization: + def test_signals_to_json_roundtrip(self): + sigs = [ + ReachabilitySignal("a:f", "entry_point", "high", "r1"), + ReachabilitySignal("b:g", "external_input", "low", "r2"), + ] + out = signals_to_json(sigs) + assert isinstance(out, list) + assert all(isinstance(item, dict) for item in out) + # Round-trips through JSON cleanly. + json.loads(json.dumps(out)) + + +# --------------------------------------------------------------------------- +# CLI flag plumbing — mock scan_repository to confirm wiring without API +# --------------------------------------------------------------------------- + + +class TestCliPlumbing: + """Confirms that the --llm-reachability flag exists in scan --help and + that, by default (no flag), the LLM reachability path is not invoked. + + These tests exercise the Python CLI directly (no Go binary required), so + they always run in the basic pytest suite. + """ + + def test_flag_appears_in_scan_help(self, capsys): + from openant.cli import main + + with pytest.raises(SystemExit): + import sys + old = sys.argv + try: + sys.argv = ["openant", "scan", "--help"] + main() + finally: + sys.argv = old + out = capsys.readouterr().out + capsys.readouterr().err + assert "--llm-reachability" in out + + def test_default_does_not_invoke_llm_reachability(self, monkeypatch, tmp_path): + """When --llm-reachability is NOT passed, ``analyze_reachability`` in + the scanner module must not be called. + + We achieve this by monkey-patching ``scan_repository`` to a stub + that records its kwargs, then driving ``cmd_scan`` through it. + """ + captured = {} + + from openant import cli as cli_mod + + def fake_scan(**kwargs): + captured.update(kwargs) + from core.schemas import ScanResult + r = ScanResult(output_dir=str(tmp_path)) + return r + + monkeypatch.setattr( + "core.scanner.scan_repository", fake_scan, raising=True + ) + + # Drive cmd_scan via argparse + import argparse + ns = argparse.Namespace( + repo=str(tmp_path), + output=str(tmp_path / "out"), + language="auto", + level="reachable", + verify=False, + no_context=True, + no_enhance=True, + enhance_mode="agentic", + no_report=True, + dynamic_test=False, + no_skip_tests=False, + limit=None, + model="opus", + workers=1, + repo_name=None, + repo_url=None, + commit_sha=None, + backoff=30, + diff_manifest=None, + llm_reachability=False, + ) + rc = cli_mod.cmd_scan(ns) + # rc 0 or 1 acceptable; we only care about plumbing. + assert rc in (0, 1) + assert captured.get("llm_reachability") is False + + def test_flag_passes_through_when_set(self, monkeypatch, tmp_path): + captured = {} + from openant import cli as cli_mod + + def fake_scan(**kwargs): + captured.update(kwargs) + from core.schemas import ScanResult + return ScanResult(output_dir=str(tmp_path)) + + monkeypatch.setattr( + "core.scanner.scan_repository", fake_scan, raising=True + ) + + import argparse + ns = argparse.Namespace( + repo=str(tmp_path), + output=str(tmp_path / "out"), + language="auto", + level="reachable", + verify=False, + no_context=True, + no_enhance=True, + enhance_mode="agentic", + no_report=True, + dynamic_test=False, + no_skip_tests=False, + limit=None, + model="opus", + workers=1, + repo_name=None, + repo_url=None, + commit_sha=None, + backoff=30, + diff_manifest=None, + llm_reachability=True, + ) + cli_mod.cmd_scan(ns) + assert captured.get("llm_reachability") is True diff --git a/libs/openant-core/utilities/agentic_enhancer/entry_point_detector.py b/libs/openant-core/utilities/agentic_enhancer/entry_point_detector.py index 22aab91..16df5b5 100644 --- a/libs/openant-core/utilities/agentic_enhancer/entry_point_detector.py +++ b/libs/openant-core/utilities/agentic_enhancer/entry_point_detector.py @@ -25,6 +25,7 @@ # Entry point patterns by unit_type (from function extractor classification) ENTRY_POINT_TYPES = { 'route_handler', # Flask/FastAPI/Express routes + 'route_middleware', # Express anonymous middleware callbacks (req, res, next) 'view_function', # Django views 'websocket_handler', # WebSocket endpoints 'cli_handler', # CLI commands