diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7582c4d..84a1121 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,12 +29,12 @@ jobs: pip install ruff - name: Lint with ruff - run: ruff check src/ + run: ruff check src/ tests/ - name: Format check with ruff - run: ruff format --check src/ + run: ruff format --check src/ tests/ - test: + test-unit: runs-on: ubuntu-latest strategy: matrix: @@ -47,8 +47,8 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install package - run: pip install -e . + - name: Install package + test deps + run: pip install -e '.[mcp]' pytest pytest-cov - name: Verify CLI entry point run: memory --version @@ -61,8 +61,37 @@ jobs: python -c "from ai_memory_protocol.formatter import format_brief" python -c "from ai_memory_protocol.rst import generate_rst_directive" python -c "from ai_memory_protocol.config import TYPE_FILES" + python -c "from ai_memory_protocol.mcp_server import create_mcp_server; create_mcp_server()" + + - name: Run unit tests + run: pytest tests/ -v -m "not integration" --cov=ai_memory_protocol --cov-report=term-missing --cov-report=html + + - name: Upload coverage + if: matrix.python-version == '3.12' + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: htmlcov/ + + test-integration: + runs-on: ubuntu-latest + needs: [test-unit] + steps: + - uses: actions/checkout@v4 - - name: Test init + add + rebuild workflow + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install package + run: pip install -e '.[mcp]' pytest pytest-timeout + + - name: Run integration tests + run: pytest tests/ -v -m integration --timeout=120 + continue-on-error: true + + - name: Run workflow smoke test run: | memory init /tmp/test-memory --name "CI Test" --install memory --dir /tmp/test-memory add fact "CI test fact" \ @@ -74,9 +103,11 @@ jobs: memory --dir /tmp/test-memory get FACT_ci_test_fact memory --dir /tmp/test-memory tags memory --dir /tmp/test-memory list + memory --dir /tmp/test-memory doctor test-mcp: runs-on: ubuntu-latest + needs: [test-unit] steps: - uses: actions/checkout@v4 @@ -86,9 +117,12 @@ jobs: python-version: "3.12" - name: Install package with MCP extras - run: pip install -e '.[mcp]' + run: pip install -e '.[mcp]' pytest - name: Verify MCP server imports run: | python -c "from ai_memory_protocol.mcp_server import create_mcp_server; create_mcp_server()" python -c "from ai_memory_protocol.mcp_server import TOOLS; print(f'{len(TOOLS)} MCP tools registered')" + + - name: Run MCP tests + run: pytest tests/test_mcp_server.py -v diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6c1788b --- /dev/null +++ b/Makefile @@ -0,0 +1,42 @@ +.PHONY: install install-dev install-mcp inject-mcp test test-cov test-unit test-integration lint format doctor uninstall help + +install: ## Install via pipx (CLI only) + pipx install -e . 
+ +install-mcp: ## Install via pipx with MCP support + pipx install -e '.[mcp]' + +install-dev: ## Install for development (editable + dev deps) + pip install -e '.[mcp]' + pip install ruff pytest pytest-cov + +inject-mcp: ## Add MCP to existing pipx install + pipx inject ai-memory-protocol mcp + +test: ## Run all tests + pytest tests/ -v + +test-unit: ## Run unit tests only (no Sphinx needed) + pytest tests/ -v -m "not integration" + +test-integration: ## Run integration tests (requires Sphinx) + pytest tests/ -v -m integration + +test-cov: ## Run tests with coverage + pytest tests/ -v --cov=ai_memory_protocol --cov-report=term-missing --cov-report=html + +lint: ## Run linters + ruff check src/ tests/ + ruff format --check src/ tests/ + +format: ## Format code + ruff format src/ tests/ + +doctor: ## Verify installation + memory doctor + +uninstall: ## Uninstall from pipx + pipx uninstall ai-memory-protocol + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*## ' Makefile | sort | awk 'BEGIN {FS = ":.*## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}' diff --git a/pyproject.toml b/pyproject.toml index 7ac0feb..1a1a1ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ai-memory-protocol" -version = "0.2.0" +version = "0.3.0" description = "AI Memory Protocol — versioned, graph-based memory for AI agents using Sphinx-Needs" readme = "README.md" license = { text = "Apache-2.0" } @@ -44,6 +44,7 @@ memory-mcp-stdio = "ai_memory_protocol.mcp_server:main_stdio" dev = [ "ruff>=0.8", "pytest>=8.0", + "pytest-cov>=6.0", "pre-commit>=4.0", ] @@ -71,3 +72,7 @@ known-first-party = ["ai_memory_protocol"] [tool.pytest.ini_options] testpaths = ["tests"] +markers = [ + "integration: marks tests that require Sphinx build (deselect with '-m \"not integration\"')", +] +addopts = "-ra --strict-markers" diff --git a/src/ai_memory_protocol/__init__.py b/src/ai_memory_protocol/__init__.py index 27dad86..714084e 100644 --- a/src/ai_memory_protocol/__init__.py +++ b/src/ai_memory_protocol/__init__.py @@ -1,3 +1,3 @@ """AI Memory Protocol — versioned, graph-based memory for AI agents.""" -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/src/ai_memory_protocol/capture.py b/src/ai_memory_protocol/capture.py new file mode 100644 index 0000000..82c7ce4 --- /dev/null +++ b/src/ai_memory_protocol/capture.py @@ -0,0 +1,470 @@ +"""Capture knowledge from external sources — git, CI, discussions. + +Primary use-case: extract memories from git history so the agent does +not lose context from past development sessions. 
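+
+Each extracted commit becomes a ``MemoryCandidate`` rather than a memory:
+candidates are meant to be reviewed (see ``format_candidates``) and are only
+written to the workspace on explicit request, for example via
+``memory capture git --auto-add``.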
+ +Usage: + from ai_memory_protocol.capture import capture_from_git + candidates = capture_from_git(workspace, repo_path, since="2 weeks ago") +""" + +from __future__ import annotations + +import re +import subprocess +from dataclasses import dataclass, field +from difflib import SequenceMatcher +from pathlib import Path +from typing import Any + +from .engine import load_needs + +# --------------------------------------------------------------------------- +# Candidate (not yet a memory — needs review before adding) +# --------------------------------------------------------------------------- + + +@dataclass +class MemoryCandidate: + """A candidate memory extracted from a source (git, CI, discussion).""" + + type: str + title: str + body: str + tags: list[str] = field(default_factory=list) + source: str = "" + confidence: str = "medium" + scope: str = "global" + # For dedup + _source_hashes: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + d = { + "type": self.type, + "title": self.title, + "body": self.body, + "tags": self.tags, + "source": self.source, + "confidence": self.confidence, + "scope": self.scope, + } + return {k: v for k, v in d.items() if v} + + +# --------------------------------------------------------------------------- +# Git commit parsing +# --------------------------------------------------------------------------- + + +_GIT_RECORD_SEP = "\x1e" # ASCII Record Separator between commits +_GIT_FIELD_SEP = "\x1f" # ASCII Unit Separator between fields +_GIT_LOG_FORMAT = ( + f"%H{_GIT_FIELD_SEP}%s{_GIT_FIELD_SEP}%b{_GIT_FIELD_SEP}%an{_GIT_FIELD_SEP}%ad{_GIT_RECORD_SEP}" +) + + +@dataclass +class _GitCommit: + """Parsed git commit.""" + + hash: str + subject: str + body: str + author: str + date: str + files: list[str] = field(default_factory=list) + + +def _parse_git_log(repo_path: Path, since: str, until: str) -> list[_GitCommit]: + """Run git log and parse the output.""" + # Base git log command + cmd: list[str] = [ + "git", + "log", + f"--format={_GIT_LOG_FORMAT}", + "--date=iso-strict", + ] + + # Heuristic: if arguments contain spaces (e.g. "2 weeks ago"), treat them as + # date expressions and use --since/--until. Otherwise treat them as refs and + # use git's revision range syntax. 
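+    # For example (refs below are illustrative): since="2 weeks ago",
+    # until="HEAD" becomes --since=2 weeks ago --until=HEAD, whereas
+    # since="v1.2", until="HEAD" becomes the revision range v1.2..HEAD.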
+ has_since = bool(since) + has_until = bool(until) + since_is_date = has_since and (" " in since) + until_is_date = has_until and (" " in until) + + if has_since and has_until: + if since_is_date or until_is_date: + # Date-based range + cmd.append(f"--since={since}") + cmd.append(f"--until={until}") + else: + # Ref-based range: use {since}..{until} + cmd.append(f"{since}..{until}") + elif has_since: + if since_is_date: + cmd.append(f"--since={since}") + else: + cmd.append(since) + elif has_until: + if until_is_date: + cmd.append(f"--until={until}") + else: + cmd.append(until) + + try: + result = subprocess.run( + cmd, + cwd=str(repo_path), + capture_output=True, + text=True, + ) + if result.returncode != 0: + return [] + except OSError: + return [] + + commits: list[_GitCommit] = [] + # Split on record separator — safe even when %b contains newlines + for record in result.stdout.split(_GIT_RECORD_SEP): + record = record.strip() + if not record: + continue + parts = record.split(_GIT_FIELD_SEP) + if len(parts) < 5: + continue + commit = _GitCommit( + hash=parts[0].strip(), + subject=parts[1].strip(), + body=parts[2].strip(), + author=parts[3].strip(), + date=parts[4].strip(), + ) + commits.append(commit) + + # Get changed files per commit + for c in commits: + try: + files_result = subprocess.run( + ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", c.hash], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + c.files = [f.strip() for f in files_result.stdout.strip().split("\n") if f.strip()] + except OSError: + # Best-effort: if git diff-tree fails for this commit (e.g. git not + # available or repository in an unexpected state), leave the files + # list unchanged for this commit and continue processing others. + pass + + return commits + + +# --------------------------------------------------------------------------- +# Commit classification +# --------------------------------------------------------------------------- + +# Patterns for classifying commits by conventional commit prefix +_CLASSIFY_PATTERNS: list[tuple[str, str, str]] = [ + # (regex_pattern, memory_type, confidence) + (r"^fix[\(:]|^bugfix[\(:]|^hotfix[\(:]", "mem", "high"), + (r"^feat[\(:]|^add[\(:]|^feature[\(:]", "fact", "medium"), + (r"^refactor[\(:]|^perf[\(:]|^optimize[\(:]", "dec", "medium"), + (r"^docs[\(:]|^doc[\(:]", "fact", "low"), + (r"^test[\(:]|^tests[\(:]", "mem", "low"), + (r"^ci[\(:]|^build[\(:]|^chore[\(:]", "mem", "low"), + (r"BREAKING[ _]CHANGE", "risk", "high"), + (r"^revert[\(:]", "mem", "medium"), + (r"^style[\(:]", "pref", "low"), +] + + +def _classify_commit(commit: _GitCommit) -> tuple[str, str]: + """Classify a commit into a memory type + confidence. + + Returns (type, confidence). Falls back to "mem"/"low" for unclassifiable commits. + """ + subject = commit.subject + body_text = f"{subject} {commit.body}" + + # Check BREAKING CHANGE first (can appear anywhere) + if re.search(r"BREAKING[ _]CHANGE", body_text, re.IGNORECASE): + return "risk", "high" + + for pattern, mem_type, confidence in _CLASSIFY_PATTERNS: + if re.search(pattern, subject, re.IGNORECASE): + return mem_type, confidence + + return "mem", "low" + + +def _extract_scope(subject: str) -> str: + """Extract scope from conventional commit format. e.g. 'fix(gateway): ...' 
→ 'gateway'.""" + match = re.match(r"^\w+\(([^)]+)\)", subject) + return match.group(1) if match else "" + + +def _infer_tags(commit: _GitCommit, repo_name: str) -> list[str]: + """Infer tags from commit metadata.""" + tags: list[str] = [f"repo:{repo_name}"] + + # Extract scope as topic tag + scope = _extract_scope(commit.subject) + if scope: + tags.append(f"topic:{scope}") + + # Infer topic from file paths + path_topics: set[str] = set() + for f in commit.files: + parts = Path(f).parts + if len(parts) >= 2: + # Use first meaningful directory as topic + for part in parts: + if part not in ("src", "lib", "test", "tests", "include", ".", ".."): + path_topics.add(part.replace("_", "-")) + break + + for topic in sorted(path_topics)[:3]: # Limit to 3 path-based topics + tag = f"topic:{topic}" + if tag not in tags: + tags.append(tag) + + return tags + + +# --------------------------------------------------------------------------- +# Grouping related commits +# --------------------------------------------------------------------------- + + +def _file_overlap(files1: list[str], files2: list[str]) -> float: + """Compute Jaccard similarity of file sets.""" + s1, s2 = set(files1), set(files2) + union = s1 | s2 + if not union: + return 0.0 + return len(s1 & s2) / len(union) + + +def _group_commits( + commits: list[_GitCommit], file_overlap_threshold: float = 0.3 +) -> list[list[_GitCommit]]: + """Group commits by file overlap (simple greedy clustering).""" + if not commits: + return [] + + groups: list[list[_GitCommit]] = [[commits[0]]] + + for commit in commits[1:]: + best_group = -1 + best_overlap = 0.0 + for i, group in enumerate(groups): + # Compare against all commits in the group + group_files = [f for c in group for f in c.files] + overlap = _file_overlap(commit.files, group_files) + if overlap > best_overlap: + best_overlap = overlap + best_group = i + + if best_overlap >= file_overlap_threshold and best_group >= 0: + groups[best_group].append(commit) + else: + groups.append([commit]) + + return groups + + +# --------------------------------------------------------------------------- +# Deduplication against existing memories +# --------------------------------------------------------------------------- + + +def _is_duplicate( + candidate: MemoryCandidate, + existing_needs: dict[str, Any], + title_threshold: float = 0.7, +) -> bool: + """Check if a candidate is a near-duplicate of an existing memory.""" + for need in existing_needs.values(): + if need.get("status") == "deprecated": + continue + existing_title = need.get("title", "").lower() + candidate_title = candidate.title.lower() + sim = SequenceMatcher(None, candidate_title, existing_title).ratio() + if sim >= title_threshold: + return True + + # Also check by source (exact commit hash match) + existing_source = need.get("source", "") + if candidate.source and candidate.source in existing_source: + return True + + return False + + +# --------------------------------------------------------------------------- +# Public interface: capture from git +# --------------------------------------------------------------------------- + + +def capture_from_git( + workspace: Path, + repo_path: Path, + since: str = "HEAD~20", + until: str = "HEAD", + repo_name: str | None = None, + deduplicate: bool = True, + min_confidence: str = "low", +) -> list[MemoryCandidate]: + """Analyze git log and generate memory candidates. + + Parameters + ---------- + workspace + Path to the memory workspace (for dedup against existing). 
+ repo_path + Path to the git repository to analyze. + since + Start of the range (commit or date like ``"2 weeks ago"``). + until + End of the range (default: ``"HEAD"``). + repo_name + Repository name for ``repo:`` tags. Auto-detected from path if omitted. + deduplicate + If True, filter out candidates that match existing memories. + min_confidence + Minimum confidence to include. "low" includes all. + + Returns + ------- + list[MemoryCandidate] + Candidate memories ready for review and optional insertion. + """ + if repo_name is None: + repo_name = repo_path.name + + commits = _parse_git_log(repo_path, since, until) + if not commits: + return [] + + # Load existing memories for dedup + existing: dict[str, Any] = {} + if deduplicate: + try: + existing = load_needs(workspace) + except (SystemExit, Exception): + existing = {} + + # Classify and group + conf_rank = {"high": 2, "medium": 1, "low": 0} + min_conf_rank = conf_rank.get(min_confidence, 0) + + candidates: list[MemoryCandidate] = [] + + # Group related commits + groups = _group_commits(commits) + + for group in groups: + if len(group) == 1: + # Single commit → single candidate + commit = group[0] + mem_type, confidence = _classify_commit(commit) + + if conf_rank.get(confidence, 0) < min_conf_rank: + continue + + # Clean title: remove conventional commit prefix + title = re.sub(r"^\w+(\([^)]*\))?:\s*", "", commit.subject) + if not title: + title = commit.subject + + body_parts = [commit.body] if commit.body else [] + if commit.files: + body_parts.append(f"Files: {', '.join(commit.files[:10])}") + + candidate = MemoryCandidate( + type=mem_type, + title=title[:120], + body="\n".join(body_parts), + tags=_infer_tags(commit, repo_name), + source=f"commit:{commit.hash[:8]}", + confidence=confidence, + scope=f"repo:{repo_name}", + _source_hashes=[commit.hash], + ) + candidates.append(candidate) + else: + # Multiple related commits → summarize + primary = group[0] # Use first (most recent) commit + mem_type, confidence = _classify_commit(primary) + + # Upgrade confidence for grouped commits + if len(group) >= 3 and confidence == "low": + confidence = "medium" + + if conf_rank.get(confidence, 0) < min_conf_rank: + continue + + title = re.sub(r"^\w+(\([^)]*\))?:\s*", "", primary.subject) + if not title: + title = primary.subject + + body_parts = [f"Group of {len(group)} related commits:"] + for c in group[:5]: + body_parts.append(f" - {c.subject} ({c.hash[:8]})") + if len(group) > 5: + body_parts.append(f" ... and {len(group) - 5} more") + + all_files: set[str] = set() + all_tags: set[str] = set() + for c in group: + all_files.update(c.files) + for tag in _infer_tags(c, repo_name): + all_tags.add(tag) + + if all_files: + body_parts.append(f"Files: {', '.join(sorted(all_files)[:10])}") + + candidate = MemoryCandidate( + type=mem_type, + title=title[:120], + body="\n".join(body_parts), + tags=sorted(all_tags), + source=f"commit:{primary.hash[:8]}+{len(group) - 1}", + confidence=confidence, + scope=f"repo:{repo_name}", + _source_hashes=[c.hash for c in group], + ) + candidates.append(candidate) + + # Dedup against existing + if deduplicate and existing: + candidates = [c for c in candidates if not _is_duplicate(c, existing)] + + return candidates + + +def format_candidates(candidates: list[MemoryCandidate], fmt: str = "human") -> str: + """Format capture candidates for display.""" + if not candidates: + return "No new memory candidates found." 
+ + if fmt == "json": + import json + + return json.dumps([c.to_dict() for c in candidates], indent=2, ensure_ascii=False) + + lines = [f"## {len(candidates)} memory candidate(s)\n"] + for i, c in enumerate(candidates, 1): + lines.append(f" {i}. [{c.type}] {c.title}") + lines.append(f" Tags: {', '.join(c.tags)}") + lines.append(f" Confidence: {c.confidence} | Source: {c.source}") + if c.body: + # Show first 2 lines of body + body_lines = c.body.split("\n")[:2] + for bl in body_lines: + lines.append(f" {bl}") + lines.append("") + + return "\n".join(lines) diff --git a/src/ai_memory_protocol/cli.py b/src/ai_memory_protocol/cli.py index 3b06471..5e459e7 100644 --- a/src/ai_memory_protocol/cli.py +++ b/src/ai_memory_protocol/cli.py @@ -27,6 +27,7 @@ from pathlib import Path from . import __version__ +from .capture import capture_from_git, format_candidates from .config import TYPE_FILES from .engine import ( expand_graph, @@ -37,7 +38,9 @@ tag_match, text_match, ) +from .executor import actions_from_json, execute_plan from .formatter import format_brief, format_compact, format_context_pack, format_full +from .planner import format_plan, run_plan from .rst import ( add_tags_in_rst, append_to_rst, @@ -49,6 +52,106 @@ ) from .scaffold import init_workspace +# --------------------------------------------------------------------------- +# Doctor checks +# --------------------------------------------------------------------------- + + +def _check_cli() -> tuple[bool, str]: + """Verify CLI entry point works.""" + from . import __version__ + + return True, f"v{__version__}" + + +def _check_workspace(workspace_dir: str | None) -> tuple[bool, str]: + """Verify workspace exists and is valid.""" + try: + ws = find_workspace(workspace_dir) + return True, str(ws) + except SystemExit as e: + return False, f"{e} — Run: memory init " + + +def _check_sphinx_build(workspace_dir: str | None) -> tuple[bool, str]: + """Verify sphinx-build is discoverable.""" + from .engine import find_sphinx_build + + try: + ws = find_workspace(workspace_dir) + except SystemExit: + return False, "Workspace not found (skipped)" + try: + sb = find_sphinx_build(ws) + return True, sb + except FileNotFoundError as e: + return False, f"Not found — {e}" + + +def _check_needs_json(workspace_dir: str | None) -> tuple[bool, str]: + """Verify needs.json is loadable.""" + from .engine import find_needs_json + + try: + ws = find_workspace(workspace_dir) + except SystemExit: + return False, "Workspace not found (skipped)" + path = find_needs_json(ws) + if not path.exists(): + return False, f"Not found at {path} — Run: memory rebuild" + try: + needs = load_needs(ws) + return True, f"{len(needs)} memories loaded" + except (SystemExit, Exception) as e: + return False, f"Failed to load: {e}" + + +def _check_mcp_importable() -> tuple[bool, str]: + """Verify MCP SDK is installed.""" + try: + import mcp # noqa: F401 + + return True, f"v{getattr(mcp, '__version__', '?')}" + except ImportError: + return False, "Not installed — Run: pipx inject ai-memory-protocol mcp" + + +def _check_mcp_server() -> tuple[bool, str]: + """Verify MCP server can be created.""" + try: + from .mcp_server import create_mcp_server + + create_mcp_server() + return True, "Server created successfully" + except ImportError as e: + return False, f"MCP SDK missing: {e}" + except Exception as e: + return False, f"Failed: {e}" + + +def _check_rst_files(workspace_dir: str | None) -> tuple[bool, str]: + """Verify RST files exist and are parseable.""" + try: + ws = 
find_workspace(workspace_dir) + except SystemExit: + return False, "Workspace not found (skipped)" + memory_dir = ws / "memory" + if not memory_dir.exists(): + return False, f"No memory/ directory in {ws}" + rst_files = list(memory_dir.glob("*.rst")) + if not rst_files: + return False, "No RST files found in memory/" + errors = [] + for f in rst_files: + try: + f.read_text() + except Exception as e: + errors.append(f"{f.name}: {e}") + if errors: + return False, f"{len(errors)} unreadable files: {'; '.join(errors)}" + return True, f"{len(rst_files)} RST files OK" + + # --------------------------------------------------------------------------- # Subcommands # --------------------------------------------------------------------------- @@ -65,6 +168,38 @@ def cmd_init(args: argparse.Namespace) -> None: ) +def cmd_doctor(args: argparse.Namespace) -> None: + """Run installation health checks.""" + ws_dir = getattr(args, "dir", None) + checks = [ + ("CLI entry point", _check_cli), + ("Workspace exists", lambda: _check_workspace(ws_dir)), + ("Sphinx-build available", lambda: _check_sphinx_build(ws_dir)), + ("needs.json loadable", lambda: _check_needs_json(ws_dir)), + ("MCP SDK installed", _check_mcp_importable), + ("MCP server creatable", _check_mcp_server), + ("RST files parseable", lambda: _check_rst_files(ws_dir)), + ] + all_ok = True + print("AI Memory Protocol — Health Check\n") + for name, check_fn in checks: + try: + ok, detail = check_fn() + status = "\u2713" if ok else "\u2717" + print(f" {status} {name}: {detail}") + if not ok: + all_ok = False + except Exception as e: + print(f" \u2717 {name}: CRASH — {e}") + all_ok = False + print() + if all_ok: + print("All checks passed.") + else: + print("Some checks failed. See details above.") + sys.exit(1) + + def cmd_add(args: argparse.Namespace) -> None: """Add a new memory entry.""" workspace = find_workspace(args.dir) @@ -340,6 +475,95 @@ def cmd_rebuild(args: argparse.Namespace) -> None: sys.exit(1) +def cmd_plan(args: argparse.Namespace) -> None: + """Analyze memory graph and generate a maintenance plan.""" + workspace = find_workspace(args.dir) + checks = [c.strip() for c in args.checks.split(",")] if args.checks else None + actions = run_plan(workspace, checks=checks) + fmt = args.format + print(format_plan(actions, fmt=fmt)) + + +def cmd_apply(args: argparse.Namespace) -> None: + """Execute a list of planned actions from a JSON file.""" + workspace = find_workspace(args.dir) + + if args.file: + import json as json_mod + + data = json_mod.loads(Path(args.file).read_text()) + actions = actions_from_json(data) + elif args.plan: + # Run plan first, then apply + checks = [c.strip() for c in args.plan.split(",")] if args.plan != "all" else None + actions_list = run_plan(workspace, checks=checks) + if not actions_list: + print("No issues found — nothing to apply.") + return + print(format_plan(actions_list, fmt="human")) + if not args.yes: + answer = input(f"\nApply {len(actions_list)} action(s)? 
[y/N] ") + if answer.lower() not in ("y", "yes"): + print("Aborted.") + return + actions = actions_list + else: + print("Provide --file or --plan [checks] to generate and apply.") + sys.exit(1) + + result = execute_plan( + workspace, + actions, + auto_commit=args.auto_commit, + rebuild=not args.no_rebuild, + ) + print(result.summary()) + if not result.success: + sys.exit(1) + + +def cmd_capture(args: argparse.Namespace) -> None: + """Capture memory candidates from external sources.""" + workspace = find_workspace(args.dir) + + if args.source == "git": + repo_path = Path(args.repo).resolve() if args.repo else Path.cwd() + candidates = capture_from_git( + workspace=workspace, + repo_path=repo_path, + since=args.since, + until=args.until, + repo_name=args.repo_name, + min_confidence=args.min_confidence, + ) + print(format_candidates(candidates, fmt=args.format)) + + if args.auto_add and candidates: + from .rst import append_to_rst, generate_rst_directive + + count = 0 + for c in candidates: + directive = generate_rst_directive( + mem_type=c.type, + title=c.title, + tags=c.tags, + source=c.source, + confidence=c.confidence, + scope=c.scope, + body=c.body, + ) + append_to_rst(workspace, c.type, directive) + count += 1 + print(f"\nAdded {count} memories to workspace.") + if not args.no_rebuild: + success, message = run_rebuild(workspace) + print(message) + else: + print(f"Unknown capture source: {args.source}") + print("Supported sources: git") + sys.exit(1) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -548,6 +772,106 @@ def build_parser() -> argparse.ArgumentParser: p_rebuild = sub.add_parser("rebuild", help="Rebuild needs.json from RST sources") p_rebuild.set_defaults(func=cmd_rebuild) + # --- doctor --- + p_doctor = sub.add_parser("doctor", help="Run installation health checks") + p_doctor.set_defaults(func=cmd_doctor) + + # --- plan --- + p_plan = sub.add_parser("plan", help="Analyze memory graph and generate maintenance plan") + p_plan.add_argument( + "--checks", + help=( + "Comma-separated checks to run. " + "Options: duplicates, missing_tags, stale, conflicts, tag_normalize, split_files. " + "Default: all." + ), + ) + p_plan.add_argument( + "--format", + "-f", + choices=["human", "json"], + default="human", + help="Output format (default: human)", + ) + p_plan.set_defaults(func=cmd_plan) + + # --- apply --- + p_apply = sub.add_parser("apply", help="Execute planned maintenance actions") + p_apply.add_argument("--file", help="JSON file containing actions to apply") + p_apply.add_argument( + "--plan", + nargs="?", + const="all", + help="Run plan first, then apply. 
Optionally specify checks (comma-separated).", + ) + p_apply.add_argument( + "--auto-commit", + action="store_true", + help="Commit changes to git after successful apply", + ) + p_apply.add_argument( + "--no-rebuild", + action="store_true", + help="Skip Sphinx rebuild after applying", + ) + p_apply.add_argument( + "-y", + "--yes", + action="store_true", + help="Skip confirmation prompt when using --plan", + ) + p_apply.set_defaults(func=cmd_apply) + + # --- capture --- + p_capture = sub.add_parser("capture", help="Capture memories from external sources") + p_capture.add_argument( + "source", + choices=["git"], + help="Capture source type", + ) + p_capture.add_argument( + "--repo", + help="Path to git repository (default: current directory)", + ) + p_capture.add_argument( + "--repo-name", + help="Repository name for repo: tags (auto-detected from path if omitted)", + ) + p_capture.add_argument( + "--since", + default="HEAD~20", + help="Start of git range (commit ref or date like '2 weeks ago'). Default: HEAD~20", + ) + p_capture.add_argument( + "--until", + default="HEAD", + help="End of git range. Default: HEAD", + ) + p_capture.add_argument( + "--min-confidence", + choices=["low", "medium", "high"], + default="low", + help="Minimum confidence to include (default: low)", + ) + p_capture.add_argument( + "--format", + "-f", + choices=["human", "json"], + default="human", + help="Output format (default: human)", + ) + p_capture.add_argument( + "--auto-add", + action="store_true", + help="Automatically add candidates to workspace (skip review)", + ) + p_capture.add_argument( + "--no-rebuild", + action="store_true", + help="Skip rebuild after auto-add", + ) + p_capture.set_defaults(func=cmd_capture) + return parser diff --git a/src/ai_memory_protocol/executor.py b/src/ai_memory_protocol/executor.py new file mode 100644 index 0000000..a71b5dc --- /dev/null +++ b/src/ai_memory_protocol/executor.py @@ -0,0 +1,448 @@ +"""Execute planned maintenance actions against the memory workspace. + +The executor takes a list of ``Action`` objects (from ``planner.py``) +and applies them sequentially using existing ``rst.py`` functions. +Includes git-based rollback on build failure. 
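+
+Actions are validated first (missing required fields, circular supersede
+chains), then applied one by one; if the subsequent Sphinx rebuild fails,
+uncommitted changes stashed beforehand are restored via ``git stash pop``
+and the result is reported as unsuccessful.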
+ +Usage: + from ai_memory_protocol.executor import execute_plan + result = execute_plan(workspace, actions, auto_commit=False) +""" + +from __future__ import annotations + +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from .engine import run_rebuild +from .planner import Action +from .rst import ( + add_tags_in_rst, + append_to_rst, + deprecate_in_rst, + generate_rst_directive, + remove_tags_in_rst, + update_field_in_rst, +) + +# --------------------------------------------------------------------------- +# Result type +# --------------------------------------------------------------------------- + + +@dataclass +class ExecutionResult: + """Result of executing a plan.""" + + success: bool + applied: list[dict[str, Any]] = field(default_factory=list) + failed: list[dict[str, Any]] = field(default_factory=list) + skipped: list[dict[str, Any]] = field(default_factory=list) + build_output: str = "" + message: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "success": self.success, + "applied_count": len(self.applied), + "failed_count": len(self.failed), + "skipped_count": len(self.skipped), + "applied": self.applied, + "failed": self.failed, + "skipped": self.skipped, + "build_output": self.build_output, + "message": self.message, + } + + def summary(self) -> str: + parts = [self.message] if self.message else [] + parts.append( + f"Applied: {len(self.applied)}, " + f"Failed: {len(self.failed)}, " + f"Skipped: {len(self.skipped)}" + ) + if self.build_output: + parts.append(f"Build: {self.build_output[:200]}") + return "\n".join(parts) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def validate_actions(actions: list[Action]) -> tuple[list[Action], list[dict[str, Any]]]: + """Validate actions before execution. + + Returns (valid_actions, skipped_with_reasons). 
+ Checks: + - Circular supersedes (A supersedes B, B supersedes A) + - Missing required fields per action kind + """ + valid: list[Action] = [] + skipped: list[dict[str, Any]] = [] + + # Build supersede graph for cycle detection + supersede_map: dict[str, str] = {} + for a in actions: + if a.kind == "SUPERSEDE" and a.old_id and a.by_id: + supersede_map[a.old_id] = a.by_id + + for a in actions: + # Check required fields + if a.kind == "RETAG" and not a.id: + skipped.append({"action": a.to_dict(), "reason": "RETAG requires 'id'"}) + continue + if a.kind == "SUPERSEDE" and not a.old_id: + skipped.append({"action": a.to_dict(), "reason": "SUPERSEDE requires 'old_id'"}) + continue + if a.kind == "DEPRECATE" and not a.id: + skipped.append({"action": a.to_dict(), "reason": "DEPRECATE requires 'id'"}) + continue + if a.kind == "UPDATE" and not a.id: + skipped.append({"action": a.to_dict(), "reason": "UPDATE requires 'id'"}) + continue + if a.kind == "SPLIT_FILE" and not a.rst_path: + skipped.append({"action": a.to_dict(), "reason": "SPLIT_FILE requires 'rst_path'"}) + continue + + # Check supersede cycles + if a.kind == "SUPERSEDE" and a.old_id: + visited: set[str] = set() + current = a.old_id + cycle = False + while current in supersede_map: + if current in visited: + cycle = True + break + visited.add(current) + current = supersede_map[current] + if cycle: + skipped.append( + { + "action": a.to_dict(), + "reason": f"Circular supersede chain involving {a.old_id}", + } + ) + continue + + valid.append(a) + + return valid, skipped + + +# --------------------------------------------------------------------------- +# Individual action executors +# --------------------------------------------------------------------------- + + +def _execute_retag(workspace: Path, action: Action) -> tuple[bool, str]: + """Execute a RETAG action — add/remove tags on a memory.""" + messages: list[str] = [] + ok = True + + if action.remove_tags: + success, msg = remove_tags_in_rst(workspace, action.id, action.remove_tags) + messages.append(msg) + ok = ok and success + + if action.add_tags: + success, msg = add_tags_in_rst(workspace, action.id, action.add_tags) + messages.append(msg) + ok = ok and success + + return ok, "; ".join(messages) + + +def _execute_supersede(workspace: Path, action: Action) -> tuple[bool, str]: + """Execute a SUPERSEDE action — deprecate old, optionally create new.""" + messages: list[str] = [] + + # Deprecate the old memory + ok, msg = deprecate_in_rst(workspace, action.old_id, action.by_id) + messages.append(msg) + + if not ok: + return False, f"Failed to deprecate {action.old_id}: {msg}" + + # If new memory details provided, create it + if action.new_type and action.new_title: + directive = generate_rst_directive( + mem_type=action.new_type, + title=action.new_title, + tags=action.new_tags or [], + body=action.new_body or "", + supersedes=[action.old_id], + ) + target = append_to_rst(workspace, action.new_type, directive) + messages.append(f"Created replacement in {target.name}") + + return True, "; ".join(messages) + + +def _execute_deprecate(workspace: Path, action: Action) -> tuple[bool, str]: + """Execute a DEPRECATE action.""" + return deprecate_in_rst(workspace, action.id, action.by_id or None) + + +def _execute_update(workspace: Path, action: Action) -> tuple[bool, str]: + """Execute an UPDATE action — change metadata fields.""" + if not action.field_changes: + return True, f"No field changes for {action.id}" + + messages: list[str] = [] + all_ok = True + + for field_name, value in 
action.field_changes.items(): + ok, msg = update_field_in_rst(workspace, action.id, field_name, value) + messages.append(msg) + all_ok = all_ok and ok + + return all_ok, "; ".join(messages) + + +def _execute_prune(workspace: Path, action: Action) -> tuple[bool, str]: + """Execute a PRUNE action — deprecate without replacement.""" + return deprecate_in_rst(workspace, action.id) + + +def _execute_split_file(workspace: Path, action: Action) -> tuple[bool, str]: # noqa: ARG001 + """Execute a SPLIT_FILE action. + + This is informational — actual splitting happens automatically + via rst.py append_to_rst when MAX_ENTRIES_PER_FILE is exceeded. + """ + return True, ( + f"File splitting noted for {action.rst_path} — handled automatically on next append." + ) + + +# Dispatcher +_EXECUTORS = { + "RETAG": _execute_retag, + "SUPERSEDE": _execute_supersede, + "DEPRECATE": _execute_deprecate, + "UPDATE": _execute_update, + "PRUNE": _execute_prune, + "SPLIT_FILE": _execute_split_file, +} + + +# --------------------------------------------------------------------------- +# Git operations for rollback +# --------------------------------------------------------------------------- + + +def _git_stash_push(workspace: Path) -> bool: + """Stash uncommitted changes for rollback. Returns True if stash was created.""" + try: + result = subprocess.run( + ["git", "stash", "push", "-m", "memory_apply pre-backup"], + cwd=str(workspace), + capture_output=True, + text=True, + ) + # Only treat as successful if git exited cleanly. + if result.returncode != 0: + return False + # "No local changes to save" means nothing was stashed. + # This may appear in stdout or stderr. + output = (result.stdout or "") + (result.stderr or "") + return "No local changes to save" not in output + except OSError: + return False + + +def _git_stash_pop(workspace: Path) -> bool: + """Pop stashed changes to rollback.""" + try: + result = subprocess.run( + ["git", "stash", "pop"], + cwd=str(workspace), + capture_output=True, + text=True, + ) + return result.returncode == 0 + except OSError: + return False + + +def _git_stash_drop(workspace: Path) -> bool: + """Drop the stash (cleanup after successful apply).""" + try: + result = subprocess.run( + ["git", "stash", "drop"], + cwd=str(workspace), + capture_output=True, + text=True, + ) + return result.returncode == 0 + except OSError: + return False + + +def _git_commit(workspace: Path, message: str) -> bool: + """Stage and commit memory changes.""" + try: + subprocess.run( + ["git", "add", "memory/", "*.rst"], + cwd=str(workspace), + capture_output=True, + text=True, + ) + result = subprocess.run( + ["git", "commit", "-m", message], + cwd=str(workspace), + capture_output=True, + text=True, + ) + return result.returncode == 0 + except OSError: + return False + + +# --------------------------------------------------------------------------- +# Main execution entry point +# --------------------------------------------------------------------------- + + +def execute_plan( + workspace: Path, + actions: list[Action], + auto_commit: bool = False, + rebuild: bool = True, +) -> ExecutionResult: + """Execute a list of planned actions. + + Parameters + ---------- + workspace + Path to the memory workspace. + actions + Actions to execute (from ``run_plan`` or deserialized from JSON). + auto_commit + If True, commit changes to git after successful execution. + rebuild + If True, run Sphinx rebuild after applying actions. 
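+        The rebuild only runs when at least one action was actually applied.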
+ + Returns + ------- + ExecutionResult + Summary of applied/failed/skipped actions + build output. + """ + # Validate + valid_actions, skipped = validate_actions(actions) + + if not valid_actions: + return ExecutionResult( + success=True, + skipped=skipped, + message="No valid actions to execute.", + ) + + # Stash for rollback + stashed = _git_stash_push(workspace) + + # Execute sequentially + applied: list[dict[str, Any]] = [] + failed: list[dict[str, Any]] = [] + + for action in valid_actions: + executor = _EXECUTORS.get(action.kind) + if not executor: + failed.append( + { + "action": action.to_dict(), + "error": f"Unknown action kind: {action.kind}", + } + ) + continue + + try: + ok, msg = executor(workspace, action) + entry = {"action": action.to_dict(), "message": msg} + if ok: + applied.append(entry) + else: + failed.append({**entry, "error": msg}) + except Exception as e: + failed.append({"action": action.to_dict(), "error": str(e)}) + + # Rebuild + build_output = "" + build_ok = True + if rebuild and applied: + build_ok, build_output = run_rebuild(workspace) + + # If build failed, always treat as unsuccessful; use git stash for rollback + # when available. + if not build_ok: + if stashed: + _git_stash_pop(workspace) + applied_result: list[dict[str, Any]] = [] + message = "Build failed after applying actions — rolled back via git stash pop." + else: + # No stash available: cannot automatically roll back workspace changes. + applied_result = applied + message = ( + "Build failed after applying actions — no git stash available for " + "rollback; workspace may be in an inconsistent state." + ) + + return ExecutionResult( + success=False, + applied=applied_result, + failed=failed, + skipped=skipped, + build_output=build_output, + message=message, + ) + + # Cleanup stash on success + if stashed: + _git_stash_drop(workspace) + + # Auto-commit + if auto_commit and applied: + kinds = set(a.get("action", {}).get("kind", "?") for a in applied) + msg = f"memory: auto-apply {', '.join(sorted(kinds))} ({len(applied)} actions)" + _git_commit(workspace, msg) + + all_succeeded = not failed + return ExecutionResult( + success=all_succeeded, + applied=applied, + failed=failed, + skipped=skipped, + build_output=build_output, + message=( + f"Plan executed: {len(applied)} applied, {len(failed)} failed, {len(skipped)} skipped." + ), + ) + + +def actions_from_json(data: list[dict[str, Any]]) -> list[Action]: + """Deserialize a list of action dicts (e.g. 
from JSON) into Action objects.""" + actions: list[Action] = [] + for d in data: + actions.append( + Action( + kind=d.get("kind", "UPDATE"), + reason=d.get("reason", ""), + id=d.get("id", ""), + add_tags=d.get("add_tags", []), + remove_tags=d.get("remove_tags", []), + field_changes=d.get("field_changes", {}), + old_id=d.get("old_id", ""), + new_type=d.get("new_type", ""), + new_title=d.get("new_title", ""), + new_body=d.get("new_body", ""), + new_tags=d.get("new_tags", []), + new_links=d.get("new_links", []), + by_id=d.get("by_id", ""), + rst_path=d.get("rst_path", ""), + ) + ) + return actions diff --git a/src/ai_memory_protocol/mcp_server.py b/src/ai_memory_protocol/mcp_server.py index 9901bab..6c8d07a 100644 --- a/src/ai_memory_protocol/mcp_server.py +++ b/src/ai_memory_protocol/mcp_server.py @@ -15,14 +15,24 @@ import json import logging +import sys from datetime import date from pathlib import Path from typing import Any -from mcp.server import Server -from mcp.types import TextContent, Tool +# Lazy MCP import — provide helpful error instead of ImportError crash +_MCP_AVAILABLE = False +try: + from mcp.server import Server + from mcp.types import TextContent, Tool -from .engine import ( + _MCP_AVAILABLE = True +except ImportError: + Server = None # type: ignore[assignment,misc] + TextContent = None # type: ignore[assignment,misc] + Tool = None # type: ignore[assignment,misc] + +from .engine import ( # noqa: E402 expand_graph, find_workspace, load_needs, @@ -31,8 +41,13 @@ tag_match, text_match, ) -from .formatter import format_brief, format_compact, format_context_pack, format_full -from .rst import ( +from .formatter import ( # noqa: E402 + format_brief, + format_compact, + format_context_pack, + format_full, +) +from .rst import ( # noqa: E402 add_tags_in_rst, append_to_rst, deprecate_in_rst, @@ -51,6 +66,13 @@ def create_mcp_server(name: str = "ai-memory-protocol") -> Server: """Create and configure the MCP server with all memory tools.""" + if not _MCP_AVAILABLE: + raise ImportError( + "MCP SDK not installed. Install with:\n" + " pipx install -e '.[mcp]'\n" + "Or if already installed via pipx:\n" + " pipx inject ai-memory-protocol mcp\n" + ) server = Server(name) _register_tools(server) _register_handlers(server) @@ -61,257 +83,395 @@ def create_mcp_server(name: str = "ai-memory-protocol") -> Server: # Tool definitions # --------------------------------------------------------------------------- -TOOLS: list[Tool] = [ - Tool( - name="memory_recall", - description=( - "Search memories by free text query and/or tags. " - "Returns matching memories formatted for context windows. " - "Use format='brief' FIRST to peek at titles, then memory_get to drill into specific IDs. " - "Recall at EVERY topic transition: new task, unfamiliar code, before decisions, when stuck — not just session start." + +def _build_tools() -> list: + """Build tool definitions. Returns empty list if MCP SDK not available.""" + if not _MCP_AVAILABLE: + return [] + return [ + Tool( + name="memory_recall", + description=( + "Search memories by free text query and/or tags. " + "Returns matching memories formatted for context windows. " + "Use format='brief' FIRST to peek at titles, " + "then memory_get to drill into specific IDs. " + "Recall at EVERY topic transition: new task, " + "unfamiliar code, before decisions, when stuck " + "— not just session start." 
+ ), + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + "Free-text search query (OR logic across words). " + "Optional if tag is provided." + ), + }, + "tag": { + "type": "string", + "description": ( + "Comma-separated tag filters (AND logic). " + "E.g. 'topic:gateway,repo:ros2_medkit'." + ), + }, + "type": { + "type": "string", + "description": "Filter by memory type: mem, dec, fact, pref, risk, goal, q.", + "enum": ["mem", "dec", "fact", "pref", "risk", "goal", "q"], + }, + "format": { + "type": "string", + "description": ( + "Output format: brief (minimal tokens), " + "compact (one-liner), context (grouped by " + "type, default), json." + ), + "enum": ["brief", "compact", "context", "json"], + "default": "context", + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return. 0 = unlimited.", + "default": 0, + }, + "body": { + "type": "boolean", + "description": "Include body text in output. Default false to save tokens.", + "default": False, + }, + "sort": { + "type": "string", + "description": "Sort order for results.", + "enum": ["newest", "oldest", "confidence", "updated"], + }, + "expand": { + "type": "integer", + "description": ( + "Graph expansion hops from matched memories. " + "0 = exact matches only. Default 1." + ), + "default": 1, + }, + "stale": { + "type": "boolean", + "description": "If true, show only expired or review-overdue memories.", + "default": False, + }, + }, + "required": [], + }, ), - inputSchema={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Free-text search query (OR logic across words). Optional if tag is provided.", - }, - "tag": { - "type": "string", - "description": "Comma-separated tag filters (AND logic). E.g. 'topic:gateway,repo:ros2_medkit'.", - }, - "type": { - "type": "string", - "description": "Filter by memory type: mem, dec, fact, pref, risk, goal, q.", - "enum": ["mem", "dec", "fact", "pref", "risk", "goal", "q"], - }, - "format": { - "type": "string", - "description": "Output format: brief (minimal tokens), compact (one-liner), context (grouped by type, default), json.", - "enum": ["brief", "compact", "context", "json"], - "default": "context", - }, - "limit": { - "type": "integer", - "description": "Maximum number of results to return. 0 = unlimited.", - "default": 0, - }, - "body": { - "type": "boolean", - "description": "Include body text in output. Default false to save tokens.", - "default": False, - }, - "sort": { - "type": "string", - "description": "Sort order for results.", - "enum": ["newest", "oldest", "confidence", "updated"], - }, - "expand": { - "type": "integer", - "description": "Graph expansion hops from matched memories. 0 = exact matches only. Default 1.", - "default": 1, - }, - "stale": { - "type": "boolean", - "description": "If true, show only expired or review-overdue memories.", - "default": False, - }, + Tool( + name="memory_get", + description=( + "Get full details of a specific memory by ID — the DRILL step. " + "Always shows body text. Use AFTER memory_recall(format='brief') to read " + "the 2-3 most relevant memories. Never skip the peek step." + ), + inputSchema={ + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Memory ID, e.g. DEC_rest_framework or FACT_gateway_port.", + }, + }, + "required": ["id"], }, - "required": [], - }, - ), - Tool( - name="memory_get", - description=( - "Get full details of a specific memory by ID — the DRILL step. 
" - "Always shows body text. Use AFTER memory_recall(format='brief') to read " - "the 2-3 most relevant memories. Never skip the peek step." ), - inputSchema={ - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Memory ID, e.g. DEC_rest_framework or FACT_gateway_port.", - }, + Tool( + name="memory_add", + description=( + "Record a new memory. You MUST use this at specific trigger points: " + "chose approach A over B → dec; fixed non-obvious bug → mem; " + "discovered undocumented pattern → fact; user stated preference → pref; " + "identified risk → risk; question unanswered → q. " + "Also batch-write at end of task: architecture learned → fact, conventions → pref. " + "Tags are mandatory. Body must be self-contained with file paths and concrete details." + ), + inputSchema={ + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Memory type.", + "enum": ["mem", "dec", "fact", "pref", "risk", "goal", "q"], + }, + "title": { + "type": "string", + "description": "Short title for the memory.", + }, + "tags": { + "type": "string", + "description": ( + "Comma-separated tags in prefix:value format. " + "E.g. 'topic:api,repo:backend'." + ), + }, + "body": { + "type": "string", + "description": "Detailed description text.", + }, + "confidence": { + "type": "string", + "description": "Trust level.", + "enum": ["low", "medium", "high"], + "default": "medium", + }, + "source": { + "type": "string", + "description": "Provenance — URL, commit, ticket, or description of origin.", + }, + "scope": { + "type": "string", + "description": "Applicability scope. E.g. 'global', 'repo:ros2_medkit'.", + "default": "global", + }, + "relates": { + "type": "string", + "description": "Comma-separated IDs of related memories.", + }, + "supersedes": { + "type": "string", + "description": "Comma-separated IDs that this memory supersedes.", + }, + "id": { + "type": "string", + "description": "Custom memory ID. Auto-generated from type + title if omitted.", + }, + "rebuild": { + "type": "boolean", + "description": "Auto-rebuild needs.json after adding. Default true.", + "default": True, + }, + }, + "required": ["type", "title", "tags"], }, - "required": ["id"], - }, - ), - Tool( - name="memory_add", - description=( - "Record a new memory. You MUST use this at specific trigger points: " - "chose approach A over B → dec; fixed non-obvious bug → mem; " - "discovered undocumented pattern → fact; user stated preference → pref; " - "identified risk → risk; question unanswered → q. " - "Also batch-write at end of task: architecture learned → fact, conventions → pref. " - "Tags are mandatory. Body must be self-contained with file paths and concrete details." ), - inputSchema={ - "type": "object", - "properties": { - "type": { - "type": "string", - "description": "Memory type.", - "enum": ["mem", "dec", "fact", "pref", "risk", "goal", "q"], - }, - "title": { - "type": "string", - "description": "Short title for the memory.", - }, - "tags": { - "type": "string", - "description": "Comma-separated tags in prefix:value format. E.g. 'topic:api,repo:backend'.", - }, - "body": { - "type": "string", - "description": "Detailed description text.", - }, - "confidence": { - "type": "string", - "description": "Trust level.", - "enum": ["low", "medium", "high"], - "default": "medium", - }, - "source": { - "type": "string", - "description": "Provenance — URL, commit, ticket, or description of origin.", - }, - "scope": { - "type": "string", - "description": "Applicability scope. 
E.g. 'global', 'repo:ros2_medkit'.", - "default": "global", - }, - "relates": { - "type": "string", - "description": "Comma-separated IDs of related memories.", - }, - "supersedes": { - "type": "string", - "description": "Comma-separated IDs that this memory supersedes.", - }, - "id": { - "type": "string", - "description": "Custom memory ID. Auto-generated from type + title if omitted.", - }, - "rebuild": { - "type": "boolean", - "description": "Auto-rebuild needs.json after adding. Default true.", - "default": True, - }, + Tool( + name="memory_update", + description=( + "Update metadata on an existing memory. " + "Can change status, confidence, scope, tags, review date, etc." + ), + inputSchema={ + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Memory ID to update.", + }, + "status": { + "type": "string", + "description": "New status.", + "enum": ["draft", "active", "promoted", "deprecated", "review"], + }, + "confidence": { + "type": "string", + "description": "New confidence level.", + "enum": ["low", "medium", "high"], + }, + "scope": { + "type": "string", + "description": "New scope.", + }, + "review_after": { + "type": "string", + "description": "New review date (ISO-8601, e.g. 2026-06-01).", + }, + "source": { + "type": "string", + "description": "New source/provenance.", + }, + "add_tags": { + "type": "string", + "description": "Tags to add, comma-separated.", + }, + "remove_tags": { + "type": "string", + "description": "Tags to remove, comma-separated.", + }, + }, + "required": ["id"], }, - "required": ["type", "title", "tags"], - }, - ), - Tool( - name="memory_update", - description=( - "Update metadata on an existing memory. " - "Can change status, confidence, scope, tags, review date, etc." ), - inputSchema={ - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Memory ID to update.", - }, - "status": { - "type": "string", - "description": "New status.", - "enum": ["draft", "active", "promoted", "deprecated", "review"], - }, - "confidence": { - "type": "string", - "description": "New confidence level.", - "enum": ["low", "medium", "high"], - }, - "scope": { - "type": "string", - "description": "New scope.", - }, - "review_after": { - "type": "string", - "description": "New review date (ISO-8601, e.g. 2026-06-01).", - }, - "source": { - "type": "string", - "description": "New source/provenance.", - }, - "add_tags": { - "type": "string", - "description": "Tags to add, comma-separated.", - }, - "remove_tags": { - "type": "string", - "description": "Tags to remove, comma-separated.", - }, + Tool( + name="memory_deprecate", + description=( + "Mark a memory as deprecated. Optionally specify the superseding memory. " + "Use this instead of editing — supersede, don't silently edit." + ), + inputSchema={ + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Memory ID to deprecate.", + }, + "by": { + "type": "string", + "description": "ID of the superseding memory.", + }, + }, + "required": ["id"], }, - "required": ["id"], - }, - ), - Tool( - name="memory_deprecate", - description=( - "Mark a memory as deprecated. Optionally specify the superseding memory. " - "Use this instead of editing — supersede, don't silently edit." 
), - inputSchema={ - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Memory ID to deprecate.", - }, - "by": { - "type": "string", - "description": "ID of the superseding memory.", - }, + Tool( + name="memory_tags", + description=( + "List all tags in use with counts, grouped by prefix. " + "Use before filtering to discover available tag prefixes." + ), + inputSchema={ + "type": "object", + "properties": { + "prefix": { + "type": "string", + "description": "Filter by tag prefix, e.g. 'topic' to see only topic:* tags.", + }, + }, + "required": [], }, - "required": ["id"], - }, - ), - Tool( - name="memory_tags", - description=( - "List all tags in use with counts, grouped by prefix. " - "Use before filtering to discover available tag prefixes." ), - inputSchema={ - "type": "object", - "properties": { - "prefix": { - "type": "string", - "description": "Filter by tag prefix, e.g. 'topic' to see only topic:* tags.", - }, + Tool( + name="memory_stale", + description=( + "Show expired or review-overdue memories. " + "Use periodically to keep the memory graph fresh." + ), + inputSchema={ + "type": "object", + "properties": {}, + "required": [], }, - "required": [], - }, - ), - Tool( - name="memory_stale", - description="Show expired or review-overdue memories. Use periodically to keep the memory graph fresh.", - inputSchema={ - "type": "object", - "properties": {}, - "required": [], - }, - ), - Tool( - name="memory_rebuild", - description=( - "Rebuild needs.json from RST sources by running Sphinx build. " - "Required after adding or modifying memories to make changes searchable." ), - inputSchema={ - "type": "object", - "properties": {}, - "required": [], - }, - ), -] + Tool( + name="memory_rebuild", + description=( + "Rebuild needs.json from RST sources by running Sphinx build. " + "Required after adding or modifying memories to make changes searchable." + ), + inputSchema={ + "type": "object", + "properties": {}, + "required": [], + }, + ), + Tool( + name="memory_plan", + description=( + "Analyze memory graph and return planned maintenance actions (no modifications). " + "Checks for duplicates, missing tags, stale entries, conflicts, tag normalization, " + "and oversized files. Returns a list of proposed actions." + ), + inputSchema={ + "type": "object", + "properties": { + "checks": { + "type": "array", + "items": { + "type": "string", + "enum": [ + "duplicates", + "missing_tags", + "stale", + "conflicts", + "tag_normalize", + "split_files", + ], + }, + "description": "Which checks to run. Default: all.", + }, + "format": { + "type": "string", + "enum": ["human", "json"], + "default": "human", + "description": "Output format. 'json' for machine-readable actions.", + }, + }, + "required": [], + }, + ), + Tool( + name="memory_apply", + description=( + "Execute a list of planned memory actions, rebuild, and validate. " + "Includes git-based rollback on build failure. Pass actions from memory_plan output." + ), + inputSchema={ + "type": "object", + "properties": { + "actions": { + "type": "array", + "items": {"type": "object"}, + "description": "Actions from memory_plan output (JSON format).", + }, + "auto_commit": { + "type": "boolean", + "default": False, + "description": "Commit changes to git after successful apply.", + }, + }, + "required": ["actions"], + }, + ), + Tool( + name="memory_capture_git", + description=( + "Analyze git log and generate memory candidates from commit history. 
" + "Classifies commits by conventional commit prefix " + "and deduplicates against existing memories." + ), + inputSchema={ + "type": "object", + "properties": { + "repo_path": { + "type": "string", + "description": "Path to git repository. Default: current directory.", + }, + "since": { + "type": "string", + "default": "HEAD~20", + "description": "Start of range (commit ref or date like '2 weeks ago').", + }, + "until": { + "type": "string", + "default": "HEAD", + "description": "End of range.", + }, + "repo_name": { + "type": "string", + "description": "Repository name for repo: tags. Auto-detected if omitted.", + }, + "min_confidence": { + "type": "string", + "enum": ["low", "medium", "high"], + "default": "low", + "description": "Minimum confidence to include.", + }, + "format": { + "type": "string", + "enum": ["human", "json"], + "default": "human", + "description": "Output format.", + }, + "auto_add": { + "type": "boolean", + "default": False, + "description": "Automatically add candidates to workspace.", + }, + }, + "required": [], + }, + ), + ] + + +# Module-level TOOLS — populated lazily so the module can be imported without MCP SDK +TOOLS: list = _build_tools() def _register_tools(server: Server) -> None: @@ -428,6 +588,12 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: return _handle_stale(arguments) elif name == "memory_rebuild": return _handle_rebuild(arguments) + elif name == "memory_plan": + return _handle_plan(arguments) + elif name == "memory_apply": + return _handle_apply(arguments) + elif name == "memory_capture_git": + return _handle_capture_git(arguments) else: return _text_response(f"Unknown tool: {name}") except SystemExit as e: @@ -663,6 +829,68 @@ def _handle_rebuild(args: dict[str, Any]) -> list[TextContent]: return _text_response(result) +def _handle_plan(args: dict[str, Any]) -> list[TextContent]: + from .planner import format_plan, run_plan + + workspace = _get_workspace() + checks = args.get("checks") # list[str] or None + fmt = args.get("format", "human") + actions = run_plan(workspace, checks=checks) + output = format_plan(actions, fmt=fmt) + return _text_response(output) + + +def _handle_apply(args: dict[str, Any]) -> list[TextContent]: + from .executor import actions_from_json, execute_plan + + workspace = _get_workspace() + raw_actions = args.get("actions", []) + auto_commit = args.get("auto_commit", False) + actions = actions_from_json(raw_actions) + result = execute_plan(workspace, actions, auto_commit=auto_commit) + return _text_response(result.summary()) + + +def _handle_capture_git(args: dict[str, Any]) -> list[TextContent]: + from .capture import capture_from_git, format_candidates + from .rst import append_to_rst, generate_rst_directive + + workspace = _get_workspace() + repo_path = Path(args.get("repo_path", ".")).resolve() + candidates = capture_from_git( + workspace=workspace, + repo_path=repo_path, + since=args.get("since", "HEAD~20"), + until=args.get("until", "HEAD"), + repo_name=args.get("repo_name"), + min_confidence=args.get("min_confidence", "low"), + ) + + output_lines: list[str] = [] + fmt = args.get("format", "human") + output_lines.append(format_candidates(candidates, fmt=fmt)) + + if args.get("auto_add", False) and candidates: + count = 0 + for c in candidates: + directive = generate_rst_directive( + mem_type=c.type, + title=c.title, + tags=c.tags, + source=c.source, + confidence=c.confidence, + scope=c.scope, + body=c.body, + ) + append_to_rst(workspace, c.type, directive) + count += 1 + 
output_lines.append(f"\nAdded {count} memories to workspace.") + success, msg = run_rebuild(workspace) + output_lines.append(msg) + + return _text_response("\n".join(output_lines)) + + # --------------------------------------------------------------------------- # Entry points # --------------------------------------------------------------------------- @@ -670,6 +898,17 @@ def _handle_rebuild(args: dict[str, Any]) -> list[TextContent]: def main_stdio() -> None: """Run the MCP server over stdio transport.""" + if not _MCP_AVAILABLE: + print( + "ERROR: MCP SDK not installed.\n" + "Install with:\n" + " pipx install -e '.[mcp]'\n" + "Or if already installed via pipx:\n" + " pipx inject ai-memory-protocol mcp", + file=sys.stderr, + ) + sys.exit(1) + import asyncio from mcp.server.stdio import stdio_server diff --git a/src/ai_memory_protocol/planner.py b/src/ai_memory_protocol/planner.py new file mode 100644 index 0000000..3a36e76 --- /dev/null +++ b/src/ai_memory_protocol/planner.py @@ -0,0 +1,431 @@ +"""Plan analysis — detect problems and generate maintenance actions. + +The planner is **read-only**: it loads needs.json + RST files, runs +detection algorithms, and returns a list of ``Action`` dicts describing +what *should* be done. The executor (``executor.py``) is responsible +for actually applying actions. + +Usage: + from ai_memory_protocol.planner import run_plan + actions = run_plan(workspace, checks=["duplicates", "missing_tags"]) +""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import asdict, dataclass, field +from datetime import date +from difflib import SequenceMatcher +from pathlib import Path +from typing import Any, Literal + +from .config import TYPE_FILES +from .engine import load_needs +from .rst import MAX_ENTRIES_PER_FILE, _count_entries, _find_all_rst_files + +# --------------------------------------------------------------------------- +# Action types +# --------------------------------------------------------------------------- + +ActionKind = Literal["RETAG", "SUPERSEDE", "DEPRECATE", "UPDATE", "PRUNE", "SPLIT_FILE"] + +ALL_CHECKS: list[str] = [ + "duplicates", + "missing_tags", + "stale", + "conflicts", + "tag_normalize", + "split_files", +] + + +@dataclass +class Action: + """A planned maintenance action.""" + + kind: ActionKind + reason: str + # Fields used by individual action kinds + id: str = "" + add_tags: list[str] = field(default_factory=list) + remove_tags: list[str] = field(default_factory=list) + field_changes: dict[str, str] = field(default_factory=dict) + # SUPERSEDE-specific + old_id: str = "" + new_type: str = "" + new_title: str = "" + new_body: str = "" + new_tags: list[str] = field(default_factory=list) + new_links: list[str] = field(default_factory=list) + # DEPRECATE-specific + by_id: str = "" + # SPLIT_FILE-specific + rst_path: str = "" + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain dict, omitting empty fields.""" + d = asdict(self) + return {k: v for k, v in d.items() if v} + + +# --------------------------------------------------------------------------- +# Detection algorithms +# --------------------------------------------------------------------------- + + +def _active_needs(needs: dict[str, Any]) -> dict[str, Any]: + """Filter to non-deprecated needs only.""" + return {k: v for k, v in needs.items() if v.get("status") != "deprecated"} + + +def detect_duplicates( + needs: dict[str, Any], + title_threshold: float = 0.8, + tag_overlap_threshold: float = 0.5, +) -> 
list[Action]: + """Find near-duplicate memories by title similarity + tag overlap. + + Complexity: O(n²) — acceptable for n < 500. + """ + active = _active_needs(needs) + items = list(active.items()) + actions: list[Action] = [] + seen_pairs: set[tuple[str, str]] = set() + + for i, (id1, n1) in enumerate(items): + for id2, n2 in items[i + 1 :]: + pair = tuple(sorted((id1, id2))) + if pair in seen_pairs: + continue + + title_sim = SequenceMatcher( + None, n1.get("title", "").lower(), n2.get("title", "").lower() + ).ratio() + if title_sim < title_threshold: + continue + + tags1 = set(n1.get("tags", [])) + tags2 = set(n2.get("tags", [])) + union = tags1 | tags2 + if not union: + continue + tag_overlap = len(tags1 & tags2) / len(union) + if tag_overlap < tag_overlap_threshold: + continue + + seen_pairs.add(pair) + + # Prefer newer + higher-confidence as canonical + conf_rank = {"high": 2, "medium": 1, "low": 0} + score1 = ( + conf_rank.get(n1.get("confidence", "medium"), 1), + n1.get("created_at", ""), + ) + score2 = ( + conf_rank.get(n2.get("confidence", "medium"), 1), + n2.get("created_at", ""), + ) + + if score2 > score1: + old_id, new_id = id1, id2 + else: + old_id, new_id = id2, id1 + + actions.append( + Action( + kind="SUPERSEDE", + reason=( + f"Near-duplicate: title similarity {title_sim:.0%}, " + f"tag overlap {tag_overlap:.0%}. " + f"Keep {new_id} (higher score), deprecate {old_id}." + ), + old_id=old_id, + by_id=new_id, + ) + ) + + return actions + + +def detect_missing_tags(needs: dict[str, Any]) -> list[Action]: + """Find memories without required tag prefixes (topic: or repo:). + + O(n) — checks every active need once. + """ + active = _active_needs(needs) + actions: list[Action] = [] + + for nid, need in active.items(): + tags = need.get("tags", []) + has_topic = any(t.startswith("topic:") for t in tags) + has_repo = any(t.startswith("repo:") for t in tags) + missing: list[str] = [] + if not has_topic: + missing.append("topic:") + if not has_repo: + missing.append("repo:") + if missing: + actions.append( + Action( + kind="RETAG", + reason=f"Missing required tag prefix(es): {', '.join(missing)}", + id=nid, + ) + ) + + return actions + + +def detect_stale(needs: dict[str, Any]) -> list[Action]: + """Find expired or review-overdue memories. + + O(n) — mirrors the logic in ``cmd_stale`` but returns actions. + """ + active = _active_needs(needs) + today = date.today().isoformat() + actions: list[Action] = [] + + for nid, need in active.items(): + expires = need.get("expires_at", "") + review = need.get("review_after", "") + + if expires and expires <= today: + actions.append( + Action( + kind="UPDATE", + reason=f"Expired on {expires} — needs review or deprecation.", + id=nid, + field_changes={"status": "review"}, + ) + ) + elif review and review <= today: + actions.append( + Action( + kind="UPDATE", + reason=f"Review overdue since {review}.", + id=nid, + field_changes={"status": "review"}, + ) + ) + + return actions + + +def detect_conflicts(needs: dict[str, Any]) -> list[Action]: + """Find active needs with same topic but no contradicts link. + + Heuristic: two decisions on the same topic:* tag with no explicit + relationship may indicate an unrecorded contradiction. + + O(n²) per topic group — practical for small graphs. 
+ """ + active = _active_needs(needs) + # Group decisions by topic tag + by_topic: dict[str, list[str]] = defaultdict(list) + for nid, need in active.items(): + if need.get("type") not in ("dec", "pref"): + continue + for tag in need.get("tags", []): + if tag.startswith("topic:"): + by_topic[tag].append(nid) + + actions: list[Action] = [] + seen: set[tuple[str, str]] = set() + + for _topic, ids in by_topic.items(): + if len(ids) < 2: + continue + for i, id1 in enumerate(ids): + for id2 in ids[i + 1 :]: + pair = tuple(sorted((id1, id2))) + if pair in seen: + continue + seen.add(pair) + + n1, n2 = active[id1], active[id2] + # Check if they already have a relationship + links1 = set() + links2 = set() + for lt in ("relates", "supports", "depends", "contradicts", "supersedes"): + links1.update(n1.get(lt, [])) + links2.update(n2.get(lt, [])) + + if id2 not in links1 and id1 not in links2: + actions.append( + Action( + kind="UPDATE", + reason=( + f"Potential conflict: {id1} and {id2} are both " + f"active {n1.get('type')}/{n2.get('type')} entries on " + f"the same topic with no explicit relationship link." + ), + id=id1, + field_changes={"status": "review"}, + ) + ) + + return actions + + +def detect_tag_normalization(needs: dict[str, Any]) -> list[Action]: + """Find case-insensitive tag duplicates and normalize to most common form. + + O(n) scan + O(types per lowered tag). + """ + active = _active_needs(needs) + # Collect all tag forms with usage counts + tag_usage: dict[str, int] = defaultdict(int) # exact form -> count + for need in active.values(): + for tag in need.get("tags", []): + tag_usage[tag] += 1 + + # Group by lowercased form + lower_groups: dict[str, set[str]] = defaultdict(set) + for tag in tag_usage: + lower_groups[tag.lower()].add(tag) + + actions: list[Action] = [] + for forms in lower_groups.values(): + if len(forms) <= 1: + continue + # Pick the most common form as canonical + canonical = max(forms, key=lambda t: tag_usage[t]) + non_canonical = forms - {canonical} + for form in non_canonical: + for nid, need in active.items(): + if form in need.get("tags", []): + actions.append( + Action( + kind="RETAG", + reason=f"Tag normalization: '{form}' → '{canonical}'", + id=nid, + remove_tags=[form], + add_tags=[canonical], + ) + ) + + return actions + + +def detect_split_files(workspace: Path) -> list[Action]: + """Find RST files that exceed MAX_ENTRIES_PER_FILE. + + O(files) — scans each RST file once. + """ + actions: list[Action] = [] + + for mem_type in TYPE_FILES: + for rst_path in _find_all_rst_files(workspace, mem_type): + count = _count_entries(rst_path) + if count > MAX_ENTRIES_PER_FILE: + actions.append( + Action( + kind="SPLIT_FILE", + reason=( + f"{rst_path.name} has {count} entries (limit: {MAX_ENTRIES_PER_FILE})." 
+ ), + rst_path=str(rst_path), + ) + ) + + return actions + + +# --------------------------------------------------------------------------- +# Public interface +# --------------------------------------------------------------------------- + +# Map check names to detector functions +_DETECTORS: dict[str, Any] = { + "duplicates": lambda needs, ws: detect_duplicates(needs), + "missing_tags": lambda needs, ws: detect_missing_tags(needs), + "stale": lambda needs, ws: detect_stale(needs), + "conflicts": lambda needs, ws: detect_conflicts(needs), + "tag_normalize": lambda needs, ws: detect_tag_normalization(needs), + "split_files": lambda needs, ws: detect_split_files(ws), +} + + +def run_plan( + workspace: Path, + checks: list[str] | None = None, + needs: dict[str, Any] | None = None, +) -> list[Action]: + """Run selected (or all) checks and return a unified list of actions. + + Parameters + ---------- + workspace + Path to the memory workspace (containing conf.py, memory/, etc.). + checks + Which checks to run. Defaults to all checks. + needs + Pre-loaded needs dict. If ``None``, loads from workspace. + + Returns + ------- + list[Action] + Ordered list of planned actions — not yet executed. + """ + if needs is None: + needs = load_needs(workspace) + + selected = checks or ALL_CHECKS + all_actions: list[Action] = [] + + for check in selected: + detector = _DETECTORS.get(check) + if detector is None: + raise ValueError( + f"Unknown check {check!r} requested. Available checks: {sorted(_DETECTORS.keys())}" + ) + # split_files only needs workspace, others need needs dict + actions = detector(needs, workspace) + all_actions.extend(actions) + + return all_actions + + +def format_plan(actions: list[Action], fmt: str = "human") -> str: + """Render a plan as human-readable text or JSON. + + Parameters + ---------- + fmt + ``"human"`` for readable text, ``"json"`` for machine-readable. + """ + if not actions: + return "No issues found — memory graph looks healthy." 
+ + if fmt == "json": + import json + + return json.dumps([a.to_dict() for a in actions], indent=2, ensure_ascii=False) + + lines: list[str] = [f"## Memory Maintenance Plan — {len(actions)} action(s)\n"] + + # Group by kind + by_kind: dict[str, list[Action]] = defaultdict(list) + for a in actions: + by_kind[a.kind].append(a) + + for kind in ("SUPERSEDE", "DEPRECATE", "RETAG", "UPDATE", "PRUNE", "SPLIT_FILE"): + group = by_kind.get(kind, []) + if not group: + continue + lines.append(f"### {kind} ({len(group)})\n") + for a in group: + target = a.id or a.old_id or a.rst_path + lines.append(f" - **{target}**: {a.reason}") + if a.add_tags: + lines.append(f" + add tags: {', '.join(a.add_tags)}") + if a.remove_tags: + lines.append(f" - remove tags: {', '.join(a.remove_tags)}") + if a.field_changes: + for k, v in a.field_changes.items(): + lines.append(f" ~ {k} → {v}") + if a.by_id: + lines.append(f" → superseded by: {a.by_id}") + lines.append("") + + return "\n".join(lines) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..58b3a22 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,123 @@ +"""Shared test fixtures for AI Memory Protocol tests.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from ai_memory_protocol.config import TYPE_FILES + + +@pytest.fixture +def tmp_workspace(tmp_path: Path) -> Path: + """Create a minimal memory workspace in a temp directory.""" + ws = tmp_path / ".memories" + ws.mkdir() + + # Minimal conf.py + (ws / "conf.py").write_text( + 'project = "Test"\n' + 'extensions = ["sphinx_needs"]\n' + "needs_types = [\n" + ' {"directive": "mem", "title": "Observation", "prefix": "MEM_", ' + '"color": "#BDD7EE", "style": "node"},\n' + ' {"directive": "dec", "title": "Decision", "prefix": "DEC_", ' + '"color": "#B6D7A8", "style": "node"},\n' + ' {"directive": "fact", "title": "Fact", "prefix": "FACT_", ' + '"color": "#FFE599", "style": "node"},\n' + ' {"directive": "pref", "title": "Preference", "prefix": "PREF_", ' + '"color": "#D5A6BD", "style": "node"},\n' + ' {"directive": "risk", "title": "Risk", "prefix": "RISK_", ' + '"color": "#EA9999", "style": "node"},\n' + ' {"directive": "goal", "title": "Goal", "prefix": "GOAL_", ' + '"color": "#A4C2F4", "style": "node"},\n' + ' {"directive": "q", "title": "Open Question", "prefix": "Q_", ' + '"color": "#D9D2E9", "style": "node"},\n' + "]\n" + "needs_extra_options = {\n" + ' "source": {}, "owner": {}, "confidence": {}, "scope": {},\n' + ' "created_at": {}, "updated_at": {}, "expires_at": {}, "review_after": {},\n' + "}\n" + "needs_build_json = True\n" + ) + + # Minimal index.rst with toctree + (ws / "index.rst").write_text( + "Test Memory\n===========\n\n.. 
toctree::\n :glob:\n\n memory/*\n" + ) + + # Memory subdirectory + mem_dir = ws / "memory" + mem_dir.mkdir() + + # Create RST files matching TYPE_FILES + seen_files: set[str] = set() + for _mem_type, rel_path in TYPE_FILES.items(): + filename = rel_path.split("/")[-1] + if filename in seen_files: + continue + seen_files.add(filename) + header = filename.replace(".rst", "").title() + (mem_dir / filename).write_text(f"{'=' * len(header)}\n{header}\n{'=' * len(header)}\n\n") + + return ws + + +@pytest.fixture +def sample_needs() -> dict[str, dict]: + """Return a dict of sample needs for testing search/filter/format.""" + return { + "MEM_test_observation": { + "id": "MEM_test_observation", + "type": "mem", + "title": "Test observation about gateway", + "description": "The gateway uses port 8080 by default.", + "status": "active", + "tags": ["topic:gateway", "repo:ros2_medkit"], + "confidence": "high", + "created_at": "2026-01-15", + "review_after": "2026-07-15", + "expires_at": "", + "source": "manual", + "scope": "repo:ros2_medkit", + }, + "DEC_use_httplib": { + "id": "DEC_use_httplib", + "type": "dec", + "title": "Use cpp-httplib for REST server", + "description": "Header-only, simple, sufficient for SOVD API.", + "status": "active", + "tags": ["topic:gateway", "topic:http"], + "confidence": "high", + "created_at": "2026-01-10", + "review_after": "", + "expires_at": "", + "source": "architecture review", + "scope": "repo:ros2_medkit", + "relates": ["MEM_test_observation"], + }, + "FACT_deprecated": { + "id": "FACT_deprecated", + "type": "fact", + "title": "Old fact", + "description": "Deprecated.", + "status": "deprecated", + "tags": ["topic:old"], + "confidence": "low", + "created_at": "2025-01-01", + "review_after": "", + "expires_at": "", + }, + } + + +@pytest.fixture +def needs_json_file(tmp_workspace: Path, sample_needs: dict) -> Path: + """Create a needs.json file in the workspace build directory.""" + build_dir = tmp_workspace / "_build" / "html" + build_dir.mkdir(parents=True) + needs_data = {"current_version": "", "versions": {"": {"needs": sample_needs}}} + (build_dir / "needs.json").write_text(json.dumps(needs_data, indent=2)) + return tmp_workspace diff --git a/tests/test_capture.py b/tests/test_capture.py new file mode 100644 index 0000000..02b7003 --- /dev/null +++ b/tests/test_capture.py @@ -0,0 +1,436 @@ +"""Tests for the capture module — git commit analysis and candidate generation.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from ai_memory_protocol.capture import ( + MemoryCandidate, + _classify_commit, + _extract_scope, + _file_overlap, + _GitCommit, + _group_commits, + _infer_tags, + _is_duplicate, + capture_from_git, + format_candidates, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fix_commit() -> _GitCommit: + return _GitCommit( + hash="abc12345", + subject="fix(gateway): resolve timeout issue", + body="The timeout was set too low.", + author="dev", + date="2026-01-15T10:00:00+00:00", + files=["src/gateway/server.cpp", "src/gateway/config.hpp"], + ) + + +@pytest.fixture +def feat_commit() -> _GitCommit: + return _GitCommit( + hash="def67890", + subject="feat(api): add health endpoint", + body="", + author="dev", + date="2026-01-16T10:00:00+00:00", + files=["src/api/health.cpp"], + ) + + +@pytest.fixture +def breaking_commit() -> _GitCommit: 
+ return _GitCommit( + hash="ghi11111", + subject="refactor(core): restructure module layout", + body="BREAKING CHANGE: Module paths changed.", + author="dev", + date="2026-01-17T10:00:00+00:00", + files=["src/core/module.cpp"], + ) + + +@pytest.fixture +def plain_commit() -> _GitCommit: + return _GitCommit( + hash="jkl22222", + subject="Update README", + body="", + author="dev", + date="2026-01-18T10:00:00+00:00", + files=["README.md"], + ) + + +# --------------------------------------------------------------------------- +# Tests: _classify_commit +# --------------------------------------------------------------------------- + + +class TestClassifyCommit: + def test_fix_commit(self, fix_commit): + mem_type, confidence = _classify_commit(fix_commit) + assert mem_type == "mem" + assert confidence == "high" + + def test_feat_commit(self, feat_commit): + mem_type, confidence = _classify_commit(feat_commit) + assert mem_type == "fact" + assert confidence == "medium" + + def test_breaking_change(self, breaking_commit): + mem_type, confidence = _classify_commit(breaking_commit) + assert mem_type == "risk" + assert confidence == "high" + + def test_plain_commit(self, plain_commit): + mem_type, confidence = _classify_commit(plain_commit) + assert mem_type == "mem" + assert confidence == "low" + + def test_style_commit(self): + c = _GitCommit(hash="x", subject="style(ui): fix formatting", body="", author="", date="") + mem_type, _ = _classify_commit(c) + assert mem_type == "pref" + + def test_docs_commit(self): + c = _GitCommit(hash="x", subject="docs: update API guide", body="", author="", date="") + mem_type, _ = _classify_commit(c) + assert mem_type == "fact" + + +# --------------------------------------------------------------------------- +# Tests: _extract_scope +# --------------------------------------------------------------------------- + + +class TestExtractScope: + def test_with_scope(self): + assert _extract_scope("fix(gateway): bug") == "gateway" + + def test_without_scope(self): + assert _extract_scope("fix: general bug") == "" + + def test_with_parentheses_in_title(self): + assert _extract_scope("feat(api): add handler (new)") == "api" + + +# --------------------------------------------------------------------------- +# Tests: _infer_tags +# --------------------------------------------------------------------------- + + +class TestInferTags: + def test_includes_repo(self, fix_commit): + tags = _infer_tags(fix_commit, "ros2_medkit") + assert "repo:ros2_medkit" in tags + + def test_includes_scope_as_topic(self, fix_commit): + tags = _infer_tags(fix_commit, "ros2_medkit") + assert "topic:gateway" in tags + + def test_infers_from_paths(self, feat_commit): + tags = _infer_tags(feat_commit, "ros2_medkit") + assert "repo:ros2_medkit" in tags + # Should infer topic from file path + assert any("topic:" in t for t in tags) + + +# --------------------------------------------------------------------------- +# Tests: _file_overlap +# --------------------------------------------------------------------------- + + +class TestFileOverlap: + def test_full_overlap(self): + assert _file_overlap(["a.cpp", "b.cpp"], ["a.cpp", "b.cpp"]) == 1.0 + + def test_no_overlap(self): + assert _file_overlap(["a.cpp"], ["b.cpp"]) == 0.0 + + def test_partial_overlap(self): + overlap = _file_overlap(["a.cpp", "b.cpp"], ["b.cpp", "c.cpp"]) + assert 0.3 < overlap < 0.4 # 1/3 + + def test_empty_files(self): + assert _file_overlap([], []) == 0.0 + + +# 
--------------------------------------------------------------------------- +# Tests: _group_commits +# --------------------------------------------------------------------------- + + +class TestGroupCommits: + def test_groups_by_file_overlap(self, fix_commit): + c2 = _GitCommit( + hash="222", + subject="fix(gateway): another fix", + body="", + author="", + date="", + files=["src/gateway/server.cpp"], + ) + groups = _group_commits([fix_commit, c2]) + assert len(groups) == 1 # Should be grouped together + + def test_separate_groups_for_unrelated(self, fix_commit, feat_commit): + groups = _group_commits([fix_commit, feat_commit]) + assert len(groups) == 2 # Different files → separate groups + + def test_empty_list(self): + assert _group_commits([]) == [] + + def test_single_commit(self, fix_commit): + groups = _group_commits([fix_commit]) + assert len(groups) == 1 + assert len(groups[0]) == 1 + + +# --------------------------------------------------------------------------- +# Tests: _is_duplicate +# --------------------------------------------------------------------------- + + +class TestIsDuplicate: + def test_similar_title_is_duplicate(self): + candidate = MemoryCandidate( + type="mem", + title="Gateway timeout is 30 seconds", + body="test", + source="commit:abc", + ) + existing = { + "MEM_x": { + "title": "Gateway timeout is 30 seconds by default", + "status": "active", + "source": "", + }, + } + assert _is_duplicate(candidate, existing) + + def test_different_title_not_duplicate(self): + candidate = MemoryCandidate( + type="mem", + title="API supports pagination", + body="test", + source="commit:abc", + ) + existing = { + "MEM_x": { + "title": "Gateway timeout issue", + "status": "active", + "source": "", + }, + } + assert not _is_duplicate(candidate, existing) + + def test_same_source_is_duplicate(self): + candidate = MemoryCandidate( + type="mem", + title="Completely different title", + body="test", + source="commit:abc12345", + ) + existing = { + "MEM_x": { + "title": "Something else", + "status": "active", + "source": "commit:abc12345", + }, + } + assert _is_duplicate(candidate, existing) + + def test_skips_deprecated(self): + candidate = MemoryCandidate( + type="mem", + title="Gateway timeout", + body="", + source="", + ) + existing = { + "MEM_x": { + "title": "Gateway timeout", + "status": "deprecated", + "source": "", + }, + } + assert not _is_duplicate(candidate, existing) + + +# --------------------------------------------------------------------------- +# Tests: format_candidates +# --------------------------------------------------------------------------- + + +class TestFormatCandidates: + def test_empty_candidates(self): + result = format_candidates([]) + assert "No new" in result + + def test_human_format(self): + candidates = [ + MemoryCandidate( + type="mem", + title="Test fix", + body="Fixed a bug", + tags=["topic:test"], + source="commit:abc", + confidence="high", + ), + ] + result = format_candidates(candidates, fmt="human") + assert "Test fix" in result + assert "topic:test" in result + assert "commit:abc" in result + + def test_json_format(self): + candidates = [ + MemoryCandidate(type="fact", title="New feature", body="Added X", tags=["topic:api"]), + ] + result = format_candidates(candidates, fmt="json") + import json + + parsed = json.loads(result) + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0]["type"] == "fact" + + def test_multiple_candidates(self): + candidates = [ + MemoryCandidate(type="mem", title=f"Item {i}", body="", 
tags=[]) for i in range(5) + ] + result = format_candidates(candidates, fmt="human") + assert "5 memory candidate" in result + + +# --------------------------------------------------------------------------- +# Tests: capture_from_git (integration-ish, mocked subprocess) +# --------------------------------------------------------------------------- + + +class TestCaptureFromGit: + def test_no_commits_returns_empty(self, tmp_workspace): + with patch("ai_memory_protocol.capture._parse_git_log", return_value=[]): + candidates = capture_from_git( + workspace=tmp_workspace, + repo_path=Path("/fake/repo"), + since="HEAD~5", + until="HEAD", + ) + assert candidates == [] + + def test_single_commit_creates_candidate(self, tmp_workspace): + mock_commits = [ + _GitCommit( + hash="abc12345", + subject="fix(gateway): timeout bug", + body="Set default to 30s", + author="dev", + date="2026-01-15", + files=["src/server.cpp"], + ), + ] + with ( + patch("ai_memory_protocol.capture._parse_git_log", return_value=mock_commits), + patch("ai_memory_protocol.capture.load_needs", return_value={}), + ): + candidates = capture_from_git( + workspace=tmp_workspace, + repo_path=Path("/fake/repo"), + repo_name="ros2_medkit", + ) + assert len(candidates) == 1 + assert candidates[0].type == "mem" + assert "timeout" in candidates[0].title.lower() + assert "repo:ros2_medkit" in candidates[0].tags + + def test_dedup_filters_existing(self, tmp_workspace): + mock_commits = [ + _GitCommit( + hash="abc12345", + subject="fix: gateway timeout issue", + body="", + author="dev", + date="2026-01-15", + files=[], + ), + ] + existing_needs = { + "MEM_x": { + "title": "gateway timeout issue", + "status": "active", + "source": "", + }, + } + with ( + patch("ai_memory_protocol.capture._parse_git_log", return_value=mock_commits), + patch("ai_memory_protocol.capture.load_needs", return_value=existing_needs), + ): + candidates = capture_from_git( + workspace=tmp_workspace, + repo_path=Path("/fake/repo"), + deduplicate=True, + ) + assert len(candidates) == 0 + + def test_min_confidence_filters(self, tmp_workspace): + mock_commits = [ + _GitCommit( + hash="abc12345", + subject="chore: update deps", + body="", + author="dev", + date="2026-01-15", + files=[], + ), + ] + with ( + patch("ai_memory_protocol.capture._parse_git_log", return_value=mock_commits), + patch("ai_memory_protocol.capture.load_needs", return_value={}), + ): + candidates = capture_from_git( + workspace=tmp_workspace, + repo_path=Path("/fake/repo"), + min_confidence="medium", + ) + assert len(candidates) == 0 # chore → low confidence, filtered out + + +# --------------------------------------------------------------------------- +# Tests: MemoryCandidate +# --------------------------------------------------------------------------- + + +class TestMemoryCandidate: + def test_to_dict(self): + c = MemoryCandidate( + type="mem", + title="Test", + body="Body text", + tags=["topic:test"], + source="commit:abc", + confidence="high", + ) + d = c.to_dict() + assert d["type"] == "mem" + assert d["title"] == "Test" + assert d["confidence"] == "high" + assert "_source_hashes" not in d # Private field excluded + + def test_empty_fields_omitted(self): + c = MemoryCandidate(type="mem", title="Minimal", body="", tags=[]) + d = c.to_dict() + assert "body" not in d + assert "source" not in d diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..593d6f7 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,156 @@ +"""Integration tests for CLI commands. 
Requires Sphinx/sphinx-needs.""" + +from __future__ import annotations + +import subprocess + +import pytest + +pytestmark = pytest.mark.integration + + +class TestCLIWorkflow: + def test_version(self) -> None: + result = subprocess.run(["memory", "--version"], capture_output=True, text=True) + assert result.returncode == 0 + assert "0." in result.stdout + + def test_init_add_recall(self, tmp_path) -> None: + ws = str(tmp_path / ".memories") + + # Init + result = subprocess.run( + ["memory", "init", ws, "--name", "CLI Test", "--install"], + capture_output=True, + text=True, + timeout=120, + ) + assert result.returncode == 0 + + # Add + result = subprocess.run( + [ + "memory", + "--dir", + ws, + "add", + "fact", + "CLI test fact", + "--tags", + "topic:test", + "--confidence", + "high", + "--body", + "Test body content", + "--rebuild", + ], + capture_output=True, + text=True, + timeout=120, + ) + assert result.returncode == 0 + + # Recall + result = subprocess.run( + ["memory", "--dir", ws, "recall", "test", "--format", "brief"], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0 + assert "FACT_" in result.stdout + + def test_doctor(self, tmp_path) -> None: + ws = str(tmp_path / ".memories") + subprocess.run( + ["memory", "init", ws, "--name", "Doctor Test", "--install"], + capture_output=True, + text=True, + timeout=120, + ) + result = subprocess.run( + ["memory", "--dir", ws, "doctor"], + capture_output=True, + text=True, + timeout=30, + ) + # doctor should run without crash + assert result.returncode in (0, 1) + + def test_tags_command(self, tmp_path) -> None: + ws = str(tmp_path / ".memories") + subprocess.run( + ["memory", "init", ws, "--name", "Tags Test", "--install"], + capture_output=True, + text=True, + timeout=120, + ) + subprocess.run( + [ + "memory", + "--dir", + ws, + "add", + "mem", + "Tagged memory", + "--tags", + "topic:test,repo:example", + "--rebuild", + ], + capture_output=True, + text=True, + timeout=120, + ) + result = subprocess.run( + ["memory", "--dir", ws, "tags"], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0 + assert "topic:test" in result.stdout + + def test_update_and_deprecate(self, tmp_path) -> None: + ws = str(tmp_path / ".memories") + subprocess.run( + ["memory", "init", ws, "--name", "Update Test", "--install"], + capture_output=True, + text=True, + timeout=120, + ) + subprocess.run( + [ + "memory", + "--dir", + ws, + "add", + "mem", + "Updateable", + "--tags", + "topic:test", + "--id", + "MEM_updateable", + "--rebuild", + ], + capture_output=True, + text=True, + timeout=120, + ) + + # Update confidence + result = subprocess.run( + ["memory", "--dir", ws, "update", "MEM_updateable", "--confidence", "high"], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0 + + # Deprecate + result = subprocess.run( + ["memory", "--dir", ws, "deprecate", "MEM_updateable"], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..03bd19c --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,78 @@ +"""Tests for config constants — ensure consistency of mappings.""" + +from ai_memory_protocol.config import ( + CONTEXT_PACK_LABELS, + CONTEXT_PACK_ORDER, + DEFAULT_STATUS, + LINK_FIELDS, + METADATA_FIELDS, + TYPE_FILES, + TYPE_LABELS, + TYPE_PREFIXES, +) + + +def test_type_files_maps_all_types(): + """Every type in TYPE_PREFIXES should have a 
corresponding file.""" + for typ in TYPE_PREFIXES: + assert typ in TYPE_FILES, f"Type '{typ}' missing from TYPE_FILES" + + +def test_type_prefixes_maps_all_types(): + """Every type in TYPE_FILES should have a prefix.""" + for typ in TYPE_FILES: + assert typ in TYPE_PREFIXES, f"Type '{typ}' missing from TYPE_PREFIXES" + + +def test_type_prefixes_uppercase(): + """Prefixes should be uppercase.""" + for typ, prefix in TYPE_PREFIXES.items(): + assert prefix == prefix.upper(), f"Prefix for '{typ}' should be uppercase: got '{prefix}'" + + +def test_type_labels_all_types(): + """Every type should have a human-readable label.""" + for typ in TYPE_PREFIXES: + assert typ in TYPE_LABELS, f"Type '{typ}' missing from TYPE_LABELS" + + +def test_default_status_all_types(): + """Every type should have a default status.""" + for typ in TYPE_PREFIXES: + assert typ in DEFAULT_STATUS, f"Type '{typ}' missing from DEFAULT_STATUS" + + +def test_link_fields_are_strings(): + """Link fields should all be strings.""" + for field in LINK_FIELDS: + assert isinstance(field, str) + + +def test_metadata_fields_are_strings(): + """Metadata fields should all be strings.""" + for field in METADATA_FIELDS: + assert isinstance(field, str) + + +def test_context_pack_order_covers_types(): + """Context pack order should include all types.""" + for typ in TYPE_PREFIXES: + assert typ in CONTEXT_PACK_ORDER, f"Type '{typ}' missing from CONTEXT_PACK_ORDER" + + +def test_context_pack_labels_covers_order(): + """Every type in context pack order should have a label.""" + for typ in CONTEXT_PACK_ORDER: + assert typ in CONTEXT_PACK_LABELS, f"Type '{typ}' missing from CONTEXT_PACK_LABELS" + + +def test_type_files_paths_are_rst(): + """All type file paths should be .rst files.""" + for typ, path in TYPE_FILES.items(): + assert path.endswith(".rst"), f"Type '{typ}' file path should be .rst: got '{path}'" + + +def test_type_files_in_memory_dir(): + """All type file paths should be under memory/.""" + for typ, path in TYPE_FILES.items(): + assert path.startswith("memory/"), f"Type '{typ}' path should start with 'memory/': {path}" diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..4c60c6a --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,163 @@ +"""Tests for the core engine — search, filter, graph traversal, workspace detection.""" + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from ai_memory_protocol.engine import ( + expand_graph, + find_workspace, + load_needs, + resolve_id, + tag_match, + text_match, +) + + +class TestFindWorkspace: + def test_explicit_dir(self, tmp_workspace: Path) -> None: + ws = find_workspace(str(tmp_workspace)) + assert ws == tmp_workspace + + def test_cli_override(self, tmp_workspace: Path) -> None: + ws = find_workspace(str(tmp_workspace)) + assert ws == tmp_workspace + + def test_env_var(self, tmp_workspace: Path) -> None: + with patch.dict(os.environ, {"MEMORY_DIR": str(tmp_workspace)}): + ws = find_workspace(None) + assert ws == tmp_workspace + + def test_missing_dir_raises(self, tmp_path: Path) -> None: + nonexistent = tmp_path / "nonexistent" + with pytest.raises(SystemExit): + find_workspace(str(nonexistent)) + + def test_invalid_workspace_raises(self, tmp_path: Path) -> None: + # Directory exists but no conf.py + with pytest.raises(SystemExit): + find_workspace(str(tmp_path)) + + def test_walk_up_finds_workspace(self, tmp_workspace: Path) -> None: + # Create a subdirectory and check 
walk-up from there + subdir = tmp_workspace / "subdir" / "deep" + subdir.mkdir(parents=True) + # Walk-up from the subdirectory should find the workspace at tmp_workspace + original_cwd = Path.cwd() + try: + os.chdir(subdir) + # Clear MEMORY_DIR to test pure walk-up (but real workspace may be found first) + with patch.dict(os.environ, {}, clear=True): + ws = find_workspace(None) + # Should find *some* workspace by walking up + assert ws is not None + finally: + os.chdir(original_cwd) + + +class TestLoadNeeds: + def test_loads_sample_data(self, needs_json_file: Path, sample_needs: dict) -> None: + needs = load_needs(needs_json_file) + assert "MEM_test_observation" in needs + assert "DEC_use_httplib" in needs + + def test_missing_json_exits(self, tmp_workspace: Path) -> None: + with pytest.raises(SystemExit): + load_needs(tmp_workspace) + + def test_loads_correct_fields(self, needs_json_file: Path) -> None: + needs = load_needs(needs_json_file) + mem = needs["MEM_test_observation"] + assert mem["title"] == "Test observation about gateway" + assert mem["confidence"] == "high" + assert mem["type"] == "mem" + + +class TestResolveId: + def test_exact_match(self, sample_needs: dict) -> None: + result = resolve_id(sample_needs, "MEM_test_observation") + assert result == "MEM_test_observation" + + def test_case_insensitive(self, sample_needs: dict) -> None: + result = resolve_id(sample_needs, "mem_test_observation") + assert result == "MEM_test_observation" + + def test_not_found(self, sample_needs: dict) -> None: + result = resolve_id(sample_needs, "NONEXISTENT_id") + assert result is None + + +class TestTextMatch: + def test_matches_title(self, sample_needs: dict) -> None: + assert text_match(sample_needs["MEM_test_observation"], "gateway") + + def test_matches_body(self, sample_needs: dict) -> None: + assert text_match(sample_needs["MEM_test_observation"], "port 8080") + + def test_no_match(self, sample_needs: dict) -> None: + assert not text_match(sample_needs["MEM_test_observation"], "nonexistent_keyword_xyz") + + def test_case_insensitive(self, sample_needs: dict) -> None: + assert text_match(sample_needs["MEM_test_observation"], "GATEWAY") + + def test_matches_id(self, sample_needs: dict) -> None: + assert text_match(sample_needs["MEM_test_observation"], "MEM_test") + + def test_matches_tags(self, sample_needs: dict) -> None: + assert text_match(sample_needs["MEM_test_observation"], "repo:ros2_medkit") + + def test_or_logic(self, sample_needs: dict) -> None: + # Any word matching = True + assert text_match(sample_needs["MEM_test_observation"], "nonexistent gateway") + + def test_all_words_miss(self, sample_needs: dict) -> None: + assert not text_match(sample_needs["MEM_test_observation"], "aaa bbb ccc") + + +class TestTagMatch: + def test_single_tag(self, sample_needs: dict) -> None: + assert tag_match(sample_needs["MEM_test_observation"], ["topic:gateway"]) + + def test_multiple_tags_and(self, sample_needs: dict) -> None: + assert tag_match( + sample_needs["MEM_test_observation"], ["topic:gateway", "repo:ros2_medkit"] + ) + + def test_tag_not_found(self, sample_needs: dict) -> None: + assert not tag_match(sample_needs["MEM_test_observation"], ["topic:nonexistent"]) + + def test_partial_tag_no_match(self, sample_needs: dict) -> None: + assert not tag_match(sample_needs["MEM_test_observation"], ["topic:gate"]) + + def test_empty_tags(self, sample_needs: dict) -> None: + assert tag_match(sample_needs["MEM_test_observation"], []) + + +class TestExpandGraph: + def test_expands_one_hop(self, 
sample_needs: dict) -> None: + matched = {"DEC_use_httplib": sample_needs["DEC_use_httplib"]} + expanded = expand_graph(sample_needs, set(matched.keys()), hops=1) + # Should pull in MEM_test_observation via "relates" link + assert "MEM_test_observation" in expanded + + def test_zero_hops_no_expansion(self, sample_needs: dict) -> None: + matched = {"DEC_use_httplib": sample_needs["DEC_use_httplib"]} + expanded = expand_graph(sample_needs, set(matched.keys()), hops=0) + assert set(expanded.keys()) == {"DEC_use_httplib"} + + def test_includes_seed(self, sample_needs: dict) -> None: + expanded = expand_graph(sample_needs, {"MEM_test_observation"}, hops=1) + assert "MEM_test_observation" in expanded + + def test_nonexistent_seed_excluded(self, sample_needs: dict) -> None: + expanded = expand_graph(sample_needs, {"NONEXISTENT"}, hops=1) + assert "NONEXISTENT" not in expanded + + def test_multiple_hops(self, sample_needs: dict) -> None: + # With 2 hops, starting from DEC, should reach MEM through relates + expanded = expand_graph(sample_needs, {"DEC_use_httplib"}, hops=2) + assert "MEM_test_observation" in expanded diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..029f582 --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,283 @@ +"""Tests for the executor module — action execution and rollback.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from ai_memory_protocol.executor import ( + ExecutionResult, + actions_from_json, + execute_plan, + validate_actions, +) +from ai_memory_protocol.planner import Action +from ai_memory_protocol.rst import ( + append_to_rst, + generate_rst_directive, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def workspace_with_memories(tmp_workspace: Path) -> Path: + """Workspace with a few memories already added.""" + for _, (mid, title, tags) in enumerate( + [ + ("MEM_alpha", "Alpha observation", ["topic:test", "repo:demo"]), + ("MEM_beta", "Beta observation", ["topic:test", "repo:demo"]), + ("DEC_choice_a", "Choose option A", ["topic:api"]), + ] + ): + mem_type = mid.split("_")[0].lower() + directive = generate_rst_directive( + mem_type=mem_type, + title=title, + need_id=mid, + tags=tags, + confidence="medium", + ) + append_to_rst(tmp_workspace, mem_type, directive) + return tmp_workspace + + +# --------------------------------------------------------------------------- +# Tests: validate_actions +# --------------------------------------------------------------------------- + + +class TestValidateActions: + def test_valid_retag(self): + actions = [Action(kind="RETAG", reason="fix tags", id="MEM_test")] + valid, skipped = validate_actions(actions) + assert len(valid) == 1 + assert len(skipped) == 0 + + def test_retag_missing_id(self): + actions = [Action(kind="RETAG", reason="fix tags")] + valid, skipped = validate_actions(actions) + assert len(valid) == 0 + assert len(skipped) == 1 + + def test_supersede_missing_old_id(self): + actions = [Action(kind="SUPERSEDE", reason="dup")] + valid, skipped = validate_actions(actions) + assert len(valid) == 0 + assert len(skipped) == 1 + + def test_circular_supersede_detected(self): + actions = [ + Action(kind="SUPERSEDE", reason="dup", old_id="A", by_id="B"), + Action(kind="SUPERSEDE", reason="dup", old_id="B", by_id="A"), + ] + valid, skipped = validate_actions(actions) + # At least one should 
be skipped for circular reference + assert len(skipped) >= 1 + + def test_update_missing_id(self): + actions = [Action(kind="UPDATE", reason="stale", field_changes={"status": "review"})] + valid, skipped = validate_actions(actions) + assert len(valid) == 0 + assert len(skipped) == 1 + + def test_split_file_missing_path(self): + actions = [Action(kind="SPLIT_FILE", reason="too big")] + valid, skipped = validate_actions(actions) + assert len(valid) == 0 + assert len(skipped) == 1 + + def test_mixed_valid_and_invalid(self): + actions = [ + Action(kind="RETAG", reason="good", id="MEM_x"), + Action(kind="RETAG", reason="bad"), # missing id + Action(kind="UPDATE", reason="ok", id="MEM_y", field_changes={"status": "review"}), + ] + valid, skipped = validate_actions(actions) + assert len(valid) == 2 + assert len(skipped) == 1 + + +# --------------------------------------------------------------------------- +# Tests: execute_plan +# --------------------------------------------------------------------------- + + +class TestExecutePlan: + def test_empty_actions(self, workspace_with_memories): + result = execute_plan(workspace_with_memories, []) + assert result.success + + def test_retag_action(self, workspace_with_memories): + actions = [ + Action( + kind="RETAG", + reason="add missing tag", + id="MEM_alpha", + add_tags=["tier:core"], + ) + ] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + + def test_update_action(self, workspace_with_memories): + actions = [ + Action( + kind="UPDATE", + reason="mark for review", + id="MEM_alpha", + field_changes={"status": "review"}, + ) + ] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + + def test_deprecate_action(self, workspace_with_memories): + actions = [Action(kind="DEPRECATE", reason="outdated", id="MEM_beta")] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + + def test_supersede_action(self, workspace_with_memories): + actions = [ + Action( + kind="SUPERSEDE", + reason="duplicate", + old_id="MEM_beta", + by_id="MEM_alpha", + ) + ] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + + def test_invalid_action_skipped(self, workspace_with_memories): + actions = [ + Action(kind="RETAG", reason="no id"), # Missing id + Action(kind="RETAG", reason="valid", id="MEM_alpha", add_tags=["topic:new"]), + ] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + assert len(result.skipped) == 1 + + def test_split_file_informational(self, workspace_with_memories): + actions = [ + Action(kind="SPLIT_FILE", reason="too large", rst_path="/some/path.rst"), + ] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + + def test_prune_action(self, workspace_with_memories): + actions = [Action(kind="PRUNE", reason="irrelevant", id="MEM_alpha")] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 1 + + def test_unknown_action_kind(self, workspace_with_memories): + actions = [Action(kind="RETAG", reason="valid", id="MEM_alpha")] + # Patch the kind after creation to test unknown handling + actions[0].kind = "NONEXISTENT" # 
type: ignore[assignment] + valid, _ = validate_actions(actions) + result = execute_plan(workspace_with_memories, valid, rebuild=False) + assert len(result.failed) == 1 + + def test_multiple_actions_sequential(self, workspace_with_memories): + actions = [ + Action(kind="RETAG", reason="add tag", id="MEM_alpha", add_tags=["tier:core"]), + Action( + kind="UPDATE", + reason="update status", + id="MEM_beta", + field_changes={"confidence": "high"}, + ), + ] + result = execute_plan(workspace_with_memories, actions, rebuild=False) + assert result.success + assert len(result.applied) == 2 + + +# --------------------------------------------------------------------------- +# Tests: ExecutionResult +# --------------------------------------------------------------------------- + + +class TestExecutionResult: + def test_to_dict(self): + result = ExecutionResult( + success=True, + applied=[{"action": {"kind": "RETAG"}, "message": "done"}], + message="OK", + ) + d = result.to_dict() + assert d["success"] is True + assert d["applied_count"] == 1 + assert d["failed_count"] == 0 + + def test_summary(self): + result = ExecutionResult( + success=True, + applied=[{}], + failed=[{}], + skipped=[{}], + message="test", + ) + s = result.summary() + assert "Applied: 1" in s + assert "Failed: 1" in s + assert "Skipped: 1" in s + + +# --------------------------------------------------------------------------- +# Tests: actions_from_json +# --------------------------------------------------------------------------- + + +class TestActionsFromJson: + def test_basic_deserialization(self): + data = [ + {"kind": "RETAG", "reason": "fix tags", "id": "MEM_x", "add_tags": ["topic:new"]}, + { + "kind": "UPDATE", + "reason": "stale", + "id": "MEM_y", + "field_changes": {"status": "review"}, + }, + ] + actions = actions_from_json(data) + assert len(actions) == 2 + assert actions[0].kind == "RETAG" + assert actions[0].add_tags == ["topic:new"] + assert actions[1].field_changes == {"status": "review"} + + def test_empty_list(self): + assert actions_from_json([]) == [] + + def test_defaults_for_missing_fields(self): + data = [{"kind": "DEPRECATE", "reason": "old", "id": "MEM_z"}] + actions = actions_from_json(data) + assert actions[0].add_tags == [] + assert actions[0].field_changes == {} + assert actions[0].by_id == "" + + def test_roundtrip(self): + """Action → to_dict → actions_from_json → same action.""" + original = Action( + kind="SUPERSEDE", + reason="duplicate", + old_id="MEM_old", + by_id="MEM_new", + new_tags=["topic:x"], + ) + d = original.to_dict() + restored = actions_from_json([d])[0] + assert restored.kind == original.kind + assert restored.old_id == original.old_id + assert restored.by_id == original.by_id + assert restored.new_tags == original.new_tags diff --git a/tests/test_formatter.py b/tests/test_formatter.py new file mode 100644 index 0000000..61d8f7c --- /dev/null +++ b/tests/test_formatter.py @@ -0,0 +1,129 @@ +"""Tests for output formatters.""" + +from __future__ import annotations + +from ai_memory_protocol.formatter import ( + format_brief, + format_compact, + format_context_pack, + format_full, +) + + +class TestFormatBrief: + def test_basic_output(self, sample_needs: dict) -> None: + result = format_brief(sample_needs["MEM_test_observation"]) + assert "MEM_test_observation" in result + assert "Test observation" in result + + def test_includes_confidence(self, sample_needs: dict) -> None: + result = format_brief(sample_needs["MEM_test_observation"]) + assert "high" in result + + def 
test_includes_key_tags(self, sample_needs: dict) -> None: + result = format_brief(sample_needs["MEM_test_observation"]) + assert "topic:gateway" in result + + def test_deprecated_entry(self, sample_needs: dict) -> None: + result = format_brief(sample_needs["FACT_deprecated"]) + assert "FACT_deprecated" in result + + def test_no_tags(self) -> None: + need = {"id": "TEST_1", "title": "No tags", "confidence": "low", "tags": []} + result = format_brief(need) + assert "TEST_1" in result + assert "No tags" in result + + +class TestFormatCompact: + def test_basic_output(self, sample_needs: dict) -> None: + result = format_compact(sample_needs["MEM_test_observation"]) + assert "MEM_test_observation" in result + + def test_with_body(self, sample_needs: dict) -> None: + result = format_compact(sample_needs["MEM_test_observation"], show_body=True) + assert "port 8080" in result + + def test_without_body(self, sample_needs: dict) -> None: + result = format_compact(sample_needs["MEM_test_observation"], show_body=False) + assert "port 8080" not in result + + def test_includes_status(self, sample_needs: dict) -> None: + result = format_compact(sample_needs["MEM_test_observation"]) + assert "status=active" in result + + def test_includes_tags(self, sample_needs: dict) -> None: + result = format_compact(sample_needs["MEM_test_observation"]) + assert "topic:gateway" in result + + def test_includes_links(self, sample_needs: dict) -> None: + result = format_compact(sample_needs["DEC_use_httplib"]) + assert "relates" in result + assert "MEM_test_observation" in result + + def test_long_body_truncated(self) -> None: + need = { + "id": "TEST_1", + "title": "Long body", + "description": "x" * 600, + "status": "active", + "confidence": "medium", + "tags": [], + } + result = format_compact(need, show_body=True) + assert "..." 
in result + + +class TestFormatFull: + def test_includes_all_fields(self, sample_needs: dict) -> None: + result = format_full(sample_needs["DEC_use_httplib"]) + assert "DEC_use_httplib" in result + assert "cpp-httplib" in result + assert "topic:gateway" in result + + def test_includes_header(self, sample_needs: dict) -> None: + result = format_full(sample_needs["MEM_test_observation"]) + assert result.startswith("# MEM_test_observation") + + def test_includes_type(self, sample_needs: dict) -> None: + result = format_full(sample_needs["MEM_test_observation"]) + assert "type: mem" in result + + def test_includes_scope(self, sample_needs: dict) -> None: + result = format_full(sample_needs["MEM_test_observation"]) + assert "scope: repo:ros2_medkit" in result + + def test_includes_body(self, sample_needs: dict) -> None: + result = format_full(sample_needs["MEM_test_observation"]) + assert "port 8080" in result + + def test_includes_links(self, sample_needs: dict) -> None: + result = format_full(sample_needs["DEC_use_httplib"]) + assert "relates: MEM_test_observation" in result + + +class TestFormatContextPack: + def test_groups_by_type(self, sample_needs: dict) -> None: + active = {k: v for k, v in sample_needs.items() if v["status"] != "deprecated"} + result = format_context_pack(active) + assert isinstance(result, str) + assert len(result) > 0 + + def test_includes_count(self, sample_needs: dict) -> None: + active = {k: v for k, v in sample_needs.items() if v["status"] != "deprecated"} + result = format_context_pack(active) + assert "2 results" in result + + def test_empty_needs(self) -> None: + result = format_context_pack({}) + assert "No relevant" in result + + def test_hide_body_by_default(self, sample_needs: dict) -> None: + active = {k: v for k, v in sample_needs.items() if v["status"] != "deprecated"} + result = format_context_pack(active, show_body=False) + assert "memory get" in result.lower() or "memory_get" in result.lower() + + def test_show_body(self, sample_needs: dict) -> None: + active = {k: v for k, v in sample_needs.items() if v["status"] != "deprecated"} + result = format_context_pack(active, show_body=True) + assert "port 8080" in result diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..bac271a --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,123 @@ +"""Tests for MCP server tool handlers.""" + +from __future__ import annotations + +import json + +import pytest + +try: + from ai_memory_protocol.mcp_server import ( + _MCP_AVAILABLE, + TOOLS, + _format_output, + _sort_needs, + create_mcp_server, + ) # noqa: I001 + + MCP_AVAILABLE = _MCP_AVAILABLE +except ImportError: + MCP_AVAILABLE = False + +pytestmark = pytest.mark.skipif(not MCP_AVAILABLE, reason="MCP SDK not installed") + + +class TestMCPAvailability: + def test_mcp_flag_is_set(self) -> None: + assert MCP_AVAILABLE is True + + def test_create_server(self) -> None: + server = create_mcp_server() + assert server is not None + + +class TestMCPToolDefinitions: + def test_tool_count(self) -> None: + assert len(TOOLS) >= 8 + + def test_all_tools_have_schemas(self) -> None: + for tool in TOOLS: + assert tool.inputSchema is not None + assert tool.inputSchema.get("type") == "object" + + def test_required_tools_present(self) -> None: + names = {t.name for t in TOOLS} + for expected in [ + "memory_recall", + "memory_get", + "memory_add", + "memory_update", + "memory_deprecate", + "memory_tags", + "memory_stale", + "memory_rebuild", + "memory_plan", + "memory_apply", + 
"memory_capture_git", + ]: + assert expected in names, f"Missing tool: {expected}" + + def test_tools_have_descriptions(self) -> None: + for tool in TOOLS: + assert tool.description, f"Tool {tool.name} has empty description" + assert len(tool.description) > 10, f"Tool {tool.name} description too short" + + +class TestFormatOutput: + def test_brief_format(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="brief") + assert isinstance(output, str) + assert "MEM_test_observation" in output + + def test_json_format(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="json") + parsed = json.loads(output) + assert isinstance(parsed, dict) + + def test_compact_format(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="compact") + assert isinstance(output, str) + + def test_context_format(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="context") + assert isinstance(output, str) + assert "Recalled Memories" in output + + def test_limit(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="brief", limit=1) + assert "omitted" in output.lower() + + def test_with_body(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="compact", show_body=True) + assert "port 8080" in output + + def test_without_body(self, sample_needs: dict) -> None: + output = _format_output(sample_needs, fmt="compact", show_body=False) + assert "port 8080" not in output + + +class TestSortNeeds: + def test_sort_newest(self, sample_needs: dict) -> None: + sorted_items = _sort_needs(sample_needs, "newest") + dates = [item[1].get("created_at", "") for item in sorted_items] + assert dates == sorted(dates, reverse=True) + + def test_sort_oldest(self, sample_needs: dict) -> None: + sorted_items = _sort_needs(sample_needs, "oldest") + dates = [item[1].get("created_at", "") for item in sorted_items] + assert dates == sorted(dates) + + def test_sort_confidence(self, sample_needs: dict) -> None: + sorted_items = _sort_needs(sample_needs, "confidence") + conf_order = [item[1].get("confidence", "medium") for item in sorted_items] + # high should come first + assert conf_order[0] == "high" + + def test_sort_none(self, sample_needs: dict) -> None: + sorted_items = _sort_needs(sample_needs, None) + assert len(sorted_items) == len(sample_needs) + + def test_sort_updated(self, sample_needs: dict) -> None: + sorted_items = _sort_needs(sample_needs, "updated") + # Should not crash even without updated_at fields + assert len(sorted_items) == len(sample_needs) diff --git a/tests/test_planner.py b/tests/test_planner.py new file mode 100644 index 0000000..7052fff --- /dev/null +++ b/tests/test_planner.py @@ -0,0 +1,425 @@ +"""Tests for the planner module — detection algorithms and plan formatting.""" + +from __future__ import annotations + +import json +from datetime import date, timedelta + +import pytest + +from ai_memory_protocol.planner import ( + Action, + detect_conflicts, + detect_duplicates, + detect_missing_tags, + detect_split_files, + detect_stale, + detect_tag_normalization, + format_plan, + run_plan, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def needs_with_duplicates() -> dict: + """Two near-duplicate active needs.""" + return { + "MEM_gateway_timeout": { + "id": "MEM_gateway_timeout", + "type": "mem", + "title": "Gateway 
timeout is 30 seconds", + "description": "Default timeout.", + "status": "active", + "tags": ["topic:gateway", "repo:ros2_medkit"], + "confidence": "medium", + "created_at": "2026-01-10", + }, + "MEM_gateway_timeout_issue": { + "id": "MEM_gateway_timeout_issue", + "type": "mem", + "title": "Gateway timeout is 30 seconds by default", + "description": "Same info.", + "status": "active", + "tags": ["topic:gateway", "repo:ros2_medkit"], + "confidence": "high", + "created_at": "2026-01-15", + }, + } + + +@pytest.fixture +def needs_with_missing_tags() -> dict: + """Needs missing topic: or repo: tags.""" + return { + "MEM_no_topic": { + "id": "MEM_no_topic", + "type": "mem", + "title": "An observation", + "status": "active", + "tags": ["repo:ros2_medkit"], + "confidence": "medium", + }, + "MEM_no_repo": { + "id": "MEM_no_repo", + "type": "mem", + "title": "Another observation", + "status": "active", + "tags": ["topic:gateway"], + "confidence": "medium", + }, + "MEM_no_tags_at_all": { + "id": "MEM_no_tags_at_all", + "type": "mem", + "title": "Missing all tags", + "status": "active", + "tags": [], + "confidence": "low", + }, + "MEM_complete": { + "id": "MEM_complete", + "type": "mem", + "title": "Complete one", + "status": "active", + "tags": ["topic:api", "repo:ros2_medkit"], + "confidence": "high", + }, + } + + +@pytest.fixture +def needs_with_stale() -> dict: + """Needs with expired/review-overdue dates.""" + yesterday = (date.today() - timedelta(days=1)).isoformat() + tomorrow = (date.today() + timedelta(days=1)).isoformat() + return { + "MEM_expired": { + "id": "MEM_expired", + "type": "mem", + "title": "Old memory", + "status": "active", + "tags": ["topic:test"], + "expires_at": yesterday, + "review_after": "", + }, + "MEM_review_due": { + "id": "MEM_review_due", + "type": "mem", + "title": "Review needed", + "status": "active", + "tags": ["topic:test"], + "expires_at": "", + "review_after": yesterday, + }, + "MEM_still_fresh": { + "id": "MEM_still_fresh", + "type": "mem", + "title": "Still fresh", + "status": "active", + "tags": ["topic:test"], + "expires_at": "", + "review_after": tomorrow, + }, + } + + +@pytest.fixture +def needs_with_tag_issues() -> dict: + """Needs with case-inconsistent tags.""" + return { + "MEM_one": { + "id": "MEM_one", + "type": "mem", + "title": "First", + "status": "active", + "tags": ["topic:Gateway", "repo:ros2_medkit"], + }, + "MEM_two": { + "id": "MEM_two", + "type": "mem", + "title": "Second", + "status": "active", + "tags": ["topic:gateway", "repo:ros2_medkit"], + }, + "MEM_three": { + "id": "MEM_three", + "type": "mem", + "title": "Third", + "status": "active", + "tags": ["topic:gateway", "repo:ros2_medkit"], + }, + } + + +@pytest.fixture +def needs_with_conflicts() -> dict: + """Two decisions on the same topic with no link.""" + return { + "DEC_use_rest": { + "id": "DEC_use_rest", + "type": "dec", + "title": "Use REST for API", + "status": "active", + "tags": ["topic:api"], + }, + "DEC_use_grpc": { + "id": "DEC_use_grpc", + "type": "dec", + "title": "Use gRPC for API", + "status": "active", + "tags": ["topic:api"], + }, + } + + +# --------------------------------------------------------------------------- +# Tests: detect_duplicates +# --------------------------------------------------------------------------- + + +class TestDetectDuplicates: + def test_finds_near_duplicates(self, needs_with_duplicates): + actions = detect_duplicates(needs_with_duplicates) + assert len(actions) == 1 + assert actions[0].kind == "SUPERSEDE" + + def 
test_prefers_higher_confidence(self, needs_with_duplicates): + actions = detect_duplicates(needs_with_duplicates) + # The higher-confidence one should be kept + action = actions[0] + assert action.by_id == "MEM_gateway_timeout_issue" # high confidence + assert action.old_id == "MEM_gateway_timeout" # medium confidence + + def test_no_duplicates_for_different_titles(self, sample_needs): + actions = detect_duplicates(sample_needs) + assert len(actions) == 0 + + def test_skips_deprecated(self, sample_needs): + """Deprecated needs should not be flagged as duplicates.""" + actions = detect_duplicates(sample_needs) + assert all(a.old_id != "FACT_deprecated" for a in actions) + + def test_threshold_respected(self, needs_with_duplicates): + # With very high threshold, should find nothing + actions = detect_duplicates(needs_with_duplicates, title_threshold=0.99) + assert len(actions) == 0 + + def test_tag_overlap_threshold(self, needs_with_duplicates): + # With very high tag overlap requirement, should still match (100% overlap) + actions = detect_duplicates( + needs_with_duplicates, title_threshold=0.8, tag_overlap_threshold=0.9 + ) + assert len(actions) == 1 + + +# --------------------------------------------------------------------------- +# Tests: detect_missing_tags +# --------------------------------------------------------------------------- + + +class TestDetectMissingTags: + def test_finds_missing_topic(self, needs_with_missing_tags): + actions = detect_missing_tags(needs_with_missing_tags) + ids_with_actions = {a.id for a in actions} + assert "MEM_no_topic" in ids_with_actions + + def test_finds_missing_repo(self, needs_with_missing_tags): + actions = detect_missing_tags(needs_with_missing_tags) + ids_with_actions = {a.id for a in actions} + assert "MEM_no_repo" in ids_with_actions + + def test_finds_missing_both(self, needs_with_missing_tags): + actions = detect_missing_tags(needs_with_missing_tags) + ids_with_actions = {a.id for a in actions} + assert "MEM_no_tags_at_all" in ids_with_actions + + def test_skips_complete(self, needs_with_missing_tags): + actions = detect_missing_tags(needs_with_missing_tags) + ids_with_actions = {a.id for a in actions} + assert "MEM_complete" not in ids_with_actions + + def test_action_type_is_retag(self, needs_with_missing_tags): + actions = detect_missing_tags(needs_with_missing_tags) + assert all(a.kind == "RETAG" for a in actions) + + +# --------------------------------------------------------------------------- +# Tests: detect_stale +# --------------------------------------------------------------------------- + + +class TestDetectStale: + def test_finds_expired(self, needs_with_stale): + actions = detect_stale(needs_with_stale) + ids = {a.id for a in actions} + assert "MEM_expired" in ids + + def test_finds_review_overdue(self, needs_with_stale): + actions = detect_stale(needs_with_stale) + ids = {a.id for a in actions} + assert "MEM_review_due" in ids + + def test_skips_fresh(self, needs_with_stale): + actions = detect_stale(needs_with_stale) + ids = {a.id for a in actions} + assert "MEM_still_fresh" not in ids + + def test_action_type_is_update(self, needs_with_stale): + actions = detect_stale(needs_with_stale) + assert all(a.kind == "UPDATE" for a in actions) + for a in actions: + assert a.field_changes.get("status") == "review" + + +# --------------------------------------------------------------------------- +# Tests: detect_conflicts +# --------------------------------------------------------------------------- + + +class TestDetectConflicts: 
+ def test_finds_unlinked_decisions(self, needs_with_conflicts): + actions = detect_conflicts(needs_with_conflicts) + assert len(actions) >= 1 + + def test_skips_linked_decisions(self): + needs = { + "DEC_a": { + "type": "dec", + "title": "A", + "status": "active", + "tags": ["topic:api"], + "relates": ["DEC_b"], + }, + "DEC_b": { + "type": "dec", + "title": "B", + "status": "active", + "tags": ["topic:api"], + }, + } + actions = detect_conflicts(needs) + assert len(actions) == 0 + + +# --------------------------------------------------------------------------- +# Tests: detect_tag_normalization +# --------------------------------------------------------------------------- + + +class TestDetectTagNormalization: + def test_finds_inconsistent_case(self, needs_with_tag_issues): + actions = detect_tag_normalization(needs_with_tag_issues) + assert len(actions) >= 1 + + def test_normalizes_to_most_common(self, needs_with_tag_issues): + actions = detect_tag_normalization(needs_with_tag_issues) + # "topic:gateway" appears twice, "topic:Gateway" once → normalize to lowercase + for a in actions: + if "Gateway" in str(a.remove_tags): + assert "topic:gateway" in a.add_tags + + def test_action_type_is_retag(self, needs_with_tag_issues): + actions = detect_tag_normalization(needs_with_tag_issues) + assert all(a.kind == "RETAG" for a in actions) + + +# --------------------------------------------------------------------------- +# Tests: detect_split_files +# --------------------------------------------------------------------------- + + +class TestDetectSplitFiles: + def test_no_split_needed(self, tmp_workspace): + actions = detect_split_files(tmp_workspace) + assert len(actions) == 0 + + def test_detects_oversized(self, tmp_workspace): + # Write many directives to a file + rst_path = tmp_workspace / "memory" / "observations.rst" + content = rst_path.read_text() + for i in range(55): + content += f"\n.. 
mem:: Entry {i}\n :id: MEM_entry_{i}\n\n Body text.\n" + rst_path.write_text(content) + + actions = detect_split_files(tmp_workspace) + assert len(actions) >= 1 + assert actions[0].kind == "SPLIT_FILE" + + +# --------------------------------------------------------------------------- +# Tests: run_plan +# --------------------------------------------------------------------------- + + +class TestRunPlan: + def test_runs_all_checks(self, needs_json_file, sample_needs): + actions = run_plan(needs_json_file, needs=sample_needs) + # sample_needs has missing repo tags on DEC_use_httplib + assert isinstance(actions, list) + + def test_runs_specific_checks(self, needs_json_file, sample_needs): + actions = run_plan(needs_json_file, checks=["missing_tags"], needs=sample_needs) + # Only RETAG actions from missing_tags check + assert all(a.kind == "RETAG" for a in actions) + + def test_empty_needs_returns_empty(self, needs_json_file): + actions = run_plan(needs_json_file, needs={}) + assert actions == [] + + +# --------------------------------------------------------------------------- +# Tests: format_plan +# --------------------------------------------------------------------------- + + +class TestFormatPlan: + def test_empty_plan(self): + result = format_plan([]) + assert "healthy" in result.lower() + + def test_human_format(self): + actions = [ + Action(kind="RETAG", reason="Missing topic tag", id="MEM_test"), + ] + result = format_plan(actions, fmt="human") + assert "RETAG" in result + assert "MEM_test" in result + + def test_json_format(self): + actions = [ + Action(kind="UPDATE", reason="Stale", id="MEM_old", field_changes={"status": "review"}), + ] + result = format_plan(actions, fmt="json") + parsed = json.loads(result) + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0]["kind"] == "UPDATE" + + +# --------------------------------------------------------------------------- +# Tests: Action dataclass +# --------------------------------------------------------------------------- + + +class TestAction: + def test_to_dict_omits_empty(self): + a = Action(kind="RETAG", reason="test", id="MEM_x", add_tags=["topic:new"]) + d = a.to_dict() + assert "remove_tags" not in d + assert "field_changes" not in d + assert d["kind"] == "RETAG" + assert d["id"] == "MEM_x" + + def test_supersede_action(self): + a = Action( + kind="SUPERSEDE", + reason="duplicate", + old_id="MEM_old", + by_id="MEM_new", + ) + d = a.to_dict() + assert d["old_id"] == "MEM_old" + assert d["by_id"] == "MEM_new" diff --git a/tests/test_rst.py b/tests/test_rst.py new file mode 100644 index 0000000..4894ed1 --- /dev/null +++ b/tests/test_rst.py @@ -0,0 +1,241 @@ +"""Tests for RST directive generation and in-place editing.""" + +from __future__ import annotations + +from pathlib import Path + +from ai_memory_protocol.config import TYPE_PREFIXES +from ai_memory_protocol.rst import ( + add_tags_in_rst, + append_to_rst, + deprecate_in_rst, + generate_id, + generate_rst_directive, + remove_tags_in_rst, + update_field_in_rst, +) + + +class TestGenerateId: + def test_basic_id(self) -> None: + result = generate_id("mem", "Gateway timeout issue") + assert result.startswith("MEM_") + assert "gateway" in result.lower() + + def test_special_chars_removed(self) -> None: + result = generate_id("dec", "Use C++ HTTP library (cpp-httplib)") + assert "(" not in result + assert ")" not in result + assert "++" not in result + + def test_different_types(self) -> None: + for typ, prefix in TYPE_PREFIXES.items(): + result = 
generate_id(typ, "test title") + assert result.startswith(f"{prefix}_"), ( + f"ID for type '{typ}' should start with '{prefix}_'" + ) + + def test_max_length(self) -> None: + long_title = "a " * 100 + result = generate_id("mem", long_title) + # slugify limits to 50 chars + prefix + assert len(result) <= 60 + + def test_empty_title(self) -> None: + result = generate_id("mem", "") + assert result.startswith("MEM_") + + +class TestGenerateRstDirective: + def test_minimal_directive(self) -> None: + rst = generate_rst_directive("mem", "Test memory", tags=["topic:test"]) + assert ".. mem::" in rst + assert "Test memory" in rst + assert ":tags: topic:test" in rst + + def test_with_body(self) -> None: + rst = generate_rst_directive( + "dec", "Test decision", tags=["topic:api"], body="Detailed rationale." + ) + assert "Detailed rationale." in rst + + def test_with_all_fields(self) -> None: + rst = generate_rst_directive( + "fact", + "Complete fact", + need_id="FACT_custom_id", + tags=["topic:test", "tier:core"], + source="manual test", + confidence="high", + scope="global", + body="Full body text.", + relates=["MEM_related"], + supersedes=["FACT_old"], + ) + assert ":id: FACT_custom_id" in rst + assert ":confidence: high" in rst + assert ":source: manual test" in rst + assert ":relates: MEM_related" in rst + assert ":supersedes: FACT_old" in rst + + def test_includes_created_at(self) -> None: + rst = generate_rst_directive("mem", "Test", tags=["topic:test"]) + assert ":created_at:" in rst + + def test_includes_review_after(self) -> None: + rst = generate_rst_directive("mem", "Test", tags=["topic:test"]) + assert ":review_after:" in rst + + def test_default_status_by_type(self) -> None: + rst_mem = generate_rst_directive("mem", "Observation", tags=["topic:test"]) + rst_fact = generate_rst_directive("fact", "Verified fact", tags=["topic:test"]) + assert ":status: draft" in rst_mem + assert ":status: promoted" in rst_fact + + def test_custom_review_days(self) -> None: + rst1 = generate_rst_directive("mem", "Short", tags=["t:x"], review_days=7) + rst2 = generate_rst_directive("mem", "Long", tags=["t:x"], review_days=365) + # Both should have review_after but different dates + assert ":review_after:" in rst1 + assert ":review_after:" in rst2 + + def test_empty_body_placeholder(self) -> None: + rst = generate_rst_directive("mem", "No body", tags=["topic:test"]) + assert "TODO: Add description." 
in rst + + def test_no_tags(self) -> None: + rst = generate_rst_directive("mem", "Tagless", tags=[]) + assert ":tags:" not in rst + + +class TestAppendToRst: + def test_append_creates_entry(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive("mem", "New memory", tags=["topic:test"]) + target = append_to_rst(tmp_workspace, "mem", directive) + assert target.exists() + content = target.read_text() + assert "New memory" in content + + def test_append_to_correct_file(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive("dec", "New decision", tags=["topic:test"]) + target = append_to_rst(tmp_workspace, "dec", directive) + assert "decision" in target.name.lower() + + def test_multiple_appends(self, tmp_workspace: Path) -> None: + for i in range(3): + directive = generate_rst_directive("mem", f"Memory {i}", tags=["topic:test"]) + append_to_rst(tmp_workspace, "mem", directive) + # Read the file and check all three are there + from ai_memory_protocol.config import TYPE_FILES + + target = tmp_workspace / TYPE_FILES["mem"] + content = target.read_text() + for i in range(3): + assert f"Memory {i}" in content + + +class TestUpdateFieldInRst: + def test_update_existing_field(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", + "Updatable memory", + need_id="MEM_updatable", + tags=["topic:test"], + confidence="low", + ) + append_to_rst(tmp_workspace, "mem", directive) + ok, msg = update_field_in_rst(tmp_workspace, "MEM_updatable", "confidence", "high") + assert ok, msg + # Verify the change + from ai_memory_protocol.config import TYPE_FILES + + content = (tmp_workspace / TYPE_FILES["mem"]).read_text() + assert ":confidence: high" in content + + def test_update_nonexistent_field_inserts(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", "Field insert test", need_id="MEM_field_insert", tags=["topic:test"] + ) + append_to_rst(tmp_workspace, "mem", directive) + ok, msg = update_field_in_rst(tmp_workspace, "MEM_field_insert", "expires_at", "2099-12-31") + assert ok, msg + + def test_update_nonexistent_id(self, tmp_workspace: Path) -> None: + ok, msg = update_field_in_rst(tmp_workspace, "MEM_nonexistent", "status", "active") + assert not ok + assert "not found" in msg.lower() + + +class TestTagOperations: + def test_add_tags(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", "Tag test", need_id="MEM_tag_test", tags=["topic:original"] + ) + append_to_rst(tmp_workspace, "mem", directive) + ok, msg = add_tags_in_rst(tmp_workspace, "MEM_tag_test", ["topic:new"]) + assert ok, msg + from ai_memory_protocol.config import TYPE_FILES + + content = (tmp_workspace / TYPE_FILES["mem"]).read_text() + assert "topic:original" in content + assert "topic:new" in content + + def test_add_duplicate_tag(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", "Dedup test", need_id="MEM_dedup", tags=["topic:existing"] + ) + append_to_rst(tmp_workspace, "mem", directive) + ok, msg = add_tags_in_rst(tmp_workspace, "MEM_dedup", ["topic:existing"]) + assert ok + # Should not double-add + from ai_memory_protocol.config import TYPE_FILES + + content = (tmp_workspace / TYPE_FILES["mem"]).read_text() + assert content.count("topic:existing") == 1 + + def test_remove_tags(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", + "Tag remove test", + need_id="MEM_tag_remove", + tags=["topic:keep", "topic:remove"], + ) + append_to_rst(tmp_workspace, "mem", 
directive) + ok, msg = remove_tags_in_rst(tmp_workspace, "MEM_tag_remove", ["topic:remove"]) + assert ok, msg + from ai_memory_protocol.config import TYPE_FILES + + content = (tmp_workspace / TYPE_FILES["mem"]).read_text() + assert "topic:keep" in content + assert "topic:remove" not in content + + def test_remove_nonexistent_id(self, tmp_workspace: Path) -> None: + ok, msg = remove_tags_in_rst(tmp_workspace, "MEM_nonexistent", ["topic:x"]) + assert not ok + + +class TestDeprecate: + def test_deprecate_sets_status(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", "Deprecatable", need_id="MEM_deprecatable", tags=["topic:test"] + ) + append_to_rst(tmp_workspace, "mem", directive) + ok, msg = deprecate_in_rst(tmp_workspace, "MEM_deprecatable") + assert ok, msg + from ai_memory_protocol.config import TYPE_FILES + + content = (tmp_workspace / TYPE_FILES["mem"]).read_text() + assert ":status: deprecated" in content + + def test_deprecate_with_superseded_by(self, tmp_workspace: Path) -> None: + directive = generate_rst_directive( + "mem", "Old memory", need_id="MEM_old", tags=["topic:test"] + ) + append_to_rst(tmp_workspace, "mem", directive) + ok, msg = deprecate_in_rst(tmp_workspace, "MEM_old", "MEM_new") + assert ok + assert "MEM_new" in msg + + def test_deprecate_nonexistent(self, tmp_workspace: Path) -> None: + ok, msg = deprecate_in_rst(tmp_workspace, "MEM_nonexistent") + assert not ok diff --git a/tests/test_scaffold.py b/tests/test_scaffold.py new file mode 100644 index 0000000..5f99331 --- /dev/null +++ b/tests/test_scaffold.py @@ -0,0 +1,69 @@ +"""Tests for workspace initialization (scaffold module).""" + +from __future__ import annotations + +from pathlib import Path + +from ai_memory_protocol.scaffold import init_workspace + + +class TestInitWorkspace: + def test_creates_directory_structure(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test Memory") + assert (ws / "conf.py").exists() + assert (ws / "index.rst").exists() + assert (ws / "memory").is_dir() + + def test_creates_rst_files(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test Memory") + assert (ws / "memory" / "observations.rst").exists() + assert (ws / "memory" / "decisions.rst").exists() + assert (ws / "memory" / "facts.rst").exists() + assert (ws / "memory" / "preferences.rst").exists() + assert (ws / "memory" / "risks.rst").exists() + assert (ws / "memory" / "goals.rst").exists() + assert (ws / "memory" / "questions.rst").exists() + + def test_creates_memory_index(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test Memory") + assert (ws / "memory" / "index.rst").exists() + + def test_creates_makefile(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test Memory") + assert (ws / "Makefile").exists() + + def test_creates_gitignore(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test Memory") + assert (ws / ".gitignore").exists() + + def test_conf_has_needs_types(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test Memory") + content = (ws / "conf.py").read_text() + assert "needs_types" in content + + def test_conf_has_project_name(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="My Custom Project") + content = (ws / "conf.py").read_text() + assert "My Custom 
Project" in content + + def test_idempotent(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test 1") + init_workspace(ws, project_name="Test 2") # Should not crash + assert (ws / "conf.py").exists() + # Original content should be preserved (skips existing files) + content = (ws / "conf.py").read_text() + assert "Test 1" in content # First creation wins + + def test_custom_author(self, tmp_path: Path) -> None: + ws = tmp_path / ".memories" + init_workspace(ws, project_name="Test", author="testauthor") + content = (ws / "conf.py").read_text() + assert "testauthor" in content