From 5e0e13c26ccfca8d966f3c29ef942a2f4e8aa407 Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Fri, 6 Mar 2026 12:56:26 +0100 Subject: [PATCH 01/10] add refresh button to stride ui --- src/stride/ui/app.py | 158 +++++++++++++++++++++++++--- src/stride/ui/assets/dark-theme.css | 8 +- src/stride/ui/project_manager.py | 2 +- 3 files changed, 150 insertions(+), 18 deletions(-) diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index a23be02..4012c25 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -41,6 +41,24 @@ _loaded_projects: dict[str, tuple[Project, ColorManager, StridePlots, str]] = {} _current_project_path: str | None = None +# Maximum number of projects to keep open simultaneously. +# Each open project holds a DuckDB connection with file descriptors; +# on BlobFuse2 FUSE mounts too many concurrent connections cause [Errno 5]. +MAX_CACHED_PROJECTS = 3 + + +def _evict_oldest_project() -> None: + """Evict the least-recently-used project from the cache if at capacity.""" + while len(_loaded_projects) >= MAX_CACHED_PROJECTS: + # Dict is insertion-ordered; first key is the oldest (LRU) + oldest_path = next(iter(_loaded_projects)) + old_project, _, _, old_name = _loaded_projects.pop(oldest_path) + try: + old_project.close() + logger.info(f"Evicted and closed project '{old_name}' at {oldest_path}") + except Exception as e: + logger.warning(f"Error closing evicted project at {oldest_path}: {e}") + def create_fresh_color_manager(palette: ColorPalette, scenarios: list[str]) -> ColorManager: """Create a fresh ColorManager instance, bypassing the singleton. @@ -83,7 +101,7 @@ def load_project(project_path: str) -> tuple[bool, str]: (success, message) where success is True if loaded successfully """ global _loaded_projects, _current_project_path - + try: path = Path(project_path).resolve() path_str = str(path) @@ -91,11 +109,17 @@ def load_project(project_path: str) -> tuple[bool, str]: # Check if already loaded - just switch to it if path_str in _loaded_projects: _current_project_path = path_str + # Move to end of dict for LRU ordering (most recently used = last) + entry = _loaded_projects.pop(path_str) + _loaded_projects[path_str] = entry + cached_project, _, _, project_name = entry # Update the APIClient singleton to point to this project - cached_project, _, _, project_name = _loaded_projects[path_str] APIClient(cached_project) # Updates singleton return True, f"Switched to cached project: {project_name}" + # Evict oldest project(s) if at capacity + _evict_oldest_project() + # Load new project project = Project.load(path, read_only=True) data_handler = APIClient(project) # Updates singleton @@ -310,15 +334,34 @@ def create_app( # noqa: C901 className="small mb-2", style={"fontSize": "0.8rem"}, ), - # Dropdown for available projects (recent + discovered) - dcc.Dropdown( - id="project-switcher-dropdown", - options=dropdown_options, # type: ignore[arg-type] - value=current_project_path, - placeholder="Switch project...", + + html.Div( + [ + dcc.Dropdown( + id="project-switcher-dropdown", + options=dropdown_options, # type: ignore[arg-type] + value=current_project_path, + placeholder="Switch project...", + className="mb-2", + style={"fontSize": "0.85rem", "width": "calc(100% - 35px)", "display": "inline-block"}, + clearable=False, + ), + html.Button( + html.I(className="bi bi-arrow-clockwise"), + id="refresh-projects-btn", + className="btn btn-sm btn-outline-secondary", + style={ + "width": "30px", + "height": "38px", + "marginLeft": "5px", + "verticalAlign": "top", + "padding": "0", + }, + title="Refresh project list", + ), + ], className="mb-2", - style={"fontSize": "0.85rem"}, - clearable=False, + style={"display": "flex", "alignItems": "flex-start"}, ), ] ), @@ -870,7 +913,10 @@ def handle_project_switch( if dropdown_value in _loaded_projects: # Project already loaded - just switch to it _current_project_path = dropdown_value - cached_project, _, _, project_name = _loaded_projects[dropdown_value] + # Move to end of dict for LRU ordering + entry = _loaded_projects.pop(dropdown_value) + _loaded_projects[dropdown_value] = entry + cached_project, _, _, project_name = entry # Update the APIClient singleton to point to this project APIClient(cached_project) return ( @@ -929,6 +975,42 @@ def update_navigation_tabs(project_path: str) -> tuple[list[dict[str, str]], str ] return options, "compare" # Reset to home view + # Refresh dropdown options when refresh button is clicked + @callback( + Output("project-switcher-dropdown", "options", allow_duplicate=True), + Input("refresh-projects-btn", "n_clicks"), + State("current-project-path", "data"), + prevent_initial_call=True, + ) + def refresh_dropdown_options(n_clicks: int | None, current_path: str) -> list[dict[str, str]]: + """Refresh the project switcher dropdown options with latest recent projects.""" + if not n_clicks: + raise PreventUpdate + + # Get current project info + data_handler = get_current_data_handler() + if not data_handler: + raise PreventUpdate + + current_project_name = data_handler.project.config.project_id + + # Build fresh dropdown options with deduplication by project_id + dropdown_options = [{"label": current_project_name, "value": current_path}] + seen_project_ids = {current_project_name} + + # Get recent projects + recent = get_recent_projects() + for proj in recent: + project_id = proj.get("project_id", "") + proj_path = proj.get("path", "") + if project_id and project_id not in seen_project_ids and Path(proj_path).exists(): + dropdown_options.append( + {"label": proj.get("name", project_id), "value": proj_path} + ) + seen_project_ids.add(project_id) + + return dropdown_options + # Regenerate home layout when project changes @callback( Output("home-view", "children"), @@ -1172,9 +1254,22 @@ def create_app_no_project( value=None, placeholder="Select a recent project...", className="mb-2", - style={"fontSize": "0.85rem"}, - clearable=True, + style={"fontSize": "0.85rem", "width": "calc(100% - 35px)", "display": "inline-block"}, + clearable=False, ), + html.Button( + html.I(className="bi bi-arrow-clockwise"), + id="refresh-projects-btn", + className="btn btn-sm btn-outline-secondary", + style={ + "width": "30px", + "height": "38px", + "marginLeft": "5px", + "verticalAlign": "top", + "padding": "0", + }, + title="Refresh project list", + ), ] ), html.Hr(className="bg-white"), @@ -1448,6 +1543,9 @@ def _register_no_project_callbacks( # Register the scenario CSS update callback _register_scenario_css_callback(get_current_color_manager) + # Register the refresh projects callback + _register_refresh_projects_callback() + # Register home and scenario callbacks with dynamic data fetching register_home_callbacks( _get_current_data_handler_no_project, @@ -1471,6 +1569,40 @@ def _register_no_project_callbacks( ) +def _register_refresh_projects_callback() -> None: + """Register the refresh projects button callback.""" + + @callback( + Output("project-switcher-dropdown", "options", allow_duplicate=True), + Input("refresh-projects-btn", "n_clicks"), + State("current-project-path", "data"), + prevent_initial_call=True, + ) + def refresh_dropdown_options( + n_clicks: int | None, current_path: str + ) -> list[dict[str, str]]: + """Refresh the project switcher dropdown options with latest recent projects.""" + if not n_clicks: + raise PreventUpdate + + # Build fresh dropdown options with deduplication by project_id + dropdown_options: list[dict[str, str]] = [] + seen_project_ids: set[str] = set() + + # Get recent projects + recent = get_recent_projects() + for proj in recent: + project_id = proj.get("project_id", "") + proj_path = proj.get("path", "") + if project_id and project_id not in seen_project_ids and Path(proj_path).exists(): + dropdown_options.append( + {"label": proj.get("name", project_id), "value": proj_path} + ) + seen_project_ids.add(project_id) + + return dropdown_options + + def _register_sidebar_toggle_callback() -> None: """Register the sidebar toggle callback.""" diff --git a/src/stride/ui/assets/dark-theme.css b/src/stride/ui/assets/dark-theme.css index 417c1f4..b524cfc 100644 --- a/src/stride/ui/assets/dark-theme.css +++ b/src/stride/ui/assets/dark-theme.css @@ -10,8 +10,8 @@ --bg-hover: #404040; --bg-card: #252525; - --text-primary: #e0e0e0; - --text-secondary: #b0b0b0; + --text-primary: #9e9e9e; + --text-secondary: #8a8a8a; --text-muted: #808080; --border-color: #404040; @@ -22,10 +22,10 @@ --input-bg: #2d2d2d; --input-border: #404040; - --input-text: #e0e0e0; + --input-text: #9e9e9e; --dropdown-bg: #2d2d2d; - --dropdown-text: #e0e0e0; + --dropdown-text: #9e9e9e; --dropdown-hover: #3a3a3a; --modal-bg: #2d2d2d; diff --git a/src/stride/ui/project_manager.py b/src/stride/ui/project_manager.py index 05495d9..a150b14 100644 --- a/src/stride/ui/project_manager.py +++ b/src/stride/ui/project_manager.py @@ -177,7 +177,7 @@ def load_project_by_path(project_path: str | Path, **kwargs: Any) -> Project: return Project.load(project_path, **kwargs) -def get_recent_projects(max_count: int = 5) -> list[dict[str, Any]]: +def get_recent_projects(max_count: int = 10) -> list[dict[str, Any]]: """ Get recently accessed projects from config. From 4b00335b02964e4f72b86559059dbe9f62fbe420 Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Fri, 6 Mar 2026 13:09:46 +0100 Subject: [PATCH 02/10] remove trailing spaces --- src/stride/ui/app.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index 4012c25..e664725 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -101,7 +101,7 @@ def load_project(project_path: str) -> tuple[bool, str]: (success, message) where success is True if loaded successfully """ global _loaded_projects, _current_project_path - + try: path = Path(project_path).resolve() path_str = str(path) @@ -334,7 +334,6 @@ def create_app( # noqa: C901 className="small mb-2", style={"fontSize": "0.8rem"}, ), - html.Div( [ dcc.Dropdown( @@ -986,18 +985,18 @@ def refresh_dropdown_options(n_clicks: int | None, current_path: str) -> list[di """Refresh the project switcher dropdown options with latest recent projects.""" if not n_clicks: raise PreventUpdate - + # Get current project info data_handler = get_current_data_handler() if not data_handler: raise PreventUpdate - + current_project_name = data_handler.project.config.project_id - + # Build fresh dropdown options with deduplication by project_id dropdown_options = [{"label": current_project_name, "value": current_path}] seen_project_ids = {current_project_name} - + # Get recent projects recent = get_recent_projects() for proj in recent: @@ -1008,7 +1007,7 @@ def refresh_dropdown_options(n_clicks: int | None, current_path: str) -> list[di {"label": proj.get("name", project_id), "value": proj_path} ) seen_project_ids.add(project_id) - + return dropdown_options # Regenerate home layout when project changes @@ -1269,7 +1268,7 @@ def create_app_no_project( "padding": "0", }, title="Refresh project list", - ), + ), ] ), html.Hr(className="bg-white"), From 22265c7a5d989e1b680036aba515de21fad2ffbd Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Fri, 13 Mar 2026 10:21:13 +0100 Subject: [PATCH 03/10] test for refresh button --- tests/test_app_cache.py | 381 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 381 insertions(+) create mode 100644 tests/test_app_cache.py diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py new file mode 100644 index 0000000..1a42699 --- /dev/null +++ b/tests/test_app_cache.py @@ -0,0 +1,381 @@ +""" +Tests for the project cache and LRU eviction logic in app.py, +and the refresh-projects dropdown logic. + +Covers code added in the `refresh_recent_projects` branch: +- MAX_CACHED_PROJECTS constant +- _evict_oldest_project() +- LRU reordering when switching to a cached project via load_project() +- refresh_dropdown_options (no-project variant via _register_refresh_projects_callback) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from stride.ui import app as app_module + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _make_mock_project(name: str = "proj") -> MagicMock: + """Return a lightweight mock that behaves like a Project.""" + proj = MagicMock() + proj.config.project_id = name + proj.close = MagicMock() + return proj + + +def _make_cache_entry( + name: str = "proj", +) -> tuple[MagicMock, MagicMock, MagicMock, str]: + """Return a (Project, ColorManager, StridePlots, name) tuple for the cache.""" + return (_make_mock_project(name), MagicMock(), MagicMock(), name) + + +@pytest.fixture(autouse=True) +def _reset_global_state() -> None: + """Ensure module-level cache state is clean before *and* after each test.""" + app_module._loaded_projects.clear() + app_module._current_project_path = None + yield + app_module._loaded_projects.clear() + app_module._current_project_path = None + + +# =================================================================== +# Tests for _evict_oldest_project +# =================================================================== + + +class TestEvictOldestProject: + """Tests for the _evict_oldest_project helper.""" + + def test_no_eviction_when_below_capacity(self) -> None: + """No project should be evicted when cache is below MAX_CACHED_PROJECTS.""" + app_module._loaded_projects["/a"] = _make_cache_entry("A") + app_module._loaded_projects["/b"] = _make_cache_entry("B") + assert len(app_module._loaded_projects) < app_module.MAX_CACHED_PROJECTS + + app_module._evict_oldest_project() + + assert len(app_module._loaded_projects) == 2 + assert "/a" in app_module._loaded_projects + assert "/b" in app_module._loaded_projects + + def test_eviction_when_at_capacity(self) -> None: + """The oldest (first-inserted) project should be evicted when at capacity.""" + for i in range(app_module.MAX_CACHED_PROJECTS): + app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") + + oldest_project = app_module._loaded_projects["/0"][0] + + app_module._evict_oldest_project() + + assert len(app_module._loaded_projects) == app_module.MAX_CACHED_PROJECTS - 1 + assert "/0" not in app_module._loaded_projects + oldest_project.close.assert_called_once() + + def test_eviction_removes_oldest_preserves_newest(self) -> None: + """Eviction should remove the first key (LRU) and keep the rest.""" + app_module._loaded_projects["/old"] = _make_cache_entry("Old") + app_module._loaded_projects["/mid"] = _make_cache_entry("Mid") + app_module._loaded_projects["/new"] = _make_cache_entry("New") + + app_module._evict_oldest_project() + + assert "/old" not in app_module._loaded_projects + assert "/mid" in app_module._loaded_projects + assert "/new" in app_module._loaded_projects + + def test_eviction_on_empty_cache(self) -> None: + """Eviction on an empty cache should be a no-op.""" + app_module._evict_oldest_project() + assert len(app_module._loaded_projects) == 0 + + def test_eviction_handles_close_exception(self) -> None: + """If Project.close() raises, eviction should still proceed (logged as warning).""" + for i in range(app_module.MAX_CACHED_PROJECTS): + app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") + + # Make close() raise for the oldest project + app_module._loaded_projects["/0"][0].close.side_effect = RuntimeError("oops") + + # Should not raise + app_module._evict_oldest_project() + + assert "/0" not in app_module._loaded_projects + assert len(app_module._loaded_projects) == app_module.MAX_CACHED_PROJECTS - 1 + + def test_eviction_over_capacity(self) -> None: + """If the cache somehow exceeds capacity, evict until below MAX_CACHED_PROJECTS.""" + # Manually stuff more than MAX_CACHED_PROJECTS entries + for i in range(app_module.MAX_CACHED_PROJECTS + 2): + app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") + + app_module._evict_oldest_project() + + assert len(app_module._loaded_projects) == app_module.MAX_CACHED_PROJECTS - 1 + + +# =================================================================== +# Tests for LRU reordering in load_project +# =================================================================== + + +class TestLoadProjectLRU: + """Tests for LRU re-ordering when switching to a cached project.""" + + def test_cached_project_moves_to_end(self) -> None: + """Accessing a cached project should move it to the end of the dict (MRU).""" + app_module._loaded_projects["/first"] = _make_cache_entry("First") + app_module._loaded_projects["/second"] = _make_cache_entry("Second") + + # Patch APIClient so it doesn't actually try to create an instance + with patch.object(app_module, "APIClient"): + success, msg = app_module.load_project("/first") + + assert success is True + assert "Switched to cached" in msg + # /first should now be the last key (most recently used) + keys = list(app_module._loaded_projects.keys()) + assert keys[-1] == str(Path("/first").resolve()) + + def test_cached_project_updates_current_path(self) -> None: + """Switching to a cached project should update _current_project_path.""" + resolved = str(Path("/cached").resolve()) + app_module._loaded_projects[resolved] = _make_cache_entry("Cached") + + with patch.object(app_module, "APIClient"): + app_module.load_project("/cached") + + assert app_module._current_project_path == resolved + + def test_load_new_project_triggers_eviction(self) -> None: + """Loading a new project when at capacity should evict the oldest first.""" + for i in range(app_module.MAX_CACHED_PROJECTS): + app_module._loaded_projects[f"/proj{i}"] = _make_cache_entry(f"P{i}") + + oldest_mock = app_module._loaded_projects["/proj0"][0] + + mock_project = _make_mock_project("NewProj") + mock_project.palette = MagicMock() + mock_project.palette.copy.return_value = MagicMock( + scenario_theme=["#aaa"], model_year_theme=["#bbb"], metric_theme=["#ccc"] + ) + + with ( + patch.object(app_module, "Project") as MockProject, + patch.object(app_module, "APIClient") as MockAPIClient, + patch.object(app_module, "create_fresh_color_manager") as mock_cm, + patch.object(app_module, "StridePlots"), + patch.object(app_module, "add_recent_project"), + ): + MockProject.load.return_value = mock_project + mock_api = MagicMock() + mock_api.scenarios = ["baseline"] + MockAPIClient.return_value = mock_api + mock_cm.return_value = MagicMock() + + success, _ = app_module.load_project("/brand_new") + + assert success is True + oldest_mock.close.assert_called_once() + assert "/proj0" not in app_module._loaded_projects + + def test_load_project_failure_returns_false(self) -> None: + """A failed load should return (False, error_message).""" + with patch.object(app_module.Path, "resolve", side_effect=RuntimeError("bad")): + success, msg = app_module.load_project("/does/not/exist") + + assert success is False + assert "bad" in msg + + +# =================================================================== +# Tests for get_loaded_project_options +# =================================================================== + + +class TestGetLoadedProjectOptions: + """Tests for the get_loaded_project_options helper.""" + + def test_empty_cache(self) -> None: + assert app_module.get_loaded_project_options() == [] + + def test_returns_all_cached_projects(self) -> None: + app_module._loaded_projects["/a"] = _make_cache_entry("Alpha") + app_module._loaded_projects["/b"] = _make_cache_entry("Beta") + + options = app_module.get_loaded_project_options() + + assert len(options) == 2 + labels = {o["label"] for o in options} + assert labels == {"Alpha", "Beta"} + + def test_preserves_insertion_order(self) -> None: + app_module._loaded_projects["/x"] = _make_cache_entry("X") + app_module._loaded_projects["/y"] = _make_cache_entry("Y") + + options = app_module.get_loaded_project_options() + assert options[0]["value"] == "/x" + assert options[1]["value"] == "/y" + + +# =================================================================== +# Tests for MAX_CACHED_PROJECTS constant +# =================================================================== + + +def test_max_cached_projects_is_positive_int() -> None: + """Sanity check: MAX_CACHED_PROJECTS should be a small positive integer.""" + assert isinstance(app_module.MAX_CACHED_PROJECTS, int) + assert app_module.MAX_CACHED_PROJECTS > 0 + + +# =================================================================== +# Tests for _register_refresh_projects_callback logic +# (We test the inner function indirectly by extracting the logic.) +# =================================================================== + + +class TestRefreshDropdownLogic: + """ + Test the dropdown-refresh logic used by both the 'with-project' and + 'no-project' variants of refresh_dropdown_options. + + Since the actual functions are Dash callbacks registered inside closures, + we replicate/test the shared logic that builds dropdown options from + get_recent_projects(). + """ + + @staticmethod + def _build_dropdown_options_no_project( + recent: list[dict[str, Any]], + ) -> list[dict[str, str]]: + """Replicate the logic of the no-project refresh_dropdown_options.""" + dropdown_options: list[dict[str, str]] = [] + seen_project_ids: set[str] = set() + for proj in recent: + project_id = proj.get("project_id", "") + proj_path = proj.get("path", "") + if ( + project_id + and project_id not in seen_project_ids + and Path(proj_path).exists() + ): + dropdown_options.append( + {"label": proj.get("name", project_id), "value": proj_path} + ) + seen_project_ids.add(project_id) + return dropdown_options + + @staticmethod + def _build_dropdown_options_with_project( + current_project_name: str, + current_path: str, + recent: list[dict[str, Any]], + ) -> list[dict[str, str]]: + """Replicate the logic of the with-project refresh_dropdown_options.""" + dropdown_options = [{"label": current_project_name, "value": current_path}] + seen_project_ids = {current_project_name} + for proj in recent: + project_id = proj.get("project_id", "") + proj_path = proj.get("path", "") + if ( + project_id + and project_id not in seen_project_ids + and Path(proj_path).exists() + ): + dropdown_options.append( + {"label": proj.get("name", project_id), "value": proj_path} + ) + seen_project_ids.add(project_id) + return dropdown_options + + def test_no_project_empty_recent(self) -> None: + """No recent projects should yield an empty list.""" + assert self._build_dropdown_options_no_project([]) == [] + + def test_no_project_deduplicates_by_project_id(self, tmp_path: Path) -> None: + """Duplicate project_ids should be collapsed to a single entry.""" + p = tmp_path / "proj" + p.mkdir() + recent = [ + {"project_id": "dup", "path": str(p), "name": "Dup1"}, + {"project_id": "dup", "path": str(p), "name": "Dup2"}, + ] + result = self._build_dropdown_options_no_project(recent) + assert len(result) == 1 + assert result[0]["label"] == "Dup1" + + def test_no_project_skips_missing_paths(self, tmp_path: Path) -> None: + """Projects whose paths don't exist should be excluded.""" + recent = [ + {"project_id": "gone", "path": "/no/such/path", "name": "Gone"}, + ] + result = self._build_dropdown_options_no_project(recent) + assert result == [] + + def test_no_project_skips_empty_project_id(self, tmp_path: Path) -> None: + """Entries with an empty project_id should be skipped.""" + p = tmp_path / "proj" + p.mkdir() + recent = [{"project_id": "", "path": str(p), "name": "NoId"}] + result = self._build_dropdown_options_no_project(recent) + assert result == [] + + def test_no_project_uses_project_id_as_fallback_label(self, tmp_path: Path) -> None: + """If 'name' is missing, the project_id should be used as the label.""" + p = tmp_path / "proj" + p.mkdir() + recent = [{"project_id": "myid", "path": str(p)}] + result = self._build_dropdown_options_no_project(recent) + assert result[0]["label"] == "myid" + + def test_with_project_includes_current_first(self, tmp_path: Path) -> None: + """The current project should always be the first entry.""" + p = tmp_path / "other" + p.mkdir() + recent = [{"project_id": "other", "path": str(p), "name": "Other"}] + result = self._build_dropdown_options_with_project( + "Current", "/current/path", recent + ) + assert result[0] == {"label": "Current", "value": "/current/path"} + assert len(result) == 2 + + def test_with_project_does_not_duplicate_current(self, tmp_path: Path) -> None: + """If the current project also appears in recent, it should not be duplicated.""" + p = tmp_path / "cur" + p.mkdir() + recent = [{"project_id": "Current", "path": str(p), "name": "Current"}] + result = self._build_dropdown_options_with_project( + "Current", str(p), recent + ) + # Only one entry because deduplication by project_id + assert len(result) == 1 + + def test_with_project_multiple_recent(self, tmp_path: Path) -> None: + """Multiple valid recent projects should all appear after the current.""" + dirs = [] + for name in ("alpha", "beta", "gamma"): + d = tmp_path / name + d.mkdir() + dirs.append(d) + + recent = [ + {"project_id": f"P{i}", "path": str(d), "name": f"Project {i}"} + for i, d in enumerate(dirs) + ] + result = self._build_dropdown_options_with_project( + "Current", "/current", recent + ) + # Current + 3 recent + assert len(result) == 4 + assert result[0]["label"] == "Current" From 2afa0f8abd0c06668a8163a73bc8ec954054a124 Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Fri, 13 Mar 2026 14:12:46 +0100 Subject: [PATCH 04/10] make max cached project a configurable variable in both UI and CLI --- docs/how_tos/launch_dashboard.md | 32 +++++ src/stride/cli/stride.py | 14 ++- src/stride/ui/app.py | 49 +++++++- src/stride/ui/settings/callbacks.py | 51 ++++++++ src/stride/ui/settings/layout.py | 84 +++++++++++++ src/stride/ui/tui.py | 29 +++++ tests/test_app_cache.py | 180 ++++++++++++++++++++++++++-- 7 files changed, 423 insertions(+), 16 deletions(-) diff --git a/docs/how_tos/launch_dashboard.md b/docs/how_tos/launch_dashboard.md index 1812e50..13ff2aa 100644 --- a/docs/how_tos/launch_dashboard.md +++ b/docs/how_tos/launch_dashboard.md @@ -49,3 +49,35 @@ Launch the dashboard without a project: ``` Projects can be loaded and color palettes can be managed from the sidebar. + +## Configure Max Cached Projects + +By default, STRIDE keeps up to 3 projects open simultaneously. Each open project holds a DuckDB connection, +and on BlobFuse2 FUSE mounts too many concurrent connections can cause errors. + +You can configure this limit via three methods (highest priority first): + +### CLI Flag + +```{eval-rst} + +.. code-block:: console + + $ stride view my_project --max-cached-projects 5 +``` + +### Environment Variable + +```{eval-rst} + +.. code-block:: console + + $ STRIDE_MAX_CACHED_PROJECTS=5 stride view my_project +``` + +### Settings UI + +Open the sidebar, click **Settings**, and adjust the **Max Cached Projects** value in the General section. +This persists the setting to `~/.stride/config.json`. + +Valid range is 1–10. The default is 3. diff --git a/src/stride/cli/stride.py b/src/stride/cli/stride.py index c1e3ea4..88056ee 100644 --- a/src/stride/cli/stride.py +++ b/src/stride/cli/stride.py @@ -636,6 +636,12 @@ def calculated_tables() -> None: default=False, help="Disable automatic loading of default user palette", ) +@click.option( + "--max-cached-projects", + type=click.IntRange(1, 10), + default=None, + help="Maximum number of projects to keep open simultaneously (1-10, default: 3)", +) @click.pass_context def view( ctx: click.Context, @@ -645,6 +651,7 @@ def view( debug: bool, user_palette: str | None, no_default_palette: bool, + max_cached_projects: int | None, ) -> None: """Start the STRIDE dashboard UI. @@ -657,9 +664,14 @@ def view( a different user palette to use. """ from stride.api import APIClient - from stride.ui.app import create_app, create_app_no_project + from stride.ui.app import create_app, create_app_no_project, set_max_cached_projects_override from stride.ui.tui import get_default_user_palette, load_user_palette + # Apply max cached projects override if provided via CLI + if max_cached_projects is not None: + set_max_cached_projects_override(max_cached_projects) + logger.info(f"Max cached projects set to {max_cached_projects} via CLI") + # Determine which palette to use palette_override = None palette_name = None diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index e664725..0940830 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -20,7 +20,7 @@ from stride.ui.scenario import create_scenario_layout, register_scenario_callbacks from stride.ui.settings import create_settings_layout, register_settings_callbacks from stride.ui.settings.layout import get_temp_color_edits -from stride.ui.tui import list_user_palettes +from stride.ui.tui import get_max_cached_projects as _get_config_max_cached, list_user_palettes assets_path = Path(__file__).parent.absolute() / "assets" app = Dash( @@ -44,12 +44,55 @@ # Maximum number of projects to keep open simultaneously. # Each open project holds a DuckDB connection with file descriptors; # on BlobFuse2 FUSE mounts too many concurrent connections cause [Errno 5]. -MAX_CACHED_PROJECTS = 3 +_DEFAULT_MAX_CACHED_PROJECTS = 3 +_max_cached_projects_override: int | None = None + + +def get_max_cached_projects() -> int: + """Resolve the effective max cached projects value. + + Priority: CLI override > STRIDE_MAX_CACHED_PROJECTS env var > config file > default (3). + Result is clamped to [1, 10]. + """ + import os + + if _max_cached_projects_override is not None: + return max(1, min(10, _max_cached_projects_override)) + + env_val = os.environ.get("STRIDE_MAX_CACHED_PROJECTS") + if env_val is not None: + try: + return max(1, min(10, int(env_val))) + except ValueError: + pass + + config_val = _get_config_max_cached() + if config_val is not None: + return max(1, min(10, config_val)) + + return _DEFAULT_MAX_CACHED_PROJECTS + + +def set_max_cached_projects_override(n: int | None) -> None: + """Set the CLI override for max cached projects. + + Parameters + ---------- + n : int | None + Override value, or None to clear + """ + global _max_cached_projects_override + _max_cached_projects_override = n + + +# Keep module-level attribute for backwards compatibility with tests +MAX_CACHED_PROJECTS = _DEFAULT_MAX_CACHED_PROJECTS def _evict_oldest_project() -> None: """Evict the least-recently-used project from the cache if at capacity.""" - while len(_loaded_projects) >= MAX_CACHED_PROJECTS: + limit = get_max_cached_projects() + while len(_loaded_projects) >= limit: # Dict is insertion-ordered; first key is the oldest (LRU) oldest_path = next(iter(_loaded_projects)) old_project, _, _, old_name = _loaded_projects.pop(oldest_path) diff --git a/src/stride/ui/settings/callbacks.py b/src/stride/ui/settings/callbacks.py index 10d9cb8..e4e5e68 100644 --- a/src/stride/ui/settings/callbacks.py +++ b/src/stride/ui/settings/callbacks.py @@ -845,6 +845,57 @@ def apply_json_palette( no_update, # type: ignore[return-value] ) + # Max Cached Projects callback + @callback( + Output("max-cached-projects-status", "children"), + Input("save-max-cached-btn", "n_clicks"), + State("max-cached-projects-input", "value"), + prevent_initial_call=True, + ) + def save_max_cached_projects( + n_clicks: int | None, + value: int | None, + ) -> html.Div: + """Save the max cached projects setting.""" + if not n_clicks: + raise PreventUpdate + print(value) + if value is None: + return html.Div( + "✗ Please enter a value", + className="text-danger", + ) + + try: + n = int(value) + except (TypeError, ValueError): + return html.Div( + "✗ Invalid number", + className="text-danger", + ) + + if n < 1 or n > 10: + return html.Div( + "✗ Value must be between 1 and 10", + className="text-danger", + ) + + from stride.ui.app import _evict_oldest_project, set_max_cached_projects_override + from stride.ui.tui import set_max_cached_projects + + # Persist to config file + set_max_cached_projects(n) + # Also update the runtime override so it takes effect immediately + set_max_cached_projects_override(n) + # Trigger eviction if current cache exceeds new limit + _evict_oldest_project() + + logger.info(f"Max cached projects set to {n}") + return html.Div( + f"✓ Max cached projects set to {n}", + className="text-success", + ) + def _convert_to_hex(color: str) -> str: """ diff --git a/src/stride/ui/settings/layout.py b/src/stride/ui/settings/layout.py index dd7b654..a07d624 100644 --- a/src/stride/ui/settings/layout.py +++ b/src/stride/ui/settings/layout.py @@ -1,5 +1,7 @@ """Settings page layout for STRIDE dashboard.""" +import os + import dash_bootstrap_components as dbc from dash import dcc, html @@ -59,6 +61,30 @@ def create_settings_layout( # Get temporary color edits temp_edits = get_temp_color_edits() + # Resolve max cached projects override state for the General section + from stride.ui.app import ( + _max_cached_projects_override, + get_max_cached_projects, + ) + + max_cached_value = get_max_cached_projects() + override_source = None + if _max_cached_projects_override is not None: + override_source = f"CLI flag (--max-cached-projects {_max_cached_projects_override})" + elif os.environ.get("STRIDE_MAX_CACHED_PROJECTS") is not None: + override_source = f"Environment variable (STRIDE_MAX_CACHED_PROJECTS={os.environ['STRIDE_MAX_CACHED_PROJECTS']})" + is_overridden = override_source is not None + + override_badge = [] + if is_overridden: + override_badge = [ + dbc.Badge( + f"Overridden by: {override_source}", + color="warning", + className="ms-2 mb-2", + ), + ] + return html.Div( [ dbc.Container( @@ -74,6 +100,64 @@ def create_settings_layout( ) ] ), + # General Settings Section + dbc.Row( + [ + dbc.Col( + [ + html.H4("General", className="mb-3"), + dbc.Card( + [ + dbc.CardBody( + [ + html.Label( + "Max Cached Projects:", + className="form-label fw-bold", + ), + *override_badge, + html.Div( + [ + dcc.Input( + id="max-cached-projects-input", + type="number", + step=1, + value=max_cached_value, + className="form-control form-control-sm", + style={"width": "100px", "display": "inline-block", "height": "31px", "fontSize": "0.85rem"}, + readOnly=is_overridden, + disabled=is_overridden, + ), + dbc.Button( + "Save", + id="save-max-cached-btn", + color="primary", + size="sm", + className="ms-2", + disabled=is_overridden, + ), + ], + className="d-flex align-items-center mb-2", + ), + html.Small( + "Number of projects to keep open simultaneously. " + "Each open project holds a DuckDB connection; " + "too many concurrent connections may cause errors on FUSE mounts.", + className="text-muted", + ), + html.Div( + id="max-cached-projects-status", + className="mt-2", + ), + ] + ) + ], + className="mb-4", + ), + ] + ) + ], + className="mb-4", + ), # Palette Selection Section dbc.Row( [ diff --git a/src/stride/ui/tui.py b/src/stride/ui/tui.py index 62064a2..2b9ccc3 100644 --- a/src/stride/ui/tui.py +++ b/src/stride/ui/tui.py @@ -1209,3 +1209,32 @@ def get_default_user_palette() -> str | None: """ config = load_stride_config() return config.get("default_user_palette") + + +def get_max_cached_projects() -> int | None: + """Get the max cached projects setting from config. + + Returns + ------- + int | None + Configured max cached projects, or None if not set + """ + config = load_stride_config() + value = config.get("max_cached_projects") + if value is not None: + return int(value) + return None + + +def set_max_cached_projects(n: int) -> None: + """Set the max cached projects in the config file. + + Parameters + ---------- + n : int + Number of max cached projects (will be clamped to [1, 10]) + """ + n = max(1, min(10, n)) + config = load_stride_config() + config["max_cached_projects"] = n + save_stride_config(config) diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py index 1a42699..57e3443 100644 --- a/tests/test_app_cache.py +++ b/tests/test_app_cache.py @@ -3,7 +3,7 @@ and the refresh-projects dropdown logic. Covers code added in the `refresh_recent_projects` branch: -- MAX_CACHED_PROJECTS constant +- get_max_cached_projects() configurable limit - _evict_oldest_project() - LRU reordering when switching to a cached project via load_project() - refresh_dropdown_options (no-project variant via _register_refresh_projects_callback) @@ -11,6 +11,7 @@ from __future__ import annotations +import os from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -44,9 +45,81 @@ def _reset_global_state() -> None: """Ensure module-level cache state is clean before *and* after each test.""" app_module._loaded_projects.clear() app_module._current_project_path = None + app_module._max_cached_projects_override = None yield app_module._loaded_projects.clear() app_module._current_project_path = None + app_module._max_cached_projects_override = None + + +# =================================================================== +# Tests for get_max_cached_projects priority chain +# =================================================================== + + +class TestGetMaxCachedProjects: + """Tests for the get_max_cached_projects resolution function.""" + + def test_default_value(self) -> None: + """Should return 3 when no override, no env var, and no config.""" + with patch.object(app_module, "_get_config_max_cached", return_value=None): + assert app_module.get_max_cached_projects() == 3 + + def test_config_value(self) -> None: + """Config file value should be used when no override or env var.""" + with patch.object(app_module, "_get_config_max_cached", return_value=5): + assert app_module.get_max_cached_projects() == 5 + + def test_env_var_overrides_config(self) -> None: + """Environment variable should override config file value.""" + with ( + patch.object(app_module, "_get_config_max_cached", return_value=5), + patch.dict(os.environ, {"STRIDE_MAX_CACHED_PROJECTS": "4"}), + ): + assert app_module.get_max_cached_projects() == 4 + + def test_cli_override_overrides_env_and_config(self) -> None: + """CLI override should take highest priority.""" + app_module._max_cached_projects_override = 2 + with ( + patch.object(app_module, "_get_config_max_cached", return_value=5), + patch.dict(os.environ, {"STRIDE_MAX_CACHED_PROJECTS": "4"}), + ): + assert app_module.get_max_cached_projects() == 2 + + def test_clamped_to_minimum(self) -> None: + """Values below 1 should be clamped to 1.""" + app_module._max_cached_projects_override = 0 + assert app_module.get_max_cached_projects() == 1 + + def test_clamped_to_maximum(self) -> None: + """Values above 10 should be clamped to 10.""" + app_module._max_cached_projects_override = 99 + assert app_module.get_max_cached_projects() == 10 + + def test_env_var_clamped(self) -> None: + """Env var out of range should be clamped.""" + with ( + patch.object(app_module, "_get_config_max_cached", return_value=None), + patch.dict(os.environ, {"STRIDE_MAX_CACHED_PROJECTS": "0"}), + ): + assert app_module.get_max_cached_projects() == 1 + + def test_env_var_invalid_ignored(self) -> None: + """Non-numeric env var should be ignored, falling through to config/default.""" + with ( + patch.object(app_module, "_get_config_max_cached", return_value=None), + patch.dict(os.environ, {"STRIDE_MAX_CACHED_PROJECTS": "abc"}), + ): + assert app_module.get_max_cached_projects() == 3 + + def test_set_and_clear_override(self) -> None: + """set_max_cached_projects_override should set and clear correctly.""" + app_module.set_max_cached_projects_override(7) + assert app_module._max_cached_projects_override == 7 + + app_module.set_max_cached_projects_override(None) + assert app_module._max_cached_projects_override is None # =================================================================== @@ -58,10 +131,10 @@ class TestEvictOldestProject: """Tests for the _evict_oldest_project helper.""" def test_no_eviction_when_below_capacity(self) -> None: - """No project should be evicted when cache is below MAX_CACHED_PROJECTS.""" + """No project should be evicted when cache is below limit.""" app_module._loaded_projects["/a"] = _make_cache_entry("A") app_module._loaded_projects["/b"] = _make_cache_entry("B") - assert len(app_module._loaded_projects) < app_module.MAX_CACHED_PROJECTS + app_module._max_cached_projects_override = 3 app_module._evict_oldest_project() @@ -71,19 +144,22 @@ def test_no_eviction_when_below_capacity(self) -> None: def test_eviction_when_at_capacity(self) -> None: """The oldest (first-inserted) project should be evicted when at capacity.""" - for i in range(app_module.MAX_CACHED_PROJECTS): + app_module._max_cached_projects_override = 3 + limit = app_module.get_max_cached_projects() + for i in range(limit): app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") oldest_project = app_module._loaded_projects["/0"][0] app_module._evict_oldest_project() - assert len(app_module._loaded_projects) == app_module.MAX_CACHED_PROJECTS - 1 + assert len(app_module._loaded_projects) == limit - 1 assert "/0" not in app_module._loaded_projects oldest_project.close.assert_called_once() def test_eviction_removes_oldest_preserves_newest(self) -> None: """Eviction should remove the first key (LRU) and keep the rest.""" + app_module._max_cached_projects_override = 3 app_module._loaded_projects["/old"] = _make_cache_entry("Old") app_module._loaded_projects["/mid"] = _make_cache_entry("Mid") app_module._loaded_projects["/new"] = _make_cache_entry("New") @@ -101,7 +177,9 @@ def test_eviction_on_empty_cache(self) -> None: def test_eviction_handles_close_exception(self) -> None: """If Project.close() raises, eviction should still proceed (logged as warning).""" - for i in range(app_module.MAX_CACHED_PROJECTS): + app_module._max_cached_projects_override = 3 + limit = app_module.get_max_cached_projects() + for i in range(limit): app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") # Make close() raise for the oldest project @@ -111,17 +189,33 @@ def test_eviction_handles_close_exception(self) -> None: app_module._evict_oldest_project() assert "/0" not in app_module._loaded_projects - assert len(app_module._loaded_projects) == app_module.MAX_CACHED_PROJECTS - 1 + assert len(app_module._loaded_projects) == limit - 1 def test_eviction_over_capacity(self) -> None: - """If the cache somehow exceeds capacity, evict until below MAX_CACHED_PROJECTS.""" - # Manually stuff more than MAX_CACHED_PROJECTS entries - for i in range(app_module.MAX_CACHED_PROJECTS + 2): + """If the cache somehow exceeds capacity, evict until below limit.""" + app_module._max_cached_projects_override = 3 + limit = app_module.get_max_cached_projects() + # Manually stuff more than limit entries + for i in range(limit + 2): app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") app_module._evict_oldest_project() - assert len(app_module._loaded_projects) == app_module.MAX_CACHED_PROJECTS - 1 + assert len(app_module._loaded_projects) == limit - 1 + + def test_eviction_respects_dynamic_limit(self) -> None: + """Lowering the limit should cause eviction of excess projects.""" + # Start with 5 projects and a limit of 5 + app_module._max_cached_projects_override = 5 + for i in range(5): + app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") + + # Lower the limit to 2 + app_module._max_cached_projects_override = 2 + app_module._evict_oldest_project() + + # Should have evicted down to 1 (limit - 1) + assert len(app_module._loaded_projects) == 1 # =================================================================== @@ -159,7 +253,9 @@ def test_cached_project_updates_current_path(self) -> None: def test_load_new_project_triggers_eviction(self) -> None: """Loading a new project when at capacity should evict the oldest first.""" - for i in range(app_module.MAX_CACHED_PROJECTS): + app_module._max_cached_projects_override = 3 + limit = app_module.get_max_cached_projects() + for i in range(limit): app_module._loaded_projects[f"/proj{i}"] = _make_cache_entry(f"P{i}") oldest_mock = app_module._loaded_projects["/proj0"][0] @@ -379,3 +475,63 @@ def test_with_project_multiple_recent(self, tmp_path: Path) -> None: # Current + 3 recent assert len(result) == 4 assert result[0]["label"] == "Current" + + +# =================================================================== +# Tests for config round-trip (tui.py helpers) +# =================================================================== + + +class TestConfigMaxCachedProjects: + """Tests for get_max_cached_projects / set_max_cached_projects in tui.py.""" + + def test_round_trip(self, tmp_path: Path) -> None: + """set_max_cached_projects(n) -> get_max_cached_projects() should return n.""" + from stride.ui.tui import ( + get_max_cached_projects as tui_get, + set_max_cached_projects as tui_set, + ) + + config_file = tmp_path / "config.json" + with patch("stride.ui.tui.get_stride_config_path", return_value=config_file): + # Initially no config file + assert tui_get() is None + + # Set a value + tui_set(5) + assert tui_get() == 5 + + # Update the value + tui_set(8) + assert tui_get() == 8 + + def test_set_clamps_to_range(self, tmp_path: Path) -> None: + """set_max_cached_projects should clamp values to [1, 10].""" + from stride.ui.tui import ( + get_max_cached_projects as tui_get, + set_max_cached_projects as tui_set, + ) + + config_file = tmp_path / "config.json" + with patch("stride.ui.tui.get_stride_config_path", return_value=config_file): + tui_set(0) + assert tui_get() == 1 + + tui_set(99) + assert tui_get() == 10 + + def test_set_preserves_other_config(self, tmp_path: Path) -> None: + """set_max_cached_projects should not clobber other config keys.""" + import json + + from stride.ui.tui import set_max_cached_projects as tui_set + + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps({"default_user_palette": "my_palette"})) + + with patch("stride.ui.tui.get_stride_config_path", return_value=config_file): + tui_set(7) + + saved = json.loads(config_file.read_text()) + assert saved["max_cached_projects"] == 7 + assert saved["default_user_palette"] == "my_palette" From 9025ad1dd5e9175f926379abbfb56a093b63289b Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Fri, 13 Mar 2026 14:26:47 +0100 Subject: [PATCH 05/10] fixing mypy test errors --- tests/test_app_cache.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py index 57e3443..e72cae9 100644 --- a/tests/test_app_cache.py +++ b/tests/test_app_cache.py @@ -12,6 +12,7 @@ from __future__ import annotations import os +from collections.abc import Generator from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -41,7 +42,7 @@ def _make_cache_entry( @pytest.fixture(autouse=True) -def _reset_global_state() -> None: +def _reset_global_state() -> Generator[None, None, None]: """Ensure module-level cache state is clean before *and* after each test.""" app_module._loaded_projects.clear() app_module._current_project_path = None @@ -146,10 +147,9 @@ def test_eviction_when_at_capacity(self) -> None: """The oldest (first-inserted) project should be evicted when at capacity.""" app_module._max_cached_projects_override = 3 limit = app_module.get_max_cached_projects() - for i in range(limit): - app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") - - oldest_project = app_module._loaded_projects["/0"][0] + entries = {f"/{i}": _make_cache_entry(f"P{i}") for i in range(limit)} + oldest_project = entries["/0"][0] + app_module._loaded_projects.update(entries) app_module._evict_oldest_project() @@ -179,11 +179,10 @@ def test_eviction_handles_close_exception(self) -> None: """If Project.close() raises, eviction should still proceed (logged as warning).""" app_module._max_cached_projects_override = 3 limit = app_module.get_max_cached_projects() - for i in range(limit): - app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") - + entries = {f"/{i}": _make_cache_entry(f"P{i}") for i in range(limit)} # Make close() raise for the oldest project - app_module._loaded_projects["/0"][0].close.side_effect = RuntimeError("oops") + entries["/0"][0].close.side_effect = RuntimeError("oops") + app_module._loaded_projects.update(entries) # Should not raise app_module._evict_oldest_project() @@ -255,11 +254,12 @@ def test_load_new_project_triggers_eviction(self) -> None: """Loading a new project when at capacity should evict the oldest first.""" app_module._max_cached_projects_override = 3 limit = app_module.get_max_cached_projects() - for i in range(limit): + entry = _make_cache_entry("P0") + oldest_mock = entry[0] + app_module._loaded_projects["/proj0"] = entry + for i in range(1, limit): app_module._loaded_projects[f"/proj{i}"] = _make_cache_entry(f"P{i}") - oldest_mock = app_module._loaded_projects["/proj0"][0] - mock_project = _make_mock_project("NewProj") mock_project.palette = MagicMock() mock_project.palette.copy.return_value = MagicMock( @@ -287,7 +287,7 @@ def test_load_new_project_triggers_eviction(self) -> None: def test_load_project_failure_returns_false(self) -> None: """A failed load should return (False, error_message).""" - with patch.object(app_module.Path, "resolve", side_effect=RuntimeError("bad")): + with patch.object(Path, "resolve", side_effect=RuntimeError("bad")): success, msg = app_module.load_project("/does/not/exist") assert success is False From f4b9ebbc382a49c97bf7870b2bcf421ff3845081 Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Wed, 1 Apr 2026 10:43:45 +0200 Subject: [PATCH 06/10] changes to resolve conflict for PR --- .pre-commit-config.yaml | 8 + docs/how_tos/compare_scenarios.md | 11 +- docs/how_tos/customize_palette.md | 5 +- pyproject.toml | 2 +- src/stride/api/__init__.py | 59 +- src/stride/cli/stride.py | 159 +-- src/stride/config.py | 92 ++ src/stride/models.py | 4 +- src/stride/project.py | 57 +- src/stride/ui/app.py | 242 +++- src/stride/ui/assets/dark-theme.css | 60 +- src/stride/ui/assets/light-theme.css | 62 +- src/stride/ui/assets/theme-detector.js | 15 +- src/stride/ui/color_manager.py | 48 +- src/stride/ui/home/callbacks.py | 106 +- src/stride/ui/home/layout.py | 3 +- src/stride/ui/palette.py | 1021 +++++++++----- src/stride/ui/palette_utils.py | 196 +++ src/stride/ui/plotting/__init__.py | 42 +- src/stride/ui/plotting/facets.py | 82 +- src/stride/ui/plotting/simple.py | 104 +- src/stride/ui/plotting/utils.py | 124 +- src/stride/ui/scenario/callbacks.py | 35 +- src/stride/ui/settings/callbacks.py | 279 ++-- src/stride/ui/settings/layout.py | 263 +++- src/stride/ui/tui.py | 1240 ----------------- tests/palette/test_auto_color.py | 10 +- tests/palette/test_color_manager_update.py | 98 +- tests/palette/test_palette.py | 373 ++++- tests/palette/test_palette_init.py | 10 +- tests/palette/test_palette_merge.py | 248 ++++ tests/palette/test_palette_override.py | 18 +- .../test_palette_override_integration.py | 6 +- tests/palette/test_rgb_color_format.py | 63 +- tests/palette/test_save_user_palette.py | 278 ++++ tests/palette/test_settings_categories.py | 71 + tests/palette/test_tol_palettes.py | 257 ++++ tests/test_api.py | 153 ++ tests/test_app_cache.py | 14 +- tests/tui/test_edit_features.py | 212 --- tests/tui/test_palette_tui.py | 191 --- tests/tui/test_tui.py | 0 tests/tui/test_tui_refresh.py | 234 ---- tests/tui/test_tui_reordering.py | 144 -- tests/tui/test_tui_simple.py | 88 -- 45 files changed, 3717 insertions(+), 3070 deletions(-) create mode 100644 src/stride/config.py create mode 100644 src/stride/ui/palette_utils.py delete mode 100644 src/stride/ui/tui.py create mode 100644 tests/palette/test_palette_merge.py create mode 100644 tests/palette/test_save_user_palette.py create mode 100644 tests/palette/test_settings_categories.py create mode 100644 tests/palette/test_tol_palettes.py delete mode 100644 tests/tui/test_edit_features.py delete mode 100644 tests/tui/test_palette_tui.py delete mode 100644 tests/tui/test_tui.py delete mode 100644 tests/tui/test_tui_refresh.py delete mode 100644 tests/tui/test_tui_reordering.py delete mode 100644 tests/tui/test_tui_simple.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a1deb1..bc2f91b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,3 +8,11 @@ repos: args: [ --fix ] # Run the formatter. - id: ruff-format +- repo: local + hooks: + - id: mypy + name: mypy + entry: mypy + language: system + pass_filenames: false + stages: [pre-push] diff --git a/docs/how_tos/compare_scenarios.md b/docs/how_tos/compare_scenarios.md index a6c2413..7980db4 100644 --- a/docs/how_tos/compare_scenarios.md +++ b/docs/how_tos/compare_scenarios.md @@ -16,8 +16,8 @@ client = APIClient(project) ## Query Multiple Scenarios ```python -baseline = client.get_total_consumption(scenario="baseline") -high_growth = client.get_total_consumption(scenario="high_growth") +baseline = client.get_annual_electricity_consumption(scenarios=["baseline"]) +high_growth = client.get_annual_electricity_consumption(scenarios=["high_growth"]) ``` ## Calculate Differences @@ -27,7 +27,7 @@ import pandas as pd comparison = pd.merge( baseline, high_growth, - on=["geography", "model_year"], + on=["year"], suffixes=("_baseline", "_high_growth") ) comparison["difference"] = ( @@ -45,9 +45,8 @@ import plotly.express as px fig = px.scatter( comparison, - x="model_year", - y="pct_change", - color="geography", + x="year", + y="pct_difference", title="Consumption Change: High Growth vs Baseline" ) fig.show() diff --git a/docs/how_tos/customize_palette.md b/docs/how_tos/customize_palette.md index e652693..33da2c2 100644 --- a/docs/how_tos/customize_palette.md +++ b/docs/how_tos/customize_palette.md @@ -14,11 +14,14 @@ Create consistent colors for visualizations. ## Preview a Palette +Launch the dashboard and open the Settings panel to preview and edit palette +colors: + ```{eval-rst} .. code-block:: console - $ stride palette view my_project + $ stride view my_project ``` ## Create a Custom Palette diff --git a/pyproject.toml b/pyproject.toml index b699fc4..de7ffd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ ] dependencies = [ "click", - "dash>=3.2.0", + "dash>=4.0.0, <5", "dash-bootstrap-components>=2.0.3", "dbt-core >= 1.10.5, < 2", "dbt-duckdb", diff --git a/src/stride/api/__init__.py b/src/stride/api/__init__.py index 15a687b..5fad7c1 100644 --- a/src/stride/api/__init__.py +++ b/src/stride/api/__init__.py @@ -566,7 +566,7 @@ def get_annual_peak_demand( t.scenario, t.model_year as year, t.{group_col}, - t.value + SUM(t.value) as value FROM energy_projection t INNER JOIN peak_hours p ON t.scenario = p.scenario @@ -576,6 +576,7 @@ def get_annual_peak_demand( WHERE t.geography = ? AND t.scenario = ANY(?) AND t.model_year = ANY(?) + GROUP BY t.scenario, t.model_year, t.{group_col} ORDER BY {scenario_order}, t.model_year, t.{group_col} """ params = [ @@ -1177,17 +1178,28 @@ def get_time_series_comparison( group_col = "metric" else: # group_by == "Sector" group_col = "sector" + group_time_period_calc = f"ROW_NUMBER() OVER (PARTITION BY scenario, model_year, {group_col} ORDER BY timestamp)" sql = f""" + WITH hourly_totals AS ( + SELECT + scenario, + model_year, + timestamp, + {group_col}, + SUM(value) as value + FROM energy_projection + WHERE geography = ? + AND scenario = ? + AND model_year = ANY(?) + GROUP BY scenario, model_year, timestamp, {group_col} + ) SELECT scenario, model_year as year, - {time_period_calc} as time_period, + {group_time_period_calc} as time_period, {group_col}, value - FROM energy_projection - WHERE geography = ? - AND scenario = ? - AND model_year = ANY(?) + FROM hourly_totals ORDER BY scenario, model_year, timestamp, {group_col} """ params: list[Any] = [self.project_country, scenario, years] @@ -1238,31 +1250,50 @@ def get_time_series_comparison( else: # group_by == "Sector" group_col = "sector" sql = f""" + WITH hourly_totals AS ( + SELECT + scenario, + model_year, + timestamp, + {group_col}, + SUM(value) as value + FROM energy_projection + WHERE geography = ? + AND scenario = ? + AND model_year = ANY(?) + GROUP BY scenario, model_year, timestamp, {group_col} + ) SELECT scenario, model_year as year, {time_period_calc} as time_period, {group_col}, AVG(value) as value - FROM energy_projection - WHERE geography = ? - AND scenario = ? - AND model_year = ANY(?) + FROM hourly_totals GROUP BY scenario, model_year, {time_period_calc}, {group_col} ORDER BY scenario, model_year, time_period, {group_col} """ params = [self.project_country, scenario, years] else: sql = f""" + WITH hourly_totals AS ( + SELECT + scenario, + model_year, + timestamp, + SUM(value) as value + FROM energy_projection + WHERE geography = ? + AND scenario = ? + AND model_year = ANY(?) + GROUP BY scenario, model_year, timestamp + ) SELECT scenario, model_year as year, {time_period_calc} as time_period, AVG(value) as value - FROM energy_projection - WHERE geography = ? - AND scenario = ? - AND model_year = ANY(?) + FROM hourly_totals GROUP BY scenario, model_year, {time_period_calc} ORDER BY scenario, model_year, time_period """ diff --git a/src/stride/cli/stride.py b/src/stride/cli/stride.py index 88056ee..c8d331a 100644 --- a/src/stride/cli/stride.py +++ b/src/stride/cli/stride.py @@ -12,6 +12,7 @@ from stride import Project from stride.models import CalculatedTableOverride from stride.project import list_valid_countries, list_valid_model_years, list_valid_weather_years +from stride.ui.palette_utils import list_user_palettes, set_palette_priority from stride.dataset_download import ( DatasetDownloadError, download_dataset, @@ -665,7 +666,11 @@ def view( """ from stride.api import APIClient from stride.ui.app import create_app, create_app_no_project, set_max_cached_projects_override - from stride.ui.tui import get_default_user_palette, load_user_palette + from stride.ui.palette_utils import ( + get_default_user_palette, + get_palette_priority, + load_user_palette, + ) # Apply max cached projects override if provided via CLI if max_cached_projects is not None: @@ -677,11 +682,13 @@ def view( palette_name = None if user_palette: - # Explicit user palette override + # Explicit user palette override (always honored) palette_name = user_palette elif not no_default_palette: - # Check for default user palette - palette_name = get_default_user_palette() + # Check palette priority and default user palette + priority = get_palette_priority() + if priority == "user": + palette_name = get_default_user_palette() if palette_name: try: @@ -910,67 +917,6 @@ def palette() -> None: """Palette commands""" -_palette_view_epilog = """ -Examples:\n -$ stride palette view test_project --project\n -$ stride palette view my_palette --user\n -""" - - -@click.command(name="view", epilog=_palette_view_epilog) -@click.argument("name", type=str) -@click.option( - "--project", - "palette_type", - flag_value="project", - default=True, - help="View a project palette (default)", -) -@click.option( - "--user", - "palette_type", - flag_value="user", - help="View a user palette", -) -@click.pass_context -def view_palette(ctx: click.Context, name: str, palette_type: str) -> None: - """View a color palette in an interactive TUI. - - For project palettes, NAME should be the path to the project directory. - For user palettes, NAME should be the palette name. - """ - from stride.ui.tui import launch_palette_viewer - - if palette_type == "project": - project_path = Path(name) - if not project_path.exists(): - logger.error(f"Project path does not exist: {project_path}") - ctx.exit(1) - - palette_file = project_path / "project.json5" - if not palette_file.exists(): - logger.error(f"Project config not found: {palette_file}") - ctx.exit(1) - - # Load project config to get better grouping info - from stride.models import ProjectConfig - - config = ProjectConfig.from_file(palette_file) - - launch_palette_viewer(palette_file, palette_type="project", project_config=config) - else: - from stride.ui.tui import get_user_palette_dir - - palette_dir = get_user_palette_dir() - palette_file = palette_dir / f"{name}.json" - - if not palette_file.exists(): - logger.error(f"User palette not found: {palette_file}") - ctx.exit(1) - - launch_palette_viewer(palette_file, palette_type="user") - - _palette_init_epilog = """ Examples:\n # Create an empty palette (for manual population)\n @@ -1035,13 +981,13 @@ def init_palette( # noqa: C901 --from-user: Copy from an existing user palette in ~/.stride/palettes/ - (No source): Create an empty palette for manual population via TUI + (No source): Create an empty palette for manual population via the Settings UI The palette can be saved to user space (default) or embedded in a project. """ from stride.api import APIClient - from stride.ui.palette import ColorPalette - from stride.ui.tui import load_user_palette, save_user_palette + from stride.ui.palette import ColorCategory, ColorPalette + from stride.ui.palette_utils import load_user_palette, save_user_palette # Validate that at most one source is specified sources = [from_project, from_user] @@ -1063,8 +1009,8 @@ def init_palette( # noqa: C901 if source_count == 0: # Create an empty palette with structured categories print(f"Creating empty palette: {name}") - print("Use 'stride palette view {name} --user' to add labels interactively") - palette_dict = {"scenarios": {}, "model_years": {}, "metrics": {}} + print("Edit colors in the Settings panel after launching: stride view") + palette_dict = {"scenarios": {}, "model_years": {}, "sectors": {}, "end_uses": {}} elif from_project: # Get labels from project configuration and database project_path = from_project @@ -1078,14 +1024,14 @@ def init_palette( # noqa: C901 scenario_names = [scenario.name for scenario in project.config.scenarios] print(f"Found {len(scenario_names)} scenarios from config") for label in scenario_names: - palette.update(label, category="scenarios") + palette.update(label, category=ColorCategory.SCENARIO) # Get model years from ProjectConfig (fast lookup) model_years = project.config.list_model_years() year_labels = [str(year) for year in model_years] print(f"Found {len(year_labels)} model years from config") for label in year_labels: - palette.update(label, category="model_years") + palette.update(label, category=ColorCategory.MODEL_YEAR) # Get sectors and end uses from database (requires query) api_client = APIClient(project) @@ -1093,9 +1039,11 @@ def init_palette( # noqa: C901 end_uses = api_client.get_unique_end_uses() print(f"Found {len(sectors)} sectors and {len(end_uses)} end uses from database") - # Add sectors and end uses to the metrics category - for label in sectors + end_uses: - palette.update(label, category="metrics") + # Add sectors and end uses to their respective categories + for label in sectors: + palette.update(label, category=ColorCategory.SECTOR) + for label in end_uses: + palette.update(label, category=ColorCategory.END_USE) palette_dict = palette.to_dict() @@ -1116,7 +1064,7 @@ def init_palette( # noqa: C901 saved_path = save_user_palette(name, palette_dict) logger.info(f"Created user palette '{name}' at {saved_path}") print(f"\nCreated user palette: {saved_path}") - print(f"View with: stride palette view {name} --user") + print("Edit colors in the Settings panel after launching: stride view") else: # Save to project palette if not project_path: @@ -1127,7 +1075,7 @@ def init_palette( # noqa: C901 project.persist() logger.info(f"Created project palette in {project_path / 'project.json5'}") print(f"\nCreated project palette in: {project_path / 'project.json5'}") - print(f"View with: stride palette view {project_path} --project") + print("Edit colors in the Settings panel after launching: stride view") @click.command(name="list") @@ -1147,8 +1095,6 @@ def init_palette( # noqa: C901 def list_palettes(palette_type: str) -> None: """List available color palettes.""" if palette_type == "user": - from stride.ui.tui import list_user_palettes - palettes = list_user_palettes() if not palettes: print("No user palettes found.") @@ -1160,7 +1106,7 @@ def list_palettes(palette_type: str) -> None: for palette_path in palettes: print(f" - {palette_path.stem} ({palette_path})") else: - print("To view a project palette, use: stride palette view --project") + print("Project palettes are stored in project.json5. Launch the dashboard to view.") @click.command(name="set-default") @@ -1180,7 +1126,7 @@ def set_default_palette(ctx: click.Context, palette_name: str | None) -> None: Clear the default palette: $ stride palette set-default """ - from stride.ui.tui import set_default_user_palette + from stride.ui.palette_utils import set_default_user_palette try: set_default_user_palette(palette_name) @@ -1202,7 +1148,7 @@ def get_default_palette() -> None: $ stride palette get-default """ - from stride.ui.tui import get_default_user_palette + from stride.ui.palette_utils import get_default_user_palette default = get_default_user_palette() if default: @@ -1212,6 +1158,48 @@ def get_default_palette() -> None: print("Set one with: stride palette set-default ") +@click.command(name="set-priority") +@click.argument("priority", type=click.Choice(["user", "project"])) +def set_priority(priority: str) -> None: + """Set which palette takes priority when launching the dashboard. + + When set to "user", the default user palette (if set) will override the + project palette on dashboard launch. When set to "project", the project + palette is always used unless --user-palette is specified. + + Examples: + + $ stride palette set-priority user + + $ stride palette set-priority project + """ + set_palette_priority(priority) + if priority == "user": + print("Palette priority set to: user") + print("Default user palette (if set) will override project palette on launch.") + else: + print("Palette priority set to: project") + print("Project palette will always be used unless --user-palette is specified.") + + +@click.command(name="get-priority") +def get_priority() -> None: + """Show the current palette priority setting. + + Examples: + + $ stride palette get-priority + """ + from stride.ui.palette_utils import get_palette_priority + + priority = get_palette_priority() + print(f"Palette priority: {priority}") + if priority == "user": + print("Default user palette (if set) will override project palette on launch.") + else: + print("Project palette will always be used unless --user-palette is specified.") + + _palette_refresh_epilog = """ Examples:\n # Fix palette colors for a project\n @@ -1238,7 +1226,7 @@ def refresh_palette(ctx: click.Context, project_path: Path) -> None: print("\nBefore refresh:") print(f" Scenarios: {len(palette.scenarios)}") print(f" Model Years: {len(palette.model_years)}") - print(f" Metrics: {len(palette.metrics)}") + print(f" Sectors: {len(palette.sectors)}, End Uses: {len(palette.end_uses)}") # Refresh colors project.refresh_palette_colors() @@ -1247,7 +1235,7 @@ def refresh_palette(ctx: click.Context, project_path: Path) -> None: print("\nAfter refresh:") print(f" Scenarios: {len(palette.scenarios)} (Bold theme)") print(f" Model Years: {len(palette.model_years)} (YlOrRd theme)") - print(f" Metrics: {len(palette.metrics)} (Prism theme)") + print(f" Sectors: {len(palette.sectors)}, End Uses: {len(palette.end_uses)}") print("\nPalette colors refreshed and saved!") @@ -1304,9 +1292,10 @@ def safe_get_project_from_context( calculated_tables.add_command(override_calculated_table) calculated_tables.add_command(export_calculated_table) calculated_tables.add_command(remove_calculated_table_override) -palette.add_command(view_palette) palette.add_command(init_palette) palette.add_command(list_palettes) palette.add_command(set_default_palette) palette.add_command(get_default_palette) +palette.add_command(set_priority) +palette.add_command(get_priority) palette.add_command(refresh_palette) diff --git a/src/stride/config.py b/src/stride/config.py new file mode 100644 index 0000000..03b8be7 --- /dev/null +++ b/src/stride/config.py @@ -0,0 +1,92 @@ +"""Stride configuration utilities. + +Manages the Stride configuration directory (``~/.stride/``) and configuration +file (``~/.stride/config.json``). +""" + +import json +from pathlib import Path +from typing import Any + + +def get_stride_config_dir() -> Path: + """Get the stride configuration directory, creating it if necessary. + + Returns + ------- + Path + Path to ~/.stride/ + """ + config_dir = Path.home() / ".stride" + config_dir.mkdir(parents=True, exist_ok=True) + return config_dir + + +def get_stride_config_path() -> Path: + """Get the stride configuration file path. + + Returns + ------- + Path + Path to ~/.stride/config.json + """ + return get_stride_config_dir() / "config.json" + + +def load_stride_config() -> dict[str, Any]: + """Load the stride configuration file. + + Returns + ------- + dict[str, Any] + Configuration dictionary, or empty dict if file doesn't exist + """ + config_path = get_stride_config_path() + if not config_path.exists(): + return {} + + with open(config_path) as f: + result: dict[str, Any] = json.load(f) + return result + + +def save_stride_config(config: dict[str, Any]) -> None: + """Save the stride configuration file. + + Parameters + ---------- + config : dict[str, Any] + Configuration dictionary to save + """ + config_path = get_stride_config_path() + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + +def get_max_cached_projects() -> int | None: + """Get the max cached projects setting from config. + + Returns + ------- + int | None + Configured max cached projects, or None if not set + """ + config = load_stride_config() + value = config.get("max_cached_projects") + if value is not None: + return int(value) + return None + + +def set_max_cached_projects(n: int) -> None: + """Set the max cached projects in the config file. + + Parameters + ---------- + n : int + Number of max cached projects (will be clamped to [1, 10]) + """ + n = max(1, min(10, n)) + config = load_stride_config() + config["max_cached_projects"] = n + save_stride_config(config) diff --git a/src/stride/models.py b/src/stride/models.py index 69a8110..86eabe2 100644 --- a/src/stride/models.py +++ b/src/stride/models.py @@ -165,8 +165,8 @@ class ProjectConfig(DSGBaseModel): # type: ignore description="Calculated tables to override", ) color_palette: dict[str, dict[str, str]] = Field( - default={"scenarios": {}, "model_years": {}, "metrics": {}}, - description="Color palette organized into scenarios, model_years, and metrics categories. Each category maps labels to hex/rgb color strings for the UI.", + default={"scenarios": {}, "model_years": {}, "sectors": {}, "end_uses": {}}, + description="Color palette organized into scenarios, model_years, sectors, and end_uses categories. Each category maps labels to hex/rgb color strings for the UI.", ) @classmethod diff --git a/src/stride/project.py b/src/stride/project.py index f25c4b3..3a82736 100644 --- a/src/stride/project.py +++ b/src/stride/project.py @@ -30,7 +30,7 @@ ProjectConfig, Scenario, ) -from stride.ui.palette import ColorPalette +from stride.ui.palette import ColorCategory, ColorPalette CONFIG_FILE = "project.json5" DATABASE_FILE = "data.duckdb" @@ -271,34 +271,35 @@ def palette(self) -> ColorPalette: >>> project.save_palette() """ if self._palette is None: - self._palette = ColorPalette(self._config.color_palette) + self._palette = ColorPalette.from_dict(self._config.color_palette) self._auto_populate_palette() return self._palette def _auto_populate_palette(self) -> None: - """Auto-populate palette with scenarios and model_years from config if not already present.""" + """Auto-populate palette with scenarios and model_years from config. + + Uses merge logic so that existing name→color mappings are preserved, + new project dimensions get unused theme colors, and extra palette + entries are kept as reserves. + """ if self._palette is None: return - # Auto-populate scenarios from config scenario_names = [scenario.name for scenario in self._config.scenarios] - for name in scenario_names: - if name not in self._palette.scenarios: - self._palette.update(name, category="scenarios") - - # Auto-populate model_years from config - model_years = self._config.list_model_years() - for year in model_years: - year_str = str(year) - if year_str not in self._palette.model_years: - self._palette.update(year_str, category="model_years") + model_years = [str(y) for y in self._config.list_model_years()] + self._palette.merge_with_project_dimensions( + scenarios=scenario_names, + model_years=model_years, + ) def populate_palette_metrics(self) -> None: """Populate the palette with all metrics (sectors and end uses) from the database. - This method queries the database for unique sectors and end uses and adds them - to the metrics category of the palette. It's called automatically during project - creation, but can be called manually to refresh the palette after updates. + Uses merge logic to preserve existing color assignments, assign unused + theme colors to new entries, and retain extra palette entries as reserves. + + This method is called automatically during project creation, but can be + called manually to refresh the palette after updates. Examples -------- @@ -314,19 +315,14 @@ def populate_palette_metrics(self) -> None: api_client = APIClient(self) - # Get all unique sectors and end uses from the database sectors = api_client.get_unique_sectors() end_uses = api_client.get_unique_end_uses() - # Add sectors to metrics category - for sector in sectors: - if self._palette is not None and sector not in self._palette.metrics: - self._palette.update(sector, category="metrics") - - # Add end uses to metrics category - for end_use in end_uses: - if self._palette is not None and end_use not in self._palette.metrics: - self._palette.update(end_use, category="metrics") + if self._palette is not None: + self._palette.merge_with_project_dimensions( + sectors=sectors, + end_uses=end_uses, + ) def refresh_palette_colors(self) -> None: """Refresh all palette colors to use the correct themes for each category. @@ -347,9 +343,10 @@ def refresh_palette_colors(self) -> None: # Refresh colors for each category using the correct theme if self._palette is not None: - self._palette.refresh_category_colors("scenarios") - self._palette.refresh_category_colors("model_years") - self._palette.refresh_category_colors("metrics") + self._palette.refresh_category_colors(ColorCategory.SCENARIO) + self._palette.refresh_category_colors(ColorCategory.MODEL_YEAR) + self._palette.refresh_category_colors(ColorCategory.SECTOR) + self._palette.refresh_category_colors(ColorCategory.END_USE) def save_palette(self) -> None: """Save the current palette state back to the project conig file.""" diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index 0940830..bc673d3 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -16,11 +16,22 @@ from stride.ui.home import create_home_layout, register_home_callbacks from stride.ui.palette import ColorPalette from stride.ui.plotting import StridePlots +from stride.ui.plotting.utils import ( + DARK_CSS_THEME, + DARK_PLOTLY_TEMPLATE, + DEFAULT_CSS_THEME, + DEFAULT_PLOTLY_TEMPLATE, +) from stride.ui.project_manager import add_recent_project, get_recent_projects from stride.ui.scenario import create_scenario_layout, register_scenario_callbacks from stride.ui.settings import create_settings_layout, register_settings_callbacks -from stride.ui.settings.layout import get_temp_color_edits -from stride.ui.tui import get_max_cached_projects as _get_config_max_cached, list_user_palettes +from stride.ui.settings.layout import ( + get_temp_color_edits, + get_temp_edits_for_category, + parse_temp_edit_key, +) +from stride.config import get_max_cached_projects as _get_config_max_cached +from stride.ui.palette_utils import get_default_user_palette, list_user_palettes assets_path = Path(__file__).parent.absolute() / "assets" app = Dash( @@ -103,17 +114,28 @@ def _evict_oldest_project() -> None: logger.warning(f"Error closing evicted project at {oldest_path}: {e}") -def create_fresh_color_manager(palette: ColorPalette, scenarios: list[str]) -> ColorManager: +def create_fresh_color_manager( + palette: ColorPalette, + scenarios: list[str], + *, + ui_theme: str = "light", +) -> ColorManager: """Create a fresh ColorManager instance, bypassing the singleton. Each project needs its own ColorManager to ensure consistent colors. + Applies *ui_theme* so that model-year colors are sampled from the + WCAG-safe iridescent range and sector/end-use colors use the + correct metric palette for the current theme. """ from itertools import cycle - # Reset the palette's iterators to ensure consistent color assignment + # Reset the scenario iterator (set_ui_theme does not manage it) palette._scenario_iterator = cycle(palette.scenario_theme) - palette._model_year_iterator = cycle(palette.model_year_theme) - palette._metric_iterator = cycle(palette.metric_theme) + + # Ensure all colors match the requested UI theme. This re-samples + # model-year colors from the WCAG-safe iridescent range and + # re-assigns sector/end-use colors, resetting their iterators. + palette.set_ui_theme(ui_theme) # Use object.__new__ to bypass ColorManager's singleton __new__ method color_manager = object.__new__(ColorManager) @@ -123,7 +145,7 @@ def create_fresh_color_manager(palette: ColorPalette, scenarios: list[str]) -> C color_manager.initialize_colors( scenarios=scenarios, sectors=literal_to_list(Sectors), - end_uses=[], + end_uses=list(palette.end_uses.keys()), ) return color_manager @@ -167,11 +189,28 @@ def load_project(project_path: str) -> tuple[bool, str]: project = Project.load(path, read_only=True) data_handler = APIClient(project) # Updates singleton + # Determine current UI theme from existing project plotter + ui_theme = "light" + if _current_project_path and _current_project_path in _loaded_projects: + _, _, existing_plotter, _ = _loaded_projects[_current_project_path] + if existing_plotter: + ui_theme = "dark" if "dark" in existing_plotter._template else "light" + current_template = DARK_PLOTLY_TEMPLATE if ui_theme == "dark" else DEFAULT_PLOTLY_TEMPLATE + # Create a fresh color manager for this project palette = project.palette.copy() - color_manager = create_fresh_color_manager(palette, data_handler.scenarios) + try: + palette.merge_with_project_dimensions( + sectors=data_handler.get_unique_sectors(), + end_uses=data_handler.get_unique_end_uses(), + ) + except Exception as e: + logger.warning(f"Could not populate sectors/end_uses: {e}") + color_manager = create_fresh_color_manager( + palette, data_handler.scenarios, ui_theme=ui_theme + ) - plotter = StridePlots(color_manager, template="plotly_dark") + plotter = StridePlots(color_manager, template=current_template) project_name = project.config.project_id @@ -245,10 +284,21 @@ def create_app( # noqa: C901 current_palette_type = "project" current_palette_name = None + # Ensure palette has sectors and end_uses from the database. + # _auto_populate_palette only handles scenarios/model_years; sectors + # and end_uses are populated during project creation but may be + # missing from older projects or user palettes. + try: + sectors = data_handler.get_unique_sectors() + end_uses = data_handler.get_unique_end_uses() + palette.merge_with_project_dimensions(sectors=sectors, end_uses=end_uses) + except Exception as e: + logger.warning(f"Could not populate sectors/end_uses in palette: {e}") + # Create fresh color manager for this project color_manager = create_fresh_color_manager(palette.copy(), data_handler.scenarios) - plotter = StridePlots(color_manager, template="plotly_dark") + plotter = StridePlots(color_manager, template=DEFAULT_PLOTLY_TEMPLATE) # Store in global cache - store Project (not APIClient) since APIClient is singleton initial_project_name = data_handler.project.config.project_id @@ -299,12 +349,14 @@ def create_app( # noqa: C901 user_palettes = [] project_palette_name = data_handler.project.config.project_id + default_user_palette_name = get_default_user_palette() settings_layout = create_settings_layout( project_palette_name=project_palette_name, user_palettes=user_palettes, current_palette_type=current_palette_type, current_palette_name=current_palette_name, color_manager=color_manager, + default_user_palette=default_user_palette_name, ) # Get current project display name @@ -452,7 +504,7 @@ def create_app( # noqa: C901 dcc.Store(id="current-project-path", data=current_project_path), dcc.Store(id="sidebar-open", data=False), dcc.Store(id="chart-refresh-trigger", data=0), - dcc.Store(id="theme-store", data="dark"), + dcc.Store(id="theme-store", data=DEFAULT_CSS_THEME), # Dynamic scenario CSS that updates with palette changes html.Div( id="scenario-css-container", @@ -541,7 +593,7 @@ def create_app( # noqa: C901 ), dbc.Switch( id="theme-toggle", - value=True, + value=False, style={ "transform": "scale(1.2)", }, @@ -600,11 +652,11 @@ def create_app( # noqa: C901 ), ], id="page-content", - className="page-content dark-theme", + className=f"page-content {DEFAULT_CSS_THEME}", style={"marginLeft": "0px", "transition": "margin-left 0.3s"}, ), ], - className="dark-theme", + className=DEFAULT_CSS_THEME, style={"minHeight": "100vh"}, ) @@ -710,7 +762,6 @@ def toggle_views( ) elif trigger_id == "back-to-dashboard-btn" or trigger_id == "home-link": # Return to home view - apply any temporary color edits and refresh charts - from stride.ui.settings.layout import get_temp_color_edits # Apply temporary color edits to the ColorManager temp_edits = get_temp_color_edits() @@ -718,8 +769,9 @@ def toggle_views( color_manager = get_current_color_manager() if color_manager: palette = color_manager.get_palette() - for label, color in temp_edits.items(): - palette.update(label, color) + for composite_key, color in temp_edits.items(): + cat_str, label = parse_temp_edit_key(composite_key) + palette.update(label, color, category=cat_str) logger.info( f"Applied {len(temp_edits)} temporary color edits when returning to home" ) @@ -765,16 +817,25 @@ def toggle_views( ) def toggle_theme(is_dark: bool, refresh_count: int) -> tuple[str, str, str, int]: """Toggle between light and dark theme.""" - theme = "dark-theme" if is_dark else "light-theme" + theme = DARK_CSS_THEME if is_dark else DEFAULT_CSS_THEME + ui_mode = "dark" if is_dark else "light" - # Update plotter template for all charts + # Update plotter template and palette colors for all charts if plotter: - template = "plotly_dark" if is_dark else "plotly_white" + template = DARK_PLOTLY_TEMPLATE if is_dark else DEFAULT_PLOTLY_TEMPLATE plotter.set_template(template) + # Update palette colors for new theme contrast requirements + plotter.color_manager.get_palette().set_ui_theme(ui_mode) logger.info(f"Switched to {theme} with plot template {template}") - # Note: For OS theme detection, add clientside callback with: - # window.matchMedia('(prefers-color-scheme: dark)').matches + # Also update the cached plotter's palette if project is loaded + if _current_project_path in _loaded_projects: + _, cm, cached_plotter, _ = _loaded_projects[_current_project_path] + if cached_plotter and cached_plotter is not plotter: + cached_plotter.set_template( + DARK_PLOTLY_TEMPLATE if is_dark else DEFAULT_PLOTLY_TEMPLATE + ) + cm.get_palette().set_ui_theme(ui_mode) return theme, f"sidebar-nav {theme}", theme, refresh_count + 1 @@ -784,18 +845,37 @@ def on_palette_change(palette: ColorPalette, palette_type: str, palette_name: st global _loaded_projects, _current_project_path if _current_project_path in _loaded_projects: - cached_project, _, _, project_name = _loaded_projects[_current_project_path] + cached_project, _, old_plotter, project_name = _loaded_projects[_current_project_path] + + # Preserve the current plotly template from the existing plotter + current_template = old_plotter._template if old_plotter else DEFAULT_PLOTLY_TEMPLATE # Create a copy of the palette to avoid modifying the original palette_copy = palette.copy() + # Determine current UI theme + ui_mode = "dark" if "dark" in current_template else "light" + # Get current data handler (singleton) data_handler = APIClient(cached_project) + # Ensure sectors/end_uses are current + try: + palette_copy.merge_with_project_dimensions( + sectors=data_handler.get_unique_sectors(), + end_uses=data_handler.get_unique_end_uses(), + ) + except Exception as e: + logger.warning(f"Could not populate sectors/end_uses: {e}") + # Create fresh color manager with new palette - color_manager = create_fresh_color_manager(palette_copy, data_handler.scenarios) + color_manager = create_fresh_color_manager( + palette_copy, + data_handler.scenarios, + ui_theme=ui_mode, + ) - plotter = StridePlots(color_manager, template="plotly_dark") + plotter = StridePlots(color_manager, template=current_template) # Update cache (preserve project and project_name) _loaded_projects[_current_project_path] = ( @@ -862,8 +942,8 @@ def update_scenario_css(palette_data: dict[str, Any], color_edits: int) -> list[ if color_manager is None: raise PreventUpdate - # Get temporary color edits to apply to CSS - temp_edits = get_temp_color_edits() + # Get scenario-only temporary color edits (plain label keys) + scenario_edits = get_temp_edits_for_category("scenarios") return [ html.Script( @@ -875,7 +955,7 @@ def update_scenario_css(palette_data: dict[str, Any], color_edits: int) -> list[ }} var style = document.createElement('style'); style.id = 'scenario-dynamic-css'; - style.textContent = `{color_manager.generate_scenario_css(temp_edits)}`; + style.textContent = `{color_manager.generate_scenario_css(scenario_edits)}`; document.head.appendChild(style); }})(); """ @@ -1117,6 +1197,7 @@ def update_scenario_css_on_project_change(project_path: str) -> list[Any]: return app + def create_app_no_project( user_palette: ColorPalette | None = None, ) -> Dash: @@ -1217,7 +1298,9 @@ def create_app_no_project( html.P( [ "To create a new project, use the CLI: ", - html.Code("stride projects create ", className="welcome-code"), + html.Code( + "stride projects create ", className="welcome-code" + ), ], className="text-muted small", ), @@ -1357,7 +1440,7 @@ def create_app_no_project( dcc.Store(id="current-project-path", data=""), dcc.Store(id="sidebar-open", data=False), dcc.Store(id="chart-refresh-trigger", data=0), - dcc.Store(id="theme-store", data="dark"), + dcc.Store(id="theme-store", data=DEFAULT_CSS_THEME), dcc.Store(id="color-edits-counter", data=0), # Empty scenario CSS container html.Div(id="scenario-css-container", children=[], style={"display": "none"}), @@ -1428,7 +1511,7 @@ def create_app_no_project( ), dbc.Switch( id="theme-toggle", - value=True, + value=False, style={ "transform": "scale(1.2)", }, @@ -1485,14 +1568,14 @@ def create_app_no_project( ), ], id="page-content", - className="page-content dark-theme", + className=f"page-content {DEFAULT_CSS_THEME}", style={ "marginLeft": "0", "transition": "margin-left 0.3s ease-in-out", }, ), ], - className="dark-theme", + className=DEFAULT_CSS_THEME, ) # Register callbacks for no-project mode @@ -1544,12 +1627,33 @@ def _on_palette_change_no_project( global _loaded_projects, _current_project_path if _current_project_path and _current_project_path in _loaded_projects: - cached_project, _, _, project_name = _loaded_projects[_current_project_path] + cached_project, _, old_plotter, project_name = _loaded_projects[_current_project_path] + + # Preserve the current plotly template from the existing plotter + current_template = old_plotter._template if old_plotter else DEFAULT_PLOTLY_TEMPLATE palette_copy = palette.copy() + + # Determine current UI theme + ui_mode = "dark" if "dark" in current_template else "light" + data_handler = APIClient(cached_project) - new_color_manager = create_fresh_color_manager(palette_copy, data_handler.scenarios) - new_plotter = StridePlots(new_color_manager, template="plotly_dark") + + # Ensure sectors/end_uses are current + try: + palette_copy.merge_with_project_dimensions( + sectors=data_handler.get_unique_sectors(), + end_uses=data_handler.get_unique_end_uses(), + ) + except Exception as e: + logger.warning(f"Could not populate sectors/end_uses: {e}") + + new_color_manager = create_fresh_color_manager( + palette_copy, + data_handler.scenarios, + ui_theme=ui_mode, + ) + new_plotter = StridePlots(new_color_manager, template=current_template) _loaded_projects[_current_project_path] = ( cached_project, @@ -1708,12 +1812,14 @@ def _register_theme_toggle_callback() -> None: ) def toggle_theme(is_dark: bool, refresh_count: int) -> tuple[str, str, str, int]: """Toggle between light and dark theme.""" - theme = "dark-theme" if is_dark else "light-theme" + theme = DARK_CSS_THEME if is_dark else DEFAULT_CSS_THEME + ui_mode = "dark" if is_dark else "light" plotter = _get_current_plotter_no_project() if plotter: - template = "plotly_dark" if is_dark else "plotly_white" + template = DARK_PLOTLY_TEMPLATE if is_dark else DEFAULT_PLOTLY_TEMPLATE plotter.set_template(template) + plotter.color_manager.get_palette().set_ui_theme(ui_mode) logger.info(f"Switched to {theme} with plot template {template}") return f"page-content {theme}", f"sidebar-nav {theme}", theme, refresh_count + 1 @@ -1884,6 +1990,7 @@ def _build_successful_load_response( current_palette_type="project", current_palette_name=None, color_manager=color_manager, + default_user_palette=get_default_user_palette(), ) # Generate scenario CSS @@ -1910,7 +2017,11 @@ def _generate_scenario_css_script( temp_edits: dict[str, str] | None = None, ) -> list[Any]: """Generate the scenario CSS script element.""" - css_content = color_manager.generate_scenario_css(temp_edits) if temp_edits else color_manager.generate_scenario_css() + css_content = ( + color_manager.generate_scenario_css(temp_edits) + if temp_edits + else color_manager.generate_scenario_css() + ) return [ html.Script( f""" @@ -1997,7 +2108,15 @@ def _toggle_views_impl( if trigger_id == "sidebar-settings-btn": if not project_path: raise PreventUpdate - return (True, True, False, {"display": "none"}, selected_view, current_refresh_count, no_update) + return ( + True, + True, + False, + {"display": "none"}, + selected_view, + current_refresh_count, + no_update, + ) if trigger_id in ("back-to-dashboard-btn", "home-link"): temp_edits = get_temp_color_edits() @@ -2005,14 +2124,33 @@ def _toggle_views_impl( color_manager = get_current_color_manager() if color_manager: palette = color_manager.get_palette() - for label, color in temp_edits.items(): - palette.update(label, color) - logger.info(f"Applied {len(temp_edits)} temporary color edits when returning to home") + for composite_key, color in temp_edits.items(): + cat_str, label = parse_temp_edit_key(composite_key) + palette.update(label, color, category=cat_str) + logger.info( + f"Applied {len(temp_edits)} temporary color edits when returning to home" + ) - return (False, True, True, {"display": "block"}, "compare", current_refresh_count + 1, no_update) + return ( + False, + True, + True, + {"display": "block"}, + "compare", + current_refresh_count + 1, + no_update, + ) if selected_view == "compare": - return (False, True, True, {"display": "block"}, selected_view, current_refresh_count, no_update) + return ( + False, + True, + True, + {"display": "block"}, + selected_view, + current_refresh_count, + no_update, + ) # Scenario view - need to create layout data_handler = _get_current_data_handler_no_project() @@ -2021,7 +2159,15 @@ def _toggle_views_impl( raise PreventUpdate scenario_layout = create_scenario_layout(data_handler.years, color_manager) - return (True, False, True, {"display": "block"}, selected_view, current_refresh_count, scenario_layout) + return ( + True, + False, + True, + {"display": "block"}, + selected_view, + current_refresh_count, + scenario_layout, + ) def _register_scenario_css_callback( @@ -2049,5 +2195,5 @@ def update_scenario_css( if color_manager is None: raise PreventUpdate - temp_edits = get_temp_color_edits() - return _generate_scenario_css_script(color_manager, temp_edits) + scenario_edits = get_temp_edits_for_category("scenarios") + return _generate_scenario_css_script(color_manager, scenario_edits) diff --git a/src/stride/ui/assets/dark-theme.css b/src/stride/ui/assets/dark-theme.css index b524cfc..aa38c0e 100644 --- a/src/stride/ui/assets/dark-theme.css +++ b/src/stride/ui/assets/dark-theme.css @@ -136,12 +136,38 @@ body { } /* Sidebar dropdown theming */ -.dark-theme #sidebar .Select-control { +.dark-theme #sidebar .dash-dropdown-trigger { background-color: var(--bg-tertiary) !important; border-color: var(--border-color) !important; } -.dark-theme #sidebar .Select-value-label { +.dark-theme #sidebar .dash-dropdown-value { + color: var(--text-primary) !important; +} + +.dark-theme #sidebar .dash-dropdown-placeholder { + color: var(--text-muted) !important; +} + +.dark-theme #sidebar .dash-dropdown-content { + background-color: var(--bg-tertiary) !important; + border-color: var(--border-color) !important; +} + +.dark-theme #sidebar .dash-dropdown-option { + background-color: var(--bg-tertiary) !important; + color: var(--text-primary) !important; +} + +.dark-theme #sidebar .dash-dropdown-option:hover { + background-color: var(--bg-hover) !important; +} + +.dark-theme #sidebar .dash-dropdown-trigger-icon { + color: var(--text-muted) !important; +} + +.dark-theme #sidebar .dash-dropdown-search { color: var(--text-primary) !important; } @@ -245,35 +271,49 @@ body { DROPDOWNS & INPUTS ============================================ */ -.dark-theme .Select-control, -.dark-theme .dash-dropdown .Select-control { + +.dark-theme .dash-dropdown-trigger { background-color: var(--dropdown-bg) !important; border-color: var(--input-border) !important; color: var(--dropdown-text) !important; } -.dark-theme .Select-placeholder, -.dark-theme .Select-value-label { +.dark-theme .dash-dropdown-placeholder, +.dark-theme .dash-dropdown-value { color: var(--dropdown-text) !important; } -.dark-theme .Select-menu-outer { +.dark-theme .dash-dropdown-content { background-color: var(--dropdown-bg) !important; border-color: var(--input-border) !important; } -.dark-theme .Select-option { +.dark-theme .dash-dropdown-option { background-color: var(--dropdown-bg) !important; color: var(--dropdown-text) !important; } -.dark-theme .Select-option:hover { +.dark-theme .dash-dropdown-option:hover { background-color: var(--dropdown-hover) !important; } -.dark-theme .Select-option.is-selected { +.dark-theme .dash-dropdown-option[aria-selected="true"] { + background-color: var(--accent-primary) !important; + color: #ffffff !important; +} + +.dark-theme .dash-dropdown-search { + color: var(--dropdown-text) !important; +} + +.dark-theme .dash-dropdown-trigger-icon { + color: var(--dropdown-text) !important; +} + +.dark-theme .dash-dropdown-value-item { background-color: var(--accent-primary) !important; color: #ffffff !important; + color: #ffffff !important; } /* Dash Bootstrap Components Dropdowns */ diff --git a/src/stride/ui/assets/light-theme.css b/src/stride/ui/assets/light-theme.css index 9a4d2ae..64aef7d 100644 --- a/src/stride/ui/assets/light-theme.css +++ b/src/stride/ui/assets/light-theme.css @@ -135,14 +135,40 @@ body { color: var(--text-primary) !important; } -/* Sidebar dropdown theming */ -.light-theme #sidebar .Select-control { - background-color: var(--bg-primary) !important; - border-color: var(--border-color) !important; +/* Sidebar dropdown theming - sidebar uses dark styling in light mode */ +.light-theme #sidebar .dash-dropdown-trigger { + background-color: #3a3a3a !important; + border-color: #404040 !important; } -.light-theme #sidebar .Select-value-label { - color: var(--text-primary) !important; +.light-theme #sidebar .dash-dropdown-value { + color: #e0e0e0 !important; +} + +.light-theme #sidebar .dash-dropdown-placeholder { + color: #808080 !important; +} + +.light-theme #sidebar .dash-dropdown-content { + background-color: #3a3a3a !important; + border-color: #404040 !important; +} + +.light-theme #sidebar .dash-dropdown-option { + background-color: #3a3a3a !important; + color: #e0e0e0 !important; +} + +.light-theme #sidebar .dash-dropdown-option:hover { + background-color: #404040 !important; +} + +.light-theme #sidebar .dash-dropdown-trigger-icon { + color: #808080 !important; +} + +.light-theme #sidebar .dash-dropdown-search { + color: #e0e0e0 !important; } /* Sidebar settings button - match nav tabs */ @@ -245,37 +271,41 @@ body { DROPDOWNS & INPUTS ============================================ */ -.light-theme .Select-control, -.light-theme .dash-dropdown .Select-control { + +.light-theme .dash-dropdown-trigger { background-color: var(--dropdown-bg) !important; border-color: var(--input-border) !important; color: var(--dropdown-text) !important; } -.light-theme .Select-placeholder, -.light-theme .Select-value-label { +.light-theme .dash-dropdown-placeholder, +.light-theme .dash-dropdown-value { color: var(--dropdown-text) !important; } -.light-theme .Select-menu-outer { +.light-theme .dash-dropdown-content { background-color: var(--dropdown-bg) !important; border-color: var(--input-border) !important; } -.light-theme .Select-option { +.light-theme .dash-dropdown-option { background-color: var(--dropdown-bg) !important; color: var(--dropdown-text) !important; } -.light-theme .Select-option:hover { +.light-theme .dash-dropdown-option:hover { background-color: var(--dropdown-hover) !important; } -.light-theme .Select-option.is-selected { +.light-theme .dash-dropdown-option[aria-selected="true"] { background-color: var(--accent-primary) !important; color: #ffffff !important; } +.light-theme .dash-dropdown-search { + color: var(--dropdown-text) !important; +} + /* Dash Bootstrap Components Dropdowns */ .light-theme .dropdown-menu { background-color: var(--dropdown-bg) !important; @@ -371,6 +401,10 @@ body { color: #212529 !important; } +.light-theme .btn-outline-secondary.theme-text { + color: #212529 !important; +} + /* ============================================ RADIO BUTTONS & CHECKBOXES ============================================ */ diff --git a/src/stride/ui/assets/theme-detector.js b/src/stride/ui/assets/theme-detector.js index ccdb623..3a9a679 100644 --- a/src/stride/ui/assets/theme-detector.js +++ b/src/stride/ui/assets/theme-detector.js @@ -4,9 +4,10 @@ (function () { "use strict"; - // Always default to dark theme + // Default to light (daytime) theme function getOSThemePreference() { - return true; // Always use dark theme by default + // Detect OS theme preference using window.matchMedia + return window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches; } // Apply theme to all relevant elements @@ -31,11 +32,9 @@ sidebar.className = "sidebar-nav " + theme; } - // Update theme toggle switch - const themeToggle = document.getElementById("theme-toggle"); - if (themeToggle) { - themeToggle.checked = isDark; - } + // NOTE: Do NOT set themeToggle.checked here — that would desync + // the DOM from Dash's React state and break the toggle callback. + // Dash manages the toggle value via dbc.Switch(value=...). console.log("STRIDE Theme applied:", theme); } @@ -52,6 +51,6 @@ applyTheme(prefersDark); } - // OS theme change listener removed - we always default to dark theme + // OS theme change listener removed - we default to light theme // Users can manually toggle theme using the theme toggle switch })(); diff --git a/src/stride/ui/color_manager.py b/src/stride/ui/color_manager.py index 0c54ff9..8184e75 100644 --- a/src/stride/ui/color_manager.py +++ b/src/stride/ui/color_manager.py @@ -1,7 +1,7 @@ import re from typing import Dict, List, Self -from .palette import ColorPalette +from .palette import ColorCategory, ColorPalette class ColorManager: @@ -34,24 +34,46 @@ def initialize_colors( sectors: List[str] | None = None, end_uses: List[str] | None = None, ) -> None: - """Initialize colors for all entities at once to ensure consistency.""" - all_keys = scenarios.copy() + """Initialize colors for all entities at once to ensure consistency. + + Each entity type is stored in its correct palette category: + scenarios → ``ColorCategory.SCENARIO``, sectors/end-uses → + ``ColorCategory.SECTOR`` / ``ColorCategory.END_USE``. + """ + # Scenarios → scenario palette + for key in scenarios: + self.get_color(key, ColorCategory.SCENARIO) + + # Sectors → sector palette if sectors: - all_keys.extend(sectors) - if end_uses: - all_keys.extend(end_uses) + for key in sectors: + self.get_color(key, ColorCategory.SECTOR) - # Pre-generate colors for all keys to ensure consistent assignment - for key in all_keys: - self.get_color(key) + # End-uses → end-use palette + if end_uses: + for key in end_uses: + self.get_color(key, ColorCategory.END_USE) # Generate scenario styling colors self._generate_scenario_colors(scenarios) - def get_color(self, key: str) -> str: - """Get consistent RGBA color for a given key.""" + def get_color( + self, + key: str, + category: ColorCategory | str | None = None, + ) -> str: + """Get consistent RGBA color for a given key. + + Parameters + ---------- + key : str + Label to look up (scenario name, sector, end-use, year, etc.) + category : ColorCategory | str | None + Which palette category to use. When ``None``, all categories are + searched and new keys default to ``ColorCategory.SECTOR``. + """ # Get color from palette (could be hex or rgb string) - color = self._palette.get(key) + color = self._palette.get(key, category=category) # Convert to RGBA for UI usage if color.startswith("#"): @@ -120,7 +142,7 @@ def get_palette(self) -> ColorPalette: def _generate_scenario_colors(self, scenarios: List[str]) -> None: """Generate background and border colors for scenarios.""" for scenario in scenarios: - base_color = self.get_color(scenario) + base_color = self.get_color(scenario, ColorCategory.SCENARIO) r, g, b, _ = self._str_to_rgba(base_color) self._scenario_colors[scenario] = { diff --git a/src/stride/ui/home/callbacks.py b/src/stride/ui/home/callbacks.py index 5ae1f46..c51e7c4 100644 --- a/src/stride/ui/home/callbacks.py +++ b/src/stride/ui/home/callbacks.py @@ -13,7 +13,8 @@ get_hoverlabel_style, get_warning_annotation_style, ) -from stride.ui.settings.layout import get_temp_color_edits +from stride.ui.palette import ColorCategory +from stride.ui.settings.layout import get_temp_edits_for_category if TYPE_CHECKING: from stride.api import APIClient @@ -109,6 +110,13 @@ def update_home_scenario_comparison( # noqa: C901 try: # Convert "None" to None breakdown_value = None if breakdown == "None" else breakdown + breakdown_type = ( + ColorCategory.END_USE + if breakdown_value == "End Use" + else ColorCategory.SECTOR + if breakdown_value == "Sector" + else None + ) # Get the main consumption data df = data_handler.get_annual_electricity_consumption( @@ -120,7 +128,9 @@ def update_home_scenario_comparison( # noqa: C901 # Create the main plot if breakdown_value: stack_col = "metric" if breakdown_value == "End Use" else str(breakdown_value) - fig = plotter.grouped_stacked_bars(df, stack_col=stack_col.lower(), value_col="value") + fig = plotter.grouped_stacked_bars( + df, stack_col=stack_col.lower(), value_col="value", breakdown_type=breakdown_type + ) else: fig = plotter.grouped_single_bars(df, "scenario") @@ -135,7 +145,9 @@ def update_home_scenario_comparison( # noqa: C901 if not secondary_df.empty: # Get scenario color from color manager - scenario_color = plotter.color_manager.get_color(scenario) + scenario_color = plotter.color_manager.get_color( + scenario, ColorCategory.SCENARIO + ) # Add background line when no breakdown (total only) if breakdown_value is None: @@ -273,6 +285,13 @@ def update_home_sector_breakdown( # noqa: C901 try: # Convert "None" to None breakdown_value = None if breakdown == "None" else breakdown + breakdown_type = ( + ColorCategory.END_USE + if breakdown_value == "End Use" + else ColorCategory.SECTOR + if breakdown_value == "Sector" + else None + ) # Get the peak demand data df = data_handler.get_annual_peak_demand( @@ -285,7 +304,9 @@ def update_home_sector_breakdown( # noqa: C901 if breakdown_value: stack_col = "metric" if breakdown_value == "End Use" else str(breakdown_value) - fig = plotter.grouped_stacked_bars(df, stack_col=stack_col.lower(), value_col="value") + fig = plotter.grouped_stacked_bars( + df, stack_col=stack_col.lower(), value_col="value", breakdown_type=breakdown_type + ) else: fig = plotter.grouped_single_bars(df, "scenario") @@ -300,7 +321,9 @@ def update_home_sector_breakdown( # noqa: C901 if not secondary_df.empty: # Get scenario color from color manager - scenario_color = plotter.color_manager.get_color(scenario) + scenario_color = plotter.color_manager.get_color( + scenario, ColorCategory.SCENARIO + ) # Add background line when no breakdown (total only) if breakdown_value is None: @@ -475,6 +498,13 @@ def update_home_scenario_timeseries( # noqa: C901 try: # Convert "None" to None breakdown_value = None if breakdown == "None" else breakdown + breakdown_type = ( + ColorCategory.END_USE + if breakdown_value == "End Use" + else ColorCategory.SECTOR + if breakdown_value == "Sector" + else None + ) # Get the consumption data for all scenarios df = data_handler.get_annual_electricity_consumption( @@ -557,7 +587,10 @@ def update_home_scenario_timeseries( # noqa: C901 name=category, mode="lines", line=dict( - color=plotter.color_manager.get_color(category) + color=plotter.color_manager.get_color( + category, + breakdown_type or ColorCategory.SECTOR, + ) ), fill="tonexty" if j > 0 else "tozeroy", stackgroup="one", @@ -579,7 +612,10 @@ def update_home_scenario_timeseries( # noqa: C901 name=category, mode="lines+markers", line=dict( - color=plotter.color_manager.get_color(category) + color=plotter.color_manager.get_color( + category, + breakdown_type or ColorCategory.SECTOR, + ) ), showlegend=show_legend, legendgroup=category, @@ -607,7 +643,9 @@ def update_home_scenario_timeseries( # noqa: C901 name=scenario, mode="lines", line=dict( - color=plotter.color_manager.get_color(scenario) + color=plotter.color_manager.get_color( + scenario, ColorCategory.SCENARIO + ) ), fill="tozeroy", showlegend=False, @@ -627,7 +665,9 @@ def update_home_scenario_timeseries( # noqa: C901 name=scenario, mode="lines+markers", line=dict( - color=plotter.color_manager.get_color(scenario) + color=plotter.color_manager.get_color( + scenario, ColorCategory.SCENARIO + ) ), showlegend=False, hovertemplate="Year: %{x}
" @@ -647,7 +687,9 @@ def update_home_scenario_timeseries( # noqa: C901 if not scenario_secondary.empty: row = (idx // cols) + 1 col = (idx % cols) + 1 - scenario_color = plotter.color_manager.get_color(scenario) + scenario_color = plotter.color_manager.get_color( + scenario, ColorCategory.SCENARIO + ) fig.add_trace( go.Scatter( @@ -703,6 +745,7 @@ def update_home_scenario_timeseries( # noqa: C901 chart_type=chart_type, group_by=stack_col.lower() if breakdown_value else None, value_col="value", + breakdown_type=breakdown_type, ) warning_style = get_warning_annotation_style(plotter.get_template()) fig.add_annotation( @@ -724,6 +767,7 @@ def update_home_scenario_timeseries( # noqa: C901 chart_type=chart_type, group_by=stack_col.lower() if breakdown_value else None, value_col="value", + breakdown_type=breakdown_type, ) error_style = get_error_annotation_style(plotter.get_template()) fig.add_annotation( @@ -745,6 +789,7 @@ def update_home_scenario_timeseries( # noqa: C901 chart_type=chart_type, group_by=stack_col.lower() if breakdown_value else None, value_col="value", + breakdown_type=breakdown_type, ) error_msg = str(e) if "does not exist" in error_msg.lower() or "not found" in error_msg.lower(): @@ -771,6 +816,7 @@ def update_home_scenario_timeseries( # noqa: C901 chart_type=chart_type, group_by=stack_col.lower() if breakdown_value else None, value_col="value", + breakdown_type=breakdown_type, ) return fig @@ -861,8 +907,8 @@ def _update_button_styles_1( if current_color_manager is None: return [{}] * len(button_ids) - # Get temporary color edits - temp_edits = get_temp_color_edits() + # Get scenario-only temporary color edits (plain label keys) + scenario_edits = get_temp_edits_for_category("scenarios") styles = [] selected_scenarios = selected_scenarios or [] @@ -872,13 +918,13 @@ def _update_button_styles_1( is_selected = scenario in selected_scenarios # Check if there's a temporary edit for this scenario - if scenario in temp_edits: - base_color = temp_edits[scenario] + if scenario in scenario_edits: + base_color = scenario_edits[scenario] # Temp edits are stored as hex, convert to rgba if base_color.startswith("#"): base_color = current_color_manager._hex_to_rgba_str(base_color) else: - base_color = current_color_manager.get_color(scenario) + base_color = current_color_manager.get_color(scenario, ColorCategory.SCENARIO) r, g, b, _ = current_color_manager._str_to_rgba(base_color) alpha = 0.9 if is_selected else 0.3 @@ -953,8 +999,8 @@ def _update_button_styles_2( if current_color_manager is None: return [{}] * len(button_ids) - # Get temporary color edits - temp_edits = get_temp_color_edits() + # Get scenario-only temporary color edits (plain label keys) + scenario_edits = get_temp_edits_for_category("scenarios") styles = [] selected_scenarios = selected_scenarios or [] @@ -964,13 +1010,13 @@ def _update_button_styles_2( is_selected = scenario in selected_scenarios # Check if there's a temporary edit for this scenario - if scenario in temp_edits: - base_color = temp_edits[scenario] + if scenario in scenario_edits: + base_color = scenario_edits[scenario] # Temp edits are stored as hex, convert to rgba if base_color.startswith("#"): base_color = current_color_manager._hex_to_rgba_str(base_color) else: - base_color = current_color_manager.get_color(scenario) + base_color = current_color_manager.get_color(scenario, ColorCategory.SCENARIO) r, g, b, _ = current_color_manager._str_to_rgba(base_color) alpha = 0.9 if is_selected else 0.3 @@ -1045,8 +1091,8 @@ def _update_button_styles_3( if current_color_manager is None: return [{}] * len(button_ids) - # Get temporary color edits - temp_edits = get_temp_color_edits() + # Get scenario-only temporary color edits (plain label keys) + scenario_edits = get_temp_edits_for_category("scenarios") styles = [] selected_scenarios = selected_scenarios or [] @@ -1056,13 +1102,13 @@ def _update_button_styles_3( is_selected = scenario in selected_scenarios # Check if there's a temporary edit for this scenario - if scenario in temp_edits: - base_color = temp_edits[scenario] + if scenario in scenario_edits: + base_color = scenario_edits[scenario] # Temp edits are stored as hex, convert to rgba if base_color.startswith("#"): base_color = current_color_manager._hex_to_rgba_str(base_color) else: - base_color = current_color_manager.get_color(scenario) + base_color = current_color_manager.get_color(scenario, ColorCategory.SCENARIO) r, g, b, _ = current_color_manager._str_to_rgba(base_color) alpha = 0.9 if is_selected else 0.3 @@ -1137,8 +1183,8 @@ def _update_button_styles_4( if current_color_manager is None: return [{}] * len(button_ids) - # Get temporary color edits - temp_edits = get_temp_color_edits() + # Get scenario-only temporary color edits (plain label keys) + scenario_edits = get_temp_edits_for_category("scenarios") styles = [] selected_scenarios = selected_scenarios or [] @@ -1148,13 +1194,13 @@ def _update_button_styles_4( is_selected = scenario in selected_scenarios # Check if there's a temporary edit for this scenario - if scenario in temp_edits: - base_color = temp_edits[scenario] + if scenario in scenario_edits: + base_color = scenario_edits[scenario] # Temp edits are stored as hex, convert to rgba if base_color.startswith("#"): base_color = current_color_manager._hex_to_rgba_str(base_color) else: - base_color = current_color_manager.get_color(scenario) + base_color = current_color_manager.get_color(scenario, ColorCategory.SCENARIO) r, g, b, _ = current_color_manager._str_to_rgba(base_color) alpha = 0.9 if is_selected else 0.3 diff --git a/src/stride/ui/home/layout.py b/src/stride/ui/home/layout.py index 79f89e0..fd908ef 100644 --- a/src/stride/ui/home/layout.py +++ b/src/stride/ui/home/layout.py @@ -5,6 +5,7 @@ from stride.api.utils import SecondaryMetric, literal_to_list from stride.ui.color_manager import ColorManager +from stride.ui.palette import ColorCategory def create_home_layout( @@ -28,7 +29,7 @@ def create_styled_checklist(scenarios_list: list[str], checklist_id: Any) -> htm scenario_buttons = [] for scenario in scenarios_list: # Get the scenario color from color manager - base_color = color_manager.get_color(scenario) + base_color = color_manager.get_color(scenario, ColorCategory.SCENARIO) r, g, b, _ = color_manager._str_to_rgba(base_color) # Determine if scenario is selected diff --git a/src/stride/ui/palette.py b/src/stride/ui/palette.py index c7525a8..1c71304 100644 --- a/src/stride/ui/palette.py +++ b/src/stride/ui/palette.py @@ -9,10 +9,10 @@ """ import re +from enum import StrEnum from itertools import cycle from typing import Any, Mapping, MutableSequence, TypedDict -from plotly import colors # can have a project color palette, or a user color palette? # can toggle between project and use color palette? @@ -24,6 +24,180 @@ rgb_color_pattern = re.compile(r"^rgba?\(\s*\d+\s*,\s*\d+\s*,\s*\d+\s*(?:,\s*[\d.]+\s*)?\)$") +class ColorCategory(StrEnum): + """Categories for color palette entries. + + ``SECTOR`` and ``END_USE`` share the same color theme but maintain + independent iterators so each group starts from position 0. + """ + + SCENARIO = "scenario" + MODEL_YEAR = "model_year" + SECTOR = "sector" + END_USE = "end_use" + + +# ============================================================================ +# Paul Tol color-blind-safe palettes +# Source: https://sronpersonalpages.nl/~pault/ +# ============================================================================ + +# Scenarios: Tol Bright (7 colors) — primary qualitative, color-blind safe +TOL_BRIGHT = [ + "#4477AA", # blue + "#CCBB44", # yellow + "#228833", # green + "#EE6677", # red + "#66CCEE", # cyan + "#AA3377", # purple + "#BBBBBB", # grey +] + +# Metrics (light mode): dark-enough colors from Tol Muted + Discrete Rainbow 14 +TOL_METRICS_LIGHT = [ + "#CC6677", # muted rose + "#999933", # muted olive + "#5289C7", # DR14 med blue + "#117733", # muted green + "#882255", # muted wine + "#1965B0", # DR14 blue + "#E8601C", # DR14 red-orange + "#332288", # muted indigo + "#AA4499", # muted purple + "#DC050C", # DR14 red + "#AE76A3", # DR14 mauve + "#882E72", # DR14 dk purple +] + +# Metrics (dark mode): light-enough colors from Tol Muted + Discrete Rainbow 14 +TOL_METRICS_DARK = [ + "#CC6677", # muted rose + "#DDCC77", # muted sand + "#88CCEE", # muted cyan + "#44AA99", # muted teal + "#AA4499", # muted purple + "#5289C7", # DR14 med blue + "#999933", # muted olive + "#F4A736", # DR14 orange + "#90C987", # DR14 lt green + "#D1BBD7", # DR14 lavender + "#AE76A3", # DR14 mauve + "#7BAFDE", # DR14 lt blue + "#E8601C", # DR14 red-orange + "#DDDDDD", # muted pale grey + "#DC050C", # DR14 red + "#117733", # muted green +] + +# Model years: Tol Iridescent (23 colors, sequential, designed for interpolation) +TOL_IRIDESCENT = [ + "#FEFBE9", # idx 0 + "#FCF7D5", # idx 1 + "#F5F3C1", # idx 2 + "#EAF0B5", # idx 3 + "#DDECBF", # idx 4 + "#D0E7CA", # idx 5 + "#C2E3D2", # idx 6 + "#B5DDD8", # idx 7 + "#A8D8DC", # idx 8 + "#9BD2E1", # idx 9 + "#8DCBE4", # idx 10 + "#81C4E7", # idx 11 + "#7BBCE7", # idx 12 + "#7EB2E4", # idx 13 + "#88A5DD", # idx 14 + "#9398D2", # idx 15 + "#9B8AC4", # idx 16 + "#9D7DB2", # idx 17 + "#9A709E", # idx 18 + "#906388", # idx 19 + "#805770", # idx 20 + "#684957", # idx 21 + "#46353A", # idx 22 +] + +# WCAG-derived usable index ranges for Iridescent (3.0:1 contrast threshold) +IRIDESCENT_LIGHT_START = 16 # First index passing 3:1 on #FFFFFF +IRIDESCENT_LIGHT_END = 22 # Last index (inclusive) +IRIDESCENT_DARK_START = 0 # First index passing 3:1 on #1A1A1A +IRIDESCENT_DARK_END = 19 # Last index passing 3:1 on #1A1A1A + + +def _hex_to_rgb(hex_color: str) -> tuple[int, int, int]: + """Convert hex color string to RGB tuple.""" + h = hex_color.lstrip("#") + return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) + + +def _rgb_to_hex(r: int, g: int, b: int) -> str: + """Convert RGB values to hex color string.""" + return f"#{r:02X}{g:02X}{b:02X}" + + +def _interpolate_hex(color1: str, color2: str, t: float) -> str: + """Linearly interpolate between two hex colors. + + Parameters + ---------- + color1 : str + Start color (hex) + color2 : str + End color (hex) + t : float + Interpolation factor, 0.0 = color1, 1.0 = color2 + """ + r1, g1, b1 = _hex_to_rgb(color1) + r2, g2, b2 = _hex_to_rgb(color2) + r = round(r1 + (r2 - r1) * t) + g = round(g1 + (g2 - g1) * t) + b = round(b1 + (b2 - b1) * t) + return _rgb_to_hex(r, g, b) + + +def sample_iridescent(n: int, theme: str = "light") -> list[str]: + """Sample n evenly-spaced colors from the Tol Iridescent ramp. + + Uses the WCAG-safe index range for the given theme, and linearly + interpolates between defined colors when more colors are needed + than available in the usable range. + + Parameters + ---------- + n : int + Number of colors to produce + theme : str + ``"light"`` or ``"dark"`` — selects the usable index range + + Returns + ------- + list[str] + List of n hex color strings + """ + if theme == "dark": + start, end = IRIDESCENT_DARK_START, IRIDESCENT_DARK_END + else: + start, end = IRIDESCENT_LIGHT_START, IRIDESCENT_LIGHT_END + + if n <= 0: + return [] + if n == 1: + mid = (start + end) // 2 + return [TOL_IRIDESCENT[mid]] + + # Generate n evenly-spaced positions in the continuous [start, end] range + positions = [start + i * (end - start) / (n - 1) for i in range(n)] + result = [] + for pos in positions: + idx_low = int(pos) + idx_high = min(idx_low + 1, len(TOL_IRIDESCENT) - 1) + if idx_low == idx_high: + result.append(TOL_IRIDESCENT[idx_low]) + else: + t = pos - idx_low + result.append(_interpolate_hex(TOL_IRIDESCENT[idx_low], TOL_IRIDESCENT[idx_high], t)) + return result + + class PaletteItem(TypedDict): """Structure for a palette item with label, color, and order.""" @@ -39,63 +213,43 @@ class ColorPalette: Keys typically map to label values in a stack chart or chart label. """ - def __init__( # noqa: C901 + def __init__( self, - palette: dict[str, dict[str, str]] | dict[str, str] | None = None, + *, + scenario_theme: list[str] | None = None, + model_year_theme: list[str] | None = None, + metric_theme: list[str] | None = None, ): - """Initializes a new ColorPalette instance with colors organized by category. + """Create an empty palette with the given color themes. + + Use :meth:`load` to construct a palette from serialized data. Parameters ---------- - palette : dict[str, str] | dict[str, dict[str, str]] | None, optional - Either a flat dictionary of label->color mappings (legacy format) or - a structured dictionary with 'scenarios', 'model_years', and 'metrics' keys. + scenario_theme : list[str] | None + Custom color cycle for scenarios. Defaults to :data:`TOL_BRIGHT`. + model_year_theme : list[str] | None + Custom color cycle for model years. Defaults to :data:`TOL_IRIDESCENT`. + metric_theme : list[str] | None + Custom color cycle for sectors and end uses. Defaults to + :data:`TOL_METRICS_LIGHT`. """ - # Different themes for each category - self.scenario_theme = colors.qualitative.Antique # type: ignore[attr-defined] - self.model_year_theme = colors.sequential.YlOrRd # type: ignore[attr-defined] - self.metric_theme = colors.qualitative.Prism # type: ignore[attr-defined] + # Color-blind-safe themes (Paul Tol palettes) — overridable + self.scenario_theme: list[str] = list(scenario_theme or TOL_BRIGHT) + self.model_year_theme: list[str] = list(model_year_theme or TOL_IRIDESCENT) + self.metric_theme: list[str] = list(metric_theme or TOL_METRICS_LIGHT) + self._ui_theme: str = "light" # "light" or "dark" self._scenario_iterator = cycle(self.scenario_theme) self._model_year_iterator = cycle(self.model_year_theme) - self._metric_iterator = cycle(self.metric_theme) + self._sector_iterator = cycle(self.metric_theme) + self._end_use_iterator = cycle(self.metric_theme) # Separate palettes for each category self.scenarios: dict[str, str] = {} self.model_years: dict[str, str] = {} - self.metrics: dict[str, str] = {} - - if palette: - # Check if it's the new structured format - if ( - isinstance(palette, dict) - and all(k in palette for k in ["scenarios", "model_years", "metrics"]) - and all(isinstance(v, dict) for v in palette.values()) - ): - # New structured format - scenarios_dict = palette["scenarios"] - model_years_dict = palette["model_years"] - metrics_dict = palette["metrics"] - - if isinstance(scenarios_dict, dict): - for label, color in scenarios_dict.items(): - if isinstance(color, str): - self.update(label, color, category="scenarios") - - if isinstance(model_years_dict, dict): - for label, color in model_years_dict.items(): - if isinstance(color, str): - self.update(label, color, category="model_years") - - if isinstance(metrics_dict, dict): - for label, color in metrics_dict.items(): - if isinstance(color, str): - self.update(label, color, category="metrics") - else: - # Legacy flat format - default to metrics - for label, color_value in palette.items(): - if isinstance(color_value, str): - self.update(label, color_value) + self.sectors: dict[str, str] = {} + self.end_uses: dict[str, str] = {} @property def palette(self) -> dict[str, str]: @@ -109,20 +263,34 @@ def palette(self) -> dict[str, str]: result = {} result.update(self.scenarios) result.update(self.model_years) - result.update(self.metrics) + result.update(self.sectors) + result.update(self.end_uses) return result def __str__(self) -> str: """Return a string representation of the palette.""" num_scenarios = len(self.scenarios) num_model_years = len(self.model_years) - num_metrics = len(self.metrics) - return f"ColorPalette(scenarios={num_scenarios}, model_years={num_model_years}, metrics={num_metrics})" + num_sectors = len(self.sectors) + num_end_uses = len(self.end_uses) + return ( + f"ColorPalette(scenarios={num_scenarios}, model_years={num_model_years}, " + f"sectors={num_sectors}, end_uses={num_end_uses})" + ) def __repr__(self) -> str: """Return a detailed string representation of the palette.""" return self.__str__() + @property + def has_custom_themes(self) -> bool: + """Return ``True`` if any theme differs from the built-in defaults.""" + return ( + self.scenario_theme != list(TOL_BRIGHT) + or self.model_year_theme != list(TOL_IRIDESCENT) + or self.metric_theme != list(TOL_METRICS_LIGHT) + ) + def copy(self) -> "ColorPalette": """Create a deep copy of this ColorPalette. @@ -131,362 +299,581 @@ def copy(self) -> "ColorPalette": ColorPalette A new ColorPalette instance with the same colors and structure. """ - return ColorPalette(self.to_dict()) + return ColorPalette.from_dict(self.to_dict()) + + # -- Helper methods (used by update / get / pop / set_ui_theme) ----------- + + @staticmethod + def _is_valid_color(color: str | None) -> bool: + """Return ``True`` if *color* is a recognised hex or rgb/rgba string.""" + return isinstance(color, str) and bool( + hex_color_pattern.match(color) or rgb_color_pattern.match(color) + ) + + def _get_target(self, category: ColorCategory) -> tuple[dict[str, str], Any]: + """Return ``(color_dict, iterator)`` for *category*.""" + _map = { + ColorCategory.SCENARIO: (self.scenarios, self._scenario_iterator), + ColorCategory.MODEL_YEAR: (self.model_years, self._model_year_iterator), + ColorCategory.SECTOR: (self.sectors, self._sector_iterator), + ColorCategory.END_USE: (self.end_uses, self._end_use_iterator), + } + return _map[category] - def update(self, key: str, color: str | None = None, category: str | None = None) -> None: # noqa: C901 - """Updates or creates a new color for the given *key* in the specified category. + def _resolve_str_category( + self, + category: "ColorCategory | str | None", + ) -> "ColorCategory | None": + """Convert a plain string to :class:`ColorCategory`, passing through ``None``.""" + if category is None or isinstance(category, ColorCategory): + return category + return ColorCategory(category) + + def _sort_model_years(self) -> None: + """Re-sort model-year entries in chronological order (in-place).""" + sorted_items = sorted( + self.model_years.items(), key=lambda x: int(x[0]) if x[0].isdigit() else 0 + ) + self.model_years.clear() + self.model_years.update(sorted_items) + + def _reassign_and_reset( + self, + category: ColorCategory, + auto_colors: set[str] | None = None, + ) -> None: + """Re-color entries in *category* from position 0 and reset its iterator. + + Parameters + ---------- + auto_colors : set[str] | None + When provided, only colors that appear in this set are + overwritten; all others are treated as user-customised and + preserved. Pass ``None`` to overwrite everything (legacy + behaviour used by ``reset_to_defaults``). + """ + target_dict, _ = self._get_target(category) + _theme_attr = { + ColorCategory.SCENARIO: "scenario_theme", + ColorCategory.MODEL_YEAR: "model_year_theme", + ColorCategory.SECTOR: "metric_theme", + ColorCategory.END_USE: "metric_theme", + } + _iter_attr = { + ColorCategory.SCENARIO: "_scenario_iterator", + ColorCategory.MODEL_YEAR: "_model_year_iterator", + ColorCategory.SECTOR: "_sector_iterator", + ColorCategory.END_USE: "_end_use_iterator", + } + theme = getattr(self, _theme_attr[category]) + fresh = cycle(theme) + for key in target_dict: + new_color = next(fresh) + if auto_colors is None or target_dict[key] in auto_colors: + target_dict[key] = new_color + # else: preserve user-customised color + # Reset iterator, advanced past assigned entries + new_iter = cycle(theme) + for _ in range(len(target_dict)): + next(new_iter) + setattr(self, _iter_attr[category], new_iter) + + # -- Public API ----------------------------------------------------------- + + def update( + self, + key: str, + color: str | None = None, + *, + category: ColorCategory | str, + ) -> None: + """Update or create a color for the given *key*. Keys are normalized to lowercase for consistent lookups. Parameters ---------- key : str - The lookup key for which to assign or update the color + The lookup key for which to assign or update the color. color : str | None, optional - A hex string or rgb/rgba string representation of the color. If ``None`` or invalid, a new color - is assigned based on the theme. - category : str | None, optional - The category to update: 'scenarios', 'model_years', or 'metrics'. - If None, attempts to determine automatically or defaults to 'metrics'. - - Raises - ------ - TypeError - If *key* is not a string. - ValueError - If *category* is not a valid category name. + A hex or rgb/rgba color string. If ``None`` or invalid a new + color is assigned from the category's theme. + category : ColorCategory | str + Target category (required). """ - if not isinstance(key, str): msg = "ColorPalette: Key must be a string" raise TypeError(msg) - # Normalize key to lowercase for consistent lookups key = key.lower() - - # Determine which palette to update to get the right color iterator - if category is None: - # Auto-detect: check if key exists in any category - if key in self.scenarios: - category = "scenarios" - elif key in self.model_years: - category = "model_years" - elif key in self.metrics: - category = "metrics" - else: - # Default to metrics for new keys - category = "metrics" - - # Get color from appropriate theme if not provided or invalid - if color is None or not isinstance(color, str): - if category == "scenarios": - color = next(self._scenario_iterator) - elif category == "model_years": - color = next(self._model_year_iterator) - else: # metrics - color = next(self._metric_iterator) - elif not (hex_color_pattern.match(color) or rgb_color_pattern.match(color)): - if category == "scenarios": - color = next(self._scenario_iterator) - elif category == "model_years": - color = next(self._model_year_iterator) - else: # metrics - color = next(self._metric_iterator) - - if category == "scenarios": - self.scenarios[key] = color - elif category == "model_years": - self.model_years[key] = color - # Re-sort to maintain chronological order (but don't reassign colors) - # This ensures display order is correct without changing existing color assignments - self.model_years = dict( - sorted(self.model_years.items(), key=lambda x: int(x[0]) if x[0].isdigit() else 0) - ) - elif category == "metrics": - self.metrics[key] = color - else: - msg = ( - f"Invalid category: {category}. Must be 'scenarios', 'model_years', or 'metrics'." - ) + resolved = self._resolve_str_category(category) + if resolved is None: + msg = "ColorPalette.update() requires a valid category" raise ValueError(msg) + target_dict, iterator = self._get_target(resolved) + target_dict[key] = ( + color if color is not None and self._is_valid_color(color) else next(iterator) + ) - def get(self, key: str, category: str | None = None) -> str: - """Returns the hex string representation of the color for a given *key*. + if resolved == ColorCategory.MODEL_YEAR: + self._sort_model_years() - Keys are normalized to lowercase for consistent lookups. - Searches across all categories (scenarios, model_years, metrics) unless - a specific category is provided. If *key* does not exist, a new color is - generated based on the theme, stored in the metrics category, and returned. + def get(self, key: str, category: ColorCategory | str | None = None) -> str: + """Return the color for *key*, generating one if it does not exist. + + Keys are normalized to lowercase. Searches across all categories + unless *category* is specified. Parameters ---------- key : str - The lookup key for the color. - category : str | None, optional - Specific category to search: 'scenarios', 'model_years', or 'metrics'. - If None, searches all categories. - - Returns - ------- - str - hex string representing the color for a given *key* + The lookup key. + category : ColorCategory | str | None, optional + Specific category to search / store into. """ - # Normalize key to lowercase for consistent lookups key = key.lower() + resolved = self._resolve_str_category(category) - color = None - - if category: - # Search specific category - if category == "scenarios": - color = self.scenarios.get(key) - elif category == "model_years": - color = self.model_years.get(key) - elif category == "metrics": - color = self.metrics.get(key) + if resolved is not None: + target_dict, _ = self._get_target(resolved) + if key in target_dict: + return target_dict[key] else: - # Search all categories - color = self.scenarios.get(key) or self.model_years.get(key) or self.metrics.get(key) - - if color is None: - # Get the next color from the appropriate theme and store it in metrics by default - color = next(self._metric_iterator) - self.metrics[key] = color - + for cat in ColorCategory: + d, _ = self._get_target(cat) + if key in d: + return d[key] + + # Generate a new color + effective = resolved or ColorCategory.SECTOR + target_dict, iterator = self._get_target(effective) + color = str(next(iterator)) + target_dict[key] = color return color - def pop(self, key: str, category: str | None = None) -> str: - """Removes the entry from the palette and returns the color string. - - Keys are normalized to lowercase for consistent lookups. + def pop(self, key: str, *, category: ColorCategory | str) -> str: + """Remove *key* from the palette and return its color. Parameters ---------- key : str - The key to remove from the palette - category : str | None, optional - Specific category to remove from. If None, searches all categories. - - Returns - ------- - str - The color string that was associated with *key* + Key to remove. + category : ColorCategory | str + Category to remove from (required). Raises ------ KeyError - If *key* is not present in any category + If *key* is not found. """ - # Normalize key to lowercase for consistent lookups key = key.lower() + resolved = self._resolve_str_category(category) + if resolved is None: + msg = "ColorPalette.pop() requires a valid category" + raise ValueError(msg) - if category: - # Remove from specific category - if category == "scenarios" and key in self.scenarios: - return self.scenarios.pop(key) - elif category == "model_years" and key in self.model_years: - return self.model_years.pop(key) - elif category == "metrics" and key in self.metrics: - return self.metrics.pop(key) - else: - # Search all categories - if key in self.scenarios: - return self.scenarios.pop(key) - elif key in self.model_years: - return self.model_years.pop(key) - elif key in self.metrics: - return self.metrics.pop(key) + d, _ = self._get_target(resolved) + if key in d: + return d.pop(key) msg = f"ColorPalette: unable to remove key: {key}" raise KeyError(msg) - @classmethod - def from_dict(cls, palette: dict[str, dict[str, str]] | dict[str, str]) -> "ColorPalette": # noqa: C901 + def set_ui_theme(self, theme: str) -> None: + """Switch palettes for the given UI theme (``"light"`` or ``"dark"``). + + Updates the metric theme, re-assigns sector/end-use colors, and + re-samples Iridescent colors for model years. colors that have + been manually customised (i.e. do not appear in the old metric + theme) are preserved. """ - Loads the color palette from a dictionary representation with sanitization. + if theme not in ("light", "dark"): + msg = f"Invalid UI theme: {theme!r}. Must be 'light' or 'dark'." + raise ValueError(msg) + + # Remember old theme colors so we can detect custom assignments + old_auto_colors = set(self.metric_theme) + old_theme = self._ui_theme + + self._ui_theme = theme + self.metric_theme = list(TOL_METRICS_LIGHT if theme == "light" else TOL_METRICS_DARK) + + # Re-assign sector and end-use colors, preserving custom ones + self._reassign_and_reset(ColorCategory.SECTOR, old_auto_colors) + self._reassign_and_reset(ColorCategory.END_USE, old_auto_colors) + + # Re-sample model year colors, preserving user-customised ones + n_years = len(self.model_years) + if n_years > 0: + old_model_year_auto = set(sample_iridescent(n_years, theme=old_theme)) + new_colors = sample_iridescent(n_years, theme=theme) + for (key, old_color), color in zip(list(self.model_years.items()), new_colors): + if old_color in old_model_year_auto: + self.model_years[key] = color + + self.model_year_theme = list(TOL_IRIDESCENT) + self._model_year_iterator = cycle(self.model_year_theme) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ColorPalette": # noqa: C901 + """Construct a :class:`ColorPalette` from a serialized dictionary. + + Accepts three on-disk shapes: + + * **Structured (current)** — top-level keys ``scenarios``, + ``model_years``, ``sectors``, ``end_uses`` (each mapping + names to hex colors). Optionally includes a ``themes`` key + with per-category color-cycle overrides ("full" palette). + * **Legacy structured** — ``scenarios``, ``model_years``, and + ``metrics`` (the old 3-key format). + * **Legacy flat** — a single-level ``{name: color}`` dict. + + When ``themes`` is present the palette is *full*: custom color + cycles replace the built-in TOL defaults. Otherwise the palette + is *minimal* and the TOL defaults are used for any names not + already assigned a color. Parameters ---------- - palette : dict[str, str] | dict[str, dict[str, str]] - Either a flat mapping of string keys to hex color strings (legacy format) - or a structured dictionary with 'scenarios', 'model_years', and 'metrics' keys. + data : dict[str, Any] + Serialized palette dictionary. Returns ------- ColorPalette - A new :class:`ColorPalette` instance populated with the provided colors. - Invalid values are replaced by the next available color in the theme. The default - theme is Plotly's Prism palette. + A new populated instance. """ - - new_palette = cls() - - # Check if it's the new structured format - if ( - isinstance(palette, dict) - and all(k in palette for k in ["scenarios", "model_years", "metrics"]) - and all(isinstance(v, dict) for v in palette.values()) - ): - # Process each category with appropriate theme - for category_name in ["scenarios", "model_years", "metrics"]: - category_value = palette.get(category_name) + # Extract optional custom themes ("full" palette) + themes_raw = data.get("themes") + custom_scenario_theme: list[str] | None = None + custom_model_year_theme: list[str] | None = None + custom_metric_theme: list[str] | None = None + + if isinstance(themes_raw, dict): + st = themes_raw.get("scenarios") + if isinstance(st, list) and st: + custom_scenario_theme = st + myt = themes_raw.get("model_years") + if isinstance(myt, list) and myt: + custom_model_year_theme = myt + # sectors and end_uses share the metric theme; prefer "sectors" + mt = themes_raw.get("sectors") or themes_raw.get("end_uses") + if isinstance(mt, list) and mt: + custom_metric_theme = mt + + new_palette = cls( + scenario_theme=custom_scenario_theme, + model_year_theme=custom_model_year_theme, + metric_theme=custom_metric_theme, + ) + + # Detect structured format — ignore "themes" key for this check + _category_keys = {"scenarios", "model_years", "metrics", "sectors", "end_uses"} + category_values = {k: v for k, v in data.items() if k in _category_keys} + is_structured = bool(category_values) and all( + isinstance(v, dict) for v in category_values.values() + ) + + if is_structured: + # Map category names to (target_dict, theme) + _category_targets: list[tuple[str, dict[str, str], list[str]]] = [ + ("scenarios", new_palette.scenarios, new_palette.scenario_theme), + ("model_years", new_palette.model_years, new_palette.model_year_theme), + ("sectors", new_palette.sectors, new_palette.metric_theme), + ("end_uses", new_palette.end_uses, new_palette.metric_theme), + ("metrics", new_palette.sectors, new_palette.metric_theme), # legacy compat + ] + + for category_name, target_dict, theme in _category_targets: + category_value = data.get(category_name) if not isinstance(category_value, dict): continue - category_dict: dict[str, str] = category_value - - # Get appropriate color iterator for this category - if category_name == "scenarios": - color_iterator = cycle(colors.qualitative.Bold) # type: ignore[attr-defined] - elif category_name == "model_years": - color_iterator = cycle(colors.sequential.YlOrRd) # type: ignore[attr-defined] - else: # metrics - color_iterator = cycle(colors.qualitative.Prism) # type: ignore[attr-defined] + color_iterator = cycle(theme) # Sort model years as integers before processing to ensure proper color gradient - items = list(category_dict.items()) + items = list(category_value.items()) if category_name == "model_years": items.sort(key=lambda x: int(x[0]) if x[0].isdigit() else 0) for key, color in items: - # Normalize key to lowercase normalized_key = key.lower() if not (hex_color_pattern.match(color) or rgb_color_pattern.match(color)): color = next(color_iterator) - if category_name == "scenarios": - new_palette.scenarios[normalized_key] = color - elif category_name == "model_years": - new_palette.model_years[normalized_key] = color - elif category_name == "metrics": - new_palette.metrics[normalized_key] = color + # Skip duplicates (e.g. legacy metrics overlapping with sectors) + if normalized_key not in target_dict: + target_dict[normalized_key] = color else: - # Legacy flat format - default to metrics - metric_iterator = cycle(colors.qualitative.Prism) # type: ignore[attr-defined] - for key, color_value in palette.items(): + # Legacy flat format - default to sectors + metric_iterator = cycle(new_palette.metric_theme) + for key, color_value in data.items(): if not isinstance(color_value, str): continue - # Normalize key to lowercase normalized_key = key.lower() if not ( hex_color_pattern.match(color_value) or rgb_color_pattern.match(color_value) ): color_value = next(metric_iterator) - new_palette.metrics[normalized_key] = color_value + new_palette.sectors[normalized_key] = color_value return new_palette - def refresh_category_colors(self, category: str) -> None: + def refresh_category_colors(self, category: ColorCategory | str) -> None: """Reassign colors for all items in a category using the correct theme. - This is useful for fixing palettes that may have been assigned incorrect - colors or to refresh colors after theme changes. - Parameters ---------- - category : str - The category to refresh: 'scenarios', 'model_years', or 'metrics' - - Raises - ------ - ValueError - If category is not a valid category name - - Examples - -------- - >>> palette.refresh_category_colors("metrics") + category : ColorCategory | str + The category to refresh. The legacy string ``"metrics"`` + refreshes both ``SECTOR`` and ``END_USE``. """ - if category == "scenarios": - labels = list(self.scenarios.keys()) - self.scenarios.clear() - for label in labels: - self.update(label, category="scenarios") - elif category == "model_years": - labels = list(self.model_years.keys()) - # Sort model years as integers so earliest gets yellow, latest gets red + resolved = self._resolve_str_category(category) + if resolved is None: + return + + target_dict, _ = self._get_target(resolved) + labels = list(target_dict.keys()) + if resolved == ColorCategory.MODEL_YEAR: labels.sort(key=lambda x: int(x) if x.isdigit() else 0) - self.model_years.clear() - for label in labels: - self.update(label, category="model_years") - elif category == "metrics": - labels = list(self.metrics.keys()) - self.metrics.clear() - for label in labels: - self.update(label, category="metrics") - else: - msg = ( - f"Invalid category: {category}. Must be 'scenarios', 'model_years', or 'metrics'." + target_dict.clear() + + # Reset the iterator for this category + theme = ( + self.metric_theme + if resolved in (ColorCategory.SECTOR, ColorCategory.END_USE) + else ( + self.scenario_theme + if resolved == ColorCategory.SCENARIO + else self.model_year_theme ) - raise ValueError(msg) + ) + new_iter = cycle(theme) + if resolved == ColorCategory.SCENARIO: + self._scenario_iterator = new_iter + elif resolved == ColorCategory.MODEL_YEAR: + self._model_year_iterator = new_iter + elif resolved == ColorCategory.SECTOR: + self._sector_iterator = new_iter + elif resolved == ColorCategory.END_USE: + self._end_use_iterator = new_iter + + for label in labels: + self.update(label, category=resolved) + + def merge_with_project_dimensions( + self, + scenarios: list[str] | None = None, + model_years: list[str] | None = None, + sectors: list[str] | None = None, + end_uses: list[str] | None = None, + ) -> None: + """Merge this palette with a project's actual dimensions. + + For each category the logic is: + + 1. **Matched names** — entries present in both the palette and the + project keep their stored color. + 2. **Reserve collection** — entries in the palette but *not* in the + project are set aside as reserves. Their colors are returned to + the front of the available-color pool so they are reused before + cycling through the theme. + 3. **New-name assignment** — names the project has but the palette + does not are assigned colors by drawing first from the reserve + pool, then from the theme (skipping colors already claimed by + matched entries). + + After merging, the category dict is reordered so that project-active + entries come first (in the order given), followed by reserves. + + Parameters + ---------- + scenarios : list[str] | None + Scenario names present in the project. + model_years : list[str] | None + Model year labels (as strings) present in the project. + sectors : list[str] | None + Sector names present in the project. + end_uses : list[str] | None + End-use names present in the project. + """ + _plan: list[tuple[list[str] | None, ColorCategory]] = [ + (scenarios, ColorCategory.SCENARIO), + (model_years, ColorCategory.MODEL_YEAR), + (sectors, ColorCategory.SECTOR), + (end_uses, ColorCategory.END_USE), + ] + + for project_names, cat in _plan: + if project_names is None: + continue + self._merge_category(project_names, cat) + + def _get_theme(self, category: ColorCategory) -> list[str]: + """Return the color theme list for *category*.""" + if category == ColorCategory.SCENARIO: + return self.scenario_theme + if category == ColorCategory.MODEL_YEAR: + return self.model_year_theme + return self.metric_theme + + def _reset_iterator(self, category: ColorCategory, advance: int) -> None: + """Reset *category*'s iterator, advanced past *advance* entries.""" + theme = self._get_theme(category) + new_iter = cycle(theme) + for _ in range(advance): + next(new_iter) + attr = { + ColorCategory.SCENARIO: "_scenario_iterator", + ColorCategory.MODEL_YEAR: "_model_year_iterator", + ColorCategory.SECTOR: "_sector_iterator", + ColorCategory.END_USE: "_end_use_iterator", + }[category] + setattr(self, attr, new_iter) + + def _merge_category(self, project_names: list[str], category: ColorCategory) -> None: + """Merge a single category with the project's dimension names. + + Order of operations: + 1. Match names present in both palette and project — keep their colors. + 2. Collect reserve entries (palette names not in the project). Their + colors are returned to the front of the available-color pool so + they get reused before cycling through the theme. + 3. Assign colors to new project names by drawing first from reserve + colors, then from the theme (skipping colors used by matches). + """ + target_dict, _ = self._get_target(category) + + # Normalize project names + normalized = [n.lower() for n in project_names] + project_set = set(normalized) + + # 1. Matched names — in both palette and project + matched_colors: dict[str, str] = {} + for name in normalized: + if name in target_dict: + matched_colors[name] = target_dict[name] + + used_colors = set(matched_colors.values()) + + # 2. Reserve entries — in palette but not in project. + # Their colors go back into the available pool (front of the line). + reserve_entries: dict[str, str] = { + k: v for k, v in target_dict.items() if k not in project_set + } + reserve_colors = [v for v in reserve_entries.values() if v not in used_colors] + + # 3. Build color iterator: reserve colors first, then theme (skipping used) + theme = self._get_theme(category) + + def _available_color_iter() -> Any: + """Yield reserve colors first, then theme colors, skipping used.""" + yield from reserve_colors + seen_skip: set[str] = set() + for color in cycle(theme): + if color in used_colors and color not in seen_skip: + seen_skip.add(color) + continue + yield color + + color_iter = _available_color_iter() + + # Assign colors to new (unmatched) project names + unmatched_names = [n for n in normalized if n not in target_dict] + new_assignments: dict[str, str] = {} + for name in unmatched_names: + new_assignments[name] = next(color_iter) + + # Rebuild the category dict: active entries (in project order), + # then reserves + target_dict.clear() + for name in normalized: + if name in matched_colors: + target_dict[name] = matched_colors[name] + else: + target_dict[name] = new_assignments[name] + target_dict.update(reserve_entries) + + self._reset_iterator(category, len(target_dict)) + + if category == ColorCategory.MODEL_YEAR: + self._sort_model_years() def get_display_items( - self, category: str | None = None + self, category: ColorCategory | str | None = None ) -> dict[str, list[tuple[str, str, str]]]: """Get palette items formatted for display with proper capitalization. Returns tuples of (display_label, lowercase_key, color) for each item. + """ + resolved = self._resolve_str_category(category) if category is not None else None - Parameters - ---------- - category : str | None, optional - Specific category to get: 'scenarios', 'model_years', or 'metrics'. - If None, returns all categories. + def _fmt(d: dict[str, str]) -> list[tuple[str, str, str]]: + return [(k.capitalize(), k, c) for k, c in d.items()] - Returns - ------- - dict[str, list[tuple[str, str, str]]] - Dictionary mapping category names to lists of (display_label, key, color) tuples. - The display_label is capitalized for presentation, while key is the lowercase - lookup key. - - Examples - -------- - >>> palette.get_display_items("metrics") - {'metrics': [('Industrial', 'industrial', 'rgb(95, 70, 144)'), ...]} - """ - result: dict[str, list[tuple[str, str, str]]] = {} - - def format_items(items_dict: dict[str, str]) -> list[tuple[str, str, str]]: - """Convert dict items to display tuples.""" - return [(key.capitalize(), key, color) for key, color in items_dict.items()] - - if category is None: - # Return all categories - if self.scenarios: - result["scenarios"] = format_items(self.scenarios) - if self.model_years: - result["model_years"] = format_items(self.model_years) - if self.metrics: - result["metrics"] = format_items(self.metrics) - elif category == "scenarios": - result["scenarios"] = format_items(self.scenarios) - elif category == "model_years": - result["model_years"] = format_items(self.model_years) - elif category == "metrics": - result["metrics"] = format_items(self.metrics) - else: - msg = ( - f"Invalid category: {category}. Must be 'scenarios', 'model_years', or 'metrics'." - ) - raise ValueError(msg) + groups: dict[str, dict[str, str]] = { + "scenarios": self.scenarios, + "model_years": self.model_years, + "sectors": self.sectors, + "end_uses": self.end_uses, + } - return result + if resolved is None: + return {name: _fmt(d) for name, d in groups.items() if d} + + _cat_to_group = { + ColorCategory.SCENARIO: "scenarios", + ColorCategory.MODEL_YEAR: "model_years", + ColorCategory.SECTOR: "sectors", + ColorCategory.END_USE: "end_uses", + } + group_name = _cat_to_group.get(resolved) + if group_name: + d = groups[group_name] + return {group_name: _fmt(d)} if d else {} + + return {} - def to_dict(self) -> dict[str, dict[str, str]]: + def to_dict(self) -> dict[str, Any]: """Serializes the internal palette to a structured dictionary. + Includes a ``"themes"`` key when the palette uses custom color + cycles (a "full" palette). Minimal palettes omit it. + Returns ------- - dict[str, dict[str, str]] - A dictionary with 'scenarios', 'model_years', and 'metrics' keys, - each containing a mapping of labels to corresponding hex color strings. + dict[str, Any] + A dictionary with 'scenarios', 'model_years', 'sectors', and + 'end_uses' keys, each mapping labels to hex color strings. + Optionally includes 'themes' for full palettes. + """ + result: dict[str, Any] = { + "scenarios": self.scenarios.copy(), + "model_years": self.model_years.copy(), + "sectors": self.sectors.copy(), + "end_uses": self.end_uses.copy(), + } + if self.has_custom_themes: + result["themes"] = { + "scenarios": list(self.scenario_theme), + "model_years": list(self.model_year_theme), + "sectors": list(self.metric_theme), + "end_uses": list(self.metric_theme), + } + return result + + def to_dict_legacy(self) -> dict[str, dict[str, str]]: + """Serializes the palette using the legacy 3-key format. + + Sectors and end-uses are merged under a single ``"metrics"`` key. + Prefer :meth:`to_dict` for new code. """ return { "scenarios": self.scenarios.copy(), "model_years": self.model_years.copy(), - "metrics": self.metrics.copy(), + "metrics": {**self.sectors, **self.end_uses}, } def to_flat_dict(self) -> dict[str, str]: @@ -500,7 +887,8 @@ def to_flat_dict(self) -> dict[str, str]: result = {} result.update(self.scenarios) result.update(self.model_years) - result.update(self.metrics) + result.update(self.sectors) + result.update(self.end_uses) return result def move_item_up(self, items: MutableSequence[dict[str, Any]], index: int) -> bool: @@ -562,7 +950,8 @@ def palette_to_grouped_items( Parameters ---------- palette : dict[str, dict[str, str]] - Structured palette with 'scenarios', 'model_years', 'metrics' categories + Structured palette with 'scenarios', 'model_years', 'sectors', + and 'end_uses' categories (also accepts legacy 'metrics' key). Returns ------- @@ -575,17 +964,26 @@ def palette_to_grouped_items( category_display_names = { "scenarios": "Scenarios", "model_years": "Model Years", - "metrics": "Metrics", + "sectors": "Sectors", + "end_uses": "End Uses", + "metrics": "Sectors", # legacy compat } - for category_name in ["scenarios", "model_years", "metrics"]: + for category_name in ["scenarios", "model_years", "sectors", "end_uses", "metrics"]: category_dict = palette.get(category_name, {}) if category_dict: items: list[dict[str, Any]] = [] for order, (label, color) in enumerate(category_dict.items()): items.append({"label": label, "color": color, "order": order}) display_name = category_display_names.get(category_name, category_name) - result[display_name] = items + if display_name in result: + # Append to existing group (e.g. legacy metrics merging into Sectors) + offset = len(result[display_name]) + for item in items: + item["order"] += offset + result[display_name].extend(items) + else: + result[display_name] = items return result @@ -603,19 +1001,22 @@ def grouped_items_to_palette( Returns ------- dict[str, dict[str, str]] - Structured palette with 'scenarios', 'model_years', 'metrics' categories + Structured palette with 'scenarios', 'model_years', 'sectors', + and 'end_uses' categories. """ # Map display names back to internal names display_to_category = { "Scenarios": "scenarios", "Model Years": "model_years", - "Metrics": "metrics", + "Sectors": "sectors", + "End Uses": "end_uses", } palette: dict[str, dict[str, str]] = { "scenarios": {}, "model_years": {}, - "metrics": {}, + "sectors": {}, + "end_uses": {}, } for display_name, items in grouped_items.items(): diff --git a/src/stride/ui/palette_utils.py b/src/stride/ui/palette_utils.py new file mode 100644 index 0000000..ded6e63 --- /dev/null +++ b/src/stride/ui/palette_utils.py @@ -0,0 +1,196 @@ +"""Utility functions for managing user palettes. + +Provides functions for saving, loading, and listing user palettes stored +in ``~/.stride/palettes/``, as well as managing palette-related settings +(default palette, palette priority) in the Stride configuration file. +""" + +import json +from pathlib import Path + +from stride.config import ( + load_stride_config, + save_stride_config, +) +from stride.ui.palette import ColorPalette + + +def get_user_palette_dir() -> Path: + """Get the user's palette directory, creating it if necessary. + + Returns + ------- + Path + Path to ~/.stride/palettes/ + """ + palette_dir = Path.home() / ".stride" / "palettes" + palette_dir.mkdir(parents=True, exist_ok=True) + return palette_dir + + +def list_user_palettes() -> list[Path]: + """List all user palettes. + + Returns + ------- + list[Path] + List of paths to user palette files + """ + palette_dir = get_user_palette_dir() + return sorted(palette_dir.glob("*.json")) + + +def save_user_palette(name: str, palette: dict[str, str] | dict[str, dict[str, str]]) -> Path: + """Save a palette to the user's palette directory. + + Parameters + ---------- + name : str + Name for the palette (will be used as filename) + palette : dict[str, str] | dict[str, dict[str, str]] + Palette dictionary to save (either flat or structured format) + + Returns + ------- + Path + Path to the saved palette file + """ + palette_dir = get_user_palette_dir() + palette_path = palette_dir / f"{name}.json" + + data = { + "name": name, + "palette": palette, + } + + with open(palette_path, "w") as f: + json.dump(data, f, indent=2) + + return palette_path + + +def load_user_palette(name: str) -> ColorPalette: + """Load a user palette by name. + + Parameters + ---------- + name : str + Name of the palette to load + + Returns + ------- + ColorPalette + Loaded color palette + + Raises + ------ + FileNotFoundError + If the palette does not exist + """ + palette_dir = get_user_palette_dir() + palette_path = palette_dir / f"{name}.json" + + if not palette_path.exists(): + msg = f"User palette '{name}' not found" + raise FileNotFoundError(msg) + + with open(palette_path) as f: + data = json.load(f) + # Handle both nested {"palette": {...}} and flat {...} structures + if isinstance(data, dict): + if "palette" in data: + palette_dict = data["palette"] + else: + palette_dict = data + else: + msg = f"Invalid palette format in {name}.json" + raise ValueError(msg) + + return ColorPalette.from_dict(palette_dict) + + +def delete_user_palette(name: str) -> None: + """Delete a user palette by name. + + Parameters + ---------- + name : str + Name of the palette to delete + + Raises + ------ + FileNotFoundError + If the palette does not exist + """ + palette_dir = get_user_palette_dir() + palette_path = palette_dir / f"{name}.json" + + if not palette_path.exists(): + msg = f"User palette '{name}' not found" + raise FileNotFoundError(msg) + + palette_path.unlink() + + +def set_default_user_palette(name: str | None) -> None: + """Set the default user palette. + + Parameters + ---------- + name : str | None + Name of the user palette to set as default, or None to clear the default + """ + config = load_stride_config() + + if name is None: + config.pop("default_user_palette", None) + else: + # Verify the palette exists + palette_dir = get_user_palette_dir() + palette_path = palette_dir / f"{name}.json" + if not palette_path.exists(): + msg = f"User palette '{name}' not found at {palette_path}" + raise FileNotFoundError(msg) + config["default_user_palette"] = name + + save_stride_config(config) + + +def get_default_user_palette() -> str | None: + """Get the default user palette name. + + Returns + ------- + str | None + Name of the default user palette, or None if not set + """ + config = load_stride_config() + return config.get("default_user_palette") + + +def set_palette_priority(priority: str) -> None: + """Set the palette priority for dashboard launch. + + Parameters + ---------- + priority : str + Priority setting: "user" to prefer user palette, "project" to prefer project palette + """ + if priority not in ("user", "project"): + msg = f"Invalid palette priority: {priority!r}. Must be 'user' or 'project'." + raise ValueError(msg) + config = load_stride_config() + config["palette_priority"] = priority + save_stride_config(config) + + +def get_palette_priority() -> str: + """Get the palette priority for dashboard launch. + + Returns + ------- + str + Priority setting: "user" or "project". Defaults to "user". + """ + config = load_stride_config() + return str(config.get("palette_priority", "user")) diff --git a/src/stride/ui/plotting/__init__.py b/src/stride/ui/plotting/__init__.py index 1ab84a9..3ee4a90 100644 --- a/src/stride/ui/plotting/__init__.py +++ b/src/stride/ui/plotting/__init__.py @@ -4,13 +4,16 @@ import plotly.graph_objects as go from . import facets, simple +from .utils import DEFAULT_PLOTLY_TEMPLATE + +from stride.ui.palette import ColorCategory if TYPE_CHECKING: from stride.ui.color_manager import ColorManager class StridePlots: - def __init__(self, color_generator: "ColorManager", template: str = "plotly_white"): + def __init__(self, color_generator: "ColorManager", template: str = DEFAULT_PLOTLY_TEMPLATE): """ Initialize StridePlots with a color generator function. @@ -78,11 +81,20 @@ def grouped_single_bars( return fig def grouped_multi_bars( - self, df: pd.DataFrame, x_group: str = "scenario", y_group: str = "end_use" + self, + df: pd.DataFrame, + x_group: str = "scenario", + y_group: str = "end_use", + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """Create grouped and multi-level bar chart.""" fig = simple.grouped_multi_bars( - df, self._color_generator, x_group, y_group, template=self._template + df, + self._color_generator, + x_group, + y_group, + template=self._template, + breakdown_type=breakdown_type, ) fig.update_layout(template=self._template) return fig @@ -95,6 +107,7 @@ def grouped_stacked_bars( stack_col: str = "metric", value_col: str = "demand", show_scenario_indicators: bool = True, + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """Create grouped and stacked bar chart.""" fig = simple.grouped_stacked_bars( @@ -106,16 +119,26 @@ def grouped_stacked_bars( value_col, show_scenario_indicators, template=self._template, + breakdown_type=breakdown_type, ) fig.update_layout(template=self._template) return fig def time_series( - self, df: pd.DataFrame, group_by: str | None = None, chart_type: str = "Line" + self, + df: pd.DataFrame, + group_by: str | None = None, + chart_type: str = "Line", + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """Plot time series data for multiple years of a single scenario.""" fig = simple.time_series( - df, self._color_generator, group_by, chart_type, template=self._template + df, + self._color_generator, + group_by, + chart_type, + template=self._template, + breakdown_type=breakdown_type, ) fig.update_layout(template=self._template) return fig @@ -138,10 +161,17 @@ def faceted_time_series( chart_type: str = "Line", group_by: str | None = None, value_col: str = "value", + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """Create faceted subplots for each scenario with shared legend.""" fig = facets.faceted_time_series( - df, self._color_generator, chart_type, group_by, value_col, template=self._template + df, + self._color_generator, + chart_type, + group_by, + value_col, + template=self._template, + breakdown_type=breakdown_type, ) fig.update_layout(template=self._template) return fig diff --git a/src/stride/ui/plotting/facets.py b/src/stride/ui/plotting/facets.py index 97ee9fc..f72d357 100644 --- a/src/stride/ui/plotting/facets.py +++ b/src/stride/ui/plotting/facets.py @@ -4,14 +4,17 @@ import plotly.graph_objects as go from plotly.subplots import make_subplots +from stride.ui.palette import ColorCategory + from .utils import ( + DEFAULT_PLOTLY_TEMPLATE, TRANSPARENT, calculate_subplot_layout, create_faceted_traces, create_seasonal_annotations, determine_facet_layout, + get_axis_style, get_hoverlabel_style, - get_plotly_template, update_faceted_layout, ) @@ -47,7 +50,7 @@ def add_seasonal_line_traces( "mode": "lines", "name": str(year), "line": dict( - color=color_generator.get_color(str(year)), + color=color_generator.get_color(str(year), ColorCategory.MODEL_YEAR), dash=line_styles[j % len(line_styles)], shape="spline", ), @@ -61,7 +64,7 @@ def add_seasonal_line_traces( def seasonal_load_lines( - df: pd.DataFrame, color_generator: "ColorManager", template: str = "plotly_dark" + df: pd.DataFrame, color_generator: "ColorManager", template: str = DEFAULT_PLOTLY_TEMPLATE ) -> go.Figure: """Create faceted subplots for seasonal load lines.""" if df.empty: @@ -94,11 +97,13 @@ def seasonal_load_lines( add_seasonal_line_traces(fig, df, layout_config, color_generator) # Update layout + axis = get_axis_style(template) + if layout_config["facet_col"]: annotations_list = create_seasonal_annotations(layout_config) fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin=dict(l=60, r=20, t=80, b=80), @@ -114,19 +119,19 @@ def seasonal_load_lines( range=[0, 23], showgrid=True, gridwidth=1, - gridcolor="lightgray", + gridcolor=axis["grid_color"], tickvals=[0, 6, 12, 18, 23], ticktext=["0", "6", "12", "18", "23"], showline=True, linewidth=1, - linecolor="black", + linecolor=axis["axis_color"], mirror=True, title_text="", ) fig.update_yaxes( showline=True, linewidth=1, - linecolor="black", + linecolor=axis["axis_color"], mirror=True, title_text="", ) @@ -138,14 +143,14 @@ def seasonal_load_lines( fig.add_vline( x=hour, line_dash="dot", - line_color="lightgray", + line_color=axis["vline_color"], line_width=1, row=row_idx, col=col_idx, ) else: fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin=dict(l=20, r=20, t=20, b=40), @@ -155,15 +160,15 @@ def seasonal_load_lines( range=[0, 23], showgrid=True, gridwidth=1, - gridcolor="lightgray", + gridcolor=axis["grid_color"], tickvals=[0, 6, 12, 18, 23], ticktext=["0", "6", "12", "18", "23"], showline=True, linewidth=1, - linecolor="black", + linecolor=axis["axis_color"], mirror=True, ), - yaxis=dict(showline=True, linewidth=1, linecolor="black", mirror=True), + yaxis=dict(showline=True, linewidth=1, linecolor=axis["axis_color"], mirror=True), legend=dict(orientation="v", yanchor="top", y=1, xanchor="left", x=1.02), hoverlabel=hoverlabel_style, hovermode="x unified", @@ -171,7 +176,7 @@ def seasonal_load_lines( # Add vertical lines for single plot for hour in [6, 12, 18]: - fig.add_vline(x=hour, line_dash="dot", line_color="lightgray", line_width=1) + fig.add_vline(x=hour, line_dash="dot", line_color=axis["vline_color"], line_width=1) return fig @@ -200,6 +205,7 @@ def _add_stacked_area_traces( color_generator: "ColorManager", breakdown_col: str | None, breakdown_categories: list[str], + breakdown_type: ColorCategory | None = None, ) -> None: """Add stacked area traces to the figure for each facet.""" for i, facet_value in enumerate(layout_config["facet_categories"]): @@ -227,7 +233,11 @@ def _add_stacked_area_traces( "y": category_df["value"], "mode": "lines", "name": category, - "line": dict(color=color_generator.get_color(category)), + "line": dict( + color=color_generator.get_color( + category, breakdown_type or ColorCategory.SECTOR + ) + ), "fill": "tonexty" if j > 0 else "tozeroy", "stackgroup": f"facet_{i}" if layout_config["facet_col"] else "one", "showlegend": show_legend, @@ -252,7 +262,8 @@ def _add_stacked_area_traces( "name": str(facet_value) if layout_config["facet_col"] else "Load", "line": dict( color=color_generator.get_color( - str(facet_value) if layout_config["facet_col"] else "Load" + str(facet_value) if layout_config["facet_col"] else "Load", + breakdown_type or ColorCategory.SECTOR, ) ), "fill": "tozeroy", @@ -266,7 +277,7 @@ def _add_stacked_area_traces( def seasonal_load_area( - df: pd.DataFrame, color_generator: "ColorManager", template: str = "plotly_dark" + df: pd.DataFrame, color_generator: "ColorManager", template: str = DEFAULT_PLOTLY_TEMPLATE ) -> go.Figure: """Create faceted area charts for seasonal load patterns.""" if df.empty: @@ -299,15 +310,23 @@ def seasonal_load_area( # Add area traces _add_stacked_area_traces( - fig, df, layout_config, color_generator, breakdown_col, breakdown_categories + fig, + df, + layout_config, + color_generator, + breakdown_col, + breakdown_categories, + breakdown_type=ColorCategory.END_USE, ) # Update layout + axis = get_axis_style(template) + if layout_config["facet_col"]: annotations_list = create_seasonal_annotations(layout_config) fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin=dict(l=60, r=20, t=80, b=80), @@ -325,19 +344,25 @@ def seasonal_load_area( range=[0, 23], showgrid=True, gridwidth=1, - gridcolor="lightgray", + gridcolor=axis["grid_color"], tickvals=[0, 6, 12, 18, 23], ticktext=["0", "6", "12", "18", "23"], showline=True, linewidth=1, - linecolor="black", + linecolor=axis["axis_color"], + mirror=True, + title_text="", + ) + fig.update_yaxes( + showline=True, + linewidth=1, + linecolor=axis["axis_color"], mirror=True, title_text="", ) - fig.update_yaxes(showline=True, linewidth=1, linecolor="black", mirror=True, title_text="") else: fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin=dict(l=20, r=20, t=20, b=40), @@ -347,15 +372,15 @@ def seasonal_load_area( range=[0, 23], showgrid=True, gridwidth=1, - gridcolor="lightgray", + gridcolor=axis["grid_color"], tickvals=[0, 6, 12, 18, 23], ticktext=["0", "6", "12", "18", "23"], showline=True, linewidth=1, - linecolor="black", + linecolor=axis["axis_color"], mirror=True, ), - yaxis=dict(showline=True, linewidth=1, linecolor="black", mirror=True), + yaxis=dict(showline=True, linewidth=1, linecolor=axis["axis_color"], mirror=True), showlegend=has_breakdown, legend=dict(orientation="v", yanchor="top", y=1, xanchor="left", x=1.02) if has_breakdown @@ -373,7 +398,8 @@ def faceted_time_series( chart_type: str = "Line", group_by: str | None = None, value_col: str = "value", - template: str = "plotly_dark", + template: str = DEFAULT_PLOTLY_TEMPLATE, + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """ Create faceted subplots for each scenario with shared legend. @@ -415,13 +441,13 @@ def faceted_time_series( # Create and add traces traces_info = create_faceted_traces( - df, scenarios, color_generator, chart_type, group_by, value_col + df, scenarios, color_generator, chart_type, group_by, value_col, breakdown_type ) for trace, row, col in traces_info: fig.add_trace(trace, row=row, col=col) # Update layout - update_faceted_layout(fig, rows, group_by) + update_faceted_layout(fig, rows, group_by, template=template) # Add hover styling fig.update_layout( diff --git a/src/stride/ui/plotting/simple.py b/src/stride/ui/plotting/simple.py index 8d0a404..208cf65 100644 --- a/src/stride/ui/plotting/simple.py +++ b/src/stride/ui/plotting/simple.py @@ -3,13 +3,16 @@ import pandas as pd import plotly.graph_objects as go +from stride.ui.palette import ColorCategory + from .utils import ( DEFAULT_BAR_COLOR, + DEFAULT_PLOTLY_TEMPLATE, TRANSPARENT, create_time_series_area_traces, create_time_series_line_traces, + get_axis_style, get_hoverlabel_style, - get_plotly_template, get_time_series_breakdown_info, ) @@ -23,7 +26,7 @@ def grouped_single_bars( color_generator: "ColorManager", use_color_manager: bool = True, fixed_color: str | None = None, - template: str = "plotly_dark", + template: str = DEFAULT_PLOTLY_TEMPLATE, ) -> go.Figure: """ Create a bar plot with 2 levels of x axis. @@ -41,7 +44,7 @@ def grouped_single_bars( fixed_color : str | None, optional Fixed color to use for all bars (overrides use_color_manager), by default None template : str, optional - Plotly template name for theme-aware styling, by default "plotly_dark" + Plotly template name for theme-aware styling Returns ------- @@ -83,8 +86,10 @@ def grouped_single_bars( ) ) + axis = get_axis_style(template) + fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin_b=0, @@ -94,6 +99,8 @@ def grouped_single_bars( barmode="group", hoverlabel=hoverlabel_style, hovermode="x unified", + xaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), + yaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), ) return fig @@ -104,7 +111,8 @@ def grouped_multi_bars( color_generator: "ColorManager", x_group: str = "scenario", y_group: str = "end_use", - template: str = "plotly_dark", + template: str = DEFAULT_PLOTLY_TEMPLATE, + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """ Create grouped and multi-level bar chart. @@ -120,7 +128,7 @@ def grouped_multi_bars( y_group : str, optional Secondary grouping column (creates stacked bars), by default "end_use" template : str, optional - Plotly template name for theme-aware styling, by default "plotly_dark" + Plotly template name for theme-aware styling Returns ------- @@ -150,7 +158,9 @@ def grouped_multi_bars( go.Bar( x=df_subset["year"].astype(str), y=df_subset["value"], - marker_color=color_generator.get_color(y_value), + marker_color=color_generator.get_color( + y_value, breakdown_type or ColorCategory.SECTOR + ), name=y_value, offsetgroup=x_value, legendgroup=y_value, @@ -182,7 +192,7 @@ def grouped_multi_bars( legendgrouptitle_text="Scenarios" if not scenario_title_added else None, legendrank=2, marker=dict( - color=color_generator.get_color(x_value), + color=color_generator.get_color(x_value, ColorCategory.SCENARIO), pattern_shape="/", pattern_solidity=0.3, ), @@ -195,9 +205,11 @@ def grouped_multi_bars( if not scenario_title_added: scenario_title_added = True + axis = get_axis_style(template) + fig = go.Figure(data=bars) fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin_b=0, @@ -205,7 +217,12 @@ def grouped_multi_bars( margin_l=20, margin_r=20, barmode="stack", - yaxis=dict(range=[-indicator_height * 4, max_value * 1.1]), + yaxis=dict( + range=[-indicator_height * 4, max_value * 1.1], + gridcolor=axis["grid_color"], + linecolor=axis["axis_color"], + ), + xaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), legend=dict( itemclick="toggle", # Single click toggles sectors (or scenario indicators) itemdoubleclick=False, # Disabled - can't handle 2D toggling properly @@ -225,7 +242,8 @@ def grouped_stacked_bars( stack_col: str = "metric", value_col: str = "demand", show_scenario_indicators: bool = True, - template: str = "plotly_dark", + template: str = DEFAULT_PLOTLY_TEMPLATE, + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """ Create grouped and stacked bar chart. @@ -250,7 +268,7 @@ def grouped_stacked_bars( show_scenario_indicators : bool, optional Whether to show hatched scenario indicator bars, by default True template : str, optional - Plotly template name for theme-aware styling, by default "plotly_dark" + Plotly template name for theme-aware styling Returns ------- @@ -291,7 +309,9 @@ def grouped_stacked_bars( legendgrouptitle_text=stack_col.replace("_", " ").title() if not stack_group_title_added else None, - marker_color=color_generator.get_color(stack_cat), + marker_color=color_generator.get_color( + stack_cat, breakdown_type or ColorCategory.SECTOR + ), offsetgroup=group, legendrank=1, showlegend=stack_cat not in added_stack_legend, @@ -317,7 +337,7 @@ def grouped_stacked_bars( legendgrouptitle_text="Scenarios" if not scenario_title_added else None, legendrank=2, marker=dict( - color=color_generator.get_color(group), + color=color_generator.get_color(group, ColorCategory.SCENARIO), pattern_shape="/", pattern_solidity=0.3, ), @@ -333,8 +353,10 @@ def grouped_stacked_bars( # Adjust y-axis range based on whether scenario indicators are shown y_min = -indicator_height * 4 if show_scenario_indicators else 0 + axis = get_axis_style(template) + fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin_b=50, @@ -342,7 +364,12 @@ def grouped_stacked_bars( margin_l=20, margin_r=20, barmode="stack", - yaxis=dict(range=[y_min, max_value * 1.1]), + yaxis=dict( + range=[y_min, max_value * 1.1], + gridcolor=axis["grid_color"], + linecolor=axis["axis_color"], + ), + xaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), legend=dict( itemclick="toggle", # Single click toggles sectors (or scenario indicators) itemdoubleclick=False, # Disabled - can't handle 2D toggling properly @@ -359,7 +386,8 @@ def time_series( color_generator: "ColorManager", group_by: str | None = None, chart_type: str = "Line", - template: str = "plotly_dark", + template: str = DEFAULT_PLOTLY_TEMPLATE, + breakdown_type: ColorCategory | None = None, ) -> go.Figure: """ Plot time series data for multiple years of a single scenario. @@ -376,7 +404,7 @@ def time_series( chart_type : str, optional "Line" or "Area" chart type, by default "Line" template : str, optional - Plotly template name for theme-aware styling, by default "plotly_dark" + Plotly template name for theme-aware styling Returns ------- @@ -404,21 +432,29 @@ def time_series( else: # Create traces based on chart type if chart_type == "Area": - traces = create_time_series_area_traces(df, color_generator, breakdown_info) + traces = create_time_series_area_traces( + df, color_generator, breakdown_info, breakdown_type + ) else: # Line chart - traces = create_time_series_line_traces(df, color_generator, breakdown_info) + traces = create_time_series_line_traces( + df, color_generator, breakdown_info, breakdown_type + ) # Add all traces to figure for trace in traces: fig.add_trace(trace) + axis = get_axis_style(template) + fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin=dict(l=20, r=20, t=20, b=40), xaxis_title="Time Period", yaxis_title="Average Power Demand (MW)", + xaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), + yaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), legend=dict(orientation="v", yanchor="top", y=1, xanchor="left", x=1.02), hoverlabel=hoverlabel_style, hovermode="x unified", @@ -428,7 +464,7 @@ def time_series( def demand_curve( - df: pd.DataFrame, color_generator: "ColorManager", template: str = "plotly_dark" + df: pd.DataFrame, color_generator: "ColorManager", template: str = DEFAULT_PLOTLY_TEMPLATE ) -> go.Figure: """ Create a load duration curve plot. @@ -441,7 +477,7 @@ def demand_curve( color_generator : Callable[[str], str] Color generator function template : str, optional - Plotly template name for theme-aware styling, by default "plotly_dark" + Plotly template name for theme-aware styling Returns ------- @@ -459,14 +495,16 @@ def demand_curve( x=df.index.values, y=df[scenario], mode="lines", - marker=dict(color=color_generator.get_color(scenario)), + marker=dict(color=color_generator.get_color(scenario, ColorCategory.SCENARIO)), name=scenario, showlegend=True, hovertemplate="%{fullData.name}: %{y:.2f}", ) ) + axis = get_axis_style(template) + fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin_b=0, @@ -474,7 +512,12 @@ def demand_curve( margin_l=20, margin_r=20, barmode="stack", - yaxis=dict(title="Power Demand (MW)"), + yaxis=dict( + title="Power Demand (MW)", + gridcolor=axis["grid_color"], + linecolor=axis["axis_color"], + ), + xaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), hoverlabel=hoverlabel_style, hovermode="x unified", ) @@ -486,6 +529,7 @@ def area_plot( color_generator: "ColorManager", scenario_name: str, metric: str = "demand", + template: str = DEFAULT_PLOTLY_TEMPLATE, ) -> go.Figure: """ Create a stacked area plot for a single scenario. @@ -514,21 +558,25 @@ def area_plot( x=end_use_df["year"], y=end_use_df[metric], mode="lines", - line=dict(color=color_generator.get_color(end_use)), + line=dict(color=color_generator.get_color(end_use, ColorCategory.END_USE)), showlegend=False, stackgroup="one", ) ) fig.update_layout(title=scenario_name) + axis = get_axis_style(template) + fig.update_layout( - template=get_plotly_template(), + template=template, plot_bgcolor=TRANSPARENT, paper_bgcolor=TRANSPARENT, margin_b=50, margin_t=50, margin_l=10, margin_r=10, + xaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), + yaxis=dict(gridcolor=axis["grid_color"], linecolor=axis["axis_color"]), ) return fig diff --git a/src/stride/ui/plotting/utils.py b/src/stride/ui/plotting/utils.py index 6deeed9..9b10e78 100644 --- a/src/stride/ui/plotting/utils.py +++ b/src/stride/ui/plotting/utils.py @@ -4,11 +4,22 @@ import pandas as pd import plotly.graph_objects as go +from stride.ui.palette import ColorCategory + if TYPE_CHECKING: from stride.ui.color_manager import ColorManager TRANSPARENT = "rgba(0, 0, 0, 0)" DEFAULT_BAR_COLOR = "rgba(0,0,200,0.8)" + +# Plotly template names +DEFAULT_PLOTLY_TEMPLATE = "plotly_white" +DARK_PLOTLY_TEMPLATE = "plotly_dark" + +# CSS theme class names (applied to DOM elements) +DEFAULT_CSS_THEME = "light-theme" +DARK_CSS_THEME = "dark-theme" + # Theme-aware neutral gray colors LIGHT_THEME_GRAY = "rgba(100, 100, 100, 0.8)" # Darker gray for light backgrounds DARK_THEME_GRAY = "rgba(180, 180, 180, 0.8)" # Lighter gray for dark backgrounds @@ -16,6 +27,42 @@ DARK_THEME_BG = "rgb(26, 26, 26)" # Dark background matching CSS #1a1a1a LIGHT_THEME_BG = "rgb(255, 255, 255)" # White background matching CSS #ffffff +# Theme-aware grid and axis line colors +LIGHT_GRID_COLOR = "rgba(0, 0, 0, 0.15)" # Visible on white backgrounds +DARK_GRID_COLOR = "rgba(255, 255, 255, 0.15)" # Visible on dark backgrounds +LIGHT_AXIS_COLOR = "rgba(0, 0, 0, 0.4)" # Axis lines on white backgrounds +DARK_AXIS_COLOR = "rgba(255, 255, 255, 0.4)" # Axis lines on dark backgrounds +LIGHT_VLINE_COLOR = "rgba(0, 0, 0, 0.18)" # Subtle reference lines on white +DARK_VLINE_COLOR = "rgba(255, 255, 255, 0.18)" # Subtle reference lines on dark + + +def get_axis_style(template: str) -> dict[str, str]: + """Return theme-aware colors for grids, axis lines, and reference vlines. + + Parameters + ---------- + template : str + Plotly template name (e.g., 'plotly_white', 'plotly_dark') + + Returns + ------- + dict + Keys: ``grid_color``, ``axis_color``, ``vline_color``, ``bg_color`` + """ + if "dark" in template.lower(): + return { + "grid_color": DARK_GRID_COLOR, + "axis_color": DARK_AXIS_COLOR, + "vline_color": DARK_VLINE_COLOR, + "bg_color": DARK_THEME_BG, + } + return { + "grid_color": LIGHT_GRID_COLOR, + "axis_color": LIGHT_AXIS_COLOR, + "vline_color": LIGHT_VLINE_COLOR, + "bg_color": LIGHT_THEME_BG, + } + def get_error_annotation_style(template: str) -> dict[str, Any]: """ @@ -113,19 +160,6 @@ def get_background_color(template: str) -> str: return LIGHT_THEME_BG -def get_plotly_template() -> str: - """ - Get the Plotly template for charts. - - Returns - ------- - str - Plotly template name (defaults to 'plotly_dark' to match app's default theme) - """ - # Default to dark theme to match the app's default - return "plotly_dark" - - def get_hoverlabel_style(template: str) -> dict[str, Any]: """ Get hover label styling based on the current template/theme. @@ -414,7 +448,10 @@ def get_time_series_breakdown_info( def create_time_series_line_traces( - df: pd.DataFrame, color_generator: "ColorManager", breakdown_info: dict[str, Any] + df: pd.DataFrame, + color_generator: "ColorManager", + breakdown_info: dict[str, Any], + breakdown_type: ColorCategory | None = None, ) -> list[go.Scatter]: """ Create line traces for time series data. @@ -448,7 +485,7 @@ def create_time_series_line_traces( mode="lines", name=str(year), line=dict( - color=color_generator.get_color(str(year)), + color=color_generator.get_color(str(year), ColorCategory.MODEL_YEAR), dash=line_styles[i % len(line_styles)], ), showlegend=True, @@ -479,7 +516,9 @@ def create_time_series_line_traces( mode="lines", name=legend_name, line=dict( - color=color_generator.get_color(category), + color=color_generator.get_color( + category, breakdown_type or ColorCategory.SECTOR + ), dash=line_style, ), legendgroup=category, @@ -491,7 +530,10 @@ def create_time_series_line_traces( def create_time_series_area_traces( - df: pd.DataFrame, color_generator: "ColorManager", breakdown_info: dict[str, Any] + df: pd.DataFrame, + color_generator: "ColorManager", + breakdown_info: dict[str, Any], + breakdown_type: ColorCategory | None = None, ) -> list[go.Scatter]: """ Create area traces for time series data. @@ -523,7 +565,9 @@ def create_time_series_area_traces( y=year_df["value"], mode="lines", name=str(year), - line=dict(color=color_generator.get_color(str(year))), + line=dict( + color=color_generator.get_color(str(year), ColorCategory.MODEL_YEAR) + ), fill="tozeroy", showlegend=True, ) @@ -551,7 +595,11 @@ def create_time_series_area_traces( y=category_df["value"], mode="lines", name=legend_name, - line=dict(color=color_generator.get_color(category)), + line=dict( + color=color_generator.get_color( + category, breakdown_type or ColorCategory.SECTOR + ) + ), fill="tonexty" if j > 0 else "tozeroy", stackgroup=f"year_{year}", legendgroup=category, @@ -608,6 +656,7 @@ def create_faceted_traces( chart_type: str, group_by: str | None = None, value_col: str = "value", + breakdown_type: ColorCategory | None = None, ) -> list[tuple[go.Scatter, int, int]]: """ Create traces for faceted time series plots. @@ -658,6 +707,7 @@ def create_faceted_traces( j, show_legend, category, + breakdown_type or ColorCategory.SECTOR, ) traces_info.append((trace, row, col)) else: @@ -670,7 +720,15 @@ def create_faceted_traces( continue trace = _create_single_trace( - scenario_df, scenario, color_generator, chart_type, value_col, 0, False, scenario + scenario_df, + scenario, + color_generator, + chart_type, + value_col, + 0, + False, + scenario, + ColorCategory.SCENARIO, ) traces_info.append((trace, row, col)) @@ -686,13 +744,14 @@ def _create_single_trace( stack_index: int, show_legend: bool, legend_group: str, + category: ColorCategory | None = None, ) -> go.Scatter: """Create a single trace for faceted plots.""" base_kwargs: dict[str, Any] = { "x": data_df["year"], "y": data_df[value_col], "name": name, - "line": dict(color=color_generator.get_color(legend_group)), + "line": dict(color=color_generator.get_color(legend_group, category)), "showlegend": show_legend, "legendgroup": legend_group, } @@ -715,7 +774,12 @@ def _create_single_trace( return go.Scatter(**base_kwargs) -def update_faceted_layout(fig: go.Figure, rows: int, group_by: str | None = None) -> None: +def update_faceted_layout( + fig: go.Figure, + rows: int, + group_by: str | None = None, + template: str = DEFAULT_PLOTLY_TEMPLATE, +) -> None: """ Update layout for faceted time series plots. @@ -727,8 +791,11 @@ def update_faceted_layout(fig: go.Figure, rows: int, group_by: str | None = None Number of subplot rows group_by : str, optional Group by column name + template : str + Plotly template name for theme-aware styling """ height = 400 if rows == 1 else 600 if rows == 2 else 800 + axis = get_axis_style(template) fig.update_layout( plot_bgcolor=TRANSPARENT, @@ -741,5 +808,14 @@ def update_faceted_layout(fig: go.Figure, rows: int, group_by: str | None = None height=height, ) - fig.update_xaxes(title_text="Year") - fig.update_yaxes(title_text="Energy Consumption (MWh)", col=1) + fig.update_xaxes( + title_text="Year", + gridcolor=axis["grid_color"], + linecolor=axis["axis_color"], + ) + fig.update_yaxes( + title_text="Energy Consumption (MWh)", + col=1, + gridcolor=axis["grid_color"], + linecolor=axis["axis_color"], + ) diff --git a/src/stride/ui/scenario/callbacks.py b/src/stride/ui/scenario/callbacks.py index dc4cf4f..fcc4631 100644 --- a/src/stride/ui/scenario/callbacks.py +++ b/src/stride/ui/scenario/callbacks.py @@ -16,6 +16,8 @@ ) from stride.ui.plotting.utils import get_error_annotation_style, get_neutral_color +from stride.ui.palette import ColorCategory + if TYPE_CHECKING: from stride.api import APIClient from stride.ui.plotting import StridePlots @@ -281,6 +283,13 @@ def update_consumption_plot( try: # Convert "None" to None breakdown_value = None if breakdown == "None" else breakdown + breakdown_type = ( + ColorCategory.END_USE + if breakdown_value == "End Use" + else ColorCategory.SECTOR + if breakdown_value == "Sector" + else None + ) # Get consumption data for this scenario df = data_handler.get_annual_electricity_consumption( scenarios=[scenario], group_by=breakdown_value @@ -294,6 +303,7 @@ def update_consumption_plot( value_col="value", group_col="scenario", show_scenario_indicators=False, + breakdown_type=breakdown_type, ) else: # Use theme-aware neutral gray color for the bars @@ -415,6 +425,13 @@ def update_peak_plot( try: # Convert "None" to None breakdown_value = None if breakdown == "None" else breakdown + breakdown_type = ( + ColorCategory.END_USE + if breakdown_value == "End Use" + else ColorCategory.SECTOR + if breakdown_value == "Sector" + else None + ) # Get peak demand data for this scenario df = data_handler.get_annual_peak_demand(scenarios=[scenario], group_by=breakdown_value) # Create plot @@ -426,6 +443,7 @@ def update_peak_plot( value_col="value", group_col="scenario", show_scenario_indicators=False, + breakdown_type=breakdown_type, ) else: # Use theme-aware neutral gray color for the bars @@ -441,7 +459,9 @@ def update_peak_plot( if not secondary_df.empty: # Get scenario color from color manager - scenario_color = plotter.color_manager.get_color(scenario) + scenario_color = plotter.color_manager.get_color( + scenario, ColorCategory.SCENARIO + ) # Add secondary metric as a line trace on the right y-axis fig.add_trace( @@ -559,6 +579,13 @@ def update_timeseries_plot( # Convert "None" to None breakdown_value = None if breakdown == "None" else breakdown + breakdown_type = ( + ColorCategory.END_USE + if breakdown_value == "End Use" + else ColorCategory.SECTOR + if breakdown_value == "Sector" + else None + ) # Get timeseries data. Need to pass "End Use" Literal Hera df = data_handler.get_time_series_comparison( @@ -570,7 +597,11 @@ def update_timeseries_plot( # Need to assign to new variable for typing. stack_col = "metric" if breakdown_value == "End Use" else str(breakdown_value) # Use the new time_series function for better multi-year visualization - fig = plotter.time_series(df, group_by=stack_col.lower() if breakdown_value else None) + fig = plotter.time_series( + df, + group_by=stack_col.lower() if breakdown_value else None, + breakdown_type=breakdown_type, + ) # Add weather variable if selected if weather_var and weather_var != "None": diff --git a/src/stride/ui/settings/callbacks.py b/src/stride/ui/settings/callbacks.py index e4e5e68..431d7c8 100644 --- a/src/stride/ui/settings/callbacks.py +++ b/src/stride/ui/settings/callbacks.py @@ -13,13 +13,15 @@ clear_temp_color_edits, create_color_preview_content, get_temp_color_edits, + parse_temp_edit_key, set_temp_color_edit, ) -from stride.ui.tui import ( +from stride.ui.palette_utils import ( delete_user_palette, list_user_palettes, load_user_palette, save_user_palette, + set_default_user_palette, ) @@ -94,15 +96,16 @@ def toggle_color_picker_modal( # noqa: C901 # Close modal and apply color on apply button if triggered_id == "color-picker-apply-btn": if current_label and picked_color: - # Store the color change temporarily + # current_label is a composite key "category:label" set_temp_color_edit(current_label, picked_color) - logger.info(f"Temporarily updated color for '{current_label}' to {picked_color}") + _, display_label = parse_temp_edit_key(current_label) + logger.info(f"Temporarily updated color for '{display_label}' to {picked_color}") # Increment counter to trigger refresh (will be handled by separate callback) return False, None, "", "#000000", "#000000", no_update # type: ignore[return-value] # Open modal when a color item is clicked if isinstance(triggered_id, dict) and triggered_id.get("type") == "color-item": - # Get the index of the clicked item + # Get the index of the clicked item (composite key "category:label") index = triggered_id.get("index") if index is None: raise PreventUpdate @@ -125,12 +128,15 @@ def toggle_color_picker_modal( # noqa: C901 if color_manager is None: raise PreventUpdate + # Parse composite key to get category and label + category_str, label = parse_temp_edit_key(index) + # Get current color (check temp edits first) temp_edits = get_temp_color_edits() if index in temp_edits: current_color = temp_edits[index] else: - current_color = color_manager.get_color(index) + current_color = color_manager.get_color(label, category_str) # Convert color to hex format for the color input hex_color = _convert_to_hex(current_color) @@ -138,7 +144,7 @@ def toggle_color_picker_modal( # noqa: C901 return ( True, index, - f"Edit Color: {index}", + f"Edit Color: {label}", hex_color, hex_color, no_update, # type: ignore[return-value] @@ -146,6 +152,24 @@ def toggle_color_picker_modal( # noqa: C901 raise PreventUpdate + @callback( + Output("unsaved-changes-indicator", "children"), + Input("color-edits-counter", "data"), + Input("settings-palette-applied", "data"), + prevent_initial_call=True, + ) + def update_unsaved_indicator(counter: int, palette_data: dict[str, Any]) -> html.Div | str: + """Show an indicator when there are unsaved color edits.""" + temp_edits = get_temp_color_edits() + if temp_edits: + n = len(temp_edits) + label = "change" if n == 1 else "changes" + return html.Div( + f"⚠ {n} unsaved color {label}. Use the save options below to keep them.", + className="text-warning small mt-1 mb-2", + ) + return "" + @callback( Output("color-preview-container", "children"), Input("color-edits-counter", "data"), @@ -203,27 +227,74 @@ def sync_color_inputs(color_value: str, hex_value: str) -> tuple[str, str]: Output("user-palette-selector-container", "style"), Output("user-palette-selector", "disabled"), Output("delete-user-palette-btn", "disabled"), + Output("set-default-palette-btn", "disabled"), Input("palette-type-selector", "value"), State("user-palette-selector", "value"), ) def toggle_user_palette_selector( palette_type: str, selected_palette: str | None - ) -> tuple[dict[str, str], bool, bool]: + ) -> tuple[dict[str, str], bool, bool, bool]: """Enable/disable user palette selector based on palette type.""" if palette_type == "user": - return {"display": "block"}, False, not selected_palette - else: - return {"display": "none"}, True, True + return {"display": "block"}, False, not selected_palette, not selected_palette + return {"display": "none"}, True, True, True @callback( Output("delete-user-palette-btn", "disabled", allow_duplicate=True), + Output("set-default-palette-btn", "disabled", allow_duplicate=True), Input("user-palette-selector", "value"), State("palette-type-selector", "value"), prevent_initial_call=True, ) - def update_delete_button(selected_palette: str | None, palette_type: str) -> bool: - """Enable/disable delete button based on whether a palette is selected.""" - return palette_type == "project" or not selected_palette + def update_delete_button(selected_palette: str | None, palette_type: str) -> tuple[bool, bool]: + """Enable/disable delete and set-default buttons based on selection.""" + disabled = palette_type != "user" or not selected_palette + return disabled, disabled + + @callback( + Output("set-default-palette-btn", "children", allow_duplicate=True), + Output("set-default-palette-btn", "color", allow_duplicate=True), + Output("default-user-palette-store", "data", allow_duplicate=True), + Input("set-default-palette-btn", "n_clicks"), + State("user-palette-selector", "value"), + State("default-user-palette-store", "data"), + prevent_initial_call=True, + ) + def toggle_default_palette( + n_clicks: int | None, + selected_palette: str | None, + current_default: str | None, + ) -> tuple[str, str, str | None]: + """Toggle setting/clearing the dashboard default palette.""" + if not n_clicks or not selected_palette: + raise PreventUpdate + + if current_default == selected_palette: + # Clear the default + set_default_user_palette(None) + logger.info("Cleared default user palette") + return "Set as Dashboard Default", "secondary", None + else: + # Set as default + set_default_user_palette(selected_palette) + logger.info(f"Set default user palette to: {selected_palette}") + return "Dashboard Default \u2713 (Clear)", "success", selected_palette + + @callback( + Output("set-default-palette-btn", "children", allow_duplicate=True), + Output("set-default-palette-btn", "color", allow_duplicate=True), + Input("user-palette-selector", "value"), + State("default-user-palette-store", "data"), + prevent_initial_call=True, + ) + def update_default_button_label( + selected_palette: str | None, + current_default: str | None, + ) -> tuple[str, str]: + """Update default button label when palette selection changes.""" + if selected_palette and selected_palette == current_default: + return "Dashboard Default \u2713 (Clear)", "success" + return "Set as Dashboard Default", "secondary" @callback( Output("revert-changes-status", "children"), @@ -347,7 +418,7 @@ def apply_selected_palette( logger.info("Switched to project palette") return {"type": "project", "name": None}, counter + 1 - elif ( + if ( triggered_id == "palette-type-selector" and palette_type == "user" and user_palette_name @@ -363,7 +434,7 @@ def apply_selected_palette( logger.error(f"Error loading user palette '{user_palette_name}': {e}") raise PreventUpdate - elif triggered_id == "user-palette-selector" and user_palette_name: + if triggered_id == "user-palette-selector" and user_palette_name: # User selected a different user palette from dropdown try: logger.info(f"Switching to user palette: {user_palette_name}") @@ -383,87 +454,69 @@ def apply_selected_palette( raise PreventUpdate @callback( - Output("save-palette-status", "children"), - Input("save-current-palette-btn", "n_clicks"), - State("settings-palette-applied", "data"), + Output("palette-type-selector", "value"), + Output("settings-palette-applied", "data", allow_duplicate=True), + Output("color-edits-counter", "data", allow_duplicate=True), + Input("reset-to-defaults-btn", "n_clicks"), + State("color-edits-counter", "data"), prevent_initial_call=True, ) - def save_current_palette( + def reset_to_defaults( n_clicks: int | None, - current_palette_data: dict[str, Any], - ) -> html.Div: - """Save edits to the currently active palette.""" + counter: int, + ) -> tuple[str, dict[str, Any], int]: + """Reset palette to defaults by creating fresh TOL colors and switching to project mode.""" if not n_clicks: raise PreventUpdate - try: - data_handler = get_data_handler_func() - color_manager = get_color_manager_func() - - if data_handler is None or color_manager is None: - return html.Div( - "✗ Error: No project loaded", - className="text-danger mt-2", - ) - - # Apply temporary edits to a copy of the palette - temp_edits = get_temp_color_edits() - palette = color_manager.get_palette() - palette_copy = palette.copy() - for label, color in temp_edits.items(): - palette_copy.update(label, color) - - # Get the palette data from the copy - palette_data = palette_copy.to_dict() + data_handler = get_data_handler_func() + if data_handler is None: + raise PreventUpdate - # Determine where to save based on current palette type - palette_type = current_palette_data.get("type", "project") - palette_name = current_palette_data.get("name") + # Create fresh palette from TOL themes + project dimensions + palette = ColorPalette() + data_handler.project._palette = palette + data_handler.project._auto_populate_palette() + palette = data_handler.project._palette - if palette_type == "project": - # Update project's palette with the modified copy - data_handler.project._palette = palette_copy - data_handler.project.save_palette() - message = "✓ Palette saved to project" - elif palette_type == "user" and palette_name: - # Save to existing user palette - save_user_palette(palette_name, palette_data) - message = f"✓ Palette saved to '{palette_name}'" - else: - return html.Div( - "✗ Error: No active palette to save to", - className="text-danger mt-2", - ) + # Save to project so it persists + data_handler.project.save_palette() - # Clear temporary edits after saving - clear_temp_color_edits() + clear_temp_color_edits() + on_palette_change_func(palette, "project", None) + logger.info("Reset palette to defaults and saved to project") - # Refresh the palette in the UI by calling on_palette_change - if palette_type == "project": - on_palette_change_func(palette_copy, "project", None) - elif palette_type == "user" and palette_name: - on_palette_change_func(palette_copy, "user", palette_name) + return "project", {"type": "project", "name": None}, counter + 1 - logger.info(f"Saved current palette ({palette_type}: {palette_name or 'project'})") - return html.Div( - message, - className="text-success mt-2", - ) - except Exception as e: - logger.error(f"Error saving current palette: {e}") + @callback( + Output("palette-source-hint", "children"), + Input("palette-type-selector", "value"), + Input("user-palette-selector", "value"), + prevent_initial_call=True, + ) + def update_palette_source_hint( + palette_type: str, user_palette_name: str | None + ) -> html.Div | str: + """Show a hint when user palette is selected.""" + if palette_type == "user" and user_palette_name: return html.Div( - f"✗ Error: {str(e)}", - className="text-danger mt-2", + f"Viewing user palette '{user_palette_name}'. " + "Changes apply to this session only. " + "Use 'Save to Project' to make permanent.", + className="text-info small mt-1 mb-2", ) + return "" @callback( - Output("save-palette-status", "children", allow_duplicate=True), + Output("save-palette-status", "children"), + Output("palette-type-selector", "value", allow_duplicate=True), + Output("settings-palette-applied", "data", allow_duplicate=True), Input("save-to-project-btn", "n_clicks"), prevent_initial_call=True, ) def save_to_project( n_clicks: int | None, - ) -> html.Div: + ) -> tuple[html.Div, str, dict[str, Any]]: """Save current palette to project.json.""" if not n_clicks: raise PreventUpdate @@ -473,17 +526,22 @@ def save_to_project( color_manager = get_color_manager_func() if data_handler is None or color_manager is None: - return html.Div( - "✗ Error: No project loaded", - className="text-danger mt-2", + return ( + html.Div( + "✗ Error: No project loaded", + className="text-danger mt-2", + ), + no_update, # type: ignore[return-value] + no_update, # type: ignore[return-value] ) # Apply temporary edits to a copy of the palette temp_edits = get_temp_color_edits() palette = color_manager.get_palette() palette_copy = palette.copy() - for label, color in temp_edits.items(): - palette_copy.update(label, color) + for composite_key, color in temp_edits.items(): + cat_str, label = parse_temp_edit_key(composite_key) + palette_copy.update(label, color, category=cat_str) # Update project's palette with the modified copy data_handler.project._palette = palette_copy @@ -497,15 +555,23 @@ def save_to_project( on_palette_change_func(palette_copy, "project", None) logger.info("Saved palette to project") - return html.Div( - "✓ Palette saved to project", - className="text-success mt-2", + return ( + html.Div( + "✓ Palette saved to project", + className="text-success mt-2", + ), + "project", + {"type": "project", "name": None}, ) except Exception as e: logger.error(f"Error saving palette to project: {e}") - return html.Div( - f"✗ Error: {str(e)}", - className="text-danger mt-2", + return ( + html.Div( + f"✗ Error: {str(e)}", + className="text-danger mt-2", + ), + no_update, # type: ignore[return-value] + no_update, # type: ignore[return-value] ) @callback( @@ -525,6 +591,8 @@ def show_new_palette_name_input(n_clicks: int | None) -> dict[str, str]: Output("save-new-palette-name", "value"), Output("user-palette-selector", "options", allow_duplicate=True), Output("user-palette-selector", "value", allow_duplicate=True), + Output("palette-type-selector", "value", allow_duplicate=True), + Output("settings-palette-applied", "data", allow_duplicate=True), Input("save-to-new-palette-btn", "n_clicks"), State("save-new-palette-name", "value"), prevent_initial_call=True, @@ -532,7 +600,7 @@ def show_new_palette_name_input(n_clicks: int | None) -> dict[str, str]: def save_to_new_palette( n_clicks: int | None, palette_name: str | None, - ) -> tuple[html.Div, dict[str, str], str, list[dict[str, str]], str]: + ) -> tuple[html.Div, dict[str, str], str, list[dict[str, str]], str, str, dict[str, Any]]: """Save current palette to a new user palette.""" if not n_clicks: raise PreventUpdate @@ -548,6 +616,8 @@ def save_to_new_palette( "", no_update, # type: ignore[return-value] no_update, # type: ignore[return-value] + no_update, # type: ignore[return-value] + no_update, # type: ignore[return-value] ) try: @@ -557,8 +627,9 @@ def save_to_new_palette( temp_edits = get_temp_color_edits() palette = color_manager.get_palette() palette_copy = palette.copy() - for label, color in temp_edits.items(): - palette_copy.update(label, color) + for composite_key, color in temp_edits.items(): + cat_str, label = parse_temp_edit_key(composite_key) + palette_copy.update(label, color, category=cat_str) # Get the palette data from the copy palette_data = palette_copy.to_dict() @@ -578,15 +649,18 @@ def save_to_new_palette( on_palette_change_func(palette_copy, "user", palette_name.strip()) logger.info(f"Saved palette to new user palette: {palette_name}") + trimmed = palette_name.strip() return ( html.Div( - f"✓ Palette saved as '{palette_name.strip()}'", + f"✓ Palette saved as '{trimmed}'", className="text-success mt-2", ), {"display": "none"}, "", updated_options, - palette_name.strip(), + trimmed, + "user", + {"type": "user", "name": trimmed}, ) except Exception as e: logger.error(f"Error saving new palette: {e}") @@ -599,6 +673,8 @@ def save_to_new_palette( palette_name, no_update, # type: ignore[return-value] no_update, # type: ignore[return-value] + no_update, # type: ignore[return-value] + no_update, # type: ignore[return-value] ) @callback( @@ -750,8 +826,9 @@ def populate_json_editor(is_open: bool, reset_clicks: int | None) -> str: temp_edits = get_temp_color_edits() # Apply temp edits to get the current state - for label, color in temp_edits.items(): - palette.update(label, color) + for composite_key, color in temp_edits.items(): + cat_str, label = parse_temp_edit_key(composite_key) + palette.update(label, color, category=cat_str) # Get the palette as a dict palette_dict = palette.to_dict() @@ -794,11 +871,15 @@ def apply_json_palette( no_update, # type: ignore[return-value] ) - # Check if it has the expected structure - if not all(key in palette_dict for key in ["scenarios", "model_years", "metrics"]): + # Check if it has the expected structure (accept both new and legacy formats) + _required_new = {"scenarios", "model_years", "sectors", "end_uses"} + _required_legacy = {"scenarios", "model_years", "metrics"} + if not ( + _required_new <= palette_dict.keys() or _required_legacy <= palette_dict.keys() + ): return ( html.Div( - "✗ Invalid palette structure: must have 'scenarios', 'model_years', and 'metrics' keys", + "✗ Invalid palette structure: must have 'scenarios', 'model_years', 'sectors', and 'end_uses' keys (or legacy 'metrics' key)", className="text-danger mt-2", ), no_update, # type: ignore[return-value] @@ -806,7 +887,7 @@ def apply_json_palette( ) # Create a ColorPalette from the JSON - palette = ColorPalette(palette_dict) + palette = ColorPalette.from_dict(palette_dict) # Apply it to the color manager on_palette_change_func(palette, "custom", None) @@ -881,7 +962,7 @@ def save_max_cached_projects( ) from stride.ui.app import _evict_oldest_project, set_max_cached_projects_override - from stride.ui.tui import set_max_cached_projects + from stride.config import set_max_cached_projects # Persist to config file set_max_cached_projects(n) diff --git a/src/stride/ui/settings/layout.py b/src/stride/ui/settings/layout.py index a07d624..ea7ee8b 100644 --- a/src/stride/ui/settings/layout.py +++ b/src/stride/ui/settings/layout.py @@ -6,8 +6,11 @@ from dash import dcc, html from stride.ui.color_manager import ColorManager +from stride.ui.palette import ColorCategory -# Store for temporarily edited colors before saving +# Store for temporarily edited colors before saving. +# Keys are composite ``"category_value:label"`` strings (e.g. +# ``"scenarios:baseline"``). _temp_color_edits: dict[str, str] = {} @@ -17,6 +20,7 @@ def create_settings_layout( current_palette_type: str, current_palette_name: str | None, color_manager: ColorManager, + default_user_palette: str | None = None, ) -> html.Div: """ Create the settings page layout. @@ -33,6 +37,8 @@ def create_settings_layout( Name of currently active user palette (if type is 'user') color_manager : ColorManager Color manager instance for displaying current colors + default_user_palette : str | None + Name of the current default user palette, or None if not set Returns ------- @@ -45,18 +51,22 @@ def create_settings_layout( # Get structured palette with categories structured_palette = palette.to_dict() - # Extract colors for each category and convert to RGBA for display + # Extract colors for each category and convert to RGBA for display. scenario_colors = {} for label in structured_palette.get("scenarios", {}): - scenario_colors[label] = color_manager.get_color(label) + scenario_colors[label] = color_manager.get_color(label, ColorCategory.SCENARIO) model_year_colors = {} for label in structured_palette.get("model_years", {}): - model_year_colors[label] = color_manager.get_color(label) + model_year_colors[label] = color_manager.get_color(label, ColorCategory.MODEL_YEAR) - metric_colors = {} - for label in structured_palette.get("metrics", {}): - metric_colors[label] = color_manager.get_color(label) + sector_colors = {} + for label in structured_palette.get("sectors", {}): + sector_colors[label] = color_manager.get_color(label, ColorCategory.SECTOR) + + end_use_colors = {} + for label in structured_palette.get("end_uses", {}): + end_use_colors[label] = color_manager.get_color(label, ColorCategory.END_USE) # Get temporary color edits temp_edits = get_temp_color_edits() @@ -203,7 +213,49 @@ def create_settings_layout( placeholder="Select a user palette...", disabled=( current_palette_type - == "project" + != "user" + ), + ), + dbc.Button( + "Delete", + id="delete-user-palette-btn", + color="danger", + outline=True, + size="sm", + className="ms-2 mt-2", + disabled=( + current_palette_type + != "user" + or not current_palette_name + ), + ), + dbc.Button( + ( + "Dashboard Default ✓ (Clear)" + if ( + current_palette_name + and current_palette_name + == default_user_palette + ) + else "Set as Dashboard Default" + ), + id="set-default-palette-btn", + color=( + "success" + if ( + current_palette_name + and current_palette_name + == default_user_palette + ) + else "secondary" + ), + outline=True, + size="sm", + className="ms-2 mt-2 theme-text", + disabled=( + current_palette_type + != "user" + or not current_palette_name ), ), ], @@ -216,6 +268,24 @@ def create_settings_layout( ) }, ), + # Palette source hint (shown dynamically) + html.Div( + id="palette-source-hint", + className="mt-2", + ), + # Unsaved changes indicator (shown dynamically) + html.Div( + id="unsaved-changes-indicator", + ), + # Reset to Defaults button + dbc.Button( + "Reset to Defaults", + id="reset-to-defaults-btn", + color="secondary", + outline=True, + size="sm", + className="mt-3", + ), ] ) ], @@ -250,7 +320,7 @@ def create_settings_layout( html.Div( [ _create_color_item( - label, color, temp_edits + ColorCategory.SCENARIO.value, label, color, temp_edits ) for label, color in scenario_colors.items() ], @@ -270,7 +340,7 @@ def create_settings_layout( html.Div( [ _create_color_item( - label, color, temp_edits + ColorCategory.MODEL_YEAR.value, label, color, temp_edits ) for label, color in model_year_colors.items() ], @@ -284,21 +354,41 @@ def create_settings_layout( html.Div( [ html.H6( - "Metrics", + "Sectors", className="mb-2 text-muted", ), html.Div( [ _create_color_item( - label, color, temp_edits + ColorCategory.SECTOR.value, label, color, temp_edits ) - for label, color in metric_colors.items() + for label, color in sector_colors.items() + ], + className="d-flex flex-wrap gap-2 mb-3", + ), + ] + ) + if sector_colors + else None, + # End Uses + html.Div( + [ + html.H6( + "End Uses", + className="mb-2 text-muted", + ), + html.Div( + [ + _create_color_item( + ColorCategory.END_USE.value, label, color, temp_edits + ) + for label, color in end_use_colors.items() ], className="d-flex flex-wrap gap-2", ), ] ) - if metric_colors + if end_use_colors else None, ], ) @@ -488,6 +578,8 @@ def create_settings_layout( dcc.Store(id="selected-color-label", data=None), # Hidden store for tracking color edits (triggers refresh) dcc.Store(id="color-edits-counter", data=0), + # Hidden store for tracking the current default user palette + dcc.Store(id="default-user-palette-store", data=default_user_palette), # Save Options Section dbc.Row( html.Div( @@ -505,13 +597,6 @@ def create_settings_layout( ), html.Div( [ - dbc.Button( - "Save Current Palette", - id="save-current-palette-btn", - color="primary", - outline=True, - className="m-1 theme-text", - ), dbc.Button( "Save to Project", id="save-to-project-btn", @@ -520,7 +605,7 @@ def create_settings_layout( className="m-1 theme-text", ), dbc.Button( - "Save to New Palette", + "Save As User Palette", id="save-to-new-palette-btn", color="info", outline=True, @@ -529,21 +614,9 @@ def create_settings_layout( dbc.Button( "Revert Changes", id="revert-changes-btn", - color="warning", - outline=True, - className="m-1 theme-text", - ), - dbc.Button( - "Delete Selected User Palette", - id="delete-user-palette-btn", - color="danger", + color="secondary", outline=True, className="m-1 theme-text", - disabled=( - current_palette_type - == "project" - or not current_palette_name - ), ), ], className="d-flex flex-wrap mb-3", @@ -601,26 +674,31 @@ def create_settings_layout( ) -def _create_color_item(label: str, color: str, temp_edits: dict[str, str]) -> html.Div: +def _create_color_item( + category: str, label: str, color: str, temp_edits: dict[str, str] +) -> html.Div: """ Create a color preview item with label. Parameters ---------- + category : str + Category value string (e.g. ``"scenarios"``, ``"model_years"``). label : str Label name color : str Color value (hex, rgb, or rgba) temp_edits : dict[str, str] - Dictionary of temporary color edits + Dictionary of temporary color edits (composite key → color) Returns ------- html.Div Color preview component """ + composite_key = f"{category}:{label}" # Check if there's a temporary edit for this color - display_color = temp_edits.get(label, color) + display_color = temp_edits.get(composite_key, color) return html.Div( [ @@ -646,7 +724,7 @@ def _create_color_item(label: str, color: str, temp_edits: dict[str, str]) -> ht }, ), ], - id={"type": "color-item", "index": label}, + id={"type": "color-item", "index": composite_key}, n_clicks=0, style={ "display": "inline-flex", @@ -666,7 +744,10 @@ def _create_color_item(label: str, color: str, temp_edits: dict[str, str]) -> ht def get_temp_color_edits() -> dict[str, str]: - """Get the temporary color edits dictionary.""" + """Get the temporary color edits dictionary. + + Keys are composite ``"category_value:label"`` strings. + """ return _temp_color_edits @@ -675,9 +756,55 @@ def clear_temp_color_edits() -> None: _temp_color_edits.clear() -def set_temp_color_edit(label: str, color: str) -> None: - """Set a temporary color edit.""" - _temp_color_edits[label] = color +def set_temp_color_edit(composite_key: str, color: str) -> None: + """Set a temporary color edit. + + Parameters + ---------- + composite_key : str + Key in ``"category_value:label"`` format. + color : str + Hex color string. + """ + _temp_color_edits[composite_key] = color + + +def parse_temp_edit_key(composite_key: str) -> tuple[str, str]: + """Split a composite temp-edit key into ``(category_value, label)``. + + Parameters + ---------- + composite_key : str + Key in ``"category_value:label"`` format. + + Returns + ------- + tuple[str, str] + ``(category_value, label)`` + """ + category, _, label = composite_key.partition(":") + return category, label + + +def get_temp_edits_for_category(category_value: str) -> dict[str, str]: + """Return temp edits for one category with plain label keys. + + Parameters + ---------- + category_value : str + The ``ColorCategory`` ``.value`` string, e.g. ``"scenarios"``. + + Returns + ------- + dict[str, str] + ``{label: color}`` for entries matching the given category. + """ + prefix = f"{category_value}:" + return { + key[len(prefix):]: color + for key, color in _temp_color_edits.items() + if key.startswith(prefix) + } def create_color_preview_content(color_manager: ColorManager) -> list[html.Div]: @@ -700,18 +827,22 @@ def create_color_preview_content(color_manager: ColorManager) -> list[html.Div]: # Get structured palette with categories structured_palette = palette.to_dict() - # Extract colors for each category and convert to RGBA for display + # Extract colors for each category and convert to RGBA for display. scenario_colors = {} for label in structured_palette.get("scenarios", {}): - scenario_colors[label] = color_manager.get_color(label) + scenario_colors[label] = color_manager.get_color(label, ColorCategory.SCENARIO) model_year_colors = {} for label in structured_palette.get("model_years", {}): - model_year_colors[label] = color_manager.get_color(label) + model_year_colors[label] = color_manager.get_color(label, ColorCategory.MODEL_YEAR) - metric_colors = {} - for label in structured_palette.get("metrics", {}): - metric_colors[label] = color_manager.get_color(label) + sector_colors = {} + for label in structured_palette.get("sectors", {}): + sector_colors[label] = color_manager.get_color(label, ColorCategory.SECTOR) + + end_use_colors = {} + for label in structured_palette.get("end_uses", {}): + end_use_colors[label] = color_manager.get_color(label, ColorCategory.END_USE) # Get temporary color edits temp_edits = get_temp_color_edits() @@ -730,7 +861,7 @@ def create_color_preview_content(color_manager: ColorManager) -> list[html.Div]: ), html.Div( [ - _create_color_item(label, color, temp_edits) + _create_color_item(ColorCategory.SCENARIO.value, label, color, temp_edits) for label, color in scenario_colors.items() ], className="d-flex flex-wrap gap-2 mb-3", @@ -750,7 +881,7 @@ def create_color_preview_content(color_manager: ColorManager) -> list[html.Div]: ), html.Div( [ - _create_color_item(label, color, temp_edits) + _create_color_item(ColorCategory.MODEL_YEAR.value, label, color, temp_edits) for label, color in model_year_colors.items() ], className="d-flex flex-wrap gap-2 mb-3", @@ -759,19 +890,39 @@ def create_color_preview_content(color_manager: ColorManager) -> list[html.Div]: ) ) - # Metrics - if metric_colors: + # Sectors + if sector_colors: + content.append( + html.Div( + [ + html.H6( + "Sectors", + className="mb-2 text-muted", + ), + html.Div( + [ + _create_color_item(ColorCategory.SECTOR.value, label, color, temp_edits) + for label, color in sector_colors.items() + ], + className="d-flex flex-wrap gap-2 mb-3", + ), + ] + ) + ) + + # End Uses + if end_use_colors: content.append( html.Div( [ html.H6( - "Metrics", + "End Uses", className="mb-2 text-muted", ), html.Div( [ - _create_color_item(label, color, temp_edits) - for label, color in metric_colors.items() + _create_color_item(ColorCategory.END_USE.value, label, color, temp_edits) + for label, color in end_use_colors.items() ], className="d-flex flex-wrap gap-2", ), diff --git a/src/stride/ui/tui.py b/src/stride/ui/tui.py deleted file mode 100644 index 2b9ccc3..0000000 --- a/src/stride/ui/tui.py +++ /dev/null @@ -1,1240 +0,0 @@ -""" -TUI (Text User Interface) framework for managing color palettes in Stride. - -This module provides a terminal-based interface for viewing and managing color palettes -at both the user and project levels. It uses the Textual library to create an interactive -interface with multiple columns for different label groups. -""" - -import re -from pathlib import Path -from typing import Any - -from loguru import logger -from rich.style import Style -from rich.text import Text -from textual.app import App, ComposeResult -from textual.binding import Binding -from textual.containers import Horizontal, ScrollableContainer -from textual.widgets import DataTable, Footer, Header, Input, Label, Static - -from stride.ui.palette import ColorPalette, hex_color_pattern - - -def color_to_rich_format(color: str) -> str: - """Convert color string to Rich-compatible format. - - Rich doesn't support rgba() format, so we need to convert it. - - Parameters - ---------- - color : str - Color string in hex, rgb, or rgba format - - Returns - ------- - str - Color string that Rich can parse - """ - # If it's rgba, convert to rgb by dropping the alpha - if color.startswith("rgba("): - # Extract rgb values only - match = re.match(r"rgba?\((\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*(?:,\s*[\d.]+)?\)", color) - if match: - r, g, b = match.groups() - return f"rgb({r},{g},{b})" - return color - - -def validate_color(color: str) -> bool: - """Validate if the color string is in a valid format. - - Parameters - ---------- - color : str - Color string to validate - - Returns - ------- - bool - True if valid, False otherwise - """ - if not color: - return False - - # Check hex format - if hex_color_pattern.match(color): - return True - - # Check rgb/rgba format - rgb_pattern = re.compile(r"^rgba?\(\s*\d+\s*,\s*\d+\s*,\s*\d+\s*(?:,\s*[\d.]+\s*)?\)$") - if rgb_pattern.match(color): - return True - - return False - - -class PaletteInfo(Static): - """Widget to display palette metadata (name, location).""" - - def __init__(self, name: str, location: Path, palette_type: str, **kwargs: Any) -> None: - """Initialize the palette info widget. - - Parameters - ---------- - name : str - Name of the palette (derived from filename) - location : Path - Full path to the palette file - palette_type : str - Type of palette ('user' or 'project') - """ - super().__init__(**kwargs) - self.palette_name = name - self.palette_location = location - self.palette_type = palette_type - - def compose(self) -> ComposeResult: - """Compose the palette info display.""" - info_text = ( - f"[bold cyan]Palette:[/bold cyan] {self.palette_name} | " - f"[bold cyan]Type:[/bold cyan] {self.palette_type} | " - f"[bold cyan]Location:[/bold cyan] {self.palette_location}" - ) - yield Label(info_text) - - -class LabelGroupColumn(Static): - """Widget to display a single label group as a column.""" - - def __init__( - self, - group_name: str, - labels: dict[str, str], - parent_viewer: "PaletteViewer", - **kwargs: Any, - ) -> None: - """Initialize a label group column. - - Parameters - ---------- - group_name : str - Name of the label group (e.g., "End Uses", "Scenarios") - labels : dict[str, str] - Mapping of label names to hex color strings - parent_viewer : PaletteViewer - Reference to parent viewer for edit callbacks - """ - super().__init__(**kwargs) - self.group_name = group_name - self.labels = labels - self.parent_viewer = parent_viewer - - def compose(self) -> ComposeResult: - """Compose the label group column.""" - # Group header - yield Label(f"[bold white on blue] {self.group_name} [/bold white on blue]") - - # Create a data table for the labels - # Use a valid CSS ID (replace spaces and special chars with underscores) - table_id = f"table_{self.group_name.replace(' ', '_').replace('-', '_')}" - table: DataTable[Any] = DataTable(zebra_stripes=True, classes="label-table", id=table_id) - table.cursor_type = "cell" - table.show_cursor = True - yield table - - def on_mount(self) -> None: - """Populate the table after mounting.""" - table_id = f"table_{self.group_name.replace(' ', '_').replace('-', '_')}" - try: - table: DataTable[Any] = self.query_one("DataTable", DataTable) - table.add_columns("Label", "Color", "Preview") - - # Add rows (preserve order from dict - insertion order is maintained in Python 3.7+) - for label, color in self.labels.items(): - # Create a color preview using the color (convert to Rich-compatible format) - rich_color = color_to_rich_format(color) - preview = Text("████", style=Style(color=rich_color)) - table.add_row(label, color, preview) - except Exception as e: - logger.error(f"Error populating table {table_id}: {e}") - raise - - -class PaletteViewer(App[None]): - """Main TUI application for viewing and managing color palettes.""" - - CSS = """ - Screen { - background: $surface; - } - - #palette-info { - height: auto; - padding: 1 2; - margin-bottom: 1; - background: $panel; - border: solid $primary; - } - - #columns-container { - height: 1fr; - padding: 1; - overflow-x: auto; - overflow-y: auto; - } - - LabelGroupColumn { - width: auto; - min-width: 25; - max-width: 35; - height: auto; - margin: 0 1; - padding: 1; - background: $panel; - border: solid $accent; - } - - .label-table { - height: auto; - min-height: 10; - max-height: 30; - margin-top: 1; - } - - /* Highlight only specific columns, not the preview */ - DataTable > .datatable--cursor { - background: transparent; - } - - /* Highlight Label column (column 0) when selected */ - DataTable > .datatable--cursor-cell-0-0 { - background: $accent 30%; - } - - /* Highlight Color column (column 1) when selected */ - DataTable > .datatable--cursor-cell-0-1 { - background: $accent 30%; - } - - Label { - margin: 0 0 1 0; - } - - Horizontal { - width: auto; - height: auto; - } - """ - - BINDINGS = [ - Binding("q", "quit", "Quit", priority=True), - Binding("e", "edit_color", "Edit"), - Binding("a", "add_label", "Add Label"), - Binding("x", "delete_label", "Delete Label"), - Binding("X", "delete_group", "Delete Group"), - Binding("s", "save_palette", "Save"), - Binding("u", "move_up", "Move Up"), - Binding("d", "move_down", "Move Down"), - Binding("r", "refresh", "Refresh"), - ("?", "help", "Help"), - Binding("escape", "cancel_edit", "Cancel Edit", show=False), - ] - - def __init__( - self, - palette_name: str, - palette_location: Path, - palette_type: str, - label_groups: dict[str, dict[str, str]], - **kwargs: Any, - ) -> None: - """Initialize the palette viewer application. - - Parameters - ---------- - palette_name : str - Name of the palette - palette_location : Path - Path to the palette file - palette_type : str - Type of palette ('user' or 'project') - label_groups : dict[str, dict[str, str]] - Nested dictionary of group_name -> label_name -> color - """ - super().__init__(**kwargs) - self.palette_name = palette_name - self.palette_location = palette_location - self.palette_type = palette_type - self.label_groups = label_groups - self.has_unsaved_changes = False - self.editing_mode = False - self.editing_table: DataTable[Any] | None = None - self.editing_row: int | None = None - self.editing_label: str | None = None - self.original_color: str | None = None - self.input_mode: str | None = ( - None # Tracks what we're inputting: 'edit', 'add_label', 'add_color' - ) - self.temp_label_name: str | None = None # Temporary storage for label name when adding - self.temp_group_name: str | None = ( - None # Temporary storage for group name when adding labels - ) - - def compose(self) -> ComposeResult: - """Compose the main UI layout.""" - yield Header(show_clock=True) - - # Palette info section - yield PaletteInfo( - self.palette_name, - self.palette_location, - self.palette_type, - id="palette-info", - ) - - # Columns container with horizontal layout - with ScrollableContainer(id="columns-container"): - with Horizontal(): - # Create a column for each label group - # Create a column for each label group - if self.label_groups: - for group_name, labels in self.label_groups.items(): - yield LabelGroupColumn(group_name, labels, self) - else: - # If palette is empty, show a helpful message - yield Label( - "[dim]Empty palette. Press 'a' to add a label.[/dim]", - id="empty-palette-msg", - ) - - yield Footer() - - def on_mount(self) -> None: - """Called when the app is mounted and all widgets are ready.""" - # Focus first table if available - tables = self.query(DataTable) - if tables: - tables.first().focus() - - def action_edit_color(self) -> None: - """Enter edit mode for the selected color cell.""" - # Find which table has focus - focused_table = None - for table in self.query(DataTable): - if table.has_focus: - focused_table = table - break - - if not focused_table: - self.notify("No color selected. Navigate to a color first.", severity="warning") - return - - # Make sure we're in the Color column (column 1) - if focused_table.cursor_column != 1: - # Move to color column - focused_table.move_cursor(column=1) - - if focused_table.cursor_row is None or focused_table.cursor_row < 0: - self.notify("No color selected. Navigate to a color first.", severity="warning") - return - - row_key = focused_table.get_row_at(focused_table.cursor_row) - label = str(row_key[0]) - current_color = str(row_key[1]) - - # Enter editing mode - self.editing_mode = True - self.editing_table = focused_table - self.editing_row = focused_table.cursor_row - self.editing_label = label - self.original_color = current_color - - # Replace the cell with an Input widget - self.mount_inline_editor(focused_table, current_color) - - def mount_inline_editor(self, table: DataTable[Any], current_color: str) -> None: - """Mount an inline input widget for editing the color. - - Parameters - ---------- - table : DataTable - The table containing the cell to edit - current_color : str - The current color value - """ - # Create an input widget for inline editing - input_widget = Input( - value=current_color, - placeholder="e.g., #FF5733 or rgb(255,87,51)", - id="inline-color-input", - ) - - # Mount it near the table - self.mount(input_widget) - input_widget.focus() - - self.notify("Type new color and press Enter to save, Esc to cancel", timeout=2) - - def on_input_submitted(self, event: Input.Submitted) -> None: - """Handle inline input submission.""" - if event.input.id == "inline-color-input" and self.editing_mode: - new_color = event.value.strip() - - # Validate the color - if validate_color(new_color): - self.apply_color_edit(new_color) - event.input.remove() - self.editing_mode = False - if self.editing_table: - self.editing_table.focus() - else: - self.notify( - "Invalid color format. Use hex (#FF5733) or rgb (rgb(255,87,51))", - severity="error", - ) - elif event.input.id == "add-label-input": - self.handle_add_label_name(event.value.strip(), event.input) - elif event.input.id == "add-color-input": - self.handle_add_label_color(event.value.strip(), event.input) - - def action_cancel_edit(self) -> None: - """Cancel inline editing or input mode.""" - if self.editing_mode: - input_widget = self.query_one("#inline-color-input", Input) - input_widget.remove() - self.editing_mode = False - if self.editing_table: - self.editing_table.focus() - self.notify("Edit cancelled") - elif self.input_mode: - # Cancel any other input mode - input_ids = ["add-label-input", "add-color-input"] - for input_id in input_ids: - try: - input_widget = self.query_one(f"#{input_id}", Input) - input_widget.remove() - except Exception: - pass - self.input_mode = None - self.temp_label_name = None - self.temp_group_name = None - self.notify("Input cancelled") - - def apply_color_edit(self, new_color: str) -> None: - """Apply the color edit to the table and data. - - Parameters - ---------- - new_color : str - The new color value - """ - if not self.editing_table or self.editing_row is None or not self.editing_label: - return - - # Update the label_groups data - for group_name, labels in self.label_groups.items(): - if self.editing_label in labels: - labels[self.editing_label] = new_color - self.has_unsaved_changes = True - - # Refresh the table to show the updated color - self._refresh_table(self.editing_table, group_name) - - self.label_groups[group_name][self.editing_label] = new_color - self.notify(f"Updated {self.editing_label} to {new_color}") - break - - def action_add_label(self) -> None: - """Add a new label to the currently focused group.""" - if self.editing_mode or self.input_mode: - return - - # Find which table has focus (or which group to add to) - group_name = None - - for table in self.query(DataTable): - if table.has_focus: - # Find the parent column to get the group name - parent_column = table.parent - while parent_column and not isinstance(parent_column, LabelGroupColumn): - parent_column = parent_column.parent - if parent_column: - group_name = parent_column.group_name # type: ignore[attr-defined] - break - - # If no group is focused and we have groups, ask which group - # If no groups exist, we need to create one first - if not group_name: - if not self.label_groups: - self.notify( - "No groups exist. Cannot add labels to empty palette.", severity="warning" - ) - return - # Default to first group - group_name = list(self.label_groups.keys())[0] - self.notify(f"No group selected. Adding to '{group_name}'", timeout=2) - - # Prompt for label name - self.input_mode = "add_label" - self.temp_group_name = group_name - input_widget = Input( - placeholder="Enter label name (e.g., 'Heating', 'Residential')", - id="add-label-input", - ) - self.mount(input_widget) - input_widget.focus() - self.notify(f"Enter label name for group '{group_name}'", timeout=3) - - def handle_add_label_name(self, label_name: str, input_widget: Input) -> None: - """Handle the label name input when adding a new label.""" - if not label_name: - self.notify("Label name cannot be empty", severity="error") - return - - # Check if label already exists in this group - group_name = self.temp_group_name - if group_name and label_name in self.label_groups.get(group_name, {}): - self.notify(f"Label '{label_name}' already exists in '{group_name}'", severity="error") - return - - # Store the label name and prompt for color - self.temp_label_name = label_name - input_widget.remove() - - # Prompt for color (optional - can press Enter for auto-assigned color) - self.input_mode = "add_color" - color_input = Input( - placeholder="Enter color (e.g., #FF5733) or press Enter for auto-color", - id="add-color-input", - ) - self.mount(color_input) - color_input.focus() - self.notify("Enter color or press Enter to auto-assign", timeout=3) - - def handle_add_label_color(self, color: str, input_widget: Input) -> None: - """Handle the color input when adding a new label.""" - group_name = self.temp_group_name - label_name = self.temp_label_name - - if not group_name or not label_name: - self.notify("Error: Missing group or label name", severity="error") - input_widget.remove() - self.input_mode = None - return - - # If no color provided, auto-assign from theme using ColorPalette - if not color: - # Create a temporary ColorPalette with existing labels to get next color - temp_palette = ColorPalette(self.label_groups.get(group_name, {})) - # This will automatically cycle to the next color in the theme - color = temp_palette.get(label_name) - elif not validate_color(color): - self.notify("Invalid color format. Use hex (#FF5733) or rgb format", severity="error") - return - - # Add the label to the group - if group_name not in self.label_groups: - self.label_groups[group_name] = {} - - self.label_groups[group_name][label_name] = color - self.has_unsaved_changes = True - - # Refresh the display - self._refresh_display() - - # Clean up - input_widget.remove() - self.input_mode = None - self.temp_label_name = None - self.temp_group_name = None - - self.notify(f"Added label '{label_name}' to '{group_name}'", severity="information") - - def _refresh_display(self) -> None: - """Refresh the entire display with updated label groups.""" - # Remove the columns container and rebuild it - container = self.query_one("#columns-container", ScrollableContainer) - - # Remove all children - for child in list(container.children): - child.remove() - - # Create horizontal layout and mount it - horizontal = Horizontal() - container.mount(horizontal) - - if self.label_groups: - # Remove empty palette message if it exists - try: - msg = self.query_one("#empty-palette-msg") - msg.remove() - except Exception: - pass - - for group_name, labels in self.label_groups.items(): - column = LabelGroupColumn(group_name, labels, self) - horizontal.mount(column) - else: - # Show empty message - label = Label( - "[dim]Empty palette. Press 'c' to create a new group or 'a' to add a label.[/dim]", - id="empty-palette-msg", - ) - horizontal.mount(label) - - def action_delete_label(self) -> None: - """Delete the currently selected label.""" - if self.editing_mode or self.input_mode: - return - - # Find which table has focus - focused_table = None - group_name = None - - for table in self.query(DataTable): - if table.has_focus: - focused_table = table - # Find the parent column to get the group name - parent_column = table.parent - while parent_column and not isinstance(parent_column, LabelGroupColumn): - parent_column = parent_column.parent - if parent_column: - group_name = parent_column.group_name # type: ignore[attr-defined] - break - - if not focused_table or not group_name: - self.notify("No label selected", severity="warning") - return - - if focused_table.cursor_row is None or focused_table.cursor_row < 0: - self.notify("No label selected", severity="warning") - return - - # Get the label name from the current row - row_key = focused_table.get_row_at(focused_table.cursor_row) - label = str(row_key[0]) - - # Delete the label from the group - if group_name in self.label_groups and label in self.label_groups[group_name]: - del self.label_groups[group_name][label] - self.has_unsaved_changes = True - - # If group is now empty, optionally remove it (or keep it empty) - if not self.label_groups[group_name]: - # Keep the empty group for now - user can delete it with 'X' - pass - - # Refresh the display - self._refresh_display() - - self.notify(f"Deleted label '{label}' from '{group_name}'", severity="information") - else: - self.notify(f"Label '{label}' not found", severity="error") - - def action_delete_group(self) -> None: - """Delete the currently focused group/category (disabled for pre-defined groups).""" - if self.editing_mode or self.input_mode: - return - - # Find which table has focus to determine the group - group_name = None - - for table in self.query(DataTable): - if table.has_focus: - # Find the parent column to get the group name - parent_column = table.parent - while parent_column and not isinstance(parent_column, LabelGroupColumn): - parent_column = parent_column.parent - if parent_column: - group_name = parent_column.group_name # type: ignore[attr-defined] - break - - if not group_name: - self.notify("No group selected", severity="warning") - return - - # Prevent deletion of pre-defined groups - predefined_groups = {"Scenarios", "Model Years", "Metrics"} - if group_name in predefined_groups: - self.notify( - f"Cannot delete pre-defined group '{group_name}'. You can only delete labels within it.", - severity="warning", - ) - return - - # Confirm deletion (since this removes all labels in the group) - if group_name in self.label_groups: - label_count = len(self.label_groups[group_name]) - del self.label_groups[group_name] - self.has_unsaved_changes = True - - # Refresh the display - self._refresh_display() - - msg = f"Deleted group '{group_name}'" - if label_count > 0: - msg += f" and {label_count} label(s)" - self.notify(msg, severity="information") - else: - self.notify(f"Group '{group_name}' not found", severity="error") - - def action_save_palette(self) -> None: - """Save the current palette to disk.""" - if not self.has_unsaved_changes: - self.notify("No changes to save") - return - - try: - # Convert label_groups to structured format - # Map display names back to internal names - display_to_category = { - "Scenarios": "scenarios", - "Model Years": "model_years", - "Metrics": "metrics", - } - - structured_palette: dict[str, dict[str, str]] = { - "scenarios": {}, - "model_years": {}, - "metrics": {}, - } - - for group_name, labels in self.label_groups.items(): - category_name = display_to_category.get(group_name) - if category_name: - structured_palette[category_name] = labels - else: - # Legacy/unknown groups - add to metrics - structured_palette["metrics"].update(labels) - - if self.palette_type == "project": - # Save to project.json5 - from stride.models import ProjectConfig - - config = ProjectConfig.from_file(self.palette_location) - config.color_palette = structured_palette - self.palette_location.write_text(config.model_dump_json(indent=2)) - self.notify( - f"Saved project palette to {self.palette_location}", severity="information" - ) - else: - # Save to user palette JSON - import json - - data = { - "name": self.palette_name, - "palette": structured_palette, - } - with open(self.palette_location, "w") as f: - json.dump(data, f, indent=2) - self.notify( - f"Saved user palette to {self.palette_location}", severity="information" - ) - - self.has_unsaved_changes = False - except Exception as e: - self.notify(f"Error saving palette: {e}", severity="error") - logger.error(f"Error saving palette: {e}") - - def action_move_up(self) -> None: - """Move the selected item up within its group.""" - if self.editing_mode: - return - - # Find which table has focus - focused_table = None - for table in self.query(DataTable): - if table.has_focus: - focused_table = table - break - - if not focused_table: - return - - # Get the current cursor position - cursor_row = focused_table.cursor_row - if cursor_row <= 0: - return - - # Find the parent column to get the group name - parent_column = focused_table.parent - while parent_column and not isinstance(parent_column, LabelGroupColumn): - parent_column = parent_column.parent - - if not parent_column: - return - - group_name = parent_column.group_name # type: ignore[attr-defined] - - # Convert the group's labels to a list of items - from stride.ui.palette import ColorPalette - - labels_dict = self.label_groups[group_name] - items = [ - {"label": label, "color": color, "order": idx} - for idx, (label, color) in enumerate(labels_dict.items()) - ] - - # Move the item up - palette = ColorPalette() - if palette.move_item_up(items, cursor_row): - # Update the label_groups with new order - self.label_groups[group_name] = { - str(item["label"]): str(item["color"]) for item in items - } - - # Refresh the table - self._refresh_table(focused_table, group_name) - - # Move cursor to follow the item - focused_table.move_cursor(row=cursor_row - 1) - - self.has_unsaved_changes = True - - def action_move_down(self) -> None: - """Move the selected item down within its group.""" - if self.editing_mode: - return - - # Find which table has focus - focused_table = None - for table in self.query(DataTable): - if table.has_focus: - focused_table = table - break - - if not focused_table: - return - - # Get the current cursor position - cursor_row = focused_table.cursor_row - if cursor_row >= focused_table.row_count - 1: - return - - # Find the parent column to get the group name - parent_column = focused_table.parent - while parent_column and not isinstance(parent_column, LabelGroupColumn): - parent_column = parent_column.parent - - if not parent_column: - return - - group_name = parent_column.group_name # type: ignore[attr-defined] - - # Convert the group's labels to a list of items - from stride.ui.palette import ColorPalette - - labels_dict = self.label_groups[group_name] - items = [ - {"label": label, "color": color, "order": idx} - for idx, (label, color) in enumerate(labels_dict.items()) - ] - - # Move the item down - palette = ColorPalette() - if palette.move_item_down(items, cursor_row): - # Update the label_groups with new order - self.label_groups[group_name] = { - str(item["label"]): str(item["color"]) for item in items - } - - # Refresh the table - self._refresh_table(focused_table, group_name) - - # Move cursor to follow the item - focused_table.move_cursor(row=cursor_row + 1) - - self.has_unsaved_changes = True - - def _refresh_table(self, table: DataTable[Any], group_name: str) -> None: - """Refresh a table with updated data from label_groups. - - Parameters - ---------- - table : DataTable - The table to refresh - group_name : str - The name of the group to refresh from - """ - # Clear and repopulate the table - table.clear() - labels = self.label_groups[group_name] - - for label, color in labels.items(): - # Create a color preview using the color - rich_color = color_to_rich_format(color) - preview = Text("████", style=Style(color=rich_color)) - table.add_row(label, color, preview) - - def action_refresh(self) -> None: - """Refresh the palette display.""" - self.notify("Palette refreshed") - - def action_help(self) -> None: - """Show help information.""" - help_text = """ -Stride Palette Viewer - Keyboard Shortcuts - -Navigation: -- Arrow keys: Navigate between cells -- Tab/Shift+Tab: Move between columns - -Actions: -- a: Add new label to current group -- e: Edit color (type directly, Enter to save, Esc to cancel) -- x: Delete current label -- X: Delete current group (disabled for pre-defined groups) -- u: Move item up within its group -- d: Move item down within its group -- s: Save changes to disk -- q: Quit -- r: Refresh -- ?: Show this help -- Esc: Cancel current input - -Note: Groups (Scenarios, Model Years, Metrics) are pre-defined and cannot be deleted. -""" - self.notify(help_text, timeout=12) - - -def organize_palette_by_groups( - palette: dict[str, str] | dict[str, dict[str, str]], - project_config: Any | None = None, -) -> dict[str, dict[str, str]]: - """Organize a palette dictionary into the three pre-defined groups. - - Parameters - ---------- - palette : dict[str, str] | dict[str, dict[str, str]] - Either a flat dictionary of label -> color mappings (legacy format) or - a structured dictionary with 'scenarios', 'model_years', and 'metrics' keys. - project_config : Any | None, optional - Optional project configuration (not currently used but kept for compatibility) - - Returns - ------- - dict[str, dict[str, str]] - Nested dictionary organized by the three pre-defined groups: - 'Scenarios', 'Model Years', and 'Metrics' - """ - # Check if it's the new structured format - if ( - isinstance(palette, dict) - and "scenarios" in palette - and "model_years" in palette - and "metrics" in palette - ): - # Use the structured format directly, mapping to display names - scenarios_dict = palette.get("scenarios", {}) - model_years_dict = palette.get("model_years", {}) - metrics_dict = palette.get("metrics", {}) - - # Ensure all values are dicts of strings - result: dict[str, dict[str, str]] = { - "Scenarios": scenarios_dict if isinstance(scenarios_dict, dict) else {}, - "Model Years": model_years_dict if isinstance(model_years_dict, dict) else {}, - "Metrics": metrics_dict if isinstance(metrics_dict, dict) else {}, - } - return result - else: - # Legacy flat format - put everything in Metrics for now - metrics_palette: dict[str, str] = {} - if isinstance(palette, dict): - for key, value in palette.items(): - if isinstance(value, str): - metrics_palette[key] = value - return { - "Scenarios": {}, - "Model Years": {}, - "Metrics": metrics_palette, - } - - -def launch_palette_viewer( - palette_path: Path, - palette_type: str = "project", - project_config: Any | None = None, -) -> None: - """Launch the palette viewer TUI. - - Parameters - ---------- - palette_path : Path - Path to the palette file (project.json5 or user palette file) - palette_type : str, optional - Type of palette ('user' or 'project'), by default "project" - project_config : Any | None, optional - Optional project configuration for better label grouping - """ - if palette_type == "project": - # For project palettes, we need to load from the project config - from stride.models import ProjectConfig - - config = ProjectConfig.from_file(palette_path) - palette_dict = config.color_palette - palette_name = config.project_id - else: - # For user palettes, load directly (assuming JSON format) - import json - - with open(palette_path) as f: - data = json.load(f) - palette_dict = data.get("palette", data) - palette_name = palette_path.stem - - # Organize palette into groups - label_groups = organize_palette_by_groups(palette_dict, project_config) - - # Launch the TUI - app = PaletteViewer( - palette_name=palette_name, - palette_location=palette_path, - palette_type=palette_type, - label_groups=label_groups, - ) - app.run() - - -def get_user_palette_dir() -> Path: - """Get the user's palette directory, creating it if necessary. - - Returns - ------- - Path - Path to ~/.stride/palettes/ - """ - palette_dir = Path.home() / ".stride" / "palettes" - palette_dir.mkdir(parents=True, exist_ok=True) - return palette_dir - - -def list_user_palettes() -> list[Path]: - """List all user palettes. - - Returns - ------- - list[Path] - List of paths to user palette files - """ - palette_dir = get_user_palette_dir() - return sorted(palette_dir.glob("*.json")) - - -def save_user_palette(name: str, palette: dict[str, str] | dict[str, dict[str, str]]) -> Path: - """Save a palette to the user's palette directory. - - Parameters - ---------- - name : str - Name for the palette (will be used as filename) - palette : dict[str, str] | dict[str, dict[str, str]] - Palette dictionary to save (either flat or structured format) - - Returns - ------- - Path - Path to the saved palette file - """ - import json - - palette_dir = get_user_palette_dir() - palette_path = palette_dir / f"{name}.json" - - data = { - "name": name, - "palette": palette, - } - - with open(palette_path, "w") as f: - json.dump(data, f, indent=2) - - return palette_path - - -def load_user_palette(name: str) -> ColorPalette: - """Load a user palette by name. - - Parameters - ---------- - name : str - Name of the palette to load - - Returns - ------- - ColorPalette - Loaded color palette - - Raises - ------ - FileNotFoundError - If the palette does not exist - """ - import json - - palette_dir = get_user_palette_dir() - palette_path = palette_dir / f"{name}.json" - - if not palette_path.exists(): - msg = f"User palette '{name}' not found" - raise FileNotFoundError(msg) - - with open(palette_path) as f: - data = json.load(f) - # Handle both nested {"palette": {...}} and flat {...} structures - if isinstance(data, dict): - if "palette" in data: - palette_dict = data["palette"] - else: - palette_dict = data - else: - msg = f"Invalid palette format in {name}.json" - raise ValueError(msg) - - return ColorPalette(palette_dict) - - -def delete_user_palette(name: str) -> None: - """Delete a user palette by name. - - Parameters - ---------- - name : str - Name of the palette to delete - - Raises - ------ - FileNotFoundError - If the palette does not exist - """ - palette_dir = get_user_palette_dir() - palette_path = palette_dir / f"{name}.json" - - if not palette_path.exists(): - msg = f"User palette '{name}' not found" - raise FileNotFoundError(msg) - - palette_path.unlink() - - -def get_stride_config_dir() -> Path: - """Get the stride configuration directory, creating it if necessary. - - Returns - ------- - Path - Path to ~/.stride/ - """ - config_dir = Path.home() / ".stride" - config_dir.mkdir(parents=True, exist_ok=True) - return config_dir - - -def get_stride_config_path() -> Path: - """Get the stride configuration file path. - - Returns - ------- - Path - Path to ~/.stride/config.json - """ - return get_stride_config_dir() / "config.json" - - -def load_stride_config() -> dict[str, Any]: - """Load the stride configuration file. - - Returns - ------- - dict[str, Any] - Configuration dictionary, or empty dict if file doesn't exist - """ - import json - - config_path = get_stride_config_path() - if not config_path.exists(): - return {} - - with open(config_path) as f: - result: dict[str, Any] = json.load(f) - return result - - -def save_stride_config(config: dict[str, Any]) -> None: - """Save the stride configuration file. - - Parameters - ---------- - config : dict[str, Any] - Configuration dictionary to save - """ - import json - - config_path = get_stride_config_path() - with open(config_path, "w") as f: - json.dump(config, f, indent=2) - - -def set_default_user_palette(name: str | None) -> None: - """Set the default user palette. - - Parameters - ---------- - name : str | None - Name of the user palette to set as default, or None to clear the default - """ - config = load_stride_config() - - if name is None: - config.pop("default_user_palette", None) - else: - # Verify the palette exists - palette_dir = get_user_palette_dir() - palette_path = palette_dir / f"{name}.json" - if not palette_path.exists(): - msg = f"User palette '{name}' not found at {palette_path}" - raise FileNotFoundError(msg) - config["default_user_palette"] = name - - save_stride_config(config) - - -def get_default_user_palette() -> str | None: - """Get the default user palette name. - - Returns - ------- - str | None - Name of the default user palette, or None if not set - """ - config = load_stride_config() - return config.get("default_user_palette") - - -def get_max_cached_projects() -> int | None: - """Get the max cached projects setting from config. - - Returns - ------- - int | None - Configured max cached projects, or None if not set - """ - config = load_stride_config() - value = config.get("max_cached_projects") - if value is not None: - return int(value) - return None - - -def set_max_cached_projects(n: int) -> None: - """Set the max cached projects in the config file. - - Parameters - ---------- - n : int - Number of max cached projects (will be clamped to [1, 10]) - """ - n = max(1, min(10, n)) - config = load_stride_config() - config["max_cached_projects"] = n - save_stride_config(config) diff --git a/tests/palette/test_auto_color.py b/tests/palette/test_auto_color.py index dda07c2..9ddb706 100644 --- a/tests/palette/test_auto_color.py +++ b/tests/palette/test_auto_color.py @@ -59,7 +59,7 @@ def test_auto_color_with_existing_palette() -> None: "Existing3": "#0000FF", } - palette = ColorPalette(existing) + palette = ColorPalette.from_dict(existing) # Add new labels - should get different colors from the theme new_labels = ["New1", "New2", "New3"] @@ -86,7 +86,7 @@ def test_color_cycling() -> None: palette = ColorPalette() - # The Prism theme has multiple colors, add more labels than theme has colors + # The metric theme has multiple colors, add more labels than theme has colors # to test cycling behavior num_labels = 20 labels = [f"Label{i}" for i in range(num_labels)] @@ -132,7 +132,7 @@ def test_palette_integration() -> None: # Simulate adding labels to Scenarios group print("\n Adding to Scenarios group:") - scenarios_palette = ColorPalette(label_groups["Scenarios"]) + scenarios_palette = ColorPalette.from_dict(label_groups["Scenarios"]) for scenario in ["Baseline", "Alternative", "High Growth"]: color = scenarios_palette.get(scenario) label_groups["Scenarios"][scenario] = color @@ -140,7 +140,7 @@ def test_palette_integration() -> None: # Simulate adding labels to Sectors group print("\n Adding to Sectors group:") - sectors_palette = ColorPalette(label_groups["Sectors"]) + sectors_palette = ColorPalette.from_dict(label_groups["Sectors"]) for sector in ["Residential", "Commercial", "Industrial"]: color = sectors_palette.get(sector) label_groups["Sectors"][sector] = color @@ -148,7 +148,7 @@ def test_palette_integration() -> None: # Simulate adding labels to Years group print("\n Adding to Years group:") - years_palette = ColorPalette(label_groups["Years"]) + years_palette = ColorPalette.from_dict(label_groups["Years"]) for year in ["2025", "2030", "2035", "2040"]: color = years_palette.get(year) label_groups["Years"][year] = color diff --git a/tests/palette/test_color_manager_update.py b/tests/palette/test_color_manager_update.py index 68c650f..be8c838 100644 --- a/tests/palette/test_color_manager_update.py +++ b/tests/palette/test_color_manager_update.py @@ -3,15 +3,17 @@ import pytest from stride.ui.color_manager import ColorManager -from stride.ui.palette import ColorPalette +from stride.ui.palette import ColorCategory, ColorPalette def test_color_manager_updates_with_new_palette() -> None: """Test that ColorManager updates when a new palette is provided.""" - # Create first palette - palette1 = ColorPalette({"Label1": "#FF0000", "Label2": "#00FF00"}) + # Create first palette with labels in the sectors category + palette1 = ColorPalette.from_dict( + {"scenarios": {}, "model_years": {}, "metrics": {"Label1": "#FF0000", "Label2": "#00FF00"}} + ) cm1 = ColorManager(palette1) - cm1.initialize_colors(["Label1", "Label2"]) + cm1.initialize_colors([], sectors=["Label1", "Label2"]) # Get colors from first palette color1_label1 = cm1.get_color("Label1") @@ -22,9 +24,11 @@ def test_color_manager_updates_with_new_palette() -> None: assert "0, 255, 0" in color1_label2 # Green # Create second palette with different colors - palette2 = ColorPalette({"Label1": "#0000FF", "Label2": "#FFFF00"}) + palette2 = ColorPalette.from_dict( + {"scenarios": {}, "model_years": {}, "metrics": {"Label1": "#0000FF", "Label2": "#FFFF00"}} + ) cm2 = ColorManager(palette2) - cm2.initialize_colors(["Label1", "Label2"]) + cm2.initialize_colors([], sectors=["Label1", "Label2"]) # Get colors from second palette color2_label1 = cm2.get_color("Label1") @@ -44,10 +48,10 @@ def test_color_manager_updates_with_new_palette() -> None: def test_color_manager_singleton_behavior() -> None: """Test that ColorManager maintains singleton behavior.""" - palette1 = ColorPalette({"A": "#111111"}) + palette1 = ColorPalette.from_dict({"scenarios": {}, "model_years": {}, "metrics": {"A": "#111111"}}) cm1 = ColorManager(palette1) - palette2 = ColorPalette({"A": "#222222"}) + palette2 = ColorPalette.from_dict({"scenarios": {}, "model_years": {}, "metrics": {"A": "#222222"}}) cm2 = ColorManager(palette2) # Should be the same instance @@ -82,15 +86,19 @@ def test_color_manager_reinitialize_colors() -> None: ColorManager._instance = None # type: ignore[misc] # First initialization - palette1 = ColorPalette( + palette1 = ColorPalette.from_dict( { - "Residential": "#E74C3C", - "Commercial": "#3498DB", - "Industrial": "#F39C12", + "scenarios": {}, + "model_years": {}, + "metrics": { + "Residential": "#E74C3C", + "Commercial": "#3498DB", + "Industrial": "#F39C12", + }, } ) cm = ColorManager(palette1) - cm.initialize_colors(["Residential", "Commercial", "Industrial"]) + cm.initialize_colors([], sectors=["Residential", "Commercial", "Industrial"]) # Store first colors res1 = cm.get_color("Residential") @@ -98,15 +106,19 @@ def test_color_manager_reinitialize_colors() -> None: ind1 = cm.get_color("Industrial") # Update with new palette - palette2 = ColorPalette( + palette2 = ColorPalette.from_dict( { - "Residential": "#2C3E50", - "Commercial": "#E74C3C", - "Industrial": "#ECF0F1", + "scenarios": {}, + "model_years": {}, + "metrics": { + "Residential": "#2C3E50", + "Commercial": "#E74C3C", + "Industrial": "#ECF0F1", + }, } ) cm2 = ColorManager(palette2) - cm2.initialize_colors(["Residential", "Commercial", "Industrial"]) + cm2.initialize_colors([], sectors=["Residential", "Commercial", "Industrial"]) # Get new colors res2 = cm2.get_color("Residential") @@ -129,7 +141,9 @@ def test_color_manager_scenario_styling_updates() -> None: # Reset singleton ColorManager._instance = None # type: ignore[misc] - palette1 = ColorPalette({"Scenario1": "#FF0000"}) + palette1 = ColorPalette.from_dict( + {"scenarios": {"Scenario1": "#FF0000"}, "model_years": {}, "metrics": {}} + ) cm1 = ColorManager(palette1) cm1.initialize_colors(["Scenario1"]) @@ -138,7 +152,9 @@ def test_color_manager_scenario_styling_updates() -> None: assert "255, 0, 0" in styling1["border"] # Update palette - palette2 = ColorPalette({"Scenario1": "#0000FF"}) + palette2 = ColorPalette.from_dict( + {"scenarios": {"Scenario1": "#0000FF"}, "model_years": {}, "metrics": {}} + ) cm2 = ColorManager(palette2) cm2.initialize_colors(["Scenario1"]) @@ -153,7 +169,7 @@ def test_color_manager_scenario_styling_updates() -> None: def test_color_manager_preserves_palette_reference() -> None: """Test that ColorManager properly references the provided palette.""" - palette = ColorPalette({"Label": "#123456"}) + palette = ColorPalette.from_dict({"scenarios": {}, "model_years": {}, "metrics": {"Label": "#123456"}}) cm = ColorManager(palette) # Get the palette back @@ -163,7 +179,7 @@ def test_color_manager_preserves_palette_reference() -> None: assert retrieved_palette is palette # Update palette directly - palette.update("NewLabel", "#ABCDEF") + palette.update("NewLabel", "#ABCDEF", category=ColorCategory.SECTOR) # ColorManager should see the change color = cm.get_color("NewLabel") @@ -179,14 +195,19 @@ def test_color_manager_multiple_scenarios() -> None: sectors = ["Residential", "Commercial", "Industrial"] # First palette - palette1 = ColorPalette( + palette1 = ColorPalette.from_dict( { - "Scenario1": "#FF0000", - "Scenario2": "#00FF00", - "Scenario3": "#0000FF", - "Residential": "#FFFF00", - "Commercial": "#FF00FF", - "Industrial": "#00FFFF", + "scenarios": { + "Scenario1": "#FF0000", + "Scenario2": "#00FF00", + "Scenario3": "#0000FF", + }, + "model_years": {}, + "metrics": { + "Residential": "#FFFF00", + "Commercial": "#FF00FF", + "Industrial": "#00FFFF", + }, } ) @@ -198,14 +219,19 @@ def test_color_manager_multiple_scenarios() -> None: assert len(all_styling1) == 3 # Update palette with completely different colors - palette2 = ColorPalette( + palette2 = ColorPalette.from_dict( { - "Scenario1": "#111111", - "Scenario2": "#222222", - "Scenario3": "#333333", - "Residential": "#444444", - "Commercial": "#555555", - "Industrial": "#666666", + "scenarios": { + "Scenario1": "#111111", + "Scenario2": "#222222", + "Scenario3": "#333333", + }, + "model_years": {}, + "metrics": { + "Residential": "#444444", + "Commercial": "#555555", + "Industrial": "#666666", + }, } ) diff --git a/tests/palette/test_palette.py b/tests/palette/test_palette.py index 15f5f31..5dce0e7 100644 --- a/tests/palette/test_palette.py +++ b/tests/palette/test_palette.py @@ -4,7 +4,7 @@ import pytest -from stride.ui.palette import ColorPalette +from stride.ui.palette import ColorCategory, ColorPalette class TestColorPaletteInitialization: @@ -22,7 +22,7 @@ def test_initialization_with_dict(self) -> None: "residential": "#FF5733", "commercial": "#3498DB", } - palette = ColorPalette(palette=initial_colors) + palette = ColorPalette.from_dict(initial_colors) assert len(palette.palette) == 2 assert palette.palette["residential"] == "#FF5733" assert palette.palette["commercial"] == "#3498DB" @@ -33,7 +33,7 @@ def test_initialization_with_invalid_colors(self) -> None: "residential": "not_a_color", "commercial": "#3498DB", } - palette = ColorPalette(palette=initial_colors) + palette = ColorPalette.from_dict(initial_colors) assert len(palette.palette) == 2 assert palette.palette["commercial"] == "#3498DB" # Invalid color should be replaced with a generated one @@ -46,26 +46,26 @@ class TestColorPaletteUpdate: def test_update_with_valid_hex(self) -> None: """Test updating with a valid hex color.""" palette = ColorPalette() - palette.update("residential", "#FF5733") + palette.update("residential", "#FF5733", category=ColorCategory.SECTOR) assert palette.palette["residential"] == "#FF5733" def test_update_with_valid_hex_alpha(self) -> None: """Test updating with a valid hex color with alpha.""" palette = ColorPalette() - palette.update("residential", "#FF5733CC") + palette.update("residential", "#FF5733CC", category=ColorCategory.SECTOR) assert palette.palette["residential"] == "#FF5733CC" def test_update_with_none(self) -> None: """Test that None generates a new color.""" palette = ColorPalette() - palette.update("residential", None) + palette.update("residential", None, category=ColorCategory.SECTOR) assert "residential" in palette.palette assert palette.palette["residential"] is not None def test_update_with_invalid_string(self) -> None: """Test that invalid color strings generate new colors.""" palette = ColorPalette() - palette.update("residential", "not_a_hex") + palette.update("residential", "not_a_hex", category=ColorCategory.SECTOR) assert "residential" in palette.palette assert palette.palette["residential"] != "not_a_hex" @@ -73,13 +73,13 @@ def test_update_non_string_key_raises_error(self) -> None: """Test that non-string keys raise TypeError.""" palette = ColorPalette() with pytest.raises(TypeError, match="Key must be a string"): - palette.update(123, "#FF5733") # type: ignore[arg-type] + palette.update(123, "#FF5733", category=ColorCategory.SECTOR) # type: ignore[arg-type] def test_update_overwrites_existing(self) -> None: """Test that update overwrites existing colors.""" palette = ColorPalette() - palette.update("residential", "#FF5733") - palette.update("residential", "#3498DB") + palette.update("residential", "#FF5733", category=ColorCategory.SECTOR) + palette.update("residential", "#3498DB", category=ColorCategory.SECTOR) assert palette.palette["residential"] == "#3498DB" @@ -89,7 +89,7 @@ class TestColorPaletteGet: def test_get_existing_color(self) -> None: """Test getting an existing color.""" palette = ColorPalette() - palette.update("residential", "#FF5733") + palette.update("residential", "#FF5733", category=ColorCategory.SECTOR) color = palette.get("residential") assert color == "#FF5733" @@ -115,8 +115,8 @@ class TestColorPalettePop: def test_pop_existing_key(self) -> None: """Test popping an existing key.""" palette = ColorPalette() - palette.update("residential", "#FF5733") - color = palette.pop("residential") + palette.update("residential", "#FF5733", category=ColorCategory.SECTOR) + color = palette.pop("residential", category=ColorCategory.SECTOR) assert color == "#FF5733" assert "residential" not in palette.palette @@ -124,7 +124,7 @@ def test_pop_nonexistent_key_raises_error(self) -> None: """Test that popping a nonexistent key raises KeyError.""" palette = ColorPalette() with pytest.raises(KeyError, match="unable to remove key"): - palette.pop("nonexistent") + palette.pop("nonexistent", category=ColorCategory.SECTOR) class TestColorPaletteFromDict: @@ -171,31 +171,33 @@ def test_to_dict_empty(self) -> None: palette = ColorPalette() result = palette.to_dict() assert isinstance(result, dict) - # Structured format has 3 categories - assert len(result) == 3 + # Structured format has 4 categories + assert len(result) == 4 assert "scenarios" in result assert "model_years" in result - assert "metrics" in result + assert "sectors" in result + assert "end_uses" in result assert len(result["scenarios"]) == 0 assert len(result["model_years"]) == 0 - assert len(result["metrics"]) == 0 + assert len(result["sectors"]) == 0 + assert len(result["end_uses"]) == 0 def test_to_dict_with_colors(self) -> None: """Test converting palette with colors to dict.""" palette = ColorPalette() - palette.update("residential", "#FF5733", category="metrics") - palette.update("commercial", "#3498DB", category="metrics") + palette.update("residential", "#FF5733", category=ColorCategory.SECTOR) + palette.update("commercial", "#3498DB", category=ColorCategory.SECTOR) result = palette.to_dict() - assert "metrics" in result - assert result["metrics"]["residential"] == "#FF5733" - assert result["metrics"]["commercial"] == "#3498DB" + assert "sectors" in result + assert result["sectors"]["residential"] == "#FF5733" + assert result["sectors"]["commercial"] == "#3498DB" def test_to_dict_returns_copy(self) -> None: """Test that to_dict returns a copy, not the original dict.""" palette = ColorPalette() - palette.update("residential", "#FF5733", category="metrics") + palette.update("residential", "#FF5733", category=ColorCategory.SECTOR) result = palette.to_dict() - result["metrics"]["new_key"] = "#000000" + result["sectors"]["new_key"] = "#000000" assert "new_key" not in palette.palette @@ -205,9 +207,9 @@ class TestColorPaletteRoundTrip: def test_round_trip_preserves_colors(self) -> None: """Test that to_dict and from_dict preserve colors.""" original = ColorPalette() - original.update("residential", "#FF5733") - original.update("commercial", "#3498DB") - original.update("industrial", "#2ECC71") + original.update("residential", "#FF5733", category=ColorCategory.SECTOR) + original.update("commercial", "#3498DB", category=ColorCategory.SECTOR) + original.update("industrial", "#2ECC71", category=ColorCategory.SECTOR) # Serialize and deserialize dict_repr = original.to_dict() @@ -239,8 +241,8 @@ def test_multiple_palettes_independent(self) -> None: palette1 = ColorPalette() palette2 = ColorPalette() - palette1.update("residential", "#FF5733") - palette2.update("residential", "#3498DB") + palette1.update("residential", "#FF5733", category=ColorCategory.SECTOR) + palette2.update("residential", "#3498DB", category=ColorCategory.SECTOR) assert palette1.get("residential") == "#FF5733" assert palette2.get("residential") == "#3498DB" @@ -252,44 +254,44 @@ class TestColorPaletteHexValidation: def test_valid_6_digit_hex(self) -> None: """Test that 6-digit hex colors are accepted.""" palette = ColorPalette() - palette.update("test", "#FF5733") + palette.update("test", "#FF5733", category=ColorCategory.SECTOR) assert palette.palette["test"] == "#FF5733" def test_valid_8_digit_hex(self) -> None: """Test that 8-digit hex colors (with alpha) are accepted.""" palette = ColorPalette() - palette.update("test", "#FF5733CC") + palette.update("test", "#FF5733CC", category=ColorCategory.SECTOR) assert palette.palette["test"] == "#FF5733CC" def test_lowercase_hex(self) -> None: """Test that lowercase hex colors are accepted.""" palette = ColorPalette() - palette.update("test", "#ff5733") + palette.update("test", "#ff5733", category=ColorCategory.SECTOR) assert palette.palette["test"] == "#ff5733" def test_mixed_case_hex(self) -> None: """Test that mixed case hex colors are accepted.""" palette = ColorPalette() - palette.update("test", "#Ff5733") + palette.update("test", "#Ff5733", category=ColorCategory.SECTOR) assert palette.palette["test"] == "#Ff5733" def test_invalid_short_hex(self) -> None: """Test that short hex colors are rejected.""" palette = ColorPalette() - palette.update("test", "#F57") + palette.update("test", "#F57", category=ColorCategory.SECTOR) # Should be replaced with auto-generated color assert palette.palette["test"] != "#F57" def test_invalid_no_hash(self) -> None: """Test that colors without # are rejected.""" palette = ColorPalette() - palette.update("test", "FF5733") + palette.update("test", "FF5733", category=ColorCategory.SECTOR) assert palette.palette["test"] != "FF5733" def test_invalid_non_hex_chars(self) -> None: """Test that non-hex characters are rejected.""" palette = ColorPalette() - palette.update("test", "#GGGGGG") + palette.update("test", "#GGGGGG", category=ColorCategory.SECTOR) assert palette.palette["test"] != "#GGGGGG" @@ -382,25 +384,26 @@ def test_palette_to_grouped_items(self) -> None: palette = { "scenarios": {"baseline": "#0000FF"}, "model_years": {}, - "metrics": {"heating": "#FF0000", "cooling": "#00FF00"}, + "sectors": {"heating": "#FF0000", "cooling": "#00FF00"}, + "end_uses": {}, } result = ColorPalette.palette_to_grouped_items(palette) - assert "Metrics" in result + assert "Sectors" in result assert "Scenarios" in result - assert len(result["Metrics"]) == 2 + assert len(result["Sectors"]) == 2 assert len(result["Scenarios"]) == 1 - assert result["Metrics"][0]["label"] == "heating" - assert result["Metrics"][0]["color"] == "#FF0000" - assert result["Metrics"][0]["order"] == 0 - assert result["Metrics"][1]["label"] == "cooling" - assert result["Metrics"][1]["order"] == 1 + assert result["Sectors"][0]["label"] == "heating" + assert result["Sectors"][0]["color"] == "#FF0000" + assert result["Sectors"][0]["order"] == 0 + assert result["Sectors"][1]["label"] == "cooling" + assert result["Sectors"][1]["order"] == 1 def test_grouped_items_to_palette(self) -> None: """Test converting grouped items back to structured palette.""" grouped_items: dict[str, list[dict[str, Any]]] = { - "Metrics": [ + "Sectors": [ {"label": "heating", "color": "#FF0000", "order": 0}, {"label": "cooling", "color": "#00FF00", "order": 1}, ], @@ -411,15 +414,15 @@ def test_grouped_items_to_palette(self) -> None: result = ColorPalette.grouped_items_to_palette(grouped_items) - assert len(result) == 3 - assert result["metrics"]["heating"] == "#FF0000" - assert result["metrics"]["cooling"] == "#00FF00" + assert len(result) == 4 + assert result["sectors"]["heating"] == "#FF0000" + assert result["sectors"]["cooling"] == "#00FF00" assert result["scenarios"]["baseline"] == "#0000FF" def test_grouped_items_preserves_order(self) -> None: """Test that grouped items respects custom ordering.""" grouped_items: dict[str, list[dict[str, Any]]] = { - "Metrics": [ + "Sectors": [ {"label": "heating", "color": "#FF0000", "order": 1}, {"label": "cooling", "color": "#00FF00", "order": 0}, ], @@ -428,7 +431,7 @@ def test_grouped_items_preserves_order(self) -> None: result = ColorPalette.grouped_items_to_palette(grouped_items) # Convert back to list to check order - keys = list(result["metrics"].keys()) + keys = list(result["sectors"].keys()) # cooling should come first because it has order 0 assert keys[0] == "cooling" assert keys[1] == "heating" @@ -438,7 +441,8 @@ def test_round_trip_grouped_items(self) -> None: palette = { "scenarios": {"baseline": "#0000FF"}, "model_years": {}, - "metrics": {"heating": "#FF0000", "cooling": "#00FF00"}, + "sectors": {"heating": "#FF0000", "cooling": "#00FF00"}, + "end_uses": {}, } # Convert to grouped items @@ -451,19 +455,276 @@ def test_round_trip_grouped_items(self) -> None: # Should have same items (order might differ) assert result["scenarios"]["baseline"] == palette["scenarios"]["baseline"] - assert result["metrics"]["heating"] == palette["metrics"]["heating"] - assert result["metrics"]["cooling"] == palette["metrics"]["cooling"] + assert result["sectors"]["heating"] == palette["sectors"]["heating"] + assert result["sectors"]["cooling"] == palette["sectors"]["cooling"] def test_empty_groups(self) -> None: """Test handling of empty groups.""" palette: dict[str, dict[str, str]] = { "scenarios": {}, "model_years": {}, - "metrics": {}, + "sectors": {}, + "end_uses": {}, } result: dict[str, list[dict[str, Any]]] = ColorPalette.palette_to_grouped_items(palette) assert result == {} back = ColorPalette.grouped_items_to_palette(result) - assert back == {"scenarios": {}, "model_years": {}, "metrics": {}} + assert back == {"scenarios": {}, "model_years": {}, "sectors": {}, "end_uses": {}} + + +class TestLegacyMetricsCompat: + """Test backward compatibility with the legacy 'metrics' key format.""" + + def test_init_with_legacy_metrics_key(self) -> None: + """Test that from_dict loads legacy 'metrics' entries into sectors.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#0000FF"}, + "model_years": {}, + "metrics": {"residential": "#FF0000", "commercial": "#00FF00"}, + } + ) + assert palette.sectors["residential"] == "#FF0000" + assert palette.sectors["commercial"] == "#00FF00" + assert palette.scenarios["baseline"] == "#0000FF" + + def test_from_dict_with_legacy_metrics_key(self) -> None: + """Test that from_dict loads legacy 'metrics' into sectors.""" + palette = ColorPalette.from_dict( + { + "scenarios": {}, + "model_years": {}, + "metrics": {"heating": "#FF0000"}, + } + ) + assert palette.sectors["heating"] == "#FF0000" + assert len(palette.end_uses) == 0 + + def test_to_dict_emits_new_format(self) -> None: + """Test that to_dict outputs the new 4-key format even after loading legacy.""" + palette = ColorPalette.from_dict( + { + "scenarios": {}, + "model_years": {}, + "metrics": {"heating": "#FF0000"}, + } + ) + result = palette.to_dict() + assert "sectors" in result + assert "end_uses" in result + assert "metrics" not in result + assert result["sectors"]["heating"] == "#FF0000" + + def test_to_dict_legacy_emits_old_format(self) -> None: + """Test that to_dict_legacy merges sectors and end_uses under 'metrics'.""" + palette = ColorPalette() + palette.update("residential", "#FF0000", category=ColorCategory.SECTOR) + palette.update("heating", "#00FF00", category=ColorCategory.END_USE) + result = palette.to_dict_legacy() + assert "metrics" in result + assert "sectors" not in result + assert "end_uses" not in result + assert result["metrics"]["residential"] == "#FF0000" + assert result["metrics"]["heating"] == "#00FF00" + + def test_init_with_new_4key_format(self) -> None: + """Test that from_dict accepts the new 4-key format.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#0000FF"}, + "model_years": {"2020": "#AABBCC"}, + "sectors": {"residential": "#FF0000"}, + "end_uses": {"heating": "#00FF00"}, + } + ) + assert palette.scenarios["baseline"] == "#0000FF" + assert palette.model_years["2020"] == "#AABBCC" + assert palette.sectors["residential"] == "#FF0000" + assert palette.end_uses["heating"] == "#00FF00" + + def test_palette_to_grouped_items_legacy_compat(self) -> None: + """Test that palette_to_grouped_items handles legacy 'metrics' key.""" + palette = { + "scenarios": {}, + "model_years": {}, + "metrics": {"heating": "#FF0000", "cooling": "#00FF00"}, + } + result = ColorPalette.palette_to_grouped_items(palette) + assert "Sectors" in result + assert len(result["Sectors"]) == 2 + + +class TestMergeWithProjectDimensions: + """Test ColorPalette.merge_with_project_dimensions.""" + + def test_matched_names_keep_colors(self) -> None: + """Entries in both palette and project keep their stored color.""" + palette = ColorPalette.from_dict( + {"scenarios": {"baseline": "#AA0000", "high": "#BB0000"}, "model_years": {}, "sectors": {}, "end_uses": {}} + ) + palette.merge_with_project_dimensions(scenarios=["baseline", "high"]) + assert palette.scenarios["baseline"] == "#AA0000" + assert palette.scenarios["high"] == "#BB0000" + + def test_new_project_names_get_colors(self) -> None: + """Names in the project but not in the palette get auto-assigned colors.""" + palette = ColorPalette.from_dict( + {"scenarios": {"baseline": "#AA0000"}, "model_years": {}, "sectors": {}, "end_uses": {}} + ) + palette.merge_with_project_dimensions(scenarios=["baseline", "new_scenario"]) + assert palette.scenarios["baseline"] == "#AA0000" + assert "new_scenario" in palette.scenarios + # New scenario should get a color different from the matched one + assert palette.scenarios["new_scenario"] != "#AA0000" + + def test_extra_palette_entries_become_reserves(self) -> None: + """Palette entries not in the project are kept as reserves at the end.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#AA0000", "old_scenario": "#BB0000"}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["baseline"]) + keys = list(palette.scenarios.keys()) + # Active entry first, reserve after + assert keys[0] == "baseline" + assert keys[1] == "old_scenario" + # Colors preserved + assert palette.scenarios["baseline"] == "#AA0000" + assert palette.scenarios["old_scenario"] == "#BB0000" + + def test_reserve_colors_reused_before_theme(self) -> None: + """New names should get reserve colors before cycling through the theme.""" + from stride.ui.palette import TOL_BRIGHT + + palette = ColorPalette.from_dict( + { + "scenarios": {"old_one": TOL_BRIGHT[0], "old_two": TOL_BRIGHT[1]}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + # Project has neither old name — both become reserves. + # Two new names should reuse the reserve colors. + palette.merge_with_project_dimensions(scenarios=["alpha", "beta"]) + assert palette.scenarios["alpha"] == TOL_BRIGHT[0] + assert palette.scenarios["beta"] == TOL_BRIGHT[1] + # Reserves still at the end + assert list(palette.scenarios.keys()) == ["alpha", "beta", "old_one", "old_two"] + + def test_mixed_match_and_reserve_reuse(self) -> None: + """Mix of matched, reserve-reuse, and fresh theme colors.""" + from stride.ui.palette import TOL_BRIGHT + + palette = ColorPalette.from_dict( + { + "scenarios": { + "baseline": TOL_BRIGHT[0], + "removed": TOL_BRIGHT[1], + }, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + # "baseline" matches, "removed" becomes reserve, "new_one" is new. + palette.merge_with_project_dimensions(scenarios=["baseline", "new_one"]) + assert palette.scenarios["baseline"] == TOL_BRIGHT[0] + # "new_one" should get the reserve color (TOL_BRIGHT[1]) before theme cycling + assert palette.scenarios["new_one"] == TOL_BRIGHT[1] + # "removed" is kept as reserve + assert "removed" in palette.scenarios + + def test_multiple_categories(self) -> None: + """Merge works across multiple categories simultaneously.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#AA0000"}, + "model_years": {}, + "sectors": {"residential": "#110000"}, + "end_uses": {"heating": "#220000"}, + } + ) + palette.merge_with_project_dimensions( + scenarios=["baseline", "new_sc"], + sectors=["residential", "commercial"], + end_uses=["heating", "cooling"], + ) + assert "new_sc" in palette.scenarios + assert "commercial" in palette.sectors + assert "cooling" in palette.end_uses + # Originals preserved + assert palette.scenarios["baseline"] == "#AA0000" + assert palette.sectors["residential"] == "#110000" + assert palette.end_uses["heating"] == "#220000" + + def test_none_categories_skipped(self) -> None: + """Passing None for a category leaves it untouched.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#AA0000"}, + "model_years": {}, + "sectors": {"residential": "#110000"}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["baseline"]) + # Sectors untouched (None was passed) + assert palette.sectors == {"residential": "#110000"} + + def test_case_insensitive_matching(self) -> None: + """Merge normalizes names to lowercase for matching.""" + palette = ColorPalette.from_dict( + {"scenarios": {"baseline": "#AA0000"}, "model_years": {}, "sectors": {}, "end_uses": {}} + ) + palette.merge_with_project_dimensions(scenarios=["Baseline"]) + assert palette.scenarios["baseline"] == "#AA0000" + assert len(palette.scenarios) == 1 + + +class TestUpdateRequiresCategory: + """Test that update() and pop() require an explicit category.""" + + def test_update_requires_category(self) -> None: + """Calling update() without category should raise TypeError.""" + palette = ColorPalette() + with pytest.raises(TypeError): + palette.update("baseline", "#AA0000") # type: ignore[call-arg] + + def test_pop_requires_category(self) -> None: + """Calling pop() without category should raise TypeError.""" + palette = ColorPalette() + palette.update("baseline", "#AA0000", category=ColorCategory.SCENARIO) + with pytest.raises(TypeError): + palette.pop("baseline") # type: ignore[call-arg] + + def test_duplicate_label_across_categories(self) -> None: + """Same label can exist in multiple categories independently.""" + palette = ColorPalette() + palette.update("2025", "#AA0000", category=ColorCategory.MODEL_YEAR) + palette.update("2025", "#BB0000", category=ColorCategory.SECTOR) + assert palette.model_years["2025"] == "#AA0000" + assert palette.sectors["2025"] == "#BB0000" + + def test_roundtrip_temp_edits_with_explicit_category(self) -> None: + """Simulate the save flow: copy palette, apply edits with category.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#AA0000"}, + "model_years": {"2025": "#BB0000"}, + "sectors": {"residential": "#CC0000"}, + "end_uses": {"heating": "#DD0000"}, + } + ) + palette_copy = palette.copy() + palette_copy.update("2025", "#112233", category=ColorCategory.MODEL_YEAR) + palette_copy.update("residential", "#445566", category=ColorCategory.SECTOR) + + assert palette_copy.model_years["2025"] == "#112233" + assert palette_copy.sectors["residential"] == "#445566" diff --git a/tests/palette/test_palette_init.py b/tests/palette/test_palette_init.py index 447db34..b104c82 100644 --- a/tests/palette/test_palette_init.py +++ b/tests/palette/test_palette_init.py @@ -5,7 +5,7 @@ from stride.api import APIClient from stride.project import Project from stride.ui.palette import ColorPalette -from stride.ui.tui import get_user_palette_dir +from stride.ui.palette_utils import get_user_palette_dir def test_api_query_methods() -> None: @@ -95,7 +95,7 @@ def test_palette_init_from_user_palette() -> None: print("Testing palette initialization from user palette") print("=" * 80) - from stride.ui.tui import load_user_palette, save_user_palette + from stride.ui.palette_utils import load_user_palette, save_user_palette # Create a test user palette test_palette_name = "test_source_palette" @@ -120,9 +120,9 @@ def test_palette_init_from_user_palette() -> None: print(f" {label}: {color}") # Verify (loaded_dict has lowercase keys) - assert loaded_dict == {k.lower(): v for k, v in test_palette.items()}, ( - "Loaded palette should match original" - ) + assert loaded_dict == { + k.lower(): v for k, v in test_palette.items() + }, "Loaded palette should match original" print("✓ User palette loaded successfully") # Clean up diff --git a/tests/palette/test_palette_merge.py b/tests/palette/test_palette_merge.py new file mode 100644 index 0000000..0d2abba --- /dev/null +++ b/tests/palette/test_palette_merge.py @@ -0,0 +1,248 @@ +"""Tests for ColorPalette.merge_with_project_dimensions.""" + +from stride.ui.palette import ColorCategory, ColorPalette, TOL_BRIGHT, TOL_METRICS_LIGHT + + +class TestMergeMatchedNames: + """Matched names keep their stored color.""" + + def test_all_names_match(self) -> None: + """When palette and project have the same names, colors are preserved.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#111111", "high_growth": "#222222"}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["baseline", "high_growth"]) + assert palette.scenarios["baseline"] == "#111111" + assert palette.scenarios["high_growth"] == "#222222" + + def test_sectors_match(self) -> None: + """Sector name matching preserves colors.""" + palette = ColorPalette.from_dict( + { + "scenarios": {}, + "model_years": {}, + "sectors": {"residential": "#AA0000", "commercial": "#BB0000"}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(sectors=["residential", "commercial"]) + assert palette.sectors["residential"] == "#AA0000" + assert palette.sectors["commercial"] == "#BB0000" + + +class TestMergeUnmatchedProjectNames: + """Project names not in palette get new theme colors, skipping used ones.""" + + def test_extra_project_scenarios(self) -> None: + """Project has more scenarios than the palette — extras get theme colors.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": TOL_BRIGHT[0]}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions( + scenarios=["baseline", "new_scenario_1", "new_scenario_2"] + ) + assert palette.scenarios["baseline"] == TOL_BRIGHT[0] + # New entries should get colors, and they shouldn't duplicate the matched color + new1 = palette.scenarios["new_scenario_1"] + new2 = palette.scenarios["new_scenario_2"] + assert new1 != TOL_BRIGHT[0] + assert new2 != TOL_BRIGHT[0] + assert new1 != new2 + + def test_completely_new_scenarios(self) -> None: + """Project has entirely different scenario names — all get fresh colors.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"old_scenario": "#AAAAAA"}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["alpha", "beta"]) + # Alpha and beta should have colors + assert "alpha" in palette.scenarios + assert "beta" in palette.scenarios + # Old scenario should still be present as reserve + assert "old_scenario" in palette.scenarios + + def test_skips_used_colors(self) -> None: + """New entries skip colors already used by matched entries.""" + # Give baseline the first TOL_BRIGHT color + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": TOL_BRIGHT[0]}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["baseline", "extra"]) + # "extra" should NOT get TOL_BRIGHT[0] since it's used by "baseline" + assert palette.scenarios["extra"] != TOL_BRIGHT[0] + + +class TestMergeReserveEntries: + """Extra palette entries not in the project are kept as reserves.""" + + def test_reserves_kept(self) -> None: + """Palette entries not in the project remain and appear after active entries.""" + palette = ColorPalette.from_dict( + { + "scenarios": { + "reserve_1": "#111111", + "reserve_2": "#222222", + "active": "#333333", + }, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["active"]) + # Active entry is preserved + assert palette.scenarios["active"] == "#333333" + # Reserves are kept + assert palette.scenarios["reserve_1"] == "#111111" + assert palette.scenarios["reserve_2"] == "#222222" + # Active entry comes first in iteration order + keys = list(palette.scenarios.keys()) + assert keys[0] == "active" + + def test_large_user_palette_reserves(self) -> None: + """A user palette with 20 sector colors applied to an 8-sector project.""" + big_palette_data: dict[str, str] = {} + for i in range(20): + big_palette_data[f"sector_{i}"] = f"#{i:02x}{i:02x}{i:02x}" + + palette = ColorPalette.from_dict( + { + "scenarios": {}, + "model_years": {}, + "sectors": big_palette_data, + "end_uses": {}, + } + ) + + project_sectors = [f"sector_{i}" for i in range(8)] + palette.merge_with_project_dimensions(sectors=project_sectors) + + # All 20 entries should still exist + assert len(palette.sectors) == 20 + # First 8 keys should be the project sectors + keys = list(palette.sectors.keys()) + assert keys[:8] == project_sectors + # Colors are preserved for matched entries + for i in range(8): + assert palette.sectors[f"sector_{i}"] == f"#{i:02x}{i:02x}{i:02x}" + + +class TestMergeOrdering: + """Merged palette has project entries first, then reserves.""" + + def test_project_order_preserved(self) -> None: + """Active entries appear in the order provided by the project.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"c": "#CC0000", "a": "#AA0000", "b": "#BB0000"}, + "model_years": {}, + "sectors": {}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["b", "a"]) + keys = list(palette.scenarios.keys()) + assert keys[0] == "b" + assert keys[1] == "a" + # "c" is a reserve and comes after + assert keys[2] == "c" + + +class TestMergeMultipleCategories: + """Merge can operate on multiple categories at once.""" + + def test_merge_scenarios_and_sectors(self) -> None: + """Merging affects both scenarios and sectors simultaneously.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#111111"}, + "model_years": {}, + "sectors": {"residential": "#AAAAAA"}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions( + scenarios=["baseline", "high"], + sectors=["residential", "commercial"], + ) + assert len(palette.scenarios) == 2 + assert len(palette.sectors) == 2 + assert palette.scenarios["baseline"] == "#111111" + assert palette.sectors["residential"] == "#AAAAAA" + # New entries have valid hex colors + assert palette.scenarios["high"].startswith("#") + assert palette.sectors["commercial"].startswith("#") + + def test_none_categories_skipped(self) -> None: + """Categories passed as None are not touched.""" + palette = ColorPalette.from_dict( + { + "scenarios": {"baseline": "#111111"}, + "model_years": {}, + "sectors": {"residential": "#AAAAAA"}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(scenarios=["baseline"]) + # Sectors should be untouched + assert palette.sectors["residential"] == "#AAAAAA" + assert len(palette.sectors) == 1 + + +class TestMergeCaseInsensitive: + """Merge handles case normalization.""" + + def test_mixed_case_matching(self) -> None: + """Project names with different casing still match palette entries.""" + palette = ColorPalette.from_dict( + { + "scenarios": {}, + "model_years": {}, + "sectors": {"residential": "#FF0000"}, + "end_uses": {}, + } + ) + palette.merge_with_project_dimensions(sectors=["Residential"]) + # Should match (keys are lowered) + assert "residential" in palette.sectors + assert palette.sectors["residential"] == "#FF0000" + + +class TestMergeEndUses: + """End uses merge correctly with their own theme.""" + + def test_end_uses_merge(self) -> None: + """End uses get metric theme colors for new entries.""" + palette = ColorPalette.from_dict( + { + "scenarios": {}, + "model_years": {}, + "sectors": {}, + "end_uses": {"heating": "#FF0000"}, + } + ) + palette.merge_with_project_dimensions(end_uses=["heating", "cooling", "lighting"]) + assert palette.end_uses["heating"] == "#FF0000" + assert "cooling" in palette.end_uses + assert "lighting" in palette.end_uses + assert len(palette.end_uses) == 3 diff --git a/tests/palette/test_palette_override.py b/tests/palette/test_palette_override.py index 70a4bec..1148092 100644 --- a/tests/palette/test_palette_override.py +++ b/tests/palette/test_palette_override.py @@ -2,9 +2,9 @@ import pytest -from stride.ui.tui import ( +from stride.config import get_stride_config_path +from stride.ui.palette_utils import ( get_default_user_palette, - get_stride_config_path, load_user_palette, save_user_palette, set_default_user_palette, @@ -20,7 +20,7 @@ def test_user_palette_save_and_load(tmp_path, monkeypatch) -> None: # type: ign def mock_get_user_palette_dir(): # type: ignore[no-untyped-def] return palette_dir - monkeypatch.setattr("stride.ui.tui.get_user_palette_dir", mock_get_user_palette_dir) + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", mock_get_user_palette_dir) # Create a test palette test_palette = { @@ -54,8 +54,8 @@ def mock_get_stride_config_dir(): # type: ignore[no-untyped-def] def mock_get_user_palette_dir(): # type: ignore[no-untyped-def] return palette_dir - monkeypatch.setattr("stride.ui.tui.get_stride_config_dir", mock_get_stride_config_dir) - monkeypatch.setattr("stride.ui.tui.get_user_palette_dir", mock_get_user_palette_dir) + monkeypatch.setattr("stride.config.get_stride_config_dir", mock_get_stride_config_dir) + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", mock_get_user_palette_dir) # Create a test palette file test_palette = {"Residential": "#FF0000"} @@ -92,8 +92,8 @@ def mock_get_stride_config_dir(): # type: ignore[no-untyped-def] def mock_get_user_palette_dir(): # type: ignore[no-untyped-def] return palette_dir - monkeypatch.setattr("stride.ui.tui.get_stride_config_dir", mock_get_stride_config_dir) - monkeypatch.setattr("stride.ui.tui.get_user_palette_dir", mock_get_user_palette_dir) + monkeypatch.setattr("stride.config.get_stride_config_dir", mock_get_stride_config_dir) + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", mock_get_user_palette_dir) # Try to set a non-existent palette as default with pytest.raises(FileNotFoundError, match="not found"): @@ -115,8 +115,8 @@ def mock_get_stride_config_dir(): # type: ignore[no-untyped-def] def mock_get_user_palette_dir(): # type: ignore[no-untyped-def] return palette_dir - monkeypatch.setattr("stride.ui.tui.get_stride_config_dir", mock_get_stride_config_dir) - monkeypatch.setattr("stride.ui.tui.get_user_palette_dir", mock_get_user_palette_dir) + monkeypatch.setattr("stride.config.get_stride_config_dir", mock_get_stride_config_dir) + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", mock_get_user_palette_dir) # Create test palettes save_user_palette("palette1", {"Label1": "#FF0000"}) diff --git a/tests/palette/test_palette_override_integration.py b/tests/palette/test_palette_override_integration.py index b26fabe..ee5fff4 100644 --- a/tests/palette/test_palette_override_integration.py +++ b/tests/palette/test_palette_override_integration.py @@ -5,7 +5,7 @@ import pytest from stride.ui.palette import ColorPalette -from stride.ui.tui import ( +from stride.ui.palette_utils import ( get_default_user_palette, load_user_palette, save_user_palette, @@ -28,8 +28,8 @@ def mock_get_stride_config_dir(): # type: ignore[no-untyped-def] def mock_get_user_palette_dir(): # type: ignore[no-untyped-def] return palette_dir - monkeypatch.setattr("stride.ui.tui.get_stride_config_dir", mock_get_stride_config_dir) - monkeypatch.setattr("stride.ui.tui.get_user_palette_dir", mock_get_user_palette_dir) + monkeypatch.setattr("stride.config.get_stride_config_dir", mock_get_stride_config_dir) + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", mock_get_user_palette_dir) return { "config_dir": config_dir, diff --git a/tests/palette/test_rgb_color_format.py b/tests/palette/test_rgb_color_format.py index 3efd853..472b194 100644 --- a/tests/palette/test_rgb_color_format.py +++ b/tests/palette/test_rgb_color_format.py @@ -3,12 +3,12 @@ import pytest from stride.ui.color_manager import ColorManager -from stride.ui.palette import ColorPalette +from stride.ui.palette import ColorCategory, ColorPalette def test_rgb_format_in_palette() -> None: """Test that ColorPalette accepts rgb() format colors.""" - palette = ColorPalette( + palette = ColorPalette.from_dict( { "Label1": "rgb(200,0,0)", "Label2": "rgb(0, 200, 0)", @@ -24,7 +24,7 @@ def test_rgb_format_in_palette() -> None: def test_rgba_format_in_palette() -> None: """Test that ColorPalette accepts rgba() format colors.""" - palette = ColorPalette( + palette = ColorPalette.from_dict( { "Label1": "rgba(200,0,0,0.5)", "Label2": "rgba(0, 200, 0, 0.8)", @@ -40,7 +40,7 @@ def test_rgba_format_in_palette() -> None: def test_hex_format_in_palette() -> None: """Test that ColorPalette still accepts hex format colors.""" - palette = ColorPalette( + palette = ColorPalette.from_dict( { "Label1": "#FF0000", "Label2": "#00FF00", @@ -56,7 +56,7 @@ def test_hex_format_in_palette() -> None: def test_mixed_formats_in_palette() -> None: """Test that ColorPalette accepts mixed format colors.""" - palette = ColorPalette( + palette = ColorPalette.from_dict( { "Hex": "#FF0000", "RGB": "rgb(0,255,0)", @@ -77,16 +77,20 @@ def test_color_manager_with_rgb_format() -> None: # Reset singleton ColorManager._instance = None # type: ignore[misc] - palette = ColorPalette( + palette = ColorPalette.from_dict( { - "Residential": "rgb(200,0,0)", - "Commercial": "rgb(0,200,0)", - "Industrial": "rgb(0,0,200)", + "scenarios": {}, + "model_years": {}, + "metrics": { + "Residential": "rgb(200,0,0)", + "Commercial": "rgb(0,200,0)", + "Industrial": "rgb(0,0,200)", + }, } ) cm = ColorManager(palette) - cm.initialize_colors(["Residential", "Commercial", "Industrial"]) + cm.initialize_colors([], sectors=["Residential", "Commercial", "Industrial"]) # Colors should be converted to rgba format res = cm.get_color("Residential") @@ -103,10 +107,12 @@ def test_color_manager_with_no_spaces_rgb() -> None: # Reset singleton ColorManager._instance = None # type: ignore[misc] - palette = ColorPalette({"Label": "rgb(123,45,67)"}) + palette = ColorPalette.from_dict( + {"scenarios": {}, "model_years": {}, "metrics": {"Label": "rgb(123,45,67)"}} + ) cm = ColorManager(palette) - cm.initialize_colors(["Label"]) + cm.initialize_colors([], sectors=["Label"]) color = cm.get_color("Label") @@ -122,7 +128,9 @@ def test_color_manager_scenario_styling_with_rgb() -> None: # Reset singleton ColorManager._instance = None # type: ignore[misc] - palette = ColorPalette({"Scenario1": "rgb(255,100,50)"}) + palette = ColorPalette.from_dict( + {"scenarios": {"Scenario1": "rgb(255,100,50)"}, "model_years": {}, "metrics": {}} + ) cm = ColorManager(palette) cm.initialize_colors(["Scenario1"]) @@ -148,7 +156,7 @@ def test_palette_update_with_rgb() -> None: palette = ColorPalette() # Update with rgb format - palette.update("Label1", "rgb(100,150,200)") + palette.update("Label1", "rgb(100,150,200)", category=ColorCategory.SECTOR) # Should be preserved assert palette.get("Label1") == "rgb(100,150,200)" @@ -159,7 +167,7 @@ def test_palette_update_with_rgba() -> None: palette = ColorPalette() # Update with rgba format - palette.update("Label1", "rgba(100,150,200,0.7)") + palette.update("Label1", "rgba(100,150,200,0.7)", category=ColorCategory.SECTOR) # Should be preserved assert palette.get("Label1") == "rgba(100,150,200,0.7)" @@ -183,7 +191,7 @@ def test_palette_from_dict_with_rgb() -> None: def test_invalid_rgb_format_gets_replaced() -> None: """Test that invalid rgb() format gets replaced with theme color.""" - palette = ColorPalette( + palette = ColorPalette.from_dict( { "Valid": "rgb(100,100,100)", "Invalid1": "rgb(300,400,500)", # Values out of range (but still valid syntax) @@ -204,7 +212,7 @@ def test_invalid_rgb_format_gets_replaced() -> None: def test_rgb_with_various_spacing() -> None: """Test rgb() colors with various spacing patterns.""" - palette = ColorPalette( + palette = ColorPalette.from_dict( { "NoSpaces": "rgb(10,20,30)", "AllSpaces": "rgb(10, 20, 30)", @@ -248,18 +256,25 @@ def test_end_to_end_rgb_workflow() -> None: # Simulate colors from project.json5 file (rgb format) project_colors = { - "baseline": "rgb(56,166,165)", - "alternate": "rgb(29,105,150)", - "Residential": "rgb(200,0,0)", - "Commercial": "rgb(0,200,0)", + "scenarios": { + "baseline": "rgb(56,166,165)", + "alternate": "rgb(29,105,150)", + }, + "model_years": {}, + "metrics": { + "Residential": "rgb(200,0,0)", + "Commercial": "rgb(0,200,0)", + }, } # Create palette - palette = ColorPalette(project_colors) + palette = ColorPalette.from_dict(project_colors) # Verify palette preserves rgb format - for label, color in project_colors.items(): - assert palette.palette[label.lower()] == color + assert palette.scenarios["baseline"] == "rgb(56,166,165)" + assert palette.scenarios["alternate"] == "rgb(29,105,150)" + assert palette.sectors["residential"] == "rgb(200,0,0)" + assert palette.sectors["commercial"] == "rgb(0,200,0)" # Initialize ColorManager cm = ColorManager(palette) diff --git a/tests/palette/test_save_user_palette.py b/tests/palette/test_save_user_palette.py new file mode 100644 index 0000000..441d2d7 --- /dev/null +++ b/tests/palette/test_save_user_palette.py @@ -0,0 +1,278 @@ +"""Test that user-edited colors survive the save-as-user-palette round-trip. + +Reproduces the bug where a user changes a color in the settings UI, +clicks "Save As User Palette", enters a name, clicks again, and the +custom color vanishes both from the display and (seemingly) from disk. + +Root cause: ``set_ui_theme()`` unconditionally overwrites all model-year +colors with freshly sampled iridescent values. The file on disk IS +correct, but every subsequent load goes through ``create_fresh_color_manager`` +→ ``set_ui_theme`` → all model-year edits are lost. +""" + +from __future__ import annotations + +import json +from itertools import cycle +from unittest.mock import patch + +import pytest + +from stride.ui.color_manager import ColorManager +from stride.ui.palette import ColorPalette +from stride.ui.settings.callbacks import register_settings_callbacks +from stride.ui.settings.layout import ( + clear_temp_color_edits, + set_temp_color_edit, +) + + +# ── helpers ────────────────────────────────────────────────────────────── +def _make_palette() -> ColorPalette: + """Build a small palette with all four categories populated.""" + data = { + "scenarios": {"reference": "#4477AA", "high_growth": "#CCBB44"}, + "model_years": {"2025": "#9B8AC4", "2030": "#9A709E", "2035": "#906388"}, + "sectors": {"residential": "#CC6677", "commercial": "#999933"}, + "end_uses": {"heating": "#5289C7", "cooling": "#117733"}, + } + return ColorPalette.from_dict(data) + + +CUSTOM_HEX = "#FF0000" # unmistakable custom color + +# Parametrize tuples: (composite_key_category, dict_key, label) +_ALL_CATEGORIES = [ + ("scenario", "scenarios", "reference"), + ("model_year", "model_years", "2030"), + ("sector", "sectors", "residential"), + ("end_use", "end_uses", "heating"), +] + + +def _palette_dict(palette: ColorPalette, dict_key: str) -> dict[str, str]: + """Return the palette attribute dict for a to_dict() key name.""" + return { + "scenarios": palette.scenarios, + "model_years": palette.model_years, + "sectors": palette.sectors, + "end_uses": palette.end_uses, + }[dict_key] + + +# ── tests ──────────────────────────────────────────────────────────────── + + +class TestSetUiThemePreservesCustomColors: + """Unit tests for ``set_ui_theme`` preserving user-customised colors. + + These isolate the ``set_ui_theme`` method directly, without involving + callbacks or disk I/O, so a regression in ``set_ui_theme`` is + immediately pinpointed. + """ + + @pytest.mark.parametrize("cat, dict_key, label", _ALL_CATEGORIES) + @pytest.mark.parametrize("theme", ["light", "dark"]) + def test_edit_survives_set_ui_theme( + self, cat: str, dict_key: str, label: str, theme: str + ) -> None: + """Custom colors must survive ``set_ui_theme``.""" + palette = _make_palette() + palette.update(label, CUSTOM_HEX, category=cat) + + palette.set_ui_theme(theme) + + assert _palette_dict(palette, dict_key)[label] == CUSTOM_HEX, ( + f"set_ui_theme('{theme}') clobbered the custom {dict_key} " f"color for '{label}'" + ) + + +# ── helper: create a real ColorManager bypassing the singleton ─────────── +def _make_color_manager(palette: ColorPalette) -> ColorManager: + """Build a standalone ``ColorManager`` (no singleton side-effects).""" + cm = object.__new__(ColorManager) + cm._initialized = False + cm._scenario_colors = {} + ColorManager.__init__(cm, palette) + cm.initialize_colors( + scenarios=list(palette.scenarios.keys()), + sectors=list(palette.sectors.keys()), + end_uses=list(palette.end_uses.keys()), + ) + return cm + + +from typing import Any, Callable + + +def _capture_settings_callbacks( + get_dh: Callable[[], Any], + get_cm: Callable[[], ColorManager], + on_change: Callable[[ColorPalette, Any, Any], Any], +) -> dict[str, Callable[..., Any]]: + """Call ``register_settings_callbacks`` with a mocked ``@callback``. + + Returns a dict mapping function name → the original (unwrapped) + callback function. Closure variables (``get_color_manager_func``, + ``on_palette_change_func``, etc.) are bound to the arguments passed + here. + """ + captured: dict[str, Callable[..., Any]] = {} + + def fake_callback( + *_args: Any, **_kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Replacement for ``dash.callback`` – just record the function.""" + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + captured[func.__name__] = func + return func + + return decorator + + with patch("stride.ui.settings.callbacks.callback", fake_callback): + register_settings_callbacks(get_dh, get_cm, on_change) + + return captured + + +class TestSaveCallbackDirectly: + """Invoke the *real* ``save_to_new_palette`` callback via MagicMock. + + By mocking ``dash.callback`` as a no-op decorator we can capture the + nested function that ``register_settings_callbacks`` creates and call + it directly with controlled arguments. + """ + + def setup_method(self) -> None: + clear_temp_color_edits() + + def teardown_method(self) -> None: + clear_temp_color_edits() + + # -- file-on-disk correctness ----------------------------------------- + + def test_callback_writes_custom_color_to_disk(self, tmp_path: Any, monkeypatch: Any) -> None: + """``save_to_new_palette`` must persist the edited color to JSON.""" + palette_dir = tmp_path / "palettes" + palette_dir.mkdir() + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", lambda: palette_dir) + + palette = _make_palette() + cm = _make_color_manager(palette) + + on_change_calls: list[tuple[Any, ...]] = [] + cbs = _capture_settings_callbacks( + lambda: None, + lambda: cm, + lambda p, t, n: on_change_calls.append((p, t, n)), + ) + + set_temp_color_edit("model_year:2030", CUSTOM_HEX) + cbs["save_to_new_palette"](1, "my_palette") + + # The JSON on disk must contain the custom color. + raw = json.loads((palette_dir / "my_palette.json").read_text()) + assert raw["palette"]["model_years"]["2030"] == CUSTOM_HEX + + def test_callback_passes_custom_color_to_on_palette_change( + self, tmp_path: Any, monkeypatch: Any + ) -> None: + """The palette handed to ``on_palette_change_func`` must carry the edit.""" + palette_dir = tmp_path / "palettes" + palette_dir.mkdir() + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", lambda: palette_dir) + + palette = _make_palette() + cm = _make_color_manager(palette) + + received_palettes: list[ColorPalette] = [] + cbs = _capture_settings_callbacks( + lambda: None, + lambda: cm, + lambda p, _t, _n: received_palettes.append(p.copy()), + ) + + set_temp_color_edit("model_year:2030", CUSTOM_HEX) + cbs["save_to_new_palette"](1, "my_palette") + + assert len(received_palettes) == 1 + assert received_palettes[0].model_years["2030"] == CUSTOM_HEX + + # -- the real bug: set_ui_theme inside on_palette_change clobbers ---- + + @pytest.mark.parametrize("theme", ["light", "dark"]) + def test_callback_color_survives_on_palette_change_chain( + self, tmp_path: Any, monkeypatch: Any, theme: str + ) -> None: + """Reproduce the full bug: save → on_palette_change → set_ui_theme. + + ``on_palette_change`` (in app.py) calls + ``create_fresh_color_manager`` which calls ``set_ui_theme``. + The custom model-year color must survive that call. + """ + palette_dir = tmp_path / "palettes" + palette_dir.mkdir() + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", lambda: palette_dir) + + palette = _make_palette() + cm = _make_color_manager(palette) + + # Simulate what the real on_palette_change → create_fresh_color_manager does. + result_palettes: list[ColorPalette] = [] + + def realistic_on_palette_change(p: Any, _ptype: Any, _pname: Any) -> None: + p_copy = p.copy() + p_copy._scenario_iterator = cycle(p_copy.scenario_theme) + p_copy.set_ui_theme(theme) + result_palettes.append(p_copy) + + cbs = _capture_settings_callbacks( + lambda: None, + lambda: cm, + realistic_on_palette_change, + ) + + set_temp_color_edit("model_year:2030", CUSTOM_HEX) + cbs["save_to_new_palette"](1, "my_palette") + + assert len(result_palettes) == 1 + assert result_palettes[0].model_years["2030"] == CUSTOM_HEX, ( + f"set_ui_theme('{theme}') inside on_palette_change clobbered " + "the custom model-year color that was just saved" + ) + + @pytest.mark.parametrize("cat, dict_key, label", _ALL_CATEGORIES) + def test_callback_all_categories_survive_chain( + self, tmp_path: Any, monkeypatch: Any, cat: str, dict_key: str, label: str + ) -> None: + """Every category's custom color must survive the full callback chain.""" + palette_dir = tmp_path / "palettes" + palette_dir.mkdir() + monkeypatch.setattr("stride.ui.palette_utils.get_user_palette_dir", lambda: palette_dir) + + palette = _make_palette() + cm = _make_color_manager(palette) + + result_palettes: list[ColorPalette] = [] + + def realistic_on_palette_change(p: Any, _ptype: Any, _pname: Any) -> None: + p_copy = p.copy() + p_copy._scenario_iterator = cycle(p_copy.scenario_theme) + p_copy.set_ui_theme("light") + result_palettes.append(p_copy) + + cbs = _capture_settings_callbacks( + lambda: None, + lambda: cm, + realistic_on_palette_change, + ) + + set_temp_color_edit(f"{cat}:{label}", CUSTOM_HEX) + cbs["save_to_new_palette"](1, "my_palette") + + assert len(result_palettes) == 1 + assert _palette_dict(result_palettes[0], dict_key)[label] == CUSTOM_HEX, ( + f"Custom {dict_key} color for '{label}' was clobbered after " + "save_to_new_palette → on_palette_change → set_ui_theme" + ) diff --git a/tests/palette/test_settings_categories.py b/tests/palette/test_settings_categories.py new file mode 100644 index 0000000..94f2cc3 --- /dev/null +++ b/tests/palette/test_settings_categories.py @@ -0,0 +1,71 @@ +"""Test that all four color categories appear in the settings UI layout.""" + +from collections.abc import Sequence + +from stride.ui.color_manager import ColorManager +from stride.ui.palette import ColorCategory, ColorPalette +from stride.ui.settings.layout import create_color_preview_content + + +def _extract_headings(content: Sequence[object]) -> list[str]: + """Extract H6 heading text from the color preview content.""" + headings: list[str] = [] + for div in content: + # Each category is wrapped in an html.Div whose first child is an H6 + if hasattr(div, "children") and div.children: + first_child = div.children[0] + if hasattr(first_child, "children") and isinstance(first_child.children, str): + headings.append(first_child.children) + return headings + + +def test_all_four_categories_in_color_preview() -> None: + """All four category headings must appear in the settings color preview.""" + # Reset ColorManager singleton so we get a fresh instance + ColorManager._instance = None # type: ignore[misc] + + palette = ColorPalette() + palette.update("baseline", "#AA0000", category=ColorCategory.SCENARIO) + palette.update("2030", "#BB0000", category=ColorCategory.MODEL_YEAR) + palette.update("residential", "#CC0000", category=ColorCategory.SECTOR) + palette.update("heating", "#DD0000", category=ColorCategory.END_USE) + + cm = ColorManager(palette) + content = create_color_preview_content(cm) + + headings = _extract_headings(content) + assert "Scenarios" in headings, f"Missing 'Scenarios' heading; got {headings}" + assert "Model Years" in headings, f"Missing 'Model Years' heading; got {headings}" + assert "Sectors" in headings, f"Missing 'Sectors' heading; got {headings}" + assert "End Uses" in headings, f"Missing 'End Uses' heading; got {headings}" + + # Clean up singleton + ColorManager._instance = None # type: ignore[misc] + + +def test_duplicate_label_across_categories_both_shown() -> None: + """A label appearing in two categories must show up in both sections.""" + ColorManager._instance = None # type: ignore[misc] + + palette = ColorPalette() + palette.update("2025", "#AA0000", category=ColorCategory.MODEL_YEAR) + palette.update("2025", "#BB0000", category=ColorCategory.SECTOR) + + cm = ColorManager(palette) + content = create_color_preview_content(cm) + + headings = _extract_headings(content) + assert "Model Years" in headings + assert "Sectors" in headings + + # Count total color items — should have 2 (one per category) + total_items = 0 + for div in content: + if hasattr(div, "children") and len(div.children) >= 2: + items_div = div.children[1] # the flex-wrap div with color items + if hasattr(items_div, "children"): + total_items += len(items_div.children) + + assert total_items == 2, f"Expected 2 color items for duplicate label; got {total_items}" + + ColorManager._instance = None # type: ignore[misc] diff --git a/tests/palette/test_tol_palettes.py b/tests/palette/test_tol_palettes.py new file mode 100644 index 0000000..e8c4ae3 --- /dev/null +++ b/tests/palette/test_tol_palettes.py @@ -0,0 +1,257 @@ +"""Tests for Paul Tol color-blind-safe palette integration. + +Validates that the Tol palette constants, sample_iridescent interpolation, +and set_ui_theme method work correctly. +""" + +import pytest + +from stride.ui.palette import ( + TOL_BRIGHT, + TOL_IRIDESCENT, + TOL_METRICS_DARK, + TOL_METRICS_LIGHT, + ColorCategory, + ColorPalette, + sample_iridescent, +) + + +class TestTolPaletteConstants: + """Verify Tol palette constants are well-formed.""" + + @pytest.mark.parametrize( + "palette,expected_len", + [ + (TOL_BRIGHT, 7), + (TOL_METRICS_LIGHT, 12), + (TOL_METRICS_DARK, 16), + (TOL_IRIDESCENT, 23), + ], + ) + def test_palette_lengths(self, palette: list[str], expected_len: int) -> None: + assert len(palette) == expected_len + + @pytest.mark.parametrize( + "palette", [TOL_BRIGHT, TOL_METRICS_LIGHT, TOL_METRICS_DARK, TOL_IRIDESCENT] + ) + def test_all_hex_format(self, palette: list[str]) -> None: + for color in palette: + assert color.startswith("#"), f"{color} does not start with #" + assert len(color) == 7, f"{color} is not 7 characters" + + def test_bright_no_duplicates(self) -> None: + assert len(set(TOL_BRIGHT)) == len(TOL_BRIGHT) + + def test_metrics_light_no_duplicates(self) -> None: + assert len(set(TOL_METRICS_LIGHT)) == len(TOL_METRICS_LIGHT) + + def test_metrics_dark_no_duplicates(self) -> None: + assert len(set(TOL_METRICS_DARK)) == len(TOL_METRICS_DARK) + + +class TestSampleIridescent: + """Test the sample_iridescent interpolation function.""" + + def test_zero_colors(self) -> None: + assert sample_iridescent(0) == [] + + def test_single_color(self) -> None: + result = sample_iridescent(1, theme="light") + assert len(result) == 1 + assert result[0].startswith("#") + + def test_light_mode_range(self) -> None: + """Light mode should use idx 16–22 (7 native colors).""" + colors = sample_iridescent(7, theme="light") + assert len(colors) == 7 + # First and last should match the Iridescent endpoints for light + assert colors[0] == TOL_IRIDESCENT[16] + assert colors[-1] == TOL_IRIDESCENT[22] + + def test_dark_mode_range(self) -> None: + """Dark mode should use idx 0–19 (20 native colors).""" + colors = sample_iridescent(20, theme="dark") + assert len(colors) == 20 + assert colors[0] == TOL_IRIDESCENT[0] + assert colors[-1] == TOL_IRIDESCENT[19] + + def test_interpolation_produces_unique_colors(self) -> None: + """When requesting more colors than native, interpolation should still produce unique values.""" + colors = sample_iridescent(10, theme="light") + assert len(colors) == 10 + # All should be valid hex + for c in colors: + assert c.startswith("#") and len(c) == 7 + # Most should be unique (some edge cases could overlap) + assert len(set(colors)) >= 8 + + def test_two_colors(self) -> None: + """Requesting exactly 2 should give endpoints.""" + colors = sample_iridescent(2, theme="dark") + assert len(colors) == 2 + assert colors[0] == TOL_IRIDESCENT[0] + assert colors[-1] == TOL_IRIDESCENT[19] + + +class TestSetUiTheme: + """Test the set_ui_theme method on ColorPalette.""" + + def test_default_is_light(self) -> None: + palette = ColorPalette() + assert palette._ui_theme == "light" + assert palette.metric_theme == list(TOL_METRICS_LIGHT) + + def test_switch_to_dark(self) -> None: + palette = ColorPalette() + palette.update("metric_a", category=ColorCategory.SECTOR) + palette.update("metric_b", category=ColorCategory.SECTOR) + + palette.set_ui_theme("dark") + + assert palette._ui_theme == "dark" + assert palette.metric_theme == list(TOL_METRICS_DARK) + # Sectors should be reassigned with dark-mode colors + assert palette.sectors["metric_a"] == TOL_METRICS_DARK[0] + assert palette.sectors["metric_b"] == TOL_METRICS_DARK[1] + + def test_switch_back_to_light(self) -> None: + palette = ColorPalette() + palette.update("metric_a", category=ColorCategory.SECTOR) + palette.set_ui_theme("dark") + palette.set_ui_theme("light") + assert palette._ui_theme == "light" + assert palette.sectors["metric_a"] == TOL_METRICS_LIGHT[0] + + def test_model_years_resampled(self) -> None: + palette = ColorPalette() + palette.update("2020", category=ColorCategory.MODEL_YEAR) + palette.update("2030", category=ColorCategory.MODEL_YEAR) + palette.update("2040", category=ColorCategory.MODEL_YEAR) + + palette.set_ui_theme("dark") + dark_colors = list(palette.model_years.values()) + + palette.set_ui_theme("light") + light_colors = list(palette.model_years.values()) + + # Dark and light should use different Iridescent ranges + assert dark_colors != light_colors + + def test_invalid_theme_raises(self) -> None: + palette = ColorPalette() + with pytest.raises(ValueError, match="Invalid UI theme"): + palette.set_ui_theme("midnight") + + def test_scenarios_unchanged_by_theme(self) -> None: + """Scenarios use Tol Bright which is theme-independent.""" + palette = ColorPalette() + palette.update("scen_a", category="scenario") + color_before = palette.scenarios["scen_a"] + + palette.set_ui_theme("dark") + color_after = palette.scenarios["scen_a"] + + assert color_before == color_after + + +class TestDefaultPaletteColors: + """Test that new ColorPalette uses Tol colors by default.""" + + def test_scenario_colors_from_tol_bright(self) -> None: + palette = ColorPalette() + palette.update("test_scenario", category="scenario") + assert palette.scenarios["test_scenario"] == TOL_BRIGHT[0] + + def test_sector_colors_from_tol_metrics_light(self) -> None: + palette = ColorPalette() + palette.update("test_metric", category=ColorCategory.SECTOR) + assert palette.sectors["test_metric"] == TOL_METRICS_LIGHT[0] + + def test_model_year_colors_from_tol_iridescent(self) -> None: + palette = ColorPalette() + palette.update("2020", category="model_year") + assert palette.model_years["2020"] == TOL_IRIDESCENT[0] + + +class TestIndependentBreakdownColors: + """Verify that sectors and end-uses get independent color sequences. + + Both should start from position 0 in the metric theme, regardless of + which group was registered first or how many items the other group has. + """ + + def test_sectors_and_end_uses_start_from_same_first_color(self) -> None: + """Sectors and end-uses should both begin at metric_theme[0].""" + palette = ColorPalette() + + # Register several sectors first + palette.update("residential", category=ColorCategory.SECTOR) + palette.update("commercial", category=ColorCategory.SECTOR) + palette.update("industrial", category=ColorCategory.SECTOR) + + # Now register end-uses — these should NOT continue after industrial + palette.update("heating", category=ColorCategory.END_USE) + + assert palette.sectors["residential"] == TOL_METRICS_LIGHT[0] + assert palette.sectors["commercial"] == TOL_METRICS_LIGHT[1] + assert palette.sectors["industrial"] == TOL_METRICS_LIGHT[2] + # Key assertion: heating starts at [0], not [3] + assert palette.end_uses["heating"] == TOL_METRICS_LIGHT[0] + + def test_end_uses_registered_first_then_sectors(self) -> None: + """Order shouldn't matter — reversing registration order works too.""" + palette = ColorPalette() + + palette.update("heating", category=ColorCategory.END_USE) + palette.update("cooling", category=ColorCategory.END_USE) + + palette.update("residential", category=ColorCategory.SECTOR) + + # End-uses got [0] and [1]; sectors restart at [0] + assert palette.end_uses["heating"] == TOL_METRICS_LIGHT[0] + assert palette.end_uses["cooling"] == TOL_METRICS_LIGHT[1] + assert palette.sectors["residential"] == TOL_METRICS_LIGHT[0] + + def test_get_also_uses_independent_iterators(self) -> None: + """palette.get() should use the same independent iterators as update().""" + palette = ColorPalette() + + # Use get() to auto-assign colors — pass ColorCategory directly + s1 = palette.get("sector_a", category=ColorCategory.SECTOR) + s2 = palette.get("sector_b", category=ColorCategory.SECTOR) + e1 = palette.get("end_use_a", category=ColorCategory.END_USE) + e2 = palette.get("end_use_b", category=ColorCategory.END_USE) + + assert s1 == TOL_METRICS_LIGHT[0] + assert s2 == TOL_METRICS_LIGHT[1] + # End-uses restart at [0] + assert e1 == TOL_METRICS_LIGHT[0] + assert e2 == TOL_METRICS_LIGHT[1] + + def test_interleaved_registration_stays_independent(self) -> None: + """Interleaving sector and end-use registrations keeps sequences separate.""" + palette = ColorPalette() + + palette.update("sector_1", category=ColorCategory.SECTOR) + palette.update("end_use_1", category=ColorCategory.END_USE) + palette.update("sector_2", category=ColorCategory.SECTOR) + palette.update("end_use_2", category=ColorCategory.END_USE) + + assert palette.sectors["sector_1"] == TOL_METRICS_LIGHT[0] + assert palette.sectors["sector_2"] == TOL_METRICS_LIGHT[1] + assert palette.end_uses["end_use_1"] == TOL_METRICS_LIGHT[0] + assert palette.end_uses["end_use_2"] == TOL_METRICS_LIGHT[1] + + def test_independent_sequences_after_theme_switch(self) -> None: + """After switching to dark mode, sequences should still be independent.""" + palette = ColorPalette() + + palette.update("sector_a", category=ColorCategory.SECTOR) + palette.update("end_use_a", category=ColorCategory.END_USE) + + palette.set_ui_theme("dark") + + # After theme switch, both should be reassigned from dark theme position 0 + assert palette.sectors["sector_a"] == TOL_METRICS_DARK[0] + assert palette.end_uses["end_use_a"] == TOL_METRICS_DARK[0] diff --git a/tests/test_api.py b/tests/test_api.py index 6aa912f..311acab 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -109,6 +109,54 @@ def test_get_annual_electricity_consumption_with_breakdown(api_client: APIClient assert "sector" in df.columns +def test_annual_consumption_sector_no_duplicates(api_client: APIClient) -> None: + """Annual consumption by Sector should have one row per (scenario, year, sector).""" + df = api_client.get_annual_electricity_consumption(group_by="Sector") + assert not df.empty + duplicates = df.duplicated(subset=["scenario", "year", "sector"]) + assert ( + not duplicates.any() + ), f"Duplicate (scenario, year, sector) rows found:\n{df[duplicates]}" + assert (df["value"] > 0).all(), "All consumption values should be positive" + + +def test_annual_consumption_enduse_no_duplicates(api_client: APIClient) -> None: + """Annual consumption by End Use should have one row per (scenario, year, metric).""" + df = api_client.get_annual_electricity_consumption(group_by="End Use") + assert not df.empty + duplicates = df.duplicated(subset=["scenario", "year", "metric"]) + assert ( + not duplicates.any() + ), f"Duplicate (scenario, year, metric) rows found:\n{df[duplicates]}" + assert (df["value"] > 0).all(), "All consumption values should be positive" + + +def test_annual_consumption_breakdown_sums_to_total(api_client: APIClient) -> None: + """Sum of sector breakdown and end-use breakdown should each equal the total.""" + scenario = api_client.scenarios[0] + year = [api_client.years[0]] + + total_df = api_client.get_annual_electricity_consumption(scenarios=[scenario], years=year) + sector_df = api_client.get_annual_electricity_consumption( + scenarios=[scenario], years=year, group_by="Sector" + ) + enduse_df = api_client.get_annual_electricity_consumption( + scenarios=[scenario], years=year, group_by="End Use" + ) + assert not total_df.empty + + total_value = total_df["value"].iloc[0] + sector_sum = sector_df["value"].sum() + enduse_sum = enduse_df["value"].sum() + + assert ( + abs(sector_sum - total_value) / total_value < 0.001 + ), f"Sector sum {sector_sum:.0f} != total {total_value:.0f}" + assert ( + abs(enduse_sum - total_value) / total_value < 0.001 + ), f"End-use sum {enduse_sum:.0f} != total {total_value:.0f}" + + def test_get_annual_peak_demand(api_client: APIClient) -> None: """Test peak demand method executes.""" df = api_client.get_annual_peak_demand() @@ -129,6 +177,28 @@ def test_get_annual_peak_demand_with_breakdown(api_client: APIClient) -> None: assert "sector" in df.columns +def test_get_annual_peak_demand_sector_no_duplicates(api_client: APIClient) -> None: + """Peak demand by Sector should return exactly one row per (scenario, year, sector).""" + df = api_client.get_annual_peak_demand(group_by="Sector") + assert not df.empty + duplicates = df.duplicated(subset=["scenario", "year", "sector"]) + assert ( + not duplicates.any() + ), f"Duplicate (scenario, year, sector) rows found:\n{df[duplicates]}" + assert (df["value"] > 0).all(), "All peak demand values should be positive" + + +def test_get_annual_peak_demand_enduse_no_duplicates(api_client: APIClient) -> None: + """Peak demand by End Use should return exactly one row per (scenario, year, metric).""" + df = api_client.get_annual_peak_demand(group_by="End Use") + assert not df.empty + duplicates = df.duplicated(subset=["scenario", "year", "metric"]) + assert ( + not duplicates.any() + ), f"Duplicate (scenario, year, metric) rows found:\n{df[duplicates]}" + assert (df["value"] > 0).all(), "All peak demand values should be positive" + + def test_get_secondary_metric(api_client: APIClient) -> None: """Test secondary metric method executes.""" valid_scenario = api_client.scenarios[0] @@ -225,6 +295,89 @@ def test_get_time_series_comparison(api_client: APIClient) -> None: assert not df.empty +def test_get_time_series_hourly_sector_no_duplicates(api_client: APIClient) -> None: + """Hourly time series by Sector should have one row per (scenario, year, time_period, sector).""" + valid_scenario = api_client.scenarios[0] + valid_years = [api_client.years[0]] + df = api_client.get_time_series_comparison( + valid_scenario, valid_years, group_by="Sector", resample="Hourly" + ) + assert not df.empty + duplicates = df.duplicated(subset=["scenario", "year", "time_period", "sector"]) + assert ( + not duplicates.any() + ), f"Duplicate (scenario, year, time_period, sector) rows found:\n{df[duplicates]}" + assert (df["value"] > 0).all(), "All hourly values should be positive" + + +def test_get_time_series_hourly_enduse_no_duplicates(api_client: APIClient) -> None: + """Hourly time series by End Use should have one row per (scenario, year, time_period, metric).""" + valid_scenario = api_client.scenarios[0] + valid_years = [api_client.years[0]] + df = api_client.get_time_series_comparison( + valid_scenario, valid_years, group_by="End Use", resample="Hourly" + ) + assert not df.empty + duplicates = df.duplicated(subset=["scenario", "year", "time_period", "metric"]) + assert ( + not duplicates.any() + ), f"Duplicate (scenario, year, time_period, metric) rows found:\n{df[duplicates]}" + assert (df["value"] > 0).all(), "All hourly values should be positive" + + +def test_resampled_sector_mean_matches_annual(api_client: APIClient) -> None: + """Daily Mean sector values * 8760 should approximate annual consumption per sector.""" + scenario = api_client.scenarios[0] + year = [api_client.years[0]] + + daily_df = api_client.get_time_series_comparison( + scenario, year, group_by="Sector", resample="Daily Mean" + ) + annual_df = api_client.get_annual_electricity_consumption( + scenarios=[scenario], years=year, group_by="Sector" + ) + assert not daily_df.empty and not annual_df.empty + + # Mean hourly load * 8760 hours should approximate annual total + daily_sector_means = daily_df.groupby("sector")["value"].mean() + annual_totals = annual_df.set_index("sector")["value"] + + for sector in annual_totals.index: + estimated_annual = daily_sector_means[sector] * 8760 + actual_annual = annual_totals[sector] + ratio = estimated_annual / actual_annual + assert 0.95 < ratio < 1.05, ( + f"Sector '{sector}': daily mean * 8760 = {estimated_annual:.0f}, " + f"annual = {actual_annual:.0f}, ratio = {ratio:.3f} (expected ~1.0)" + ) + + +def test_resampled_enduse_mean_matches_annual(api_client: APIClient) -> None: + """Daily Mean end-use values * 8760 should approximate annual consumption per end use.""" + scenario = api_client.scenarios[0] + year = [api_client.years[0]] + + daily_df = api_client.get_time_series_comparison( + scenario, year, group_by="End Use", resample="Daily Mean" + ) + annual_df = api_client.get_annual_electricity_consumption( + scenarios=[scenario], years=year, group_by="End Use" + ) + assert not daily_df.empty and not annual_df.empty + + daily_enduse_means = daily_df.groupby("metric")["value"].mean() + annual_totals = annual_df.set_index("metric")["value"] + + for enduse in annual_totals.index: + estimated_annual = daily_enduse_means[enduse] * 8760 + actual_annual = annual_totals[enduse] + ratio = estimated_annual / actual_annual + assert 0.95 < ratio < 1.05, ( + f"End use '{enduse}': daily mean * 8760 = {estimated_annual:.0f}, " + f"annual = {actual_annual:.0f}, ratio = {ratio:.3f} (expected ~1.0)" + ) + + @pytest.mark.parametrize("group_by", literal_to_list(TimeGroup)) def test_seasonal_load_lines_time_groupings( # noqa: C901 api_client: APIClient, weekday_weekend_test_data: DuckDBPyConnection, group_by: TimeGroup diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py index e72cae9..2e29427 100644 --- a/tests/test_app_cache.py +++ b/tests/test_app_cache.py @@ -483,17 +483,17 @@ def test_with_project_multiple_recent(self, tmp_path: Path) -> None: class TestConfigMaxCachedProjects: - """Tests for get_max_cached_projects / set_max_cached_projects in tui.py.""" + """Tests for get_max_cached_projects / set_max_cached_projects in config.py.""" def test_round_trip(self, tmp_path: Path) -> None: """set_max_cached_projects(n) -> get_max_cached_projects() should return n.""" - from stride.ui.tui import ( + from stride.config import ( get_max_cached_projects as tui_get, set_max_cached_projects as tui_set, ) config_file = tmp_path / "config.json" - with patch("stride.ui.tui.get_stride_config_path", return_value=config_file): + with patch("stride.config.get_stride_config_path", return_value=config_file): # Initially no config file assert tui_get() is None @@ -507,13 +507,13 @@ def test_round_trip(self, tmp_path: Path) -> None: def test_set_clamps_to_range(self, tmp_path: Path) -> None: """set_max_cached_projects should clamp values to [1, 10].""" - from stride.ui.tui import ( + from stride.config import ( get_max_cached_projects as tui_get, set_max_cached_projects as tui_set, ) config_file = tmp_path / "config.json" - with patch("stride.ui.tui.get_stride_config_path", return_value=config_file): + with patch("stride.config.get_stride_config_path", return_value=config_file): tui_set(0) assert tui_get() == 1 @@ -524,12 +524,12 @@ def test_set_preserves_other_config(self, tmp_path: Path) -> None: """set_max_cached_projects should not clobber other config keys.""" import json - from stride.ui.tui import set_max_cached_projects as tui_set + from stride.config import set_max_cached_projects as tui_set config_file = tmp_path / "config.json" config_file.write_text(json.dumps({"default_user_palette": "my_palette"})) - with patch("stride.ui.tui.get_stride_config_path", return_value=config_file): + with patch("stride.config.get_stride_config_path", return_value=config_file): tui_set(7) saved = json.loads(config_file.read_text()) diff --git a/tests/tui/test_edit_features.py b/tests/tui/test_edit_features.py deleted file mode 100644 index 8a0a653..0000000 --- a/tests/tui/test_edit_features.py +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for the new color editing features in the Palette TUI. - -This script tests: -1. Live color preview in the edit dialog -2. Color validation -3. Edit dialog composition -""" - -from pathlib import Path - -from rich.style import Style -from rich.text import Text - - -def test_color_edit_screen() -> None: - """Test the color validation function.""" - from stride.ui.tui import validate_color - - print("Testing color validation...") - - # Test validation - test_cases = [ - ("#FF5733", True), - ("#FF5733FF", True), - ("rgb(255, 87, 51)", True), - ("rgba(255, 87, 51, 1.0)", True), - ("rgb(255,87,51)", True), # No spaces - ("#1E90FF", True), - ("invalid", False), - ("#GGG", False), - ("blue", False), - ("", False), - ] - - print("\n Color validation tests:") - for color, expected_valid in test_cases: - result = validate_color(color) - status = "✓" if result == expected_valid else "✗" - print(f" {status} '{color}': {result}") - - print("\n ✓ Color validation tests passed!") - - -def test_live_preview_simulation() -> None: - """Simulate the live preview behavior.""" - from stride.ui.tui import color_to_rich_format - - print("\nTesting live preview simulation...") - - colors = ["#FF5733", "#1E90FF", "rgb(26, 188, 156)", "rgba(255, 87, 51, 0.5)"] - - print(" Simulating color changes:") - for color in colors: - # Simulate what happens in on_input_changed - rich_color = color_to_rich_format(color) - preview = Text("████████████", style=Style(color=rich_color)) - print(f" Preview for '{color}' -> '{rich_color}': {preview}") - - print("✓ Live preview simulation test passed!") - - -def test_cursor_styling() -> None: - """Test that cursor styling is configured.""" - from stride.ui.tui import PaletteViewer - - print("\nTesting cursor styling configuration...") - - # Check that CSS includes cursor styling - css = PaletteViewer.CSS - if "cursor-background" in css or "cursor" in css: - print(" ✓ Cursor styling found in CSS") - else: - print(" ⚠ No cursor styling found in CSS") - - print(" ✓ Cursor styling test complete!") - - -def test_full_edit_workflow() -> None: - """Test the complete edit workflow.""" - from stride.models import ProjectConfig - from stride.ui.tui import organize_palette_by_groups - - print("\nTesting full edit workflow...") - - # Load test project - project_path = Path("test_project/project.json5") - if not project_path.exists(): - print(" ⚠ Test project not found, skipping workflow test") - return - - config = ProjectConfig.from_file(project_path) - print(f" ✓ Loaded project: {config.project_id}") - - # Organize palette - groups = organize_palette_by_groups(config.color_palette, config) - print(f" ✓ Organized into {len(groups)} groups") - - # Simulate editing a color - test_label = "residential" - if any(test_label in labels for labels in groups.values()): - print(f" ✓ Found '{test_label}' in palette") - - # Find the current color - for group_name, labels in groups.items(): - if test_label in labels: - old_color = labels[test_label] - print(f" Current color: {old_color}") - - # Simulate editing - new_color = "#1E90FF" - labels[test_label] = new_color - print(f" New color: {new_color}") - - # Verify the change - assert labels[test_label] == new_color - print(" ✓ Color updated successfully") - break - else: - print(f" ⚠ '{test_label}' not found in palette") - - print(" ✓ Full edit workflow test passed!") - - -def test_color_preview_widget() -> None: - """Test the color preview widget rendering.""" - from stride.ui.tui import color_to_rich_format - - print("\nTesting color preview widget...") - - colors = [ - "#FF5733", - "#1E90FF", - "rgb(26, 188, 156)", - "rgba(255, 87, 51, 0.8)", - ] - - print(" Creating preview widgets:") - for color in colors: - rich_color = color_to_rich_format(color) - print(f" ✓ Created preview for {color} -> {rich_color}") - - print("✓ Color preview widget test passed!") - - -def test_edit_dialog_composition() -> None: - """Test that the PaletteViewer can be instantiated.""" - from pathlib import Path - - from stride.ui.tui import PaletteViewer - - print("\nTesting PaletteViewer instantiation...") - - # Create a simple test palette with label groups - test_label_groups = {"Test Group": {"test_label": "#FF5733", "another_label": "#1E90FF"}} - - viewer = PaletteViewer( - palette_name="test_palette", - palette_location=Path("/tmp/test_palette.json"), - palette_type="user", - label_groups=test_label_groups, - ) - print(" ✓ Created PaletteViewer instance") - - # Verify basic attributes - assert viewer.palette_name == "test_palette" - assert viewer.palette_type == "user" - print(" ✓ Palette name and type set correctly") - - # Note: Can't call compose() outside of app context - print(" ℹ Skipping compose() test (requires app context)") - - print(" ✓ PaletteViewer instantiation test passed!") - - -def main() -> int: - """Run all tests.""" - print("=" * 60) - print("Testing Palette TUI Edit Features") - print("=" * 60) - - try: - test_color_edit_screen() - test_live_preview_simulation() - test_cursor_styling() - test_color_preview_widget() - test_edit_dialog_composition() - test_full_edit_workflow() - - print("\n" + "=" * 60) - print("✓ All edit feature tests passed!") - print("=" * 60) - print("\nTo test interactively:") - print(" stride palette view test_project --project") - print(" - Navigate to a color with arrow keys") - print(" - Press 'e' to see the live preview") - print(" - Type different color values to see preview update") - print(" - The cursor should be lighter and not mask colors") - - except Exception as e: - print(f"\n✗ Test failed with error: {e}") - import traceback - - traceback.print_exc() - return 1 - - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/tests/tui/test_palette_tui.py b/tests/tui/test_palette_tui.py deleted file mode 100644 index 92b7c22..0000000 --- a/tests/tui/test_palette_tui.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify the palette TUI can be instantiated and basic functions work. - -This script tests the palette TUI without actually launching it (which would require a TTY). -""" - -from pathlib import Path - -from stride.models import ProjectConfig -from stride.ui.tui import ( - get_user_palette_dir, - list_user_palettes, - organize_palette_by_groups, - save_user_palette, -) - - -def test_organize_palette() -> None: - """Test palette organization into groups.""" - print("Testing palette organization...") - - # Sample palette with various label types - test_palette = { - "residential": "#5F4690", - "commercial": "#FF5733", - "industrial": "#3498DB", - "transportation": "#E74C3C", - "baseline": "rgb(115, 175, 72)", - "alternate_gdp": "rgb(56, 166, 165)", - "2025": "rgb(237, 173, 8)", - "2030": "rgb(204, 80, 62)", - "2035": "rgb(111, 64, 112)", - "cooling": "#1ABC9C", - "heating": "#E67E22", - "lighting": "#9B59B6", - "water_heating": "#16A085", - "other_label": "#95A5A6", - } - - # Organize the palette - groups = organize_palette_by_groups(test_palette) - - print(f"\nOrganized into {len(groups)} groups:") - for group_name, labels in groups.items(): - print(f" {group_name}: {len(labels)} labels") - for label in sorted(labels.keys()): - print(f" - {label}: {labels[label]}") - - # Verify expected groups exist - assert "Scenarios" in groups - assert "Model Years" in groups - assert "Metrics" in groups - - # Verify correct categorization (everything goes to Metrics for flat palette) - assert "residential" in groups["Metrics"] - assert "2025" in groups["Metrics"] - assert "baseline" in groups["Metrics"] - assert "cooling" in groups["Metrics"] - # In legacy flat format, everything is categorized as Metrics - assert len(groups["Scenarios"]) == 0 - assert len(groups["Model Years"]) == 0 - assert len(groups["Metrics"]) == len(test_palette) - - print("\n✓ Palette organization test passed!") - - -def test_user_palette_operations() -> None: - """Test user palette save/load operations.""" - print("\nTesting user palette operations...") - - # Get user palette directory - palette_dir = get_user_palette_dir() - print(f"User palette directory: {palette_dir}") - assert palette_dir.exists() - - # Create a test palette - test_palette = { - "label1": "#FF0000", - "label2": "#00FF00", - "label3": "#0000FF", - } - - # Save the palette - test_name = "test_palette" - saved_path = save_user_palette(test_name, test_palette) - print(f"Saved test palette to: {saved_path}") - assert saved_path.exists() - - # List palettes - palettes = list_user_palettes() - print(f"Found {len(palettes)} user palette(s)") - - # Clean up test palette - saved_path.unlink() - print("Cleaned up test palette") - - print("✓ User palette operations test passed!") - - -def test_project_palette_loading() -> None: - """Test loading palette from a project.""" - print("\nTesting project palette loading...") - - # Path to test project - project_path = Path("test_project/project.json5") - - if not project_path.exists(): - print(f"Warning: Test project not found at {project_path}") - print("Skipping project palette test") - return - - # Load project config - config = ProjectConfig.from_file(project_path) - print(f"Loaded project: {config.project_id}") - - # Check palette - palette = config.color_palette - print(f"Project palette has {len(palette)} colors") - - # Organize into groups - groups = organize_palette_by_groups(palette, config) - print(f"\nOrganized into {len(groups)} groups:") - for group_name, labels in groups.items(): - print(f" {group_name}: {len(labels)} labels") - - print("✓ Project palette loading test passed!") - - -def test_palette_viewer_instantiation() -> None: - """Test that PaletteViewer can be instantiated.""" - print("\nTesting PaletteViewer instantiation...") - - from stride.ui.tui import PaletteViewer - - # Create test data - test_groups = { - "End Uses": { - "cooling": "#1ABC9C", - "heating": "#E67E22", - }, - "Scenarios": { - "baseline": "#73AF48", - "alternate": "#38A6A5", - }, - } - - # Instantiate the viewer (but don't run it) - app = PaletteViewer( - palette_name="test_palette", - palette_location=Path("/tmp/test.json"), - palette_type="test", - label_groups=test_groups, - ) - - print(f"Created PaletteViewer instance: {app.__class__.__name__}") - print(f" Palette name: {app.palette_name}") - print(f" Palette type: {app.palette_type}") - print(f" Label groups: {len(app.label_groups)}") - - print("✓ PaletteViewer instantiation test passed!") - - -def main() -> int: - """Run all tests.""" - print("=" * 60) - print("Palette TUI Test Suite") - print("=" * 60) - - try: - test_organize_palette() - test_user_palette_operations() - test_project_palette_loading() - test_palette_viewer_instantiation() - - print("\n" + "=" * 60) - print("✓ All tests passed!") - print("=" * 60) - - except Exception as e: - print(f"\n✗ Test failed with error: {e}") - import traceback - - traceback.print_exc() - return 1 - - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/tests/tui/test_tui.py b/tests/tui/test_tui.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/tui/test_tui_refresh.py b/tests/tui/test_tui_refresh.py deleted file mode 100644 index 1edb4a9..0000000 --- a/tests/tui/test_tui_refresh.py +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify _refresh_display works correctly in the palette TUI. - -This tests that the display refresh mechanism properly handles: -- Creating new groups -- Adding new labels -- Removing groups and labels -- Empty palette states -""" - -from pathlib import Path - -from stride.ui.tui import PaletteViewer - - -def test_refresh_display_with_groups() -> None: - """Test that _refresh_display works when adding groups.""" - print("Testing _refresh_display with groups...") - - # Create a palette viewer with some initial data - initial_groups = { - "Scenarios": { - "Baseline": "#5F4690", - "Alternative": "#1D6996", - }, - } - - app = PaletteViewer( - palette_name="test", - palette_location=Path("/tmp/test.json"), - palette_type="test", - label_groups=initial_groups, - ) - - print(f" Initial groups: {list(app.label_groups.keys())}") - assert len(app.label_groups) == 1 - - # Simulate adding a new group - app.label_groups["Sectors"] = { - "Residential": "#FF0000", - "Commercial": "#00FF00", - } - - print(f" After adding group: {list(app.label_groups.keys())}") - assert len(app.label_groups) == 2 - assert "Sectors" in app.label_groups - - print("✓ Group addition test passed!") - - -def test_refresh_display_with_labels() -> None: - """Test that _refresh_display works when adding labels.""" - print("\nTesting _refresh_display with labels...") - - groups = { - "End Uses": { - "Heating": "#E67E22", - }, - } - - app = PaletteViewer( - palette_name="test", - palette_location=Path("/tmp/test.json"), - palette_type="test", - label_groups=groups, - ) - - print(f" Initial labels in 'End Uses': {list(app.label_groups['End Uses'].keys())}") - assert len(app.label_groups["End Uses"]) == 1 - - # Simulate adding a new label - app.label_groups["End Uses"]["Cooling"] = "#1ABC9C" - - print(f" After adding label: {list(app.label_groups['End Uses'].keys())}") - assert len(app.label_groups["End Uses"]) == 2 - assert "Cooling" in app.label_groups["End Uses"] - - print("✓ Label addition test passed!") - - -def test_refresh_display_empty_palette() -> None: - """Test that _refresh_display handles empty palette.""" - print("\nTesting _refresh_display with empty palette...") - - app = PaletteViewer( - palette_name="empty", - palette_location=Path("/tmp/empty.json"), - palette_type="test", - label_groups={}, - ) - - print(f" Empty palette groups: {len(app.label_groups)}") - assert len(app.label_groups) == 0 - - # Simulate adding first group to empty palette - app.label_groups["Scenarios"] = {} - - print(f" After creating first group: {list(app.label_groups.keys())}") - assert len(app.label_groups) == 1 - assert "Scenarios" in app.label_groups - - print("✓ Empty palette test passed!") - - -def test_refresh_display_remove_items() -> None: - """Test that _refresh_display handles removing items.""" - print("\nTesting _refresh_display with item removal...") - - groups = { - "Scenarios": { - "Baseline": "#5F4690", - "Alternative": "#1D6996", - }, - "Sectors": { - "Residential": "#FF0000", - }, - } - - app = PaletteViewer( - palette_name="test", - palette_location=Path("/tmp/test.json"), - palette_type="test", - label_groups=groups, - ) - - print(f" Initial groups: {list(app.label_groups.keys())}") - assert len(app.label_groups) == 2 - - # Simulate removing a group - del app.label_groups["Sectors"] - - print(f" After removing 'Sectors': {list(app.label_groups.keys())}") - assert len(app.label_groups) == 1 - assert "Sectors" not in app.label_groups - - # Simulate removing a label - del app.label_groups["Scenarios"]["Alternative"] - - print(f" After removing 'Alternative': {list(app.label_groups['Scenarios'].keys())}") - assert len(app.label_groups["Scenarios"]) == 1 - assert "Alternative" not in app.label_groups["Scenarios"] - - print("✓ Item removal test passed!") - - -def test_palette_state_consistency() -> None: - """Test that palette state remains consistent across operations.""" - print("\nTesting palette state consistency...") - - groups: dict[str, dict[str, str]] = {} - - app = PaletteViewer( - palette_name="test", - palette_location=Path("/tmp/test.json"), - palette_type="test", - label_groups=groups, - ) - - # Simulate building up a palette - operations = [ - ("add_group", "Scenarios", None), - ("add_label", "Scenarios", ("Baseline", "#5F4690")), - ("add_label", "Scenarios", ("Alternative", "#1D6996")), - ("add_group", "Sectors", None), - ("add_label", "Sectors", ("Residential", "#FF0000")), - ("add_label", "Sectors", ("Commercial", "#00FF00")), - ("remove_label", "Scenarios", "Alternative"), - ("add_group", "Years", None), - ("add_label", "Years", ("2025", "#111111")), - ] - - for op_type, group, data in operations: - if op_type == "add_group": - app.label_groups[group] = {} - print(f" Added group: {group}") - elif op_type == "add_label" and isinstance(data, tuple): - label, color = data - app.label_groups[group][label] = color - print(f" Added label '{label}' to '{group}'") - elif op_type == "remove_label" and isinstance(data, str): - label = data - del app.label_groups[group][label] - print(f" Removed label '{label}' from '{group}'") - - # Verify final state - print("\n Final state:") - print(f" Groups: {list(app.label_groups.keys())}") - for group_name, labels in app.label_groups.items(): - print(f" {group_name}: {list(labels.keys())}") - - assert len(app.label_groups) == 3 - assert len(app.label_groups["Scenarios"]) == 1 # Removed Alternative - assert len(app.label_groups["Sectors"]) == 2 - assert len(app.label_groups["Years"]) == 1 - - print("✓ State consistency test passed!") - - -def main() -> int: - """Run all refresh display tests.""" - print("=" * 60) - print("Palette TUI _refresh_display Test Suite") - print("=" * 60) - - try: - test_refresh_display_with_groups() - test_refresh_display_with_labels() - test_refresh_display_empty_palette() - test_refresh_display_remove_items() - test_palette_state_consistency() - - print("\n" + "=" * 60) - print("✓ All _refresh_display tests passed!") - print("=" * 60) - print("\nThe _refresh_display method correctly handles:") - print(" - Adding new groups") - print(" - Adding new labels") - print(" - Removing groups and labels") - print(" - Empty palette states") - print(" - Complex state transitions") - - except Exception as e: - print(f"\n✗ Test failed with error: {e}") - import traceback - - traceback.print_exc() - return 1 - - return 0 - - -if __name__ == "__main__": - exit(main()) diff --git a/tests/tui/test_tui_reordering.py b/tests/tui/test_tui_reordering.py deleted file mode 100644 index ef93b43..0000000 --- a/tests/tui/test_tui_reordering.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -Test script to verify TUI reordering functionality. - -This script creates a test project with a palette and verifies that: -1. The palette can be loaded into the TUI -2. Items maintain their order when displayed -3. Move up/down operations work correctly -""" - -from pathlib import Path - -from stride.models import ProjectConfig -from stride.ui.palette import ColorPalette -from stride.ui.tui import organize_palette_by_groups - - -def test_reordering_logic() -> None: - """Test the reordering logic without launching the TUI.""" - - print("=" * 60) - print("Testing Palette Reordering Logic") - print("=" * 60) - - # Create a test palette - palette_dict = { - "heating": "#FF0000", - "cooling": "#00FF00", - "lighting": "#0000FF", - "baseline": "#FFFF00", - "efficient": "#FF00FF", - "2030": "#00FFFF", - "2040": "#FFA500", - } - - print("\n1. Original palette order:") - for i, (label, color) in enumerate(palette_dict.items()): - print(f" {i}: {label:20s} -> {color}") - - # Organize into groups - groups = organize_palette_by_groups(palette_dict) - - print("\n2. Organized into groups:") - for group_name, group_labels in groups.items(): - print(f"\n {group_name}:") - for i, (label, color) in enumerate(group_labels.items()): - print(f" {i}: {label:20s} -> {color}") - - # Test moving items in End Uses group - if "End Uses" in groups: - print("\n3. Testing move operations on End Uses:") - - # Convert to items list - end_uses = groups["End Uses"] - items = [ - {"label": label, "color": color, "order": idx} - for idx, (label, color) in enumerate(end_uses.items()) - ] - - print(f" Original: {[item['label'] for item in items]}") - - # Move second item up - palette = ColorPalette() - palette.move_item_up(items, 1) - print(f" After move_item_up(1): {[item['label'] for item in items]}") - - # Move last item up - palette.move_item_up(items, len(items) - 1) - print(f" After move_item_up({len(items) - 1}): {[item['label'] for item in items]}") - - # Update the group dict - groups["End Uses"] = {str(item["label"]): str(item["color"]) for item in items} - - print("\n Updated End Uses group:") - for i, (label, color) in enumerate(groups["End Uses"].items()): - print(f" {i}: {label:20s} -> {color}") - - print("\n4. Verify dict order is preserved:") - # Create a new dict and verify order - test_dict = {"a": "1", "b": "2", "c": "3"} - print(f" Original: {list(test_dict.keys())}") - - # Reorder by creating new dict - items = [{"k": k, "v": v} for k, v in test_dict.items()] - items[0], items[1] = items[1], items[0] - test_dict = {str(item["k"]): str(item["v"]) for item in items} - print(f" After swap: {list(test_dict.keys())}") - - print("\n" + "=" * 60) - print("All tests completed successfully!") - print("=" * 60) - - -def test_with_project() -> None: - """Test loading a real project palette.""" - - print("\n\n" + "=" * 60) - print("Testing with Real Project") - print("=" * 60) - - # Look for a test project - test_project_path = Path("test_project/project.json5") - - if not test_project_path.exists(): - print("\nNo test project found at:", test_project_path) - print("Skipping project test.") - return - - print(f"\nLoading project from: {test_project_path}") - - try: - config = ProjectConfig.from_file(test_project_path) - print(f"Project: {config.project_id}") - print(f"Palette has {len(config.color_palette)} colors") - - # Organize palette - groups = organize_palette_by_groups(config.color_palette, config) - - print("\nPalette groups:") - for group_name, group_labels in groups.items(): - print(f"\n {group_name} ({len(group_labels)} items):") - for i, label in enumerate(group_labels.keys()): - print(f" {i}: {label}") - - print("\n" + "=" * 60) - print("Project palette loaded successfully!") - print("=" * 60) - - except Exception as e: - print(f"\nError loading project: {e}") - - -if __name__ == "__main__": - test_reordering_logic() - test_with_project() - - print("\n\n" + "=" * 60) - print("To test the TUI interactively:") - print(" 1. Run: stride palette view ") - print(" 2. Use arrow keys to navigate") - print(" 3. Press 'u' to move item up") - print(" 4. Press 'd' to move item down") - print(" 5. Press 's' to save") - print(" 6. Press 'q' to quit") - print("=" * 60) diff --git a/tests/tui/test_tui_simple.py b/tests/tui/test_tui_simple.py deleted file mode 100644 index 0f088bd..0000000 --- a/tests/tui/test_tui_simple.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple test to verify TUI rendering works correctly. -""" - -from rich.style import Style -from rich.text import Text -from textual.app import App, ComposeResult -from textual.containers import Horizontal, Vertical -from textual.widgets import DataTable, Footer, Header, Label - - -class SimpleTestApp(App[None]): - """Simple test app to verify DataTable works.""" - - CSS = """ - Screen { - background: $surface; - } - - .test-column { - width: 1fr; - height: 100%; - margin: 1; - padding: 1; - background: $panel; - border: solid $accent; - } - - DataTable { - height: 1fr; - } - """ - - BINDINGS = [("q", "quit", "Quit")] - - def compose(self) -> ComposeResult: - """Compose the UI.""" - yield Header(show_clock=True) - yield Label("Testing DataTable Display") - - with Horizontal(): - with Vertical(classes="test-column"): - yield Label("[bold cyan]Test Group 1[/bold cyan]") - yield DataTable(id="table1") - - with Vertical(classes="test-column"): - yield Label("[bold cyan]Test Group 2[/bold cyan]") - yield DataTable(id="table2") - - yield Footer() - - def on_mount(self) -> None: - """Populate tables after mounting.""" - # Table 1 - table1 = self.query_one("#table1", DataTable) - table1.add_columns("Label", "Color", "Preview") - table1.cursor_type = "row" - - test_data_1 = { - "residential": "#5F4690", - "commercial": "#FF5733", - "industrial": "#3498DB", - } - - for label, color in test_data_1.items(): - preview = Text("████", style=Style(color=color)) - table1.add_row(label, color, preview) - - # Table 2 - table2 = self.query_one("#table2", DataTable) - table2.add_columns("Year", "Value") - table2.cursor_type = "row" - - test_data_2 = { - "2025": "rgb(237, 173, 8)", - "2030": "rgb(204, 80, 62)", - "2035": "rgb(111, 64, 112)", - } - - for year, color in test_data_2.items(): - preview = Text("████", style=Style(color=color)) - table2.add_row(year, preview) - - -if __name__ == "__main__": - app = SimpleTestApp() - app.run() From 5a37cc397a3986c3675b4b79e703b95a1d1fedad Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Wed, 1 Apr 2026 11:35:14 +0200 Subject: [PATCH 07/10] Fix sidebar theme sync and dropdown readability in light mode --- src/stride/ui/app.py | 4 ++-- src/stride/ui/assets/dark-theme.css | 8 +++---- src/stride/ui/assets/light-theme.css | 36 ++++++++++++++++++---------- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index bc673d3..727387d 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -477,7 +477,7 @@ def create_app( # noqa: C901 ), ], id="sidebar", - className="sidebar-nav dark-theme", + className=f"sidebar-nav {DEFAULT_CSS_THEME}", style={ "position": "fixed", "top": 0, @@ -1416,7 +1416,7 @@ def create_app_no_project( ), ], id="sidebar", - className="sidebar-nav dark-theme", + className=f"sidebar-nav {DEFAULT_CSS_THEME}", style={ "position": "fixed", "top": 0, diff --git a/src/stride/ui/assets/dark-theme.css b/src/stride/ui/assets/dark-theme.css index aa38c0e..fce37c2 100644 --- a/src/stride/ui/assets/dark-theme.css +++ b/src/stride/ui/assets/dark-theme.css @@ -10,8 +10,8 @@ --bg-hover: #404040; --bg-card: #252525; - --text-primary: #9e9e9e; - --text-secondary: #8a8a8a; + --text-primary: #e0e0e0; + --text-secondary: #b0b0b0; --text-muted: #808080; --border-color: #404040; @@ -22,10 +22,10 @@ --input-bg: #2d2d2d; --input-border: #404040; - --input-text: #9e9e9e; + --input-text: #e0e0e0; --dropdown-bg: #2d2d2d; - --dropdown-text: #9e9e9e; + --dropdown-text: #e0e0e0; --dropdown-hover: #3a3a3a; --modal-bg: #2d2d2d; diff --git a/src/stride/ui/assets/light-theme.css b/src/stride/ui/assets/light-theme.css index 64aef7d..165ee6f 100644 --- a/src/stride/ui/assets/light-theme.css +++ b/src/stride/ui/assets/light-theme.css @@ -135,40 +135,52 @@ body { color: var(--text-primary) !important; } -/* Sidebar dropdown theming - sidebar uses dark styling in light mode */ +/* Sidebar dropdown theming - use light theme variables to match sidebar background */ +.light-theme #sidebar .dash-dropdown { + --Dash-Fill-Inverse-Strong: var(--bg-primary); + --Dash-Text-Strong: var(--text-primary); + --Dash-Text-Weak: var(--text-secondary); + --Dash-Text-Disabled: var(--text-muted); + --Dash-Stroke-Strong: var(--border-color); + --Dash-Shading-Strong: var(--border-hover); + --Dash-Shading-Weak: var(--border-color); + --Dash-Fill-Interactive-Strong: var(--accent-primary); + --Dash-Fill-Interactive-Weak: var(--bg-hover); +} + .light-theme #sidebar .dash-dropdown-trigger { - background-color: #3a3a3a !important; - border-color: #404040 !important; + background-color: var(--bg-primary) !important; + border-color: var(--border-color) !important; } .light-theme #sidebar .dash-dropdown-value { - color: #e0e0e0 !important; + color: var(--text-primary) !important; } .light-theme #sidebar .dash-dropdown-placeholder { - color: #808080 !important; + color: var(--text-muted) !important; } .light-theme #sidebar .dash-dropdown-content { - background-color: #3a3a3a !important; - border-color: #404040 !important; + background-color: var(--bg-primary) !important; + border-color: var(--border-color) !important; } .light-theme #sidebar .dash-dropdown-option { - background-color: #3a3a3a !important; - color: #e0e0e0 !important; + background-color: var(--bg-primary) !important; + color: var(--text-primary) !important; } .light-theme #sidebar .dash-dropdown-option:hover { - background-color: #404040 !important; + background-color: var(--bg-hover) !important; } .light-theme #sidebar .dash-dropdown-trigger-icon { - color: #808080 !important; + color: var(--text-muted) !important; } .light-theme #sidebar .dash-dropdown-search { - color: #e0e0e0 !important; + color: var(--text-primary) !important; } /* Sidebar settings button - match nav tabs */ From da157998cac2e68c467828003b54b82d6d92f64f Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle <80946650+irmatgit@users.noreply.github.com> Date: Wed, 8 Apr 2026 08:44:06 +0200 Subject: [PATCH 08/10] Update docs/how_tos/launch_dashboard.md Co-authored-by: Elaine Hale --- docs/how_tos/launch_dashboard.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/how_tos/launch_dashboard.md b/docs/how_tos/launch_dashboard.md index 13ff2aa..6d6b79d 100644 --- a/docs/how_tos/launch_dashboard.md +++ b/docs/how_tos/launch_dashboard.md @@ -52,8 +52,7 @@ Projects can be loaded and color palettes can be managed from the sidebar. ## Configure Max Cached Projects -By default, STRIDE keeps up to 3 projects open simultaneously. Each open project holds a DuckDB connection, -and on BlobFuse2 FUSE mounts too many concurrent connections can cause errors. +By default, STRIDE keeps up to 3 projects open simultaneously. Each open project holds a DuckDB connection, and too many concurrent connections can cause errors. You can configure this limit via three methods (highest priority first): From 5172ce483af956a2e4377e124575eaadeceff6b2 Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Wed, 8 Apr 2026 10:00:48 +0200 Subject: [PATCH 09/10] fixes after EH comments --- docs/how_tos/launch_dashboard.md | 16 +++++----- src/stride/cli/stride.py | 5 +-- src/stride/config.py | 12 +++++-- src/stride/ui/app.py | 25 +++++++-------- src/stride/ui/settings/callbacks.py | 6 ++-- tests/test_app_cache.py | 49 +++++++++++++++-------------- 6 files changed, 60 insertions(+), 53 deletions(-) diff --git a/docs/how_tos/launch_dashboard.md b/docs/how_tos/launch_dashboard.md index 6d6b79d..0e39a59 100644 --- a/docs/how_tos/launch_dashboard.md +++ b/docs/how_tos/launch_dashboard.md @@ -56,6 +56,13 @@ By default, STRIDE keeps up to 3 projects open simultaneously. Each open project You can configure this limit via three methods (highest priority first): +### Settings UI + +Open the sidebar, click **Settings**, and adjust the **Max Cached Projects** value in the General section. +This persists the setting to `~/.stride/config.json`. + +Valid range is 1–10. The default is 3. + ### CLI Flag ```{eval-rst} @@ -72,11 +79,4 @@ You can configure this limit via three methods (highest priority first): .. code-block:: console $ STRIDE_MAX_CACHED_PROJECTS=5 stride view my_project -``` - -### Settings UI - -Open the sidebar, click **Settings**, and adjust the **Max Cached Projects** value in the General section. -This persists the setting to `~/.stride/config.json`. - -Valid range is 1–10. The default is 3. +``` \ No newline at end of file diff --git a/src/stride/cli/stride.py b/src/stride/cli/stride.py index c8d331a..ea2aac6 100644 --- a/src/stride/cli/stride.py +++ b/src/stride/cli/stride.py @@ -10,6 +10,7 @@ from loguru import logger from stride import Project +from stride.config import CACHED_PROJECTS_UPPER_BOUND from stride.models import CalculatedTableOverride from stride.project import list_valid_countries, list_valid_model_years, list_valid_weather_years from stride.ui.palette_utils import list_user_palettes, set_palette_priority @@ -639,9 +640,9 @@ def calculated_tables() -> None: ) @click.option( "--max-cached-projects", - type=click.IntRange(1, 10), + type=click.IntRange(1, CACHED_PROJECTS_UPPER_BOUND), default=None, - help="Maximum number of projects to keep open simultaneously (1-10, default: 3)", + help=f"Maximum number of projects to keep open simultaneously (1-{CACHED_PROJECTS_UPPER_BOUND}, default: 3)", ) @click.pass_context def view( diff --git a/src/stride/config.py b/src/stride/config.py index 03b8be7..20ceaa6 100644 --- a/src/stride/config.py +++ b/src/stride/config.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import Any +CACHED_PROJECTS_UPPER_BOUND = 10 +DEFAULT_MAX_CACHED_PROJECTS = 3 + def get_stride_config_dir() -> Path: """Get the stride configuration directory, creating it if necessary. @@ -74,7 +77,10 @@ def get_max_cached_projects() -> int | None: config = load_stride_config() value = config.get("max_cached_projects") if value is not None: - return int(value) + try: + return max(1, min(CACHED_PROJECTS_UPPER_BOUND, int(value))) + except (TypeError, ValueError): + return None return None @@ -84,9 +90,9 @@ def set_max_cached_projects(n: int) -> None: Parameters ---------- n : int - Number of max cached projects (will be clamped to [1, 10]) + Number of max cached projects (will be clamped to [1, CACHED_PROJECTS_UPPER_BOUND]) """ - n = max(1, min(10, n)) + n = max(1, min(CACHED_PROJECTS_UPPER_BOUND, n)) config = load_stride_config() config["max_cached_projects"] = n save_stride_config(config) diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index 727387d..958bc91 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from pathlib import Path from typing import Any, Callable @@ -30,7 +31,7 @@ get_temp_edits_for_category, parse_temp_edit_key, ) -from stride.config import get_max_cached_projects as _get_config_max_cached +from stride.config import CACHED_PROJECTS_UPPER_BOUND, DEFAULT_MAX_CACHED_PROJECTS, get_max_cached_projects as _get_config_max_cached from stride.ui.palette_utils import get_default_user_palette, list_user_palettes assets_path = Path(__file__).parent.absolute() / "assets" @@ -55,7 +56,6 @@ # Maximum number of projects to keep open simultaneously. # Each open project holds a DuckDB connection with file descriptors; # on BlobFuse2 FUSE mounts too many concurrent connections cause [Errno 5]. -_DEFAULT_MAX_CACHED_PROJECTS = 3 _max_cached_projects_override: int | None = None @@ -63,25 +63,26 @@ def get_max_cached_projects() -> int: """Resolve the effective max cached projects value. Priority: CLI override > STRIDE_MAX_CACHED_PROJECTS env var > config file > default (3). - Result is clamped to [1, 10]. + Result is clamped to [1, CACHED_PROJECTS_UPPER_BOUND]. """ - import os - if _max_cached_projects_override is not None: - return max(1, min(10, _max_cached_projects_override)) + return max(1, min(CACHED_PROJECTS_UPPER_BOUND, _max_cached_projects_override)) env_val = os.environ.get("STRIDE_MAX_CACHED_PROJECTS") if env_val is not None: try: - return max(1, min(10, int(env_val))) + return max(1, min(CACHED_PROJECTS_UPPER_BOUND, int(env_val))) except ValueError: - pass + logger.warning( + f"Ignoring non-numeric STRIDE_MAX_CACHED_PROJECTS={env_val!r}, " + "falling back to config/default" + ) config_val = _get_config_max_cached() if config_val is not None: - return max(1, min(10, config_val)) + return config_val - return _DEFAULT_MAX_CACHED_PROJECTS + return DEFAULT_MAX_CACHED_PROJECTS def set_max_cached_projects_override(n: int | None) -> None: @@ -96,10 +97,6 @@ def set_max_cached_projects_override(n: int | None) -> None: _max_cached_projects_override = n -# Keep module-level attribute for backwards compatibility with tests -MAX_CACHED_PROJECTS = _DEFAULT_MAX_CACHED_PROJECTS - - def _evict_oldest_project() -> None: """Evict the least-recently-used project from the cache if at capacity.""" limit = get_max_cached_projects() diff --git a/src/stride/ui/settings/callbacks.py b/src/stride/ui/settings/callbacks.py index 431d7c8..9e0ab3c 100644 --- a/src/stride/ui/settings/callbacks.py +++ b/src/stride/ui/settings/callbacks.py @@ -8,6 +8,7 @@ from dash.exceptions import PreventUpdate from loguru import logger +from stride.config import CACHED_PROJECTS_UPPER_BOUND, set_max_cached_projects from stride.ui.palette import ColorPalette from stride.ui.settings.layout import ( clear_temp_color_edits, @@ -955,14 +956,13 @@ def save_max_cached_projects( className="text-danger", ) - if n < 1 or n > 10: + if n < 1 or n > CACHED_PROJECTS_UPPER_BOUND: return html.Div( - "✗ Value must be between 1 and 10", + f"✗ Value must be between 1 and {CACHED_PROJECTS_UPPER_BOUND}", className="text-danger", ) from stride.ui.app import _evict_oldest_project, set_max_cached_projects_override - from stride.config import set_max_cached_projects # Persist to config file set_max_cached_projects(n) diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py index 2e29427..e1d9aa1 100644 --- a/tests/test_app_cache.py +++ b/tests/test_app_cache.py @@ -19,6 +19,7 @@ import pytest +from stride.config import CACHED_PROJECTS_UPPER_BOUND from stride.ui import app as app_module @@ -94,9 +95,9 @@ def test_clamped_to_minimum(self) -> None: assert app_module.get_max_cached_projects() == 1 def test_clamped_to_maximum(self) -> None: - """Values above 10 should be clamped to 10.""" + """Values above CACHED_PROJECTS_UPPER_BOUND should be clamped.""" app_module._max_cached_projects_override = 99 - assert app_module.get_max_cached_projects() == 10 + assert app_module.get_max_cached_projects() == CACHED_PROJECTS_UPPER_BOUND def test_env_var_clamped(self) -> None: """Env var out of range should be clamped.""" @@ -329,10 +330,12 @@ def test_preserves_insertion_order(self) -> None: # =================================================================== -def test_max_cached_projects_is_positive_int() -> None: - """Sanity check: MAX_CACHED_PROJECTS should be a small positive integer.""" - assert isinstance(app_module.MAX_CACHED_PROJECTS, int) - assert app_module.MAX_CACHED_PROJECTS > 0 +def test_default_max_cached_projects_is_positive_int() -> None: + """Sanity check: DEFAULT_MAX_CACHED_PROJECTS should be a small positive integer.""" + from stride.config import DEFAULT_MAX_CACHED_PROJECTS + + assert isinstance(DEFAULT_MAX_CACHED_PROJECTS, int) + assert DEFAULT_MAX_CACHED_PROJECTS > 0 # =================================================================== @@ -478,7 +481,7 @@ def test_with_project_multiple_recent(self, tmp_path: Path) -> None: # =================================================================== -# Tests for config round-trip (tui.py helpers) +# Tests for config round-trip (config.py helpers) # =================================================================== @@ -488,49 +491,49 @@ class TestConfigMaxCachedProjects: def test_round_trip(self, tmp_path: Path) -> None: """set_max_cached_projects(n) -> get_max_cached_projects() should return n.""" from stride.config import ( - get_max_cached_projects as tui_get, - set_max_cached_projects as tui_set, + get_max_cached_projects as cfg_get, + set_max_cached_projects as cfg_set, ) config_file = tmp_path / "config.json" with patch("stride.config.get_stride_config_path", return_value=config_file): # Initially no config file - assert tui_get() is None + assert cfg_get() is None # Set a value - tui_set(5) - assert tui_get() == 5 + cfg_set(5) + assert cfg_get() == 5 # Update the value - tui_set(8) - assert tui_get() == 8 + cfg_set(8) + assert cfg_get() == 8 def test_set_clamps_to_range(self, tmp_path: Path) -> None: - """set_max_cached_projects should clamp values to [1, 10].""" + """set_max_cached_projects should clamp values to [1, CACHED_PROJECTS_UPPER_BOUND].""" from stride.config import ( - get_max_cached_projects as tui_get, - set_max_cached_projects as tui_set, + get_max_cached_projects as cfg_get, + set_max_cached_projects as cfg_set, ) config_file = tmp_path / "config.json" with patch("stride.config.get_stride_config_path", return_value=config_file): - tui_set(0) - assert tui_get() == 1 + cfg_set(0) + assert cfg_get() == 1 - tui_set(99) - assert tui_get() == 10 + cfg_set(99) + assert cfg_get() == CACHED_PROJECTS_UPPER_BOUND def test_set_preserves_other_config(self, tmp_path: Path) -> None: """set_max_cached_projects should not clobber other config keys.""" import json - from stride.config import set_max_cached_projects as tui_set + from stride.config import set_max_cached_projects as cfg_set config_file = tmp_path / "config.json" config_file.write_text(json.dumps({"default_user_palette": "my_palette"})) with patch("stride.config.get_stride_config_path", return_value=config_file): - tui_set(7) + cfg_set(7) saved = json.loads(config_file.read_text()) assert saved["max_cached_projects"] == 7 From a1ad20cffea9c510e97966fd1c059e4e41c43b77 Mon Sep 17 00:00:00 2001 From: Matthieu Irondelle Date: Wed, 8 Apr 2026 10:24:18 +0200 Subject: [PATCH 10/10] improve codecov --- src/stride/ui/app.py | 4 +- src/stride/ui/settings/layout.py | 4 +- tests/test_app_cache.py | 236 +++++++++++++++++++++++++++++++ 3 files changed, 240 insertions(+), 4 deletions(-) diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index 958bc91..680bf0c 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -54,8 +54,8 @@ _current_project_path: str | None = None # Maximum number of projects to keep open simultaneously. -# Each open project holds a DuckDB connection with file descriptors; -# on BlobFuse2 FUSE mounts too many concurrent connections cause [Errno 5]. +# Each open project holds a database connection with file descriptors; +# on network-mounted filesystems too many concurrent connections cause errors. _max_cached_projects_override: int | None = None diff --git a/src/stride/ui/settings/layout.py b/src/stride/ui/settings/layout.py index ea7ee8b..43e4fc2 100644 --- a/src/stride/ui/settings/layout.py +++ b/src/stride/ui/settings/layout.py @@ -150,8 +150,8 @@ def create_settings_layout( ), html.Small( "Number of projects to keep open simultaneously. " - "Each open project holds a DuckDB connection; " - "too many concurrent connections may cause errors on FUSE mounts.", + "Each open project holds a database connection; " + "too many concurrent connections may cause errors on network-mounted filesystems.", className="text-muted", ), html.Div( diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py index e1d9aa1..a3b7b33 100644 --- a/tests/test_app_cache.py +++ b/tests/test_app_cache.py @@ -538,3 +538,239 @@ def test_set_preserves_other_config(self, tmp_path: Path) -> None: saved = json.loads(config_file.read_text()) assert saved["max_cached_projects"] == 7 assert saved["default_user_palette"] == "my_palette" + + +# =================================================================== +# Tests for save_max_cached_projects callback logic (callbacks.py) +# =================================================================== + + +class TestSaveMaxCachedProjectsLogic: + """Test the validation and persistence logic in save_max_cached_projects. + + The actual callback is a closure inside register_settings_callbacks, + so we replicate/test the shared logic directly. + """ + + @staticmethod + def _save_max_cached_projects_logic( + n_clicks: int | None, + value: int | None, + ) -> dict[str, Any]: + """Replicate the logic of save_max_cached_projects callback. + + Returns a dict with 'success' bool, 'message' str, and optionally 'value' int. + """ + if not n_clicks: + return {"success": False, "message": "no_click"} + + if value is None: + return {"success": False, "message": "Please enter a value"} + + try: + n = int(value) + except (TypeError, ValueError): + return {"success": False, "message": "Invalid number"} + + if n < 1 or n > CACHED_PROJECTS_UPPER_BOUND: + return { + "success": False, + "message": f"Value must be between 1 and {CACHED_PROJECTS_UPPER_BOUND}", + } + + return {"success": True, "message": f"Max cached projects set to {n}", "value": n} + + def test_no_click_returns_no_update(self) -> None: + result = self._save_max_cached_projects_logic(None, 5) + assert result["success"] is False + assert result["message"] == "no_click" + + def test_none_value_returns_error(self) -> None: + result = self._save_max_cached_projects_logic(1, None) + assert result["success"] is False + assert "enter a value" in result["message"] + + def test_valid_value_succeeds(self) -> None: + result = self._save_max_cached_projects_logic(1, 5) + assert result["success"] is True + assert result["value"] == 5 + + def test_zero_value_rejected(self) -> None: + result = self._save_max_cached_projects_logic(1, 0) + assert result["success"] is False + assert "between 1" in result["message"] + + def test_over_upper_bound_rejected(self) -> None: + result = self._save_max_cached_projects_logic(1, CACHED_PROJECTS_UPPER_BOUND + 1) + assert result["success"] is False + assert "between 1" in result["message"] + + def test_upper_bound_accepted(self) -> None: + result = self._save_max_cached_projects_logic(1, CACHED_PROJECTS_UPPER_BOUND) + assert result["success"] is True + assert result["value"] == CACHED_PROJECTS_UPPER_BOUND + + def test_value_of_one_accepted(self) -> None: + result = self._save_max_cached_projects_logic(1, 1) + assert result["success"] is True + assert result["value"] == 1 + + def test_persistence_and_eviction(self, tmp_path: Path) -> None: + """Full integration: save triggers config write, override, and eviction.""" + import json + + config_file = tmp_path / "config.json" + + # Pre-fill cache with 4 projects + app_module._max_cached_projects_override = 5 + for i in range(4): + app_module._loaded_projects[f"/{i}"] = _make_cache_entry(f"P{i}") + + with patch("stride.config.get_stride_config_path", return_value=config_file): + from stride.config import set_max_cached_projects + + set_max_cached_projects(2) + app_module.set_max_cached_projects_override(2) + app_module._evict_oldest_project() + + # Should have evicted down to 1 (limit - 1) + assert len(app_module._loaded_projects) == 1 + # Config should be persisted + saved = json.loads(config_file.read_text()) + assert saved["max_cached_projects"] == 2 + + +# =================================================================== +# Tests for settings layout override display logic (layout.py) +# =================================================================== + + +class TestSettingsLayoutOverrideLogic: + """Test the override source detection logic in create_settings_layout.""" + + @staticmethod + def _resolve_override_source( + override_val: int | None, + env_val: str | None, + ) -> str | None: + """Replicate the override source resolution from layout.py.""" + if override_val is not None: + return f"CLI flag (--max-cached-projects {override_val})" + if env_val is not None: + return f"Environment variable (STRIDE_MAX_CACHED_PROJECTS={env_val})" + return None + + def test_no_override(self) -> None: + assert self._resolve_override_source(None, None) is None + + def test_cli_override(self) -> None: + result = self._resolve_override_source(5, None) + assert result is not None + assert "CLI flag" in result + assert "5" in result + + def test_env_override(self) -> None: + result = self._resolve_override_source(None, "4") + assert result is not None + assert "Environment variable" in result + assert "4" in result + + def test_cli_takes_priority_over_env(self) -> None: + """CLI override should be shown even when env var is also set.""" + result = self._resolve_override_source(5, "4") + assert result is not None + assert "CLI flag" in result + + def test_is_overridden_when_cli_set(self) -> None: + result = self._resolve_override_source(5, None) + assert result is not None # is_overridden = True + + def test_is_overridden_when_env_set(self) -> None: + result = self._resolve_override_source(None, "3") + assert result is not None # is_overridden = True + + def test_not_overridden_when_neither_set(self) -> None: + result = self._resolve_override_source(None, None) + assert result is None # is_overridden = False + + +# =================================================================== +# Tests for CLI --max-cached-projects option (stride.py) +# =================================================================== + + +class TestCLIMaxCachedProjectsOption: + """Test the --max-cached-projects CLI option integration.""" + + def test_option_sets_override(self) -> None: + """--max-cached-projects should call set_max_cached_projects_override.""" + app_module.set_max_cached_projects_override(7) + assert app_module._max_cached_projects_override == 7 + assert app_module.get_max_cached_projects() == 7 + + def test_option_none_does_not_set_override(self) -> None: + """When --max-cached-projects is not provided (None), override should not be set.""" + # Replicate the CLI logic: if max_cached_projects is not None: set_override(n) + max_cached_projects = None + if max_cached_projects is not None: + app_module.set_max_cached_projects_override(max_cached_projects) + assert app_module._max_cached_projects_override is None + + def test_override_affects_get_max(self) -> None: + """Override via CLI should change get_max_cached_projects result.""" + with patch.object(app_module, "_get_config_max_cached", return_value=5): + # Without override + assert app_module.get_max_cached_projects() == 5 + + # With override + app_module.set_max_cached_projects_override(2) + assert app_module.get_max_cached_projects() == 2 + + +# =================================================================== +# Tests for config.py get_max_cached_projects edge cases +# =================================================================== + + +class TestConfigGetMaxCachedEdgeCases: + """Test edge cases in config.py get_max_cached_projects.""" + + def test_invalid_config_value_returns_none(self, tmp_path: Path) -> None: + """Non-numeric config value should return None.""" + import json + + from stride.config import get_max_cached_projects as cfg_get + + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps({"max_cached_projects": "not_a_number"})) + + with patch("stride.config.get_stride_config_path", return_value=config_file): + assert cfg_get() is None + + def test_config_value_clamped_to_bounds(self, tmp_path: Path) -> None: + """Config values outside bounds should be clamped.""" + import json + + from stride.config import get_max_cached_projects as cfg_get + + config_file = tmp_path / "config.json" + + config_file.write_text(json.dumps({"max_cached_projects": 0})) + with patch("stride.config.get_stride_config_path", return_value=config_file): + assert cfg_get() == 1 + + config_file.write_text(json.dumps({"max_cached_projects": 99})) + with patch("stride.config.get_stride_config_path", return_value=config_file): + assert cfg_get() == CACHED_PROJECTS_UPPER_BOUND + + def test_missing_key_returns_none(self, tmp_path: Path) -> None: + """Config file without the key should return None.""" + import json + + from stride.config import get_max_cached_projects as cfg_get + + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps({"other_key": "value"})) + + with patch("stride.config.get_stride_config_path", return_value=config_file): + assert cfg_get() is None