From bae44483f01fcde8e778d35dd16d35e1cb51041c Mon Sep 17 00:00:00 2001 From: Dyan Galih Date: Wed, 17 Jun 2026 03:49:27 +0000 Subject: [PATCH 1/3] feat: interactive approval flow for community extensions --- src/specify_cli/__init__.py | 149 +++++++++++++---------- src/specify_cli/__main__.py | 5 + src/specify_cli/extensions.py | 59 ++++++++- tests/test_extensions.py | 222 ++++++++++++++++++++++++++++++++++ 4 files changed, 373 insertions(+), 62 deletions(-) create mode 100644 src/specify_cli/__main__.py diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index 9d9a869b77..52d9dae944 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -879,7 +879,7 @@ def catalog_add( }) config["catalogs"] = catalogs - config_path.write_text(yaml.dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8") + config_path.write_text(yaml.safe_dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8") install_label = "install allowed" if install_allowed else "discovery only" console.print(f"\n[green]✓[/green] Added catalog '[bold]{name}[/bold]' ({install_label})") @@ -919,7 +919,7 @@ def catalog_remove( raise typer.Exit(1) config["catalogs"] = catalogs - config_path.write_text(yaml.dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8") + config_path.write_text(yaml.safe_dump(config, default_flow_style=False, sort_keys=False, allow_unicode=True), encoding="utf-8") console.print(f"[green]✓[/green] Removed catalog '{name}'") if not catalogs: @@ -987,8 +987,8 @@ def extension_add( raise typer.Exit(0) try: - with console.status(f"[cyan]Installing extension: {extension}[/cyan]"): - if dev: + if dev: + with console.status(f"[cyan]Installing extension: {extension}[/cyan]"): # Install from local directory source_path = Path(extension).expanduser().resolve() if not source_path.exists(): @@ -1010,12 +1010,13 @@ def extension_add( force=force ) - elif from_url: - # Install from URL (ZIP file) - import urllib.error + elif from_url: + # Install from URL (ZIP file) + import urllib.error - console.print(f"Downloading from {safe_url}...") + console.print(f"Downloading from {safe_url}...") + with console.status(f"[cyan]Installing extension: {extension}[/cyan]"): # Download ZIP to temp location download_dir = project_root / ".specify" / "extensions" / ".cache" / "downloads" download_dir.mkdir(parents=True, exist_ok=True) @@ -1038,66 +1039,86 @@ def extension_add( if zip_path.exists(): zip_path.unlink() - else: - # Try bundled extensions first (shipped with spec-kit) - bundled_path = _locate_bundled_extension(extension) - if bundled_path is not None: + else: + # Try bundled extensions first (shipped with spec-kit) + bundled_path = _locate_bundled_extension(extension) + if bundled_path is not None: + with console.status(f"[cyan]Installing extension: {extension}[/cyan]"): manifest = manager.install_from_directory( bundled_path, speckit_version, priority=priority, force=force ) - else: - # Install from catalog (also resolves display names to IDs) - catalog = ExtensionCatalog(project_root) + else: + # Install from catalog (also resolves display names to IDs) + catalog = ExtensionCatalog(project_root) - # Check if extension exists in catalog (supports both ID and display name) - ext_info, catalog_error = _resolve_catalog_extension(extension, catalog, "add") - if catalog_error: - console.print(f"[red]Error:[/red] Could not query extension catalog: {catalog_error}") - raise typer.Exit(1) - if not ext_info: - console.print(f"[red]Error:[/red] Extension '{extension}' not found in catalog") - console.print("\nSearch available extensions:") - console.print(" specify extension search") - raise typer.Exit(1) + # Check if extension exists in catalog (supports both ID and display name) + ext_info, catalog_error = _resolve_catalog_extension(extension, catalog, "add") + if catalog_error: + console.print(f"[red]Error:[/red] Could not query extension catalog: {catalog_error}") + raise typer.Exit(1) + if not ext_info: + console.print(f"[red]Error:[/red] Extension '{extension}' not found in catalog") + console.print("\nSearch available extensions:") + console.print(" specify extension search") + raise typer.Exit(1) - # If catalog resolved a display name to an ID, check bundled again - resolved_id = ext_info['id'] - if resolved_id != extension: - bundled_path = _locate_bundled_extension(resolved_id) - if bundled_path is not None: + # If catalog resolved a display name to an ID, check bundled again + resolved_id = ext_info['id'] + if resolved_id != extension: + bundled_path = _locate_bundled_extension(resolved_id) + if bundled_path is not None: + with console.status(f"[cyan]Installing extension: {extension}[/cyan]"): manifest = manager.install_from_directory( bundled_path, speckit_version, priority=priority, force=force ) - if bundled_path is None: - # Bundled extensions without a download URL must come from the local package - if ext_info.get("bundled") and not ext_info.get("download_url"): - console.print( - f"[red]Error:[/red] Extension '{ext_info['id']}' is bundled with spec-kit " - f"but could not be found in the installed package." - ) - console.print( - "\nThis usually means the spec-kit installation is incomplete or corrupted." - ) - console.print("Try reinstalling spec-kit:") - console.print(f" {REINSTALL_COMMAND}") - raise typer.Exit(1) - - # Enforce install_allowed policy - if not ext_info.get("_install_allowed", True): - catalog_name = ext_info.get("_catalog_name", "community") - console.print( - f"[red]Error:[/red] '{extension}' is available in the " - f"'{catalog_name}' catalog but installation is not allowed from that catalog." - ) - console.print( - f"\nTo enable installation, add '{extension}' to an approved catalog " - f"(install_allowed: true) in .specify/extension-catalogs.yml." + if bundled_path is None: + # Bundled extensions without a download URL must come from the local package + if ext_info.get("bundled") and not ext_info.get("download_url"): + console.print( + f"[red]Error:[/red] Extension '{ext_info['id']}' is bundled with spec-kit " + f"but could not be found in the installed package." + ) + console.print( + "\nThis usually means the spec-kit installation is incomplete or corrupted." + ) + console.print("Try reinstalling spec-kit:") + console.print(f" {REINSTALL_COMMAND}") + raise typer.Exit(1) + + # Enforce install_allowed policy + if not ext_info.get("_install_allowed", True): + catalog_name = ext_info.get("_catalog_name", "community") + console.print() + console.print( + Panel( + f"[bold]'{ext_info['name']}' is available in the '{catalog_name}' catalog " + f"but installation is not allowed from that catalog.[/bold]\n\n" + f"Approve installation from '{catalog_name}' for this project?\n" + "This will update .specify/extension-catalogs.yml so future installs " + "from that catalog are allowed.", + title="[bold yellow]Catalog Approval Required[/bold yellow]", + border_style="yellow", + padding=(1, 2), ) - raise typer.Exit(1) + ) + console.print() + try: + confirm = typer.confirm("Approve catalog and continue?", default=False) + except (typer.Abort, KeyboardInterrupt): + console.print("Cancelled") + raise typer.Exit(0) + if not confirm: + console.print("Cancelled") + raise typer.Exit(0) + approved_catalog = catalog.approve_catalog_install(catalog_name) + console.print( + f"[green]✓[/green] Approved catalog '[bold]{approved_catalog.name}[/bold]' for installation" + ) - # Download extension ZIP (use resolved ID, not original argument which may be display name) - extension_id = ext_info['id'] + # Download extension ZIP (use resolved ID, not original argument which may be display name) + extension_id = ext_info['id'] + with console.status(f"[cyan]Installing extension: {ext_info['name']}[/cyan]"): console.print(f"Downloading {ext_info['name']} v{ext_info.get('version', 'unknown')}...") zip_path = catalog.download_extension(extension_id) @@ -1286,8 +1307,9 @@ def extension_search( else: console.print(f"\n [yellow]⚠[/yellow] Not directly installable from '{catalog_name}'.") console.print( - f" Add to an approved catalog with install_allowed: true, " - f"or install from a ZIP URL: specify extension add {ext['id']} --from " + f" Run [cyan]specify extension add {ext['id']}[/cyan] to approve " + f"the catalog and install, or use a ZIP URL: " + f"specify extension add {ext['id']} --from " ) console.print() @@ -1486,8 +1508,13 @@ def _print_extension_info(ext_info: dict, manager): console.print("[yellow]Not installed[/yellow]") console.print( f"\n[yellow]⚠[/yellow] '{ext_info['id']}' is available in the '{catalog_name}' catalog " - f"but not in your approved catalog. Add it to .specify/extension-catalogs.yml " - f"with install_allowed: true to enable installation." + f"but installation is not currently allowed from that catalog." + ) + console.print( + f"\n[cyan]Install:[/cyan] specify extension add {ext_info['id']}" + ) + console.print( + "[dim]You will be prompted to approve the catalog before installation proceeds.[/dim]" ) diff --git a/src/specify_cli/__main__.py b/src/specify_cli/__main__.py new file mode 100644 index 0000000000..2852c0d78e --- /dev/null +++ b/src/specify_cli/__main__.py @@ -0,0 +1,5 @@ +import sys +from specify_cli import main + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index df2eae0083..b7cb509b29 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -1275,7 +1275,7 @@ def check_compatibility( # Parse version specifier (e.g., ">=0.1.0,<2.0.0") try: specifier = SpecifierSet(required) - if current not in specifier: + if not specifier.contains(current, prereleases=True): raise CompatibilityError( f"Extension requires spec-kit {required}, " f"but {speckit_version} is installed.\n" @@ -2100,6 +2100,63 @@ def get_active_catalogs(self) -> List[CatalogEntry]: ), ] + def _catalog_entry_to_dict(self, entry: CatalogEntry) -> Dict[str, Any]: + """Serialize a catalog entry back to YAML config shape.""" + return { + "name": entry.name, + "url": entry.url, + "priority": entry.priority, + "install_allowed": entry.install_allowed, + "description": entry.description, + } + + def approve_catalog_install(self, catalog_name: str) -> CatalogEntry: + """Persist install permission for a catalog while preserving the stack.""" + active_catalogs = self.get_active_catalogs() + updated_catalogs: List[Dict[str, Any]] = [] + approved_entry: Optional[CatalogEntry] = None + + for entry in active_catalogs: + if entry.name == catalog_name: + entry = self._entry( + url=entry.url, + name=entry.name, + priority=entry.priority, + install_allowed=True, + description=entry.description, + ) + approved_entry = entry + updated_catalogs.append(self._catalog_entry_to_dict(entry)) + + if approved_entry is None: + raise ValidationError( + f"Catalog '{catalog_name}' is not active and cannot be approved" + ) + + project_root = self.project_root.resolve() + config_path = self.project_root / ".specify" / self.CONFIG_FILENAME + resolved_parent = config_path.parent.resolve() + if not resolved_parent.is_relative_to(project_root): + raise ValidationError( + "Refusing to write catalog config outside the project root" + ) + if config_path.exists() and config_path.is_symlink(): + raise ValidationError( + f"Refusing to write catalog config via symlink: {config_path}" + ) + + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text( + yaml.safe_dump( + {"catalogs": updated_catalogs}, + default_flow_style=False, + sort_keys=False, + allow_unicode=True, + ), + encoding="utf-8", + ) + return approved_entry + def get_catalog_url(self) -> str: """Get the primary catalog URL. diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 1d05e1c2c4..11347329ce 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -4171,6 +4171,43 @@ def test_search_results_include_catalog_metadata(self, temp_dir): assert results[0]["_catalog_name"] == "org" assert results[0]["_install_allowed"] is True + def test_approve_catalog_install_preserves_active_stack(self, temp_dir): + """Approving one catalog should rewrite the active stack without dropping other entries.""" + import yaml as yaml_module + + project_dir = self._make_project(temp_dir) + catalog = ExtensionCatalog(project_dir) + + approved = catalog.approve_catalog_install("community") + + config_path = project_dir / ".specify" / "extension-catalogs.yml" + parsed = yaml_module.safe_load(config_path.read_text(encoding="utf-8")) + + assert approved.name == "community" + assert approved.install_allowed is True + assert len(parsed["catalogs"]) == 2 + assert parsed["catalogs"][0]["name"] == "default" + assert parsed["catalogs"][0]["install_allowed"] is True + assert parsed["catalogs"][1]["name"] == "community" + assert parsed["catalogs"][1]["install_allowed"] is True + + def test_approve_catalog_install_rejects_symlinked_specify_dir(self, temp_dir): + """Approval writes fail closed when .specify resolves outside the project root.""" + project_dir = self._make_project(temp_dir) + if not can_create_symlink(temp_dir): + pytest.skip("Symlinks are not available on this platform") + + external_specify = temp_dir / "external-specify" + external_specify.mkdir() + symlink_path = project_dir / ".specify" + shutil.rmtree(symlink_path) + os.symlink(external_specify, symlink_path) + + catalog = ExtensionCatalog(project_dir) + + with pytest.raises(ValidationError, match="outside the project root"): + catalog.approve_catalog_install("community") + class TestExtensionIgnore: """Test .extensionignore support during extension installation.""" @@ -4760,6 +4797,191 @@ def test_add_bundled_extension_not_found_gives_clear_error(self, tmp_path): assert "bundled with spec-kit" in result.output assert "reinstall" in result.output.lower() + def test_add_blocked_extension_approval_updates_project_catalog_config(self, tmp_path): + """Approving a blocked catalog extension should update the project catalog config and continue installation.""" + from typer.testing import CliRunner + from unittest.mock import patch + from types import SimpleNamespace + from specify_cli import app + import yaml as yaml_module + + runner = CliRunner() + project_dir = tmp_path / "test-project" + project_dir.mkdir() + (project_dir / ".specify").mkdir() + + zip_path = tmp_path / "security-review.zip" + zip_path.write_bytes(b"fake-zip") + + mock_manifest = SimpleNamespace( + id="security-review", + name="Security Review", + version="1.0.0", + description="Security review extension", + warnings=[], + commands=[], + ) + + with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + "id": "security-review", + "name": "Security Review", + "version": "1.0.0", + "description": "Security review extension", + "_catalog_name": "community", + "_install_allowed": False, + }), patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), patch("typer.confirm", return_value=True), patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke( + app, + ["extension", "add", "security-review"], + catch_exceptions=True, + ) + + assert result.exit_code == 0, result.output + assert "Catalog Approval Required" in result.output + assert "Approved catalog" in result.output + assert "manually" not in result.output.lower() + + config_path = project_dir / ".specify" / "extension-catalogs.yml" + parsed = yaml_module.safe_load(config_path.read_text(encoding="utf-8")) + assert [entry["name"] for entry in parsed["catalogs"]] == ["default", "community"] + assert parsed["catalogs"][0]["install_allowed"] is True + assert parsed["catalogs"][1]["install_allowed"] is True + + def test_add_blocked_extension_prompt_comes_before_spinner(self, tmp_path): + """The approval prompt should appear before any spinner or install work starts.""" + from typer.testing import CliRunner + from unittest.mock import patch, MagicMock + from specify_cli import app + + runner = CliRunner() + project_dir = tmp_path / "test-project" + project_dir.mkdir() + (project_dir / ".specify").mkdir() + + call_order: list[str] = [] + + def record_status(*args, **kwargs): + call_order.append("spinner") + return MagicMock() + + with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + "id": "security-review", + "name": "Security Review", + "version": "1.0.0", + "description": "Security review extension", + "_catalog_name": "community", + "_install_allowed": False, + }), patch("typer.confirm", side_effect=lambda *a, **kw: (call_order.append("confirm"), False)[-1]), patch("specify_cli.console.status", side_effect=record_status), patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke( + app, + ["extension", "add", "security-review"], + catch_exceptions=True, + ) + + assert result.exit_code == 0, result.output + assert call_order and call_order[0] == "confirm" + assert "spinner" not in call_order + assert "Cancelled" in result.output + + def test_add_blocked_extension_cancel_leaves_config_unchanged(self, tmp_path): + """Cancelling approval should not create or modify project catalog config.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + project_dir = tmp_path / "test-project" + project_dir.mkdir() + (project_dir / ".specify").mkdir() + + with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + "id": "security-review", + "name": "Security Review", + "version": "1.0.0", + "description": "Security review extension", + "_catalog_name": "community", + "_install_allowed": False, + }), patch("typer.confirm", return_value=False), patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke( + app, + ["extension", "add", "security-review"], + catch_exceptions=True, + ) + + assert result.exit_code == 0, result.output + assert "Cancelled" in result.output + assert not (project_dir / ".specify" / "extension-catalogs.yml").exists() + + def test_add_approved_catalog_skips_approval_prompt(self, tmp_path): + """Already-approved catalogs should install directly without the guided approval prompt.""" + from typer.testing import CliRunner + from unittest.mock import patch + from types import SimpleNamespace + from specify_cli import app + import contextlib + + runner = CliRunner() + project_dir = tmp_path / "test-project" + project_dir.mkdir() + (project_dir / ".specify").mkdir() + + zip_path = tmp_path / "approved-extension.zip" + zip_path.write_bytes(b"fake-zip") + mock_manifest = SimpleNamespace( + id="security-review", + name="Security Review", + version="1.0.0", + description="Security review extension", + warnings=[], + commands=[], + ) + + def unexpected_confirm(*args, **kwargs): + raise AssertionError("Approval prompt should not run for approved catalogs") + + with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + "id": "security-review", + "name": "Security Review", + "version": "1.0.0", + "description": "Security review extension", + "_catalog_name": "default", + "_install_allowed": True, + }), patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), patch("typer.confirm", side_effect=unexpected_confirm), patch("specify_cli.console.status", return_value=contextlib.nullcontext()), patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke( + app, + ["extension", "add", "security-review"], + catch_exceptions=True, + ) + + assert result.exit_code == 0, result.output + assert "Catalog Approval Required" not in result.output + + def test_add_not_found_still_reports_missing_extension(self, tmp_path): + """Missing catalog entries should still use the existing not-found path.""" + from typer.testing import CliRunner + from unittest.mock import patch, MagicMock + from specify_cli import app + + runner = CliRunner() + project_dir = tmp_path / "test-project" + project_dir.mkdir() + (project_dir / ".specify").mkdir() + + mock_catalog = MagicMock() + mock_catalog.get_extension_info.return_value = None + mock_catalog.search.return_value = [] + + with patch("specify_cli.extensions.ExtensionCatalog", return_value=mock_catalog), patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke( + app, + ["extension", "add", "does-not-exist"], + catch_exceptions=True, + ) + + assert result.exit_code != 0 + assert "not found in catalog" in result.output + assert "Approve catalog" not in result.output + def test_add_from_url_prompts_before_spinner(self, tmp_path): """Confirm prompt for --from must fire before the console.status spinner. From 83d8013069ffdc538ff9c219f3c0d09896dde8a0 Mon Sep 17 00:00:00 2001 From: Dyan Galih Date: Wed, 17 Jun 2026 08:18:22 +0000 Subject: [PATCH 2/3] Address PR feedback and improve install flow --- src/specify_cli/__init__.py | 7 +- src/specify_cli/extensions.py | 62 +++++++++++++- tests/test_extensions.py | 157 ++++++++++++++++++++++++++++++++-- 3 files changed, 212 insertions(+), 14 deletions(-) diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index 52d9dae944..f32aa7db29 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -1086,7 +1086,12 @@ def extension_add( console.print(f" {REINSTALL_COMMAND}") raise typer.Exit(1) - # Enforce install_allowed policy + # If a different approved source exists, use it instead of prompting. + installable_info = catalog.get_installable_extension_info(resolved_id) + if installable_info is not None: + ext_info = installable_info + + # Enforce install_allowed policy only when no approved source exists. if not ext_info.get("_install_allowed", True): catalog_name = ext_info.get("_catalog_name", "community") console.print() diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index b7cb509b29..99d73dcdff 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -2112,11 +2112,19 @@ def _catalog_entry_to_dict(self, entry: CatalogEntry) -> Dict[str, Any]: def approve_catalog_install(self, catalog_name: str) -> CatalogEntry: """Persist install permission for a catalog while preserving the stack.""" - active_catalogs = self.get_active_catalogs() + config_path = self.project_root / ".specify" / self.CONFIG_FILENAME + + # Base the update on the project-level config if it exists + if config_path.exists(): + base_catalogs = self._load_catalog_config(config_path) or [] + else: + # Otherwise, preserve the currently active stack so user-level catalogs remain available. + base_catalogs = self.get_active_catalogs() + updated_catalogs: List[Dict[str, Any]] = [] approved_entry: Optional[CatalogEntry] = None - for entry in active_catalogs: + for entry in base_catalogs: if entry.name == catalog_name: entry = self._entry( url=entry.url, @@ -2128,6 +2136,22 @@ def approve_catalog_install(self, catalog_name: str) -> CatalogEntry: approved_entry = entry updated_catalogs.append(self._catalog_entry_to_dict(entry)) + # If the catalog wasn't found in the base (e.g., a custom user-level catalog), + # we pull it from the active catalogs and append it to the project stack. + if approved_entry is None: + for entry in self.get_active_catalogs(): + if entry.name == catalog_name: + entry = self._entry( + url=entry.url, + name=entry.name, + priority=entry.priority, + install_allowed=True, + description=entry.description, + ) + approved_entry = entry + updated_catalogs.append(self._catalog_entry_to_dict(entry)) + break + if approved_entry is None: raise ValidationError( f"Catalog '{catalog_name}' is not active and cannot be approved" @@ -2542,6 +2566,36 @@ def get_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]: return ext_data return None + def get_installable_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]: + """Return the first installable source for an extension, if any. + + This checks the active catalogs in priority order and returns the + highest-priority source that is actually allowed to install. It is + used by the add flow to avoid prompting for approval when a usable + approved source already exists. + """ + for catalog_entry in self.get_active_catalogs(): + try: + catalog_data = self._fetch_single_catalog(catalog_entry, force_refresh=False) + except ExtensionError: + continue + + ext_data = catalog_data.get("extensions", {}).get(extension_id) + if not isinstance(ext_data, dict): + continue + + if not catalog_entry.install_allowed: + continue + + return { + **ext_data, + "id": extension_id, + "_catalog_name": catalog_entry.name, + "_install_allowed": catalog_entry.install_allowed, + } + + return None + def download_extension( self, extension_id: str, target_dir: Optional[Path] = None ) -> Path: @@ -2559,8 +2613,8 @@ def download_extension( """ import urllib.error - # Get extension info from catalog - ext_info = self.get_extension_info(extension_id) + # Get the best installable source first, then fall back to the merged view. + ext_info = self.get_installable_extension_info(extension_id) or self.get_extension_info(extension_id) if not ext_info: raise ExtensionError(f"Extension '{extension_id}' not found in catalog") diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 11347329ce..5ed82822ef 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -3623,6 +3623,7 @@ def fake_open(req, timeout=None): } with patch.object(catalog, "get_extension_info", return_value=ext_info), \ + patch.object(catalog, "get_installable_extension_info", return_value=ext_info), \ patch("specify_cli.authentication.http.urllib.request.build_opener", return_value=mock_opener): catalog.download_extension("test-ext", target_dir=temp_dir) @@ -3669,6 +3670,7 @@ def fake_open(req, timeout=None): } with patch.object(catalog, "get_extension_info", return_value=ext_info), \ + patch.object(catalog, "get_installable_extension_info", return_value=ext_info), \ patch("specify_cli.authentication.http.urllib.request.build_opener", return_value=mock_opener): catalog.download_extension("test-ext", target_dir=temp_dir) @@ -4191,6 +4193,58 @@ def test_approve_catalog_install_preserves_active_stack(self, temp_dir): assert parsed["catalogs"][1]["name"] == "community" assert parsed["catalogs"][1]["install_allowed"] is True + def test_approve_catalog_install_preserves_user_level_active_catalogs(self, temp_dir): + """Approving a catalog should preserve the full active stack when no project config exists.""" + import yaml as yaml_module + from unittest.mock import patch + + project_dir = self._make_project(temp_dir) + home_dir = temp_dir / "home" + specify_home = home_dir / ".specify" + specify_home.mkdir(parents=True) + with (specify_home / "extension-catalogs.yml").open("w", encoding="utf-8") as f: + yaml_module.safe_dump( + { + "catalogs": [ + { + "name": "alpha", + "url": "https://alpha.example.com/catalog.json", + "priority": 1, + "install_allowed": False, + }, + { + "name": "community", + "url": ExtensionCatalog.COMMUNITY_CATALOG_URL, + "priority": 2, + "install_allowed": False, + }, + { + "name": "beta", + "url": "https://beta.example.com/catalog.json", + "priority": 3, + "install_allowed": True, + }, + ] + }, + f, + sort_keys=False, + allow_unicode=True, + ) + + catalog = ExtensionCatalog(project_dir) + + with patch("specify_cli.extensions.Path.home", return_value=home_dir): + approved = catalog.approve_catalog_install("community") + + config_path = project_dir / ".specify" / "extension-catalogs.yml" + parsed = yaml_module.safe_load(config_path.read_text(encoding="utf-8")) + + assert approved.name == "community" + assert [entry["name"] for entry in parsed["catalogs"]] == ["alpha", "community", "beta"] + assert parsed["catalogs"][0]["install_allowed"] is False + assert parsed["catalogs"][1]["install_allowed"] is True + assert parsed["catalogs"][2]["install_allowed"] is True + def test_approve_catalog_install_rejects_symlinked_specify_dir(self, temp_dir): """Approval writes fail closed when .specify resolves outside the project root.""" project_dir = self._make_project(temp_dir) @@ -4720,6 +4774,7 @@ def test_add_by_display_name_uses_resolved_id_for_download(self, tmp_path): # Mock catalog that returns extension by display name mock_catalog = MagicMock() mock_catalog.get_extension_info.return_value = None # ID lookup fails + mock_catalog.get_installable_extension_info.return_value = None # Installable lookup fails mock_catalog.search.return_value = [ { "id": "acme-jira-integration", @@ -4822,14 +4877,20 @@ def test_add_blocked_extension_approval_updates_project_catalog_config(self, tmp commands=[], ) - with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + with ( + patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", "version": "1.0.0", "description": "Security review extension", "_catalog_name": "community", "_install_allowed": False, - }), patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), patch("typer.confirm", return_value=True), patch.object(Path, "cwd", return_value=project_dir): + }), + patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), + patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), + patch("typer.confirm", return_value=True), + patch.object(Path, "cwd", return_value=project_dir), + ): result = runner.invoke( app, ["extension", "add", "security-review"], @@ -4864,14 +4925,19 @@ def record_status(*args, **kwargs): call_order.append("spinner") return MagicMock() - with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + with ( + patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", "version": "1.0.0", "description": "Security review extension", "_catalog_name": "community", "_install_allowed": False, - }), patch("typer.confirm", side_effect=lambda *a, **kw: (call_order.append("confirm"), False)[-1]), patch("specify_cli.console.status", side_effect=record_status), patch.object(Path, "cwd", return_value=project_dir): + }), + patch("typer.confirm", side_effect=lambda *a, **kw: (call_order.append("confirm"), False)[-1]), + patch("specify_cli.console.status", side_effect=record_status), + patch.object(Path, "cwd", return_value=project_dir), + ): result = runner.invoke( app, ["extension", "add", "security-review"], @@ -4894,14 +4960,18 @@ def test_add_blocked_extension_cancel_leaves_config_unchanged(self, tmp_path): project_dir.mkdir() (project_dir / ".specify").mkdir() - with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + with ( + patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", "version": "1.0.0", "description": "Security review extension", "_catalog_name": "community", "_install_allowed": False, - }), patch("typer.confirm", return_value=False), patch.object(Path, "cwd", return_value=project_dir): + }), + patch("typer.confirm", return_value=False), + patch.object(Path, "cwd", return_value=project_dir), + ): result = runner.invoke( app, ["extension", "add", "security-review"], @@ -4939,14 +5009,80 @@ def test_add_approved_catalog_skips_approval_prompt(self, tmp_path): def unexpected_confirm(*args, **kwargs): raise AssertionError("Approval prompt should not run for approved catalogs") - with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + with ( + patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", "version": "1.0.0", "description": "Security review extension", "_catalog_name": "default", "_install_allowed": True, - }), patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), patch("typer.confirm", side_effect=unexpected_confirm), patch("specify_cli.console.status", return_value=contextlib.nullcontext()), patch.object(Path, "cwd", return_value=project_dir): + }), + patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), + patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), + patch("typer.confirm", side_effect=unexpected_confirm), + patch("specify_cli.console.status", return_value=contextlib.nullcontext()), + patch.object(Path, "cwd", return_value=project_dir), + ): + result = runner.invoke( + app, + ["extension", "add", "security-review"], + catch_exceptions=True, + ) + + assert result.exit_code == 0, result.output + assert "Catalog Approval Required" not in result.output + + def test_add_prefers_approved_source_over_blocked_duplicate(self, tmp_path): + """If the same extension exists in approved and blocked catalogs, the add flow should skip approval.""" + from typer.testing import CliRunner + from unittest.mock import patch + from types import SimpleNamespace + from specify_cli import app + import contextlib + + runner = CliRunner() + project_dir = tmp_path / "test-project" + project_dir.mkdir() + (project_dir / ".specify").mkdir() + + zip_path = tmp_path / "approved-duplicate.zip" + zip_path.write_bytes(b"fake-zip") + mock_manifest = SimpleNamespace( + id="security-review", + name="Security Review", + version="1.0.0", + description="Security review extension", + warnings=[], + commands=[], + ) + + def unexpected_confirm(*args, **kwargs): + raise AssertionError("Approval prompt should not run when an approved source exists") + + with ( + patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ + "id": "security-review", + "name": "Security Review", + "version": "1.0.0", + "description": "Security review extension", + "_catalog_name": "community", + "_install_allowed": False, + }), + patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value={ + "id": "security-review", + "name": "Security Review", + "version": "1.0.0", + "description": "Security review extension", + "_catalog_name": "default", + "_install_allowed": True, + }), + patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), + patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), + patch("typer.confirm", side_effect=unexpected_confirm), + patch("specify_cli.console.status", return_value=contextlib.nullcontext()), + patch.object(Path, "cwd", return_value=project_dir), + ): result = runner.invoke( app, ["extension", "add", "security-review"], @@ -4971,7 +5107,10 @@ def test_add_not_found_still_reports_missing_extension(self, tmp_path): mock_catalog.get_extension_info.return_value = None mock_catalog.search.return_value = [] - with patch("specify_cli.extensions.ExtensionCatalog", return_value=mock_catalog), patch.object(Path, "cwd", return_value=project_dir): + with ( + patch("specify_cli.extensions.ExtensionCatalog", return_value=mock_catalog), + patch.object(Path, "cwd", return_value=project_dir), + ): result = runner.invoke( app, ["extension", "add", "does-not-exist"], From 21cde5f3b0883b76de4a5800e6286581cc16a7f7 Mon Sep 17 00:00:00 2001 From: Dyan Galih Date: Wed, 17 Jun 2026 11:55:05 +0000 Subject: [PATCH 3/3] Address remaining PR feedback --- src/specify_cli/__init__.py | 67 +++++++++++++++++++---------------- src/specify_cli/extensions.py | 24 ++++++------- tests/test_extensions.py | 5 +++ 3 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index f32aa7db29..746ef448cb 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -1086,40 +1086,45 @@ def extension_add( console.print(f" {REINSTALL_COMMAND}") raise typer.Exit(1) - # If a different approved source exists, use it instead of prompting. - installable_info = catalog.get_installable_extension_info(resolved_id) - if installable_info is not None: - ext_info = installable_info - # Enforce install_allowed policy only when no approved source exists. if not ext_info.get("_install_allowed", True): - catalog_name = ext_info.get("_catalog_name", "community") - console.print() - console.print( - Panel( - f"[bold]'{ext_info['name']}' is available in the '{catalog_name}' catalog " - f"but installation is not allowed from that catalog.[/bold]\n\n" - f"Approve installation from '{catalog_name}' for this project?\n" - "This will update .specify/extension-catalogs.yml so future installs " - "from that catalog are allowed.", - title="[bold yellow]Catalog Approval Required[/bold yellow]", - border_style="yellow", - padding=(1, 2), + # If a different approved source exists, use it instead of prompting. + installable_info = catalog.get_installable_extension_info(resolved_id) + if installable_info is not None: + ext_info = installable_info + else: + catalog_name = ext_info.get("_catalog_name", "community") + console.print() + console.print( + Panel( + f"[bold]'{ext_info['name']}' is available in the '{catalog_name}' catalog " + f"but installation is not allowed from that catalog.[/bold]\n\n" + f"Approve installation from '{catalog_name}' for this project?\n" + "This will update .specify/extension-catalogs.yml so future installs " + "from that catalog are allowed.", + title="[bold yellow]Catalog Approval Required[/bold yellow]", + border_style="yellow", + padding=(1, 2), + ) ) - ) - console.print() - try: - confirm = typer.confirm("Approve catalog and continue?", default=False) - except (typer.Abort, KeyboardInterrupt): - console.print("Cancelled") - raise typer.Exit(0) - if not confirm: - console.print("Cancelled") - raise typer.Exit(0) - approved_catalog = catalog.approve_catalog_install(catalog_name) - console.print( - f"[green]✓[/green] Approved catalog '[bold]{approved_catalog.name}[/bold]' for installation" - ) + console.print() + try: + confirm = typer.confirm("Approve catalog and continue?", default=False) + except (typer.Abort, KeyboardInterrupt): + console.print("Cancelled") + raise typer.Exit(0) + if not confirm: + console.print("Cancelled") + raise typer.Exit(0) + + try: + approved_catalog = catalog.approve_catalog_install(catalog_name) + console.print( + f"[green]✓[/green] Approved catalog '[bold]{approved_catalog.name}[/bold]' for installation" + ) + except ValidationError as e: + console.print(f"[red]Error:[/red] {e}") + raise typer.Exit(1) # Download extension ZIP (use resolved ID, not original argument which may be display name) extension_id = ext_info['id'] diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index 99d73dcdff..811dbd95d5 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -2114,6 +2114,18 @@ def approve_catalog_install(self, catalog_name: str) -> CatalogEntry: """Persist install permission for a catalog while preserving the stack.""" config_path = self.project_root / ".specify" / self.CONFIG_FILENAME + # Path safety checks first + project_root = self.project_root.resolve() + resolved_parent = config_path.parent.resolve() + if not resolved_parent.is_relative_to(project_root): + raise ValidationError( + "Refusing to read or write catalog config outside the project root" + ) + if config_path.is_symlink(): + raise ValidationError( + f"Refusing to read or write catalog config via symlink: {config_path}" + ) + # Base the update on the project-level config if it exists if config_path.exists(): base_catalogs = self._load_catalog_config(config_path) or [] @@ -2157,18 +2169,6 @@ def approve_catalog_install(self, catalog_name: str) -> CatalogEntry: f"Catalog '{catalog_name}' is not active and cannot be approved" ) - project_root = self.project_root.resolve() - config_path = self.project_root / ".specify" / self.CONFIG_FILENAME - resolved_parent = config_path.parent.resolve() - if not resolved_parent.is_relative_to(project_root): - raise ValidationError( - "Refusing to write catalog config outside the project root" - ) - if config_path.exists() and config_path.is_symlink(): - raise ValidationError( - f"Refusing to write catalog config via symlink: {config_path}" - ) - config_path.parent.mkdir(parents=True, exist_ok=True) config_path.write_text( yaml.safe_dump( diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 5ed82822ef..493276826a 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -4878,6 +4878,7 @@ def test_add_blocked_extension_approval_updates_project_catalog_config(self, tmp ) with ( + patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value=None), patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", @@ -4926,6 +4927,7 @@ def record_status(*args, **kwargs): return MagicMock() with ( + patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value=None), patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", @@ -4961,6 +4963,7 @@ def test_add_blocked_extension_cancel_leaves_config_unchanged(self, tmp_path): (project_dir / ".specify").mkdir() with ( + patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value=None), patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", @@ -5010,6 +5013,7 @@ def unexpected_confirm(*args, **kwargs): raise AssertionError("Approval prompt should not run for approved catalogs") with ( + patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value=None), patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review", @@ -5061,6 +5065,7 @@ def unexpected_confirm(*args, **kwargs): raise AssertionError("Approval prompt should not run when an approved source exists") with ( + patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value=None), patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={ "id": "security-review", "name": "Security Review",