From 0962f8fa8a004ed30161e35195600f320e560b06 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 13 Nov 2025 17:05:08 -0800 Subject: [PATCH 1/9] Create design doc --- docs/visualization.md | 1031 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1031 insertions(+) create mode 100644 docs/visualization.md diff --git a/docs/visualization.md b/docs/visualization.md new file mode 100644 index 00000000..8e2a076e --- /dev/null +++ b/docs/visualization.md @@ -0,0 +1,1031 @@ +# πŸ“Š Visualization Configuration System β€” Design Document + +**Status:** Draft + +--- + +## 0. Background + +Want to be able to generate graphics in a way fully configurable via composable yaml configs following the structure of Leland Wilkinson’s _Grammar of Graphics_ + +1. Data +2. Transformations / Statistics +3. Scales +4. Coordinates +5. Aesthetics +6. Geometries +7. Guides +8. Facets / Layout +9. Layering + +The design is supposed to be agnostic to the visualization backend, but `Altair` and `plotnine` in particular map closely to this structure: + +### Altair + +- data +- transform +- scale +- (coordinates isn't exposed as strongly, only geoshapes/polar-ish transforms) +- encoding (aesthetics) +- mark (geoms) +- axis/legend (guides) +- facet +- selection +- repeat +- resolve + +Notice that Altair does not distinguish β€œAesthetic” vs β€œGeometry” vs β€œScaling” in separate top-level structuresβ€”they’re all parameters to: + +```python +Chart(...) + .transform_* + .encode(...) + .mark_* + .properties(...) + .facet(...) + .layer(...) +``` + +### ggplot2 + +- data +- statistics +- scales +- coordinates +- aesthetics +- geometries +- theme (guides are considered part of geometry or theme) +- facets +- (layering is implicit, `+` adds a layer) + +## 1. Overview + +This document describes the design of a **backend-agnostic, declarative visualization configuration system** used to generate both **interactive** and **static** plots for experiment analysis. + +Configurations are written in **YAML**, validated/structured using **Hydra**, and rendered into visualizations by pluggable **backend renderers**. + +The first supported visualization backend is **Altair** (Vega-Lite), chosen for its powerful interactivity and declarative JSON-based grammar. The system is designed so that future backends such as **plotnine**, **matplotlib**, or **plotly** can be added without modifying existing configs. + +--- + +## 2. Goals + +### 2.1 Primary Goals + +- **Declarative Specs** + Make plots fully described by YAML config files (Hydra) rather than imperative code. + +- **Backend Agnostic** + A single plot config should be renderable in: + + - Altair (interactive plots) + - Plotnine (static publication-quality) + - Matplotlib (fallback / custom needs) + - Plotly (potential future) + +- **Consistent Grammar** + Model the configuration language after a unified Grammar of Graphics: + + - data β†’ transforms β†’ aesthetics β†’ geometry β†’ guides β†’ facets. + +- **Interactive + Static Support** + Altair used for rich interactive visualizations; other backends can target static use cases. + +- **Reproducibility** + Configs should be versionable, diffable, and stable across experiments. + +### 2.2 Non-Goals + +- Implementing 100% of Vega-Lite, ggplot2, or matplotlib features. +- Creating a one-to-one mapping of primitives across backends. +- Allowing arbitrary Python code inside configs. + +--- + +## 3. Key Requirements + +### Functional Requirements + +- Ability to select and transform DataFrame sources declaratively. +- Layered graphics (multiple geometry on the same plot). +- Aesthetic mapping (x, y, color, size, shape, opacity). +- Support for: + + - transformations (filter, calculate, aggregate, bin, window) + - geometry (point, line, bar, area, etc.) + - scales (domain/range, log/linear, etc.) + - axes and legends + - facets (row, column, wrap) + - selections (Altair interactivity) + - tooltips + +- Backend routing based on config (`backend: altair`). + +### Non-Functional Requirements + +- Config schema should be stable and extensible. +- Zero runtime dependency on Altair for non-Altair backends. +- Models should not assume a specific visualization implementation. + +--- + +## 4. High-Level Architecture + +``` + +-------------------+ + | YAML Config | + +---------+---------+ + | + v + +------------------------+ + | Hydra Structured cfg | + | (PlotConfig, ...) | + +-----------+------------+ + | + v + +---------------------------+ + | Backend Renderer API | + | build_plot(plot_cfg, df) | + +-----+-----------+---------+ + | | + +--------v--+ +--v---------+ + | Altair | | Plotnine | (future) + | Renderer | | Renderer | + +-----------+ +------------+ +``` + +The **PlotConfig** object serves as the canonical intermediate representation (IR). +Each backend renderer consumes the IR and constructs native visualization objects. + +--- + +## 5. Schema Design + +The schema is defined using Python `@dataclass`es to support Hydra structured configs. + +### 5.1 Top-level + +```python +from dataclasses import dataclass, field +from typing import Any, Literal + +BackendType = Literal["altair", "plotnine", "matplotlib", "plotly"] + +@dataclass +class GraphicsConfig: + """ + Root config object you can use as your Hydra config. + + You can either: + - use fixed attributes like plot_1, plot_2, ... + - or a dict of named plots. + Here we choose a dict for flexibility. + """ + default_backend: BackendType = "altair" + + # Named plots, e.g. {"loss_over_time": PlotConfig(...), "accuracy_hist": ...} + plots: dict[str, PlotConfig] = field(default_factory=dict) +``` + +Each entry in `plots` is a named visualization. + +--- + +## 5.2 PlotConfig + +```python +@dataclass +class PlotConfig: + """ + Top-level configuration for a single plot. + """ + backend: BackendType = "altair" + + # Data + transforms + data: DataConfig = field(default_factory=DataConfig) + transforms: list[TransformConfig] = field(default_factory=list) + + # Visual structure + layers: list[LayerConfig] = field(default_factory=list) + facet: FacetConfig | None = None + + # Global sizing & guides + size: PlotSizeConfig = field(default_factory=PlotSizeConfig) + guides: PlotLevelGuideConfig = field(default_factory=PlotLevelGuideConfig) + + # Background & theme-ish options + background: str | None = None # e.g. "#ffffff" + + # Global selections (e.g. selection that multiple layers reference). + # Layer-level selections can reference plot-level selections by name. + # Execution: plot-level selections are defined first, then layer selections. + selections: list[SelectionConfig] = field(default_factory=list) +``` + +This matches core Grammar-of-Graphics components: + +- **data β†’ transformations β†’ layers β†’ facet β†’ guides**. + +--- + +## 5.3 Data + +Which dataset(s) are used. + +```python +from dataclasses import dataclass, field + +@dataclass +class DataConfig: + """ + Specifies which DataFrame to use and how to subset it. + + This is backend-agnostic: your code just needs to know how to resolve + `source` to an actual pandas.DataFrame. + """ + source: str = "main" # logical name of a DataFrame + filters: list[str] = field(default_factory=list) + # e.g. ["split == 'train'", "loss < 2.0"] + # + # You can interpret these as pandas.query expressions, + # or later adapt them to Vega-Lite filter expressions. + + # Optional subset of columns to keep + columns: list[str] | None = None +``` + +Includes dataset selection + subsetting. + +**Filter Expression Language:** +Filters use a unified expression syntax that backends interpret appropriately: + +- For pandas-based backends (plotnine, matplotlib): interpreted as pandas `.query()` expressions +- For Vega-Lite (Altair): converted to Vega-Lite filter expressions +- Expression format: Python-like syntax (e.g., `"split == 'train'"`, `"loss < 2.0"`) + +**Data Source Resolution:** +The `source` field references a logical name that must be resolved to an actual `pandas.DataFrame` via a data registry (see Section 5.13). + +--- + +## 5.4 Transformations / Statistics + +Filtering, binning, aggregation, windowing, derived calculations. + +Transform operations form a pipeline. Examples include: + +```python +from dataclasses import dataclass, field +from typing import Any, Literal + +TransformOp = Literal[ + "filter", # keep rows where expression is true + "calculate", # create or overwrite a field from an expression + "aggregate", # groupby + aggregation + "bin", # numeric binning + "window", # window functions (rank, rolling, etc.) + "fold", # wide β†’ long + "pivot", # long β†’ wide +] + +@dataclass +class TransformConfig: + """ + A single transformation step in the pipeline. + + This is a superset of what both Vega-Lite and pandas can do; each backend + can choose which transforms to support. + """ + op: TransformOp + + # For "filter": an expression string. + # Expression syntax: Python-like (e.g., "loss < 1.0", "split == 'train'"). + # Backends interpret: pandas uses .query(), Vega-Lite converts to filter expressions. + filter: str | None = None + + # For "calculate": new field and expression. + # Expression syntax: Python-like (e.g., "log(datum.loss)" for Vega-Lite, + # "log(loss)" for pandas). Backends convert as needed. + as_field: str | None = None + expr: str | None = None + + # For "aggregate": groupby fields + aggregations + groupby: list[str] | None = None + aggregations: dict[str, str] | None = None + # e.g. {"mean_loss": "mean(loss)", "max_acc": "max(accuracy)"} + + # For "bin": field to bin and new field name + field: str | None = None + binned_as: str | None = None + maxbins: int | None = None + + # For "window": window calculations (rank, rolling, etc.) + window: dict[str, str] | None = None + # e.g. {"rank_loss": "rank(loss)"} + frame: list[int | None] | None = None # e.g. [-1, 1] + + # For "fold" / "pivot": wide/long reshaping + fold_fields: list[str] | None = None + as_fields: list[str] | None = None + # etc. (you can extend as you need) +``` + +Only some backends will implement some operations. This is fineβ€”unused transforms are no-ops on unsupported backends. + +- Can be applied in pandas (plotnine/matplotlib) +- Or passed to Vega-Lite’s `transform_*` + +--- + +## 5.5 Scales + +Functions that map data values to perceptual ranges (linear, log, color maps). + +```python +from dataclasses import dataclass, field +from typing import Any, Literal + +ScaleType = Literal[ + "linear", + "log", + "sqrt", + "pow", + "symlog", + "time", + "utc", + "ordinal", + "band", + "point", +] + + +@dataclass +class ScaleConfig: + """ + Describes how data values are mapped to visual channels. + + This is generic enough to map to: + - Altair: scale=... + - ggplot/plotnine: scale_x_continuous, etc. + """ + type: ScaleType | None = None + domain: list[Any] | None = None # e.g. [0, 1] or ["a", "b", "c"] + range: list[Any] | None = None # e.g. [0, 800] or color list + clamp: bool | None = None + nice: bool | None = None + reverse: bool | None = None +``` + +Attached to each channel (aesthetics.x.scale, etc.). + +--- + +## 5.6 Aesthetics (Encodings) + +Mappings from data fields to perceptual channels: x, y, color, size, shape, opacity, etc. + +`AestheticsConfig` captures x, y, color, etc.: + +```python +from dataclasses import dataclass, field +from typing import Any, Literal + +# Channel types: "quantitative" (numeric), "ordinal" (ordered categorical), +# "nominal" (unordered categorical), "temporal" (time/date) +ChannelType = Literal["quantitative", "ordinal", "nominal", "temporal"] + +@dataclass +class ChannelAestheticsConfig: + """ + Represents one visual channel (x, y, color, size, etc.) + + This is a generic version of a Vega-Lite "encoding" entry: + field, type, aggregate, bin, scale, axis, legend, value (constant). + """ + field: str | None = None + type: ChannelType | None = None # "quantitative", "nominal", etc. + + # Either constant value OR data field – not both at once. + # Validation: if both are set, field takes precedence (or raise error at config validation). + value: Any | None = None # e.g. fixed color, size, opacity + + # Aggregation and binning + aggregate: str | None = None # e.g. "mean", "sum", "count" + bin: bool | None = None + time_unit: str | None = None # e.g. "year", "month", etc. + + # Scale / guides + scale: ScaleConfig | None = None + axis: AxisConfig | None = None + legend: LegendConfig | None = None + + # Sorting + sort: str | list[Any] | None = None + # e.g. "ascending", "descending", or explicit domain order + + +@dataclass +class AestheticsConfig: + """ + Collection of channel aesthetics for a given layer. + + These correspond to: + - Altair: chart.encode(x=..., y=..., color=..., tooltip=..., etc.) + - ggplot: aes(x=..., y=..., color=..., ...) + """ + x: ChannelAestheticsConfig | None = None + y: ChannelAestheticsConfig | None = None + + color: ChannelAestheticsConfig | None = None + size: ChannelAestheticsConfig | None = None + shape: ChannelAestheticsConfig | None = None + opacity: ChannelAestheticsConfig | None = None + + tooltip: list[ChannelAestheticsConfig] | None = None + + # Note: row/column in AestheticsConfig are deprecated in favor of + # plot-level FacetConfig. These may be used for layer-specific faceting + # in some backends, but FacetConfig.row/column should be preferred. + row: ChannelAestheticsConfig | None = None + column: ChannelAestheticsConfig | None = None +``` + +This structure is expressive enough for: + +- Altair: `.encode(x=..., y=..., color=...)` +- ggplot2: `aes(x=..., y=..., color=...)` +- Matplotlib: parameters to `scatter`, `plot`, etc. + +--- + +## 5.7 Geometry + +Visual primitives: point, line, bar, area, text. + +```python +GeometryType = Literal[ + "point", + "line", + "area", + "bar", + "rect", + "rule", + "tick", + "circle", + "square", + "text", + "boxplot", + "errorbar", + "errorband", +] + +@dataclass +class GeometryConfig: + """ + Visual primitive used for a layer. + + Backend mapping examples: + - Altair: chart.mark_line(**mark_props) + - plotnine: geom_line(**kwargs) + """ + type: GeometryType + # Arbitrary properties passed to backend's mark/geom/artist: + # e.g. {"filled": True, "interpolate": "monotone", "strokeWidth": 2} + props: dict[str, Any] = field(default_factory=dict) +``` + +Backend mapping: + +- Altair: `mark_bar`, `mark_point`, … +- plotnine: `geom_bar`, `geom_point`, … + +--- + +## 5.8 Layers + +Composition rule: a plot may have multiple layers, each with its own data, aesthetics, transforms, mark, scale overrides. +Layers encapsulate their own data + aesthetics + geometry +Each layer is a visual primitive applied over the same or different data: + +```python +@dataclass +class LayerConfig: + """ + A single layer in the plot. + + Each layer can have its own data source, transforms, geometry, aesthetics, + and (optionally) its own selections. + """ + name: str | None = None + + # Optional per-layer data override. If None, inherit from plot.data. + # Note: If a layer specifies its own data source, plot-level transforms + # do NOT apply to it. Layer transforms are applied after resolving layer data. + data: DataConfig | None = None + + # Optional transforms specific to this layer. + # Execution order: + # 1. Plot-level data + filters + transforms (if layer.data is None) + # 2. Layer data override (if specified) + filters + # 3. Layer transforms + # 4. Geometry + aesthetics rendering + transforms: list[TransformConfig] = field(default_factory=list) + + geometry: GeometryConfig = field( + default_factory=lambda: GeometryConfig(type="point", props={}) + ) + + aesthetics: AestheticsConfig = field( + default_factory=AestheticsConfig + ) + + # Layer-specific selections. Can reference plot-level selections by name. + selections: list[SelectionConfig] = field(default_factory=list) +``` + +Multiple layers allow: + +- overlays +- aggregated + raw plots +- annotation layers +- regression layers + +--- + +## 5.9 Facets + +How multiple small plots are arranged (row, column, wrap). + +```python +@dataclass +class FacetConfig: + """ + High-level faceting. + + This can be interpreted as: + - Altair: facet=... / row/column encoding + - plotnine: facet_wrap / facet_grid + """ + row: str | None = None + column: str | None = None + wrap: int | None = None # for wrap-style faceting (1D grid) +``` + +Maps to: + +- Altair: `facet()`, `row=`, `column=` +- plotnine: `facet_wrap`, `facet_grid` + +--- + +## 5.10 Guides (Axes, Legend, Title) + +Plot-level guides: + +```python +@dataclass +class PlotLevelGuideConfig: + """ + Plot-level guides and title (backend-agnostic). + """ + title: str | None = None + subtitle: str | None = None + caption: str | None = None + labels: list[LabelConfig] | None = None +``` + +Channel guides handled by `AxisConfig` and `LegendConfig`. + +```python +@dataclass +class AxisConfig: + title: str | None = None + grid: bool | None = None + format: str | None = None # e.g. ".2f", "%Y-%m-%d" + tick_count: int | None = None + label_angle: float | None = None + visible: bool = True + + +@dataclass +class LegendConfig: + title: str | None = None + orient: str | None = None # "right", "left", "top", "bottom", "none" + # Altair: "none" disables legend; for other backends, map as appropriate + visible: bool = True +``` + +```python +@dataclass +class LabelConfig: + """ + Configuration for plot labels (annotations, text overlays, etc.). + + To be fully defined based on specific labeling needs. + """ + text: str | None = None + x: float | str | None = None # position or field name + y: float | str | None = None + # Additional properties TBD +``` + +--- + +## 5.11 Plot Size + +Plot size + +```python +@dataclass +class PlotSizeConfig: + """ + Size/layout configuration. + """ + width: int | None = None # pixels or backend units + height: int | None = None + autosize: str | None = None # e.g. Altair/Vega-Lite autosize mode +``` + +--- + +## 5.12 Selections (Interactivity) + +Selections are first-class objects in the schema: + +```python +SelectionType = Literal["interval", "single", "multi"] + + +@dataclass +class SelectionConfig: + """ + Abstract interactive selection. + + For Altair: + - maps to selection_interval(), selection_single(), selection_multi() + For other backends: + - may be ignored, or handled by a UI layer. + """ + name: str + type: SelectionType = "interval" + encodings: list[str] | None = None # e.g. ["x", "y", "color"] + fields: list[str] | None = None # data fields for selection + bind: dict[str, Any] | None = None # UI bindings (sliders, dropdowns, etc.) +``` + +Backend behavior: + +- Altair: maps to `selection_interval`, `selection_point` +- others: ignore or support via UI layer + +--- + +## 5.13 Data Registry / Source Management + +The visualization system requires a **data registry** that maps logical source names to actual `pandas.DataFrame` objects. + +**Data Registry Interface:** + +```python +from typing import Protocol + +class DataRegistry(Protocol): + """Protocol for data source resolution.""" + def get(self, source_name: str) -> pd.DataFrame: + """Resolve a logical source name to a DataFrame.""" + ... +``` + +**Usage Pattern:** + +```python +# Data registry is provided by the caller +data_registry = { + "main": df_main, + "metrics": df_metrics, + "validation": df_val, +} + +# Renderer uses registry to resolve DataConfig.source +chart = build_altair_chart(plot_cfg, data_registry) +``` + +**Error Handling:** + +- If a source name doesn't exist in the registry, renderers should raise a `ValueError` with a clear message. +- Missing fields in DataFrames are handled at render time (backend-specific behavior). + +**Lifecycle:** + +- DataFrames are provided at render time, not stored in configs. +- This allows the same config to work with different datasets. +- Data registry can be populated from files, databases, or in-memory DataFrames. + +--- + +## 6. Example YAML + +```yaml +default_backend: altair + +plots: + loss_over_time: + data: + source: "metrics" + filters: ["split == 'train'"] + + transforms: + - op: calculate + as_field: log_loss + expr: "log(datum.loss)" + + size: + width: 600 + height: 350 + + guides: + title: "Training Loss Over Time" + + layers: + - name: raw_runs + geometry: + type: line + props: + opacity: 0.3 + aesthetics: + x: { field: step, type: quantitative } + y: { field: log_loss, type: quantitative } + color: { field: run_id, type: nominal } + + - name: mean_line + geometry: { type: line, props: { strokeWidth: 3 } } + aesthetics: + x: { field: step, type: quantitative } + y: { field: log_loss, type: quantitative, aggregate: mean } + color: { value: "black" } +``` + +**Additional Examples:** + +**Example with Facets:** + +```yaml +plots: + loss_by_split: + data: + source: "metrics" + facet: + row: split + column: model_type + layers: + - geometry: + type: line + aesthetics: + x: { field: step, type: quantitative } + y: { field: loss, type: quantitative } + color: { field: run_id, type: nominal } +``` + +**Example with Selections:** + +```yaml +plots: + interactive_scatter: + data: + source: "results" + selections: + - name: brush + type: interval + encodings: [x, y] + layers: + - geometry: + type: point + aesthetics: + x: { field: accuracy, type: quantitative } + y: { field: loss, type: quantitative } + color: + field: model_type + type: nominal + # Selection condition would be applied here in Altair +``` + +**Example with Multiple Data Sources:** + +```yaml +plots: + comparison: + layers: + - name: training_data + data: + source: "train_metrics" + geometry: + type: line + aesthetics: + x: { field: epoch, type: quantitative } + y: { field: loss, type: quantitative } + + - name: validation_data + data: + source: "val_metrics" + geometry: + type: line + props: + strokeDash: [5, 5] + aesthetics: + x: { field: epoch, type: quantitative } + y: { field: loss, type: quantitative } +``` + +--- + +## 7. Validation & Error Handling + +### 7.1 Validation Strategy + +**Schema Validation:** + +- Hydra automatically validates config structure against dataclass schemas. +- Type checking ensures correct types for all fields. + +**Semantic Validation:** +Performed at config load time (before rendering): + +- **ChannelAestheticsConfig**: `field` and `value` cannot both be set (validation error). +- **Required fields**: Certain geometries require specific aesthetics (e.g., `bar` needs `x` or `y`). +- **Transform parameters**: Validate that required parameters are present for each transform `op`. +- **Data source existence**: Check that all referenced sources exist in the data registry (at render time). + +**Backend Capability Validation:** + +- Warn (or error) if a config uses features unsupported by the selected backend. +- Example: Using `selections` with `backend: plotnine` would generate a warning. + +### 7.2 Error Handling + +**Config Load Errors:** + +- Invalid YAML syntax β†’ YAML parse error +- Missing required fields β†’ Hydra validation error +- Invalid field values β†’ Type/validation error + +**Render-Time Errors:** + +- Missing data source β†’ `ValueError: Data source 'X' not found in registry` +- Missing DataFrame column β†’ Backend-specific (Altair: field error, pandas: KeyError) +- Transform failure β†’ Backend-specific error with context +- Unsupported feature β†’ Warning logged, feature ignored (or error if critical) + +**Error Messages:** +All errors should include: + +- The config path/plot name where the error occurred +- The specific field or operation that failed +- Suggested fixes when possible + +--- + +## 8. Altair Renderer (Summary) + +A backend renderer converts `PlotConfig` β†’ Altair `Chart`: + +``` +plot_cfg + ↓ +resolve pandas DataFrame + ↓ +apply pandas-level filters + ↓ +construct Altair base Chart + ↓ +apply Altair transforms + ↓ +construct each Layer + ↓ +combine via alt.layer(...) + ↓ +apply title, size, facet, background + ↓ +return Chart +``` + +Renderer API: + +```python +def build_altair_chart(plot_cfg: PlotConfig, + data_registry: dict[str, pd.DataFrame]) -> alt.Chart +``` + +--- + +## 9. Backend Capability Matrix + +The following table outlines which features are supported by each backend: + +| Feature | Altair | Plotnine | Matplotlib | Notes | +| ----------------- | ------ | -------- | ---------- | ----------------------------- | +| **Transforms** | | | | | +| filter | βœ… | βœ… | βœ… | pandas query / Vega-Lite | +| calculate | βœ… | ⚠️ | ⚠️ | Via pandas eval / limited | +| aggregate | βœ… | βœ… | βœ… | Via pandas groupby | +| bin | βœ… | βœ… | ⚠️ | Manual binning for matplotlib | +| window | βœ… | ⚠️ | ⚠️ | Limited pandas support | +| fold/pivot | βœ… | βœ… | βœ… | Via pandas | +| **Geometries** | | | | | +| point, line, bar | βœ… | βœ… | βœ… | Core geometries | +| area, rect | βœ… | βœ… | βœ… | | +| text | βœ… | βœ… | βœ… | | +| boxplot, errorbar | βœ… | βœ… | βœ… | | +| **Interactivity** | | | | | +| Selections | βœ… | ❌ | ❌ | Altair-only | +| Tooltips | βœ… | ❌ | ⚠️ | Limited in matplotlib | +| **Facets** | βœ… | βœ… | ⚠️ | Manual subplots in matplotlib | +| **Scales** | βœ… | βœ… | βœ… | All support log/linear/etc | + +**Legend:** + +- βœ… Fully supported +- ⚠️ Partially supported or requires workarounds +- ❌ Not supported + +--- + +## 10. Future Backends + +Because the schema is backend-agnostic, adding new rendering backends involves implementing: + +``` +build_plotnine_plot(plot_cfg, data_registry) +build_matplotlib_plot(plot_cfg, data_registry) +build_plotly_plot(plot_cfg, data_registry) +``` + +The schema remains unchanged. + +--- + +## 11. Limitations & Future Work + +- Some Vega-Lite transforms (e.g. window, fold) will need backend-specific support. +- Plotnine/matplotlib backends may not support all interactive concepts. +- More advanced layout control (grids, multi-view dashboards) is out of scope but possible. + +--- + +## 12. Conclusion + +This configuration system provides: + +- A **unified Grammar-of-Graphics-style schema** +- **Declarative, reproducible plots** +- **Immediate support for Altair** (interactive plots) +- A path to **plotnine/matplotlib** (static publication-ready plots) +- A clean separation between _configuration_, _data_, and _rendering backends_ + +By adopting this design, we gain long-term flexibility in visualization tooling while keeping plot definitions clean, expressive, and consistent across projects. + +--- + +## 13. Future Extensions + +### Support for other backends + +- plotnine +- matplotlib + +### Additional configuration + +**Coordinate Systems:** +While coordinates are a core Grammar of Graphics concept, they are deferred to future work because: + +- Most common use cases (cartesian coordinates) are implicit in all backends +- Polar/geographic coordinates have limited cross-backend support +- Can be added without breaking existing configs + +```python +@dataclass +class CoordConfig: + type: Literal["cartesian", "polar", "geo"] + # Altair only supports a subset (geo; no polar without hacks). + # plotnine supports cartesian and polar. + # Matplotlib supports all via different projections. +``` + +**Theme Configuration:** +For consistent styling across plots: + +```python +@dataclass +class ThemeConfig: + # fonts, spacing, backgrounds, etc. + font_family: str | None = None + font_size: int | None = None + background_color: str | None = None + grid_color: str | None = None + # Additional theme properties TBD +``` + +**Export Formats:** + +- PNG, SVG, PDF for static backends +- HTML for interactive backends (Altair) +- Configurable resolution/DPI for static exports From a0273943178ea3875e8cc853b3c62ef70f36b9a1 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 13 Nov 2025 17:44:30 -0800 Subject: [PATCH 2/9] Implement solution --- examples/visualization_demo.py | 105 +++++ pyproject.toml | 1 + simplexity/visualization/__init__.py | 53 +++ simplexity/visualization/altair_renderer.py | 375 ++++++++++++++++++ simplexity/visualization/data_registry.py | 39 ++ .../visualization/structured_configs.py | 261 ++++++++++++ uv.lock | 154 +++++++ 7 files changed, 988 insertions(+) create mode 100644 examples/visualization_demo.py create mode 100644 simplexity/visualization/__init__.py create mode 100644 simplexity/visualization/altair_renderer.py create mode 100644 simplexity/visualization/data_registry.py create mode 100644 simplexity/visualization/structured_configs.py diff --git a/examples/visualization_demo.py b/examples/visualization_demo.py new file mode 100644 index 00000000..4532c1bc --- /dev/null +++ b/examples/visualization_demo.py @@ -0,0 +1,105 @@ +"""Standalone demo that renders a layered Altair chart via visualization configs.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd + +from simplexity.visualization import ( + AestheticsConfig, + ChannelAestheticsConfig, + DataConfig, + DictDataRegistry, + GeometryConfig, + LayerConfig, + PlotConfig, + PlotLevelGuideConfig, + PlotSizeConfig, + TransformConfig, +) +from simplexity.visualization.altair_renderer import build_altair_chart + + +def main() -> None: + """Generate a toy dataset, build a PlotConfig, and save the rendered chart.""" + df = _create_demo_dataframe() + registry = DictDataRegistry({"metrics": df}) + plot_cfg = _build_plot_config() + chart = build_altair_chart(plot_cfg, registry) + + output_path = Path(__file__).with_name("visualization_demo.html") + chart.save(str(output_path)) + print(f"Wrote visualization demo to {output_path}") # noqa: T201 - simple example harness + + +def _create_demo_dataframe() -> pd.DataFrame: + rng = np.random.default_rng(7) + records: list[dict[str, float | str | int]] = [] + for run_idx in range(3): + run_id = f"run_{run_idx + 1}" + for epoch in range(1, 51): + base_loss = np.exp(-epoch / 25.0) + 0.1 * run_idx + jitter = rng.normal(0.0, 0.02) + loss = max(base_loss + jitter, 1e-4) + accuracy = 0.55 + 0.008 * epoch + rng.normal(0.0, 0.01) + records.append( + { + "run_id": run_id, + "epoch": epoch, + "loss": loss, + "accuracy": accuracy, + } + ) + return pd.DataFrame(records) + + +def _build_plot_config() -> PlotConfig: + log_transform = TransformConfig(op="calculate", as_field="log_loss", expr="log(loss)") + base_aesthetics = AestheticsConfig( + x=ChannelAestheticsConfig(field="epoch", type="quantitative", title="Epoch"), + y=ChannelAestheticsConfig(field="log_loss", type="quantitative", title="log(loss)"), + tooltip=[ + ChannelAestheticsConfig(field="run_id", type="nominal", title="Run"), + ChannelAestheticsConfig(field="epoch", type="quantitative", title="Epoch"), + ChannelAestheticsConfig(field="log_loss", type="quantitative", title="log(loss)"), + ], + ) + raw_layer = LayerConfig( + name="raw_runs", + geometry=GeometryConfig(type="line", props={"opacity": 0.4}), + aesthetics=AestheticsConfig( + x=base_aesthetics.x, + y=base_aesthetics.y, + color=ChannelAestheticsConfig(field="run_id", type="nominal", title="Run"), + tooltip=base_aesthetics.tooltip, + ), + ) + mean_layer = LayerConfig( + name="mean_line", + geometry=GeometryConfig(type="line", props={"strokeWidth": 3, "color": "#111111"}), + aesthetics=AestheticsConfig( + x=base_aesthetics.x, + y=ChannelAestheticsConfig( + field="log_loss", + type="quantitative", + aggregate="mean", + title="Mean log(loss)", + ), + ), + ) + return PlotConfig( + data=DataConfig(source="metrics"), + transforms=[log_transform], + layers=[raw_layer, mean_layer], + size=PlotSizeConfig(width=600, height=400), + guides=PlotLevelGuideConfig( + title="Training loss over epochs", + subtitle="Each line is a synthetic training run built from random noise.", + ), + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index b43c88f6..a0f40023 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ description = "Computational Mechanics of sequence prediction models." readme = "README.md" requires-python = ">=3.12" dependencies = [ + "altair>=5.3.0", "chex", "dotenv", "equinox>=0.13.0", diff --git a/simplexity/visualization/__init__.py b/simplexity/visualization/__init__.py new file mode 100644 index 00000000..eb036267 --- /dev/null +++ b/simplexity/visualization/__init__.py @@ -0,0 +1,53 @@ +"""Visualization utilities and configuration schemas.""" + +from .data_registry import DataRegistry, DictDataRegistry, resolve_data_source +from .structured_configs import ( + AestheticsConfig, + AxisConfig, + BackendType, + ChannelAestheticsConfig, + ChannelType, + DataConfig, + FacetConfig, + GeometryConfig, + GeometryType, + GraphicsConfig, + LabelConfig, + LayerConfig, + LegendConfig, + PlotConfig, + PlotLevelGuideConfig, + PlotSizeConfig, + ScaleConfig, + SelectionConfig, + SelectionType, + TransformConfig, + TransformOp, +) + +__all__ = [ + "DataRegistry", + "DictDataRegistry", + "resolve_data_source", + "AestheticsConfig", + "AxisConfig", + "BackendType", + "ChannelAestheticsConfig", + "ChannelType", + "DataConfig", + "FacetConfig", + "GeometryConfig", + "GeometryType", + "GraphicsConfig", + "LabelConfig", + "LayerConfig", + "LegendConfig", + "PlotConfig", + "PlotLevelGuideConfig", + "PlotSizeConfig", + "ScaleConfig", + "SelectionConfig", + "SelectionType", + "TransformConfig", + "TransformOp", +] diff --git a/simplexity/visualization/altair_renderer.py b/simplexity/visualization/altair_renderer.py new file mode 100644 index 00000000..53a659bd --- /dev/null +++ b/simplexity/visualization/altair_renderer.py @@ -0,0 +1,375 @@ +"""Altair renderer for declarative visualization configs.""" + +from __future__ import annotations + +import logging +import math +from collections.abc import Mapping +from typing import Any + +import altair as alt +import numpy as np +import pandas as pd + +from simplexity.exceptions import ConfigValidationError +from simplexity.visualization.data_registry import DataRegistry, resolve_data_source +from simplexity.visualization.structured_configs import ( + AestheticsConfig, + AxisConfig, + ChannelAestheticsConfig, + DataConfig, + FacetConfig, + GeometryConfig, + LayerConfig, + LegendConfig, + PlotConfig, + PlotLevelGuideConfig, + PlotSizeConfig, + ScaleConfig, + SelectionConfig, + TransformConfig, +) + +LOGGER = logging.getLogger(__name__) + +_CALC_ENV = { + "np": np, + "pd": pd, + "math": math, + "log": np.log, + "exp": np.exp, + "sqrt": np.sqrt, + "abs": np.abs, + "clip": np.clip, +} + +_CHANNEL_CLASS_MAP = { + "x": "X", + "y": "Y", + "color": "Color", + "size": "Size", + "shape": "Shape", + "opacity": "Opacity", + "row": "Row", + "column": "Column", +} + + +def build_altair_chart(plot_cfg: PlotConfig, data_registry: DataRegistry | Mapping[str, pd.DataFrame]): + """Render a PlotConfig into an Altair chart.""" + if not plot_cfg.layers: + raise ConfigValidationError("PlotConfig.layers must include at least one layer for Altair rendering") + + plot_df = _materialize_data(plot_cfg.data, data_registry) + plot_df = _apply_transforms(plot_df, plot_cfg.transforms) + + layer_charts = [ + _build_layer_chart(layer, _resolve_layer_dataframe(layer, plot_df, data_registry)) for layer in plot_cfg.layers + ] + + chart = layer_charts[0] if len(layer_charts) == 1 else alt.layer(*layer_charts) + + if plot_cfg.selections: + chart = chart.add_params(*[_build_selection_param(sel) for sel in plot_cfg.selections]) + + if plot_cfg.facet: + chart = _apply_facet(chart, plot_cfg.facet) + + chart = _apply_plot_level_properties(chart, plot_cfg.guides, plot_cfg.size, plot_cfg.background) + + return chart + + +def _materialize_data(data_cfg: DataConfig, data_registry: DataRegistry | Mapping[str, pd.DataFrame]) -> pd.DataFrame: + df = resolve_data_source(data_cfg.source, data_registry).copy() + if data_cfg.filters: + df = _apply_filters(df, data_cfg.filters) + if data_cfg.columns: + missing = [col for col in data_cfg.columns if col not in df.columns] + if missing: + raise ConfigValidationError(f"Columns {missing} are not present in data source '{data_cfg.source}'") + df = df.loc[:, data_cfg.columns] + return df + + +def _resolve_layer_dataframe( + layer: LayerConfig, + plot_df: pd.DataFrame, + data_registry: DataRegistry | Mapping[str, pd.DataFrame], +) -> pd.DataFrame: + if layer.data is None: + layer_df = plot_df.copy() + else: + layer_df = _materialize_data(layer.data, data_registry) + + if layer.transforms: + layer_df = _apply_transforms(layer_df, layer.transforms) + return layer_df + + +def _apply_filters(df: pd.DataFrame, filters: list[str]) -> pd.DataFrame: + result = df.copy() + for expr in filters: + norm_expr = _normalize_expression(expr) + result = result.query(norm_expr, engine="python", local_dict=_CALC_ENV) + return result + + +def _apply_transforms(df: pd.DataFrame, transforms: list[TransformConfig]) -> pd.DataFrame: + result = df.copy() + for transform in transforms: + result = _apply_transform(result, transform) + return result + + +def _apply_transform(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if transform.op == "filter": + if transform.filter is None: + raise ConfigValidationError("Filter transforms require the `filter` expression.") + return _apply_filters(df, [transform.filter]) + if transform.op == "calculate": + return _apply_calculate(df, transform) + if transform.op == "aggregate": + return _apply_aggregate(df, transform) + if transform.op == "bin": + return _apply_bin(df, transform) + if transform.op == "window": + return _apply_window(df, transform) + if transform.op == "fold": + return _apply_fold(df, transform) + if transform.op == "pivot": + raise ConfigValidationError("Pivot transforms are not implemented yet for the Altair renderer.") + raise ConfigValidationError(f"Unsupported transform operation '{transform.op}'") + + +def _apply_calculate(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + expr = _normalize_expression(transform.expr or "") + target = transform.as_field or "" + if not target: + raise ConfigValidationError("TransformConfig.as_field is required for calculate transforms") + result = df.copy() + result[target] = result.eval(expr, engine="python", local_dict=_CALC_ENV) + return result + + +def _apply_aggregate(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + groupby = transform.groupby or [] + aggregations = transform.aggregations or {} + if not groupby or not aggregations: + raise ConfigValidationError("Aggregate transforms require `groupby` and `aggregations` fields.") + + agg_kwargs: dict[str, tuple[str, str]] = {} + for alias, expr in aggregations.items(): + func, field = _parse_function_expr(expr, expected_arg=True) + agg_kwargs[alias] = (field, func) + + grouped = df.groupby(groupby, dropna=False).agg(**agg_kwargs).reset_index() + return grouped + + +def _apply_bin(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if not transform.field or not transform.binned_as: + raise ConfigValidationError("Bin transforms require `field` and `binned_as`.") + bins = transform.maxbins or 10 + result = df.copy() + result[transform.binned_as] = pd.cut(result[transform.field], bins=bins, include_lowest=True) + return result + + +def _apply_window(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if not transform.window: + raise ConfigValidationError("Window transforms require the `window` mapping.") + result = df.copy() + for alias, expr in transform.window.items(): + func, field = _parse_function_expr(expr, expected_arg=True) + if func == "rank": + result[alias] = result[field].rank(method="average") + elif func == "cumsum": + result[alias] = result[field].cumsum() + else: + raise ConfigValidationError(f"Window function '{func}' is not supported.") + return result + + +def _apply_fold(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if not transform.fold_fields: + raise ConfigValidationError("Fold transforms require `fold_fields`.") + var_name, value_name = _derive_fold_names(transform.as_fields) + return df.melt(value_vars=transform.fold_fields, var_name=var_name, value_name=value_name) + + +def _parse_function_expr(expr: str, expected_arg: bool) -> tuple[str, str]: + if "(" not in expr or not expr.endswith(")"): + raise ConfigValidationError(f"Expression '{expr}' must be of the form func(field).") + func, rest = expr.split("(", 1) + value = rest[:-1].strip() + func = func.strip() + if expected_arg and not value: + raise ConfigValidationError(f"Expression '{expr}' must supply an argument.") + return func, value + + +def _derive_fold_names(as_fields: list[str] | None) -> tuple[str, str]: + if not as_fields: + return "key", "value" + if len(as_fields) == 1: + return as_fields[0], "value" + return as_fields[0], as_fields[1] + + +def _normalize_expression(expr: str) -> str: + return expr.replace("datum.", "").strip() + + +def _build_layer_chart(layer: LayerConfig, df: pd.DataFrame): + chart = alt.Chart(df) + chart = _apply_geometry(chart, layer.geometry) + encoding_kwargs = _encode_aesthetics(layer.aesthetics) + if encoding_kwargs: + chart = chart.encode(**encoding_kwargs) + if layer.selections: + chart = chart.add_params(*[_build_selection_param(sel) for sel in layer.selections]) + return chart + + +def _apply_geometry(chart, geometry: GeometryConfig): + mark_name = f"mark_{geometry.type}" + if not hasattr(chart, mark_name): + raise ConfigValidationError(f"Altair chart does not support geometry type '{geometry.type}'") + mark_fn = getattr(chart, mark_name) + return mark_fn(**(geometry.props or {})) + + +def _encode_aesthetics(aesthetics: AestheticsConfig) -> dict[str, Any]: + encodings: dict[str, Any] = {} + for channel_name in ("x", "y", "color", "size", "shape", "opacity", "row", "column"): + channel_cfg = getattr(aesthetics, channel_name) + channel_value = _channel_to_alt(channel_name, channel_cfg) + if channel_value is not None: + encodings[channel_name] = channel_value + + if aesthetics.tooltip: + encodings["tooltip"] = [_tooltip_to_alt(tooltip_cfg) for tooltip_cfg in aesthetics.tooltip] + + return encodings + + +def _channel_to_alt(channel_name: str, cfg: ChannelAestheticsConfig | None): + if cfg is None: + return None + if cfg.value is not None and cfg.field is None: + return alt.value(cfg.value) + + channel_cls_name = _CHANNEL_CLASS_MAP[channel_name] + channel_cls = getattr(alt, channel_cls_name) + kwargs: dict[str, Any] = {} + if cfg.field: + kwargs["field"] = cfg.field + if cfg.type: + kwargs["type"] = cfg.type + if cfg.title: + kwargs["title"] = cfg.title + if cfg.aggregate: + kwargs["aggregate"] = cfg.aggregate + if cfg.bin is not None: + kwargs["bin"] = cfg.bin + if cfg.time_unit: + kwargs["timeUnit"] = cfg.time_unit + if cfg.sort is not None: + if isinstance(cfg.sort, list): + kwargs["sort"] = alt.Sort(cfg.sort) + else: + kwargs["sort"] = cfg.sort + if cfg.scale: + kwargs["scale"] = _scale_to_alt(cfg.scale) + if cfg.axis and channel_name in {"x", "y", "row", "column"}: + kwargs["axis"] = _axis_to_alt(cfg.axis) + if cfg.legend and channel_name in {"color", "size", "shape", "opacity"}: + kwargs["legend"] = _legend_to_alt(cfg.legend) + return channel_cls(**kwargs) + + +def _tooltip_to_alt(cfg: ChannelAestheticsConfig): + if cfg.value is not None and cfg.field is None: + return alt.Tooltip(value=cfg.value, title=cfg.title) + if cfg.field is None: + raise ConfigValidationError("Tooltip channels must set either a field or a constant value.") + + kwargs: dict[str, Any] = {"field": cfg.field} + if cfg.type: + kwargs["type"] = cfg.type + if cfg.title: + kwargs["title"] = cfg.title + return alt.Tooltip(**kwargs) + + +def _scale_to_alt(cfg: ScaleConfig): + kwargs = {k: v for k, v in vars(cfg).items() if v is not None} + return alt.Scale(**kwargs) + + +def _axis_to_alt(cfg: AxisConfig): + kwargs = {k: v for k, v in vars(cfg).items() if v is not None} + return alt.Axis(**kwargs) + + +def _legend_to_alt(cfg: LegendConfig): + kwargs = {k: v for k, v in vars(cfg).items() if v is not None} + return alt.Legend(**kwargs) + + +def _build_selection_param(cfg: SelectionConfig): + kwargs: dict[str, Any] = {"name": cfg.name} + if cfg.encodings is not None: + kwargs["encodings"] = cfg.encodings # type: ignore[assignment] + if cfg.fields is not None: + kwargs["fields"] = cfg.fields + if cfg.bind is not None: + kwargs["bind"] = cfg.bind # type: ignore[assignment] + + if cfg.type == "interval": + return alt.selection_interval(**kwargs) + if cfg.type == "single": + return alt.selection_single(**kwargs) + if cfg.type == "multi": + return alt.selection_multi(**kwargs) + raise ConfigValidationError(f"Unsupported selection type '{cfg.type}' for Altair renderer.") + + +def _apply_facet(chart, facet_cfg: FacetConfig): + facet_args: dict[str, Any] = {} + if facet_cfg.row: + facet_args["row"] = alt.Row(facet_cfg.row) + if facet_cfg.column: + facet_args["column"] = alt.Column(facet_cfg.column) + if facet_cfg.wrap: + raise ConfigValidationError("FacetConfig.wrap is not yet implemented for Altair rendering.") + if not facet_args: + return chart + return chart.facet(**facet_args) + + +def _apply_plot_level_properties(chart, guides: PlotLevelGuideConfig, size: PlotSizeConfig, background: str | None): + title_params = _build_title_params(guides) + if title_params is not None: + chart = chart.properties(title=title_params) + width = size.width + height = size.height + if width is not None or height is not None: + chart = chart.properties(width=width, height=height) + if size.autosize: + chart.autosize = size.autosize + if background: + chart.background = background + if guides.labels: + LOGGER.info("Plot-level labels are not yet implemented for Altair; skipping %s labels.", len(guides.labels)) + return chart + + +def _build_title_params(guides: PlotLevelGuideConfig): + subtitle_lines = [text for text in (guides.subtitle, guides.caption) if text] + if not guides.title and not subtitle_lines: + return None + if subtitle_lines: + return alt.TitleParams(text=guides.title or "", subtitle=subtitle_lines) + return guides.title diff --git a/simplexity/visualization/data_registry.py b/simplexity/visualization/data_registry.py new file mode 100644 index 00000000..c70f44c6 --- /dev/null +++ b/simplexity/visualization/data_registry.py @@ -0,0 +1,39 @@ +"""Helpers for resolving logical visualization data sources.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Protocol + +import pandas as pd + + +class DataRegistry(Protocol): # pylint: disable=too-few-public-methods + """Protocol for registry objects that return pandas DataFrames.""" + + def get(self, source_name: str) -> pd.DataFrame: + """Return the DataFrame associated with ``source_name``.""" + ... # pylint: disable=unnecessary-ellipsis + + +class DictDataRegistry: # pylint: disable=too-few-public-methods + """Simple registry backed by an in-memory mapping.""" + + def __init__(self, data: Mapping[str, pd.DataFrame] | None = None) -> None: + self._data: dict[str, pd.DataFrame] = dict(data or {}) + + def get(self, source_name: str) -> pd.DataFrame: + """Get the DataFrame associated with ``source_name``.""" + try: + return self._data[source_name] + except KeyError as exc: # pragma: no cover - simple error wrapper + raise ValueError(f"Data source '{source_name}' is not registered") from exc + + +def resolve_data_source(source_name: str, data_registry: DataRegistry | Mapping[str, pd.DataFrame]) -> pd.DataFrame: + """Resolve a logical source name regardless of the registry implementation.""" + if isinstance(data_registry, Mapping): + if source_name not in data_registry: + raise ValueError(f"Data source '{source_name}' is not registered") + return data_registry[source_name] + return data_registry.get(source_name) diff --git a/simplexity/visualization/structured_configs.py b/simplexity/visualization/structured_configs.py new file mode 100644 index 00000000..6dcea953 --- /dev/null +++ b/simplexity/visualization/structured_configs.py @@ -0,0 +1,261 @@ +"""Structured visualization configuration dataclasses. + +This module implements the schema described in docs/visualization.md. The +dataclasses are intentionally backend-agnostic so that Hydra configs can be +validated once and rendered by different visualization engines (Altair, +plotnine, matplotlib, etc.). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +from simplexity.exceptions import ConfigValidationError + +BackendType = Literal["altair"] # Currently only Altair is supported, but we could add other backends later + +TransformOp = Literal["filter", "calculate", "aggregate", "bin", "window", "fold", "pivot"] + +ScaleType = Literal["linear", "log", "sqrt", "pow", "symlog", "time", "utc", "ordinal", "band", "point"] + +ChannelType = Literal["quantitative", "ordinal", "nominal", "temporal"] + +GeometryType = Literal[ + "point", + "line", + "area", + "bar", + "rect", + "rule", + "tick", + "circle", + "square", + "text", + "boxplot", + "errorbar", + "errorband", +] + +SelectionType = Literal["interval", "single", "multi"] + + +def _ensure(condition: bool, message: str) -> None: + """Raise ConfigValidationError if condition is not met.""" + if not condition: + raise ConfigValidationError(message) + + +@dataclass +class DataConfig: + """Specifies the logical data source and lightweight filtering.""" + + source: str = "main" + filters: list[str] = field(default_factory=list) + columns: list[str] | None = None + + +@dataclass +class TransformConfig: # pylint: disable=too-many-instance-attributes + """Represents a single data transform stage.""" + + op: TransformOp + filter: str | None = None + as_field: str | None = None + expr: str | None = None + groupby: list[str] | None = None + aggregations: dict[str, str] | None = None + field: str | None = None + binned_as: str | None = None + maxbins: int | None = None + window: dict[str, str] | None = None + frame: list[int | None] | None = None + fold_fields: list[str] | None = None + as_fields: list[str] | None = None + + def __post_init__(self) -> None: + if self.op == "filter": + _ensure(bool(self.filter), "TransformConfig.filter must be provided when op='filter'") + if self.op == "calculate": + _ensure(bool(self.as_field), "TransformConfig.as_field is required for calculate transforms") + _ensure(bool(self.expr), "TransformConfig.expr is required for calculate transforms") + if self.op == "aggregate": + _ensure(bool(self.groupby), "TransformConfig.groupby is required for aggregate transforms") + _ensure( + bool(self.aggregations), + "TransformConfig.aggregations is required for aggregate transforms", + ) + if self.op == "bin": + _ensure(bool(self.field), "TransformConfig.field is required for bin transforms") + _ensure(bool(self.binned_as), "TransformConfig.binned_as is required for bin transforms") + if self.op == "window": + _ensure(bool(self.window), "TransformConfig.window is required for window transforms") + + +@dataclass +class ScaleConfig: + """Describes how raw data values are mapped to visual ranges.""" + + type: ScaleType | None = None + domain: list[Any] | None = None + range: list[Any] | None = None + clamp: bool | None = None + nice: bool | None = None + reverse: bool | None = None + + +@dataclass +class AxisConfig: + """Axis settings for positional channels.""" + + title: str | None = None + grid: bool | None = None + format: str | None = None + tick_count: int | None = None + label_angle: float | None = None + visible: bool = True + + +@dataclass +class LegendConfig: + """Legend settings for categorical or continuous mappings.""" + + title: str | None = None + orient: str | None = None + visible: bool = True + + +@dataclass +class ChannelAestheticsConfig: + """Represents one visual encoding channel (x, y, color, etc.).""" + + field: str | None = None + type: ChannelType | None = None + value: Any | None = None + aggregate: str | None = None + bin: bool | None = None + time_unit: str | None = None + scale: ScaleConfig | None = None + axis: AxisConfig | None = None + legend: LegendConfig | None = None + sort: str | list[Any] | None = None + title: str | None = None + + def __post_init__(self) -> None: + if self.field is not None and self.value is not None: + raise ConfigValidationError( + "ChannelAestheticsConfig cannot specify both 'field' and 'value'; prefer 'field'." + ) + + +@dataclass +class AestheticsConfig: + """Collection of channel encodings for a layer.""" + + x: ChannelAestheticsConfig | None = None + y: ChannelAestheticsConfig | None = None + color: ChannelAestheticsConfig | None = None + size: ChannelAestheticsConfig | None = None + shape: ChannelAestheticsConfig | None = None + opacity: ChannelAestheticsConfig | None = None + tooltip: list[ChannelAestheticsConfig] | None = None + row: ChannelAestheticsConfig | None = None + column: ChannelAestheticsConfig | None = None + + +@dataclass +class GeometryConfig: + """Visual primitive used to draw the layer.""" + + type: GeometryType + props: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + _ensure(isinstance(self.props, dict), "GeometryConfig.props must be a dictionary") + + +@dataclass +class SelectionConfig: + """Interactive selection definition.""" + + name: str + type: SelectionType = "interval" + encodings: list[str] | None = None + fields: list[str] | None = None + bind: dict[str, Any] | None = None + + +@dataclass +class PlotSizeConfig: + """Size and layout metadata for an entire plot.""" + + width: int | None = None + height: int | None = None + autosize: str | None = None + + +@dataclass +class LabelConfig: + """Free-form labels or annotations.""" + + text: str | None = None + x: float | str | None = None + y: float | str | None = None + props: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PlotLevelGuideConfig: + """Titles and caption level guides.""" + + title: str | None = None + subtitle: str | None = None + caption: str | None = None + labels: list[LabelConfig] | None = None + + +@dataclass +class FacetConfig: + """High-level faceting instructions.""" + + row: str | None = None + column: str | None = None + wrap: int | None = None + + +@dataclass +class LayerConfig: + """A single layer in a composed plot.""" + + name: str | None = None + data: DataConfig | None = None + transforms: list[TransformConfig] = field(default_factory=list) + geometry: GeometryConfig = field(default_factory=lambda: GeometryConfig(type="point")) + aesthetics: AestheticsConfig = field(default_factory=AestheticsConfig) + selections: list[SelectionConfig] = field(default_factory=list) + + +@dataclass +class PlotConfig: + """Top-level configuration for one plot.""" + + backend: BackendType = "altair" + data: DataConfig = field(default_factory=DataConfig) + transforms: list[TransformConfig] = field(default_factory=list) + layers: list[LayerConfig] = field(default_factory=list) + facet: FacetConfig | None = None + size: PlotSizeConfig = field(default_factory=PlotSizeConfig) + guides: PlotLevelGuideConfig = field(default_factory=PlotLevelGuideConfig) + background: str | None = None + selections: list[SelectionConfig] = field(default_factory=list) + + def __post_init__(self) -> None: + _ensure(self.layers is not None, "PlotConfig.layers must be a list (can be empty)") + + +@dataclass +class GraphicsConfig: + """Root Visualization config that multiplexes multiple named plots.""" + + default_backend: BackendType = "altair" + plots: dict[str, PlotConfig] = field(default_factory=dict) diff --git a/uv.lock b/uv.lock index deaae063..e245ab25 100644 --- a/uv.lock +++ b/uv.lock @@ -167,6 +167,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/44/1f/38e29b06bfed7818ebba1f84904afdc8153ef7b6c7e0d8f3bc6643f5989c/alembic-1.17.0-py3-none-any.whl", hash = "sha256:80523bc437d41b35c5db7e525ad9d908f79de65c27d6a5a5eab6df348a352d99", size = 247449, upload-time = "2025-10-11T18:40:16.288Z" }, ] +[[package]] +name = "altair" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "jsonschema" }, + { name = "narwhals" }, + { name = "packaging" }, + { name = "typing-extensions", marker = "python_full_version < '3.15'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/c0/184a89bd5feba14ff3c41cfaf1dd8a82c05f5ceedbc92145e17042eb08a4/altair-6.0.0.tar.gz", hash = "sha256:614bf5ecbe2337347b590afb111929aa9c16c9527c4887d96c9bc7f6640756b4", size = 763834, upload-time = "2025-11-12T08:59:11.519Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/33/ef2f2409450ef6daa61459d5de5c08128e7d3edb773fefd0a324d1310238/altair-6.0.0-py3-none-any.whl", hash = "sha256:09ae95b53d5fe5b16987dccc785a7af8588f2dca50de1e7a156efa8a461515f8", size = 795410, upload-time = "2025-11-12T08:59:09.804Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.3" @@ -1522,6 +1538,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, ] +[[package]] +name = "jsonschema" +version = "4.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/69/f7185de793a29082a9f3c7728268ffb31cb5095131a9c139a74078e27336/jsonschema-4.25.1.tar.gz", hash = "sha256:e4a9655ce0da0c0b67a085847e00a3a51449e1157f4f75e9fb5aa545e122eb85", size = 357342, upload-time = "2025-08-18T17:03:50.038Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/9c/8c95d856233c1f82500c2450b8c68576b4cf1c871db3afac5c34ff84e6fd/jsonschema-4.25.1-py3-none-any.whl", hash = "sha256:3fba0169e345c7175110351d456342c364814cfcf3b964ba4587f22915230a63", size = 90040, upload-time = "2025-08-18T17:03:48.373Z" }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.9" @@ -2898,6 +2941,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/8b/2e814a255436fc6d604a60f1e8b8a186e05082aa3c0cabfd9330192496a2/pylint-4.0.2-py3-none-any.whl", hash = "sha256:9627ccd129893fb8ee8e8010261cb13485daca83e61a6f854a85528ee579502d", size = 536019, upload-time = "2025-10-20T13:02:32.778Z" }, ] +[[package]] +name = "pylint-per-file-ignores" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pylint" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/dc/4588e9affbb7bbac3f14d11a7617db8e3f31669f319da03271f7f0332b09/pylint_per_file_ignores-3.1.0.tar.gz", hash = "sha256:a11db907f74cfbd8956365c36b0d453daa20448e2aaada1033c52ba6f621c7d7", size = 6616, upload-time = "2025-10-13T08:21:51.403Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/16/c1054ad8f355749eef83a2b176fc545cc27b08b5dc2bf21420174325676e/pylint_per_file_ignores-3.1.0-py3-none-any.whl", hash = "sha256:57cde138807a28a98f33edf85e687d88df3660d37c8d5d5d46de1f1e6fcc665d", size = 5666, upload-time = "2025-10-13T08:21:50.42Z" }, +] + [[package]] name = "pyparsing" version = "3.2.5" @@ -3072,6 +3127,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "referencing" +version = "0.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, +] + [[package]] name = "regex" version = "2025.10.23" @@ -3190,6 +3259,87 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/f6/5fc0574af5379606ffd57a4b68ed88f9b415eb222047fe023aefcc00a648/rich_argparse-1.7.1-py3-none-any.whl", hash = "sha256:a8650b42e4a4ff72127837632fba6b7da40784842f08d7395eb67a9cbd7b4bf9", size = 25357, upload-time = "2025-05-25T20:20:33.793Z" }, ] +[[package]] +name = "rpds-py" +version = "0.28.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/48/dc/95f074d43452b3ef5d06276696ece4b3b5d696e7c9ad7173c54b1390cd70/rpds_py-0.28.0.tar.gz", hash = "sha256:abd4df20485a0983e2ca334a216249b6186d6e3c1627e106651943dbdb791aea", size = 27419, upload-time = "2025-10-22T22:24:29.327Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/5c/6c3936495003875fe7b14f90ea812841a08fca50ab26bd840e924097d9c8/rpds_py-0.28.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6b4f28583a4f247ff60cd7bdda83db8c3f5b05a7a82ff20dd4b078571747708f", size = 366439, upload-time = "2025-10-22T22:22:04.525Z" }, + { url = "https://files.pythonhosted.org/packages/56/f9/a0f1ca194c50aa29895b442771f036a25b6c41a35e4f35b1a0ea713bedae/rpds_py-0.28.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d678e91b610c29c4b3d52a2c148b641df2b4676ffe47c59f6388d58b99cdc424", size = 348170, upload-time = "2025-10-22T22:22:06.397Z" }, + { url = "https://files.pythonhosted.org/packages/18/ea/42d243d3a586beb72c77fa5def0487daf827210069a95f36328e869599ea/rpds_py-0.28.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e819e0e37a44a78e1383bf1970076e2ccc4dc8c2bbaa2f9bd1dc987e9afff628", size = 378838, upload-time = "2025-10-22T22:22:07.932Z" }, + { url = "https://files.pythonhosted.org/packages/e7/78/3de32e18a94791af8f33601402d9d4f39613136398658412a4e0b3047327/rpds_py-0.28.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5ee514e0f0523db5d3fb171f397c54875dbbd69760a414dccf9d4d7ad628b5bd", size = 393299, upload-time = "2025-10-22T22:22:09.435Z" }, + { url = "https://files.pythonhosted.org/packages/13/7e/4bdb435afb18acea2eb8a25ad56b956f28de7c59f8a1d32827effa0d4514/rpds_py-0.28.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f3fa06d27fdcee47f07a39e02862da0100cb4982508f5ead53ec533cd5fe55e", size = 518000, upload-time = "2025-10-22T22:22:11.326Z" }, + { url = "https://files.pythonhosted.org/packages/31/d0/5f52a656875cdc60498ab035a7a0ac8f399890cc1ee73ebd567bac4e39ae/rpds_py-0.28.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46959ef2e64f9e4a41fc89aa20dbca2b85531f9a72c21099a3360f35d10b0d5a", size = 408746, upload-time = "2025-10-22T22:22:13.143Z" }, + { url = "https://files.pythonhosted.org/packages/3e/cd/49ce51767b879cde77e7ad9fae164ea15dce3616fe591d9ea1df51152706/rpds_py-0.28.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8455933b4bcd6e83fde3fefc987a023389c4b13f9a58c8d23e4b3f6d13f78c84", size = 386379, upload-time = "2025-10-22T22:22:14.602Z" }, + { url = "https://files.pythonhosted.org/packages/6a/99/e4e1e1ee93a98f72fc450e36c0e4d99c35370220e815288e3ecd2ec36a2a/rpds_py-0.28.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:ad50614a02c8c2962feebe6012b52f9802deec4263946cddea37aaf28dd25a66", size = 401280, upload-time = "2025-10-22T22:22:16.063Z" }, + { url = "https://files.pythonhosted.org/packages/61/35/e0c6a57488392a8b319d2200d03dad2b29c0db9996f5662c3b02d0b86c02/rpds_py-0.28.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5deca01b271492553fdb6c7fd974659dce736a15bae5dad7ab8b93555bceb28", size = 412365, upload-time = "2025-10-22T22:22:17.504Z" }, + { url = "https://files.pythonhosted.org/packages/ff/6a/841337980ea253ec797eb084665436007a1aad0faac1ba097fb906c5f69c/rpds_py-0.28.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:735f8495a13159ce6a0d533f01e8674cec0c57038c920495f87dcb20b3ddb48a", size = 559573, upload-time = "2025-10-22T22:22:19.108Z" }, + { url = "https://files.pythonhosted.org/packages/e7/5e/64826ec58afd4c489731f8b00729c5f6afdb86f1df1df60bfede55d650bb/rpds_py-0.28.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:961ca621ff10d198bbe6ba4957decca61aa2a0c56695384c1d6b79bf61436df5", size = 583973, upload-time = "2025-10-22T22:22:20.768Z" }, + { url = "https://files.pythonhosted.org/packages/b6/ee/44d024b4843f8386a4eeaa4c171b3d31d55f7177c415545fd1a24c249b5d/rpds_py-0.28.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2374e16cc9131022e7d9a8f8d65d261d9ba55048c78f3b6e017971a4f5e6353c", size = 553800, upload-time = "2025-10-22T22:22:22.25Z" }, + { url = "https://files.pythonhosted.org/packages/7d/89/33e675dccff11a06d4d85dbb4d1865f878d5020cbb69b2c1e7b2d3f82562/rpds_py-0.28.0-cp312-cp312-win32.whl", hash = "sha256:d15431e334fba488b081d47f30f091e5d03c18527c325386091f31718952fe08", size = 216954, upload-time = "2025-10-22T22:22:24.105Z" }, + { url = "https://files.pythonhosted.org/packages/af/36/45f6ebb3210887e8ee6dbf1bc710ae8400bb417ce165aaf3024b8360d999/rpds_py-0.28.0-cp312-cp312-win_amd64.whl", hash = "sha256:a410542d61fc54710f750d3764380b53bf09e8c4edbf2f9141a82aa774a04f7c", size = 227844, upload-time = "2025-10-22T22:22:25.551Z" }, + { url = "https://files.pythonhosted.org/packages/57/91/f3fb250d7e73de71080f9a221d19bd6a1c1eb0d12a1ea26513f6c1052ad6/rpds_py-0.28.0-cp312-cp312-win_arm64.whl", hash = "sha256:1f0cfd1c69e2d14f8c892b893997fa9a60d890a0c8a603e88dca4955f26d1edd", size = 217624, upload-time = "2025-10-22T22:22:26.914Z" }, + { url = "https://files.pythonhosted.org/packages/d3/03/ce566d92611dfac0085c2f4b048cd53ed7c274a5c05974b882a908d540a2/rpds_py-0.28.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e9e184408a0297086f880556b6168fa927d677716f83d3472ea333b42171ee3b", size = 366235, upload-time = "2025-10-22T22:22:28.397Z" }, + { url = "https://files.pythonhosted.org/packages/00/34/1c61da1b25592b86fd285bd7bd8422f4c9d748a7373b46126f9ae792a004/rpds_py-0.28.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:edd267266a9b0448f33dc465a97cfc5d467594b600fe28e7fa2f36450e03053a", size = 348241, upload-time = "2025-10-22T22:22:30.171Z" }, + { url = "https://files.pythonhosted.org/packages/fc/00/ed1e28616848c61c493a067779633ebf4b569eccaacf9ccbdc0e7cba2b9d/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85beb8b3f45e4e32f6802fb6cd6b17f615ef6c6a52f265371fb916fae02814aa", size = 378079, upload-time = "2025-10-22T22:22:31.644Z" }, + { url = "https://files.pythonhosted.org/packages/11/b2/ccb30333a16a470091b6e50289adb4d3ec656fd9951ba8c5e3aaa0746a67/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d2412be8d00a1b895f8ad827cc2116455196e20ed994bb704bf138fe91a42724", size = 393151, upload-time = "2025-10-22T22:22:33.453Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d0/73e2217c3ee486d555cb84920597480627d8c0240ff3062005c6cc47773e/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cf128350d384b777da0e68796afdcebc2e9f63f0e9f242217754e647f6d32491", size = 517520, upload-time = "2025-10-22T22:22:34.949Z" }, + { url = "https://files.pythonhosted.org/packages/c4/91/23efe81c700427d0841a4ae7ea23e305654381831e6029499fe80be8a071/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a2036d09b363aa36695d1cc1a97b36865597f4478470b0697b5ee9403f4fe399", size = 408699, upload-time = "2025-10-22T22:22:36.584Z" }, + { url = "https://files.pythonhosted.org/packages/ca/ee/a324d3198da151820a326c1f988caaa4f37fc27955148a76fff7a2d787a9/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8e1e9be4fa6305a16be628959188e4fd5cd6f1b0e724d63c6d8b2a8adf74ea6", size = 385720, upload-time = "2025-10-22T22:22:38.014Z" }, + { url = "https://files.pythonhosted.org/packages/19/ad/e68120dc05af8b7cab4a789fccd8cdcf0fe7e6581461038cc5c164cd97d2/rpds_py-0.28.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0a403460c9dd91a7f23fc3188de6d8977f1d9603a351d5db6cf20aaea95b538d", size = 401096, upload-time = "2025-10-22T22:22:39.869Z" }, + { url = "https://files.pythonhosted.org/packages/99/90/c1e070620042459d60df6356b666bb1f62198a89d68881816a7ed121595a/rpds_py-0.28.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d7366b6553cdc805abcc512b849a519167db8f5e5c3472010cd1228b224265cb", size = 411465, upload-time = "2025-10-22T22:22:41.395Z" }, + { url = "https://files.pythonhosted.org/packages/68/61/7c195b30d57f1b8d5970f600efee72a4fad79ec829057972e13a0370fd24/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5b43c6a3726efd50f18d8120ec0551241c38785b68952d240c45ea553912ac41", size = 558832, upload-time = "2025-10-22T22:22:42.871Z" }, + { url = "https://files.pythonhosted.org/packages/b0/3d/06f3a718864773f69941d4deccdf18e5e47dd298b4628062f004c10f3b34/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0cb7203c7bc69d7c1585ebb33a2e6074492d2fc21ad28a7b9d40457ac2a51ab7", size = 583230, upload-time = "2025-10-22T22:22:44.877Z" }, + { url = "https://files.pythonhosted.org/packages/66/df/62fc783781a121e77fee9a21ead0a926f1b652280a33f5956a5e7833ed30/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7a52a5169c664dfb495882adc75c304ae1d50df552fbd68e100fdc719dee4ff9", size = 553268, upload-time = "2025-10-22T22:22:46.441Z" }, + { url = "https://files.pythonhosted.org/packages/84/85/d34366e335140a4837902d3dea89b51f087bd6a63c993ebdff59e93ee61d/rpds_py-0.28.0-cp313-cp313-win32.whl", hash = "sha256:2e42456917b6687215b3e606ab46aa6bca040c77af7df9a08a6dcfe8a4d10ca5", size = 217100, upload-time = "2025-10-22T22:22:48.342Z" }, + { url = "https://files.pythonhosted.org/packages/3c/1c/f25a3f3752ad7601476e3eff395fe075e0f7813fbb9862bd67c82440e880/rpds_py-0.28.0-cp313-cp313-win_amd64.whl", hash = "sha256:e0a0311caedc8069d68fc2bf4c9019b58a2d5ce3cd7cb656c845f1615b577e1e", size = 227759, upload-time = "2025-10-22T22:22:50.219Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d6/5f39b42b99615b5bc2f36ab90423ea404830bdfee1c706820943e9a645eb/rpds_py-0.28.0-cp313-cp313-win_arm64.whl", hash = "sha256:04c1b207ab8b581108801528d59ad80aa83bb170b35b0ddffb29c20e411acdc1", size = 217326, upload-time = "2025-10-22T22:22:51.647Z" }, + { url = "https://files.pythonhosted.org/packages/5c/8b/0c69b72d1cee20a63db534be0df271effe715ef6c744fdf1ff23bb2b0b1c/rpds_py-0.28.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:f296ea3054e11fc58ad42e850e8b75c62d9a93a9f981ad04b2e5ae7d2186ff9c", size = 355736, upload-time = "2025-10-22T22:22:53.211Z" }, + { url = "https://files.pythonhosted.org/packages/f7/6d/0c2ee773cfb55c31a8514d2cece856dd299170a49babd50dcffb15ddc749/rpds_py-0.28.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5a7306c19b19005ad98468fcefeb7100b19c79fc23a5f24a12e06d91181193fa", size = 342677, upload-time = "2025-10-22T22:22:54.723Z" }, + { url = "https://files.pythonhosted.org/packages/e2/1c/22513ab25a27ea205144414724743e305e8153e6abe81833b5e678650f5a/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5d9b86aa501fed9862a443c5c3116f6ead8bc9296185f369277c42542bd646b", size = 371847, upload-time = "2025-10-22T22:22:56.295Z" }, + { url = "https://files.pythonhosted.org/packages/60/07/68e6ccdb4b05115ffe61d31afc94adef1833d3a72f76c9632d4d90d67954/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e5bbc701eff140ba0e872691d573b3d5d30059ea26e5785acba9132d10c8c31d", size = 381800, upload-time = "2025-10-22T22:22:57.808Z" }, + { url = "https://files.pythonhosted.org/packages/73/bf/6d6d15df80781d7f9f368e7c1a00caf764436518c4877fb28b029c4624af/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5690671cd672a45aa8616d7374fdf334a1b9c04a0cac3c854b1136e92374fe", size = 518827, upload-time = "2025-10-22T22:22:59.826Z" }, + { url = "https://files.pythonhosted.org/packages/7b/d3/2decbb2976cc452cbf12a2b0aaac5f1b9dc5dd9d1f7e2509a3ee00421249/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9f1d92ecea4fa12f978a367c32a5375a1982834649cdb96539dcdc12e609ab1a", size = 399471, upload-time = "2025-10-22T22:23:01.968Z" }, + { url = "https://files.pythonhosted.org/packages/b1/2c/f30892f9e54bd02e5faca3f6a26d6933c51055e67d54818af90abed9748e/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d252db6b1a78d0a3928b6190156042d54c93660ce4d98290d7b16b5296fb7cc", size = 377578, upload-time = "2025-10-22T22:23:03.52Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5d/3bce97e5534157318f29ac06bf2d279dae2674ec12f7cb9c12739cee64d8/rpds_py-0.28.0-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:d61b355c3275acb825f8777d6c4505f42b5007e357af500939d4a35b19177259", size = 390482, upload-time = "2025-10-22T22:23:05.391Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f0/886bd515ed457b5bd93b166175edb80a0b21a210c10e993392127f1e3931/rpds_py-0.28.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:acbe5e8b1026c0c580d0321c8aae4b0a1e1676861d48d6e8c6586625055b606a", size = 402447, upload-time = "2025-10-22T22:23:06.93Z" }, + { url = "https://files.pythonhosted.org/packages/42/b5/71e8777ac55e6af1f4f1c05b47542a1eaa6c33c1cf0d300dca6a1c6e159a/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:8aa23b6f0fc59b85b4c7d89ba2965af274346f738e8d9fc2455763602e62fd5f", size = 552385, upload-time = "2025-10-22T22:23:08.557Z" }, + { url = "https://files.pythonhosted.org/packages/5d/cb/6ca2d70cbda5a8e36605e7788c4aa3bea7c17d71d213465a5a675079b98d/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7b14b0c680286958817c22d76fcbca4800ddacef6f678f3a7c79a1fe7067fe37", size = 575642, upload-time = "2025-10-22T22:23:10.348Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d4/407ad9960ca7856d7b25c96dcbe019270b5ffdd83a561787bc682c797086/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bcf1d210dfee61a6c86551d67ee1031899c0fdbae88b2d44a569995d43797712", size = 544507, upload-time = "2025-10-22T22:23:12.434Z" }, + { url = "https://files.pythonhosted.org/packages/51/31/2f46fe0efcac23fbf5797c6b6b7e1c76f7d60773e525cb65fcbc582ee0f2/rpds_py-0.28.0-cp313-cp313t-win32.whl", hash = "sha256:3aa4dc0fdab4a7029ac63959a3ccf4ed605fee048ba67ce89ca3168da34a1342", size = 205376, upload-time = "2025-10-22T22:23:13.979Z" }, + { url = "https://files.pythonhosted.org/packages/92/e4/15947bda33cbedfc134490a41841ab8870a72a867a03d4969d886f6594a2/rpds_py-0.28.0-cp313-cp313t-win_amd64.whl", hash = "sha256:7b7d9d83c942855e4fdcfa75d4f96f6b9e272d42fffcb72cd4bb2577db2e2907", size = 215907, upload-time = "2025-10-22T22:23:15.5Z" }, + { url = "https://files.pythonhosted.org/packages/08/47/ffe8cd7a6a02833b10623bf765fbb57ce977e9a4318ca0e8cf97e9c3d2b3/rpds_py-0.28.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:dcdcb890b3ada98a03f9f2bb108489cdc7580176cb73b4f2d789e9a1dac1d472", size = 353830, upload-time = "2025-10-22T22:23:17.03Z" }, + { url = "https://files.pythonhosted.org/packages/f9/9f/890f36cbd83a58491d0d91ae0db1702639edb33fb48eeb356f80ecc6b000/rpds_py-0.28.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f274f56a926ba2dc02976ca5b11c32855cbd5925534e57cfe1fda64e04d1add2", size = 341819, upload-time = "2025-10-22T22:23:18.57Z" }, + { url = "https://files.pythonhosted.org/packages/09/e3/921eb109f682aa24fb76207698fbbcf9418738f35a40c21652c29053f23d/rpds_py-0.28.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4fe0438ac4a29a520ea94c8c7f1754cdd8feb1bc490dfda1bfd990072363d527", size = 373127, upload-time = "2025-10-22T22:23:20.216Z" }, + { url = "https://files.pythonhosted.org/packages/23/13/bce4384d9f8f4989f1a9599c71b7a2d877462e5fd7175e1f69b398f729f4/rpds_py-0.28.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8a358a32dd3ae50e933347889b6af9a1bdf207ba5d1a3f34e1a38cd3540e6733", size = 382767, upload-time = "2025-10-22T22:23:21.787Z" }, + { url = "https://files.pythonhosted.org/packages/23/e1/579512b2d89a77c64ccef5a0bc46a6ef7f72ae0cf03d4b26dcd52e57ee0a/rpds_py-0.28.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e80848a71c78aa328fefaba9c244d588a342c8e03bda518447b624ea64d1ff56", size = 517585, upload-time = "2025-10-22T22:23:23.699Z" }, + { url = "https://files.pythonhosted.org/packages/62/3c/ca704b8d324a2591b0b0adcfcaadf9c862375b11f2f667ac03c61b4fd0a6/rpds_py-0.28.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f586db2e209d54fe177e58e0bc4946bea5fb0102f150b1b2f13de03e1f0976f8", size = 399828, upload-time = "2025-10-22T22:23:25.713Z" }, + { url = "https://files.pythonhosted.org/packages/da/37/e84283b9e897e3adc46b4c88bb3f6ec92a43bd4d2f7ef5b13459963b2e9c/rpds_py-0.28.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ae8ee156d6b586e4292491e885d41483136ab994e719a13458055bec14cf370", size = 375509, upload-time = "2025-10-22T22:23:27.32Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c2/a980beab869d86258bf76ec42dec778ba98151f253a952b02fe36d72b29c/rpds_py-0.28.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:a805e9b3973f7e27f7cab63a6b4f61d90f2e5557cff73b6e97cd5b8540276d3d", size = 392014, upload-time = "2025-10-22T22:23:29.332Z" }, + { url = "https://files.pythonhosted.org/packages/da/b5/b1d3c5f9d3fa5aeef74265f9c64de3c34a0d6d5cd3c81c8b17d5c8f10ed4/rpds_py-0.28.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5d3fd16b6dc89c73a4da0b4ac8b12a7ecc75b2864b95c9e5afed8003cb50a728", size = 402410, upload-time = "2025-10-22T22:23:31.14Z" }, + { url = "https://files.pythonhosted.org/packages/74/ae/cab05ff08dfcc052afc73dcb38cbc765ffc86f94e966f3924cd17492293c/rpds_py-0.28.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6796079e5d24fdaba6d49bda28e2c47347e89834678f2bc2c1b4fc1489c0fb01", size = 553593, upload-time = "2025-10-22T22:23:32.834Z" }, + { url = "https://files.pythonhosted.org/packages/70/80/50d5706ea2a9bfc9e9c5f401d91879e7c790c619969369800cde202da214/rpds_py-0.28.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:76500820c2af232435cbe215e3324c75b950a027134e044423f59f5b9a1ba515", size = 576925, upload-time = "2025-10-22T22:23:34.47Z" }, + { url = "https://files.pythonhosted.org/packages/ab/12/85a57d7a5855a3b188d024b099fd09c90db55d32a03626d0ed16352413ff/rpds_py-0.28.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:bbdc5640900a7dbf9dd707fe6388972f5bbd883633eb68b76591044cfe346f7e", size = 542444, upload-time = "2025-10-22T22:23:36.093Z" }, + { url = "https://files.pythonhosted.org/packages/6c/65/10643fb50179509150eb94d558e8837c57ca8b9adc04bd07b98e57b48f8c/rpds_py-0.28.0-cp314-cp314-win32.whl", hash = "sha256:adc8aa88486857d2b35d75f0640b949759f79dc105f50aa2c27816b2e0dd749f", size = 207968, upload-time = "2025-10-22T22:23:37.638Z" }, + { url = "https://files.pythonhosted.org/packages/b4/84/0c11fe4d9aaea784ff4652499e365963222481ac647bcd0251c88af646eb/rpds_py-0.28.0-cp314-cp314-win_amd64.whl", hash = "sha256:66e6fa8e075b58946e76a78e69e1a124a21d9a48a5b4766d15ba5b06869d1fa1", size = 218876, upload-time = "2025-10-22T22:23:39.179Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e0/3ab3b86ded7bb18478392dc3e835f7b754cd446f62f3fc96f4fe2aca78f6/rpds_py-0.28.0-cp314-cp314-win_arm64.whl", hash = "sha256:a6fe887c2c5c59413353b7c0caff25d0e566623501ccfff88957fa438a69377d", size = 212506, upload-time = "2025-10-22T22:23:40.755Z" }, + { url = "https://files.pythonhosted.org/packages/51/ec/d5681bb425226c3501eab50fc30e9d275de20c131869322c8a1729c7b61c/rpds_py-0.28.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:7a69df082db13c7070f7b8b1f155fa9e687f1d6aefb7b0e3f7231653b79a067b", size = 355433, upload-time = "2025-10-22T22:23:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/be/ec/568c5e689e1cfb1ea8b875cffea3649260955f677fdd7ddc6176902d04cd/rpds_py-0.28.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b1cde22f2c30ebb049a9e74c5374994157b9b70a16147d332f89c99c5960737a", size = 342601, upload-time = "2025-10-22T22:23:44.372Z" }, + { url = "https://files.pythonhosted.org/packages/32/fe/51ada84d1d2a1d9d8f2c902cfddd0133b4a5eb543196ab5161d1c07ed2ad/rpds_py-0.28.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5338742f6ba7a51012ea470bd4dc600a8c713c0c72adaa0977a1b1f4327d6592", size = 372039, upload-time = "2025-10-22T22:23:46.025Z" }, + { url = "https://files.pythonhosted.org/packages/07/c1/60144a2f2620abade1a78e0d91b298ac2d9b91bc08864493fa00451ef06e/rpds_py-0.28.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e1460ebde1bcf6d496d80b191d854adedcc619f84ff17dc1c6d550f58c9efbba", size = 382407, upload-time = "2025-10-22T22:23:48.098Z" }, + { url = "https://files.pythonhosted.org/packages/45/ed/091a7bbdcf4038a60a461df50bc4c82a7ed6d5d5e27649aab61771c17585/rpds_py-0.28.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e3eb248f2feba84c692579257a043a7699e28a77d86c77b032c1d9fbb3f0219c", size = 518172, upload-time = "2025-10-22T22:23:50.16Z" }, + { url = "https://files.pythonhosted.org/packages/54/dd/02cc90c2fd9c2ef8016fd7813bfacd1c3a1325633ec8f244c47b449fc868/rpds_py-0.28.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3bbba5def70b16cd1c1d7255666aad3b290fbf8d0fe7f9f91abafb73611a91", size = 399020, upload-time = "2025-10-22T22:23:51.81Z" }, + { url = "https://files.pythonhosted.org/packages/ab/81/5d98cc0329bbb911ccecd0b9e19fbf7f3a5de8094b4cda5e71013b2dd77e/rpds_py-0.28.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3114f4db69ac5a1f32e7e4d1cbbe7c8f9cf8217f78e6e002cedf2d54c2a548ed", size = 377451, upload-time = "2025-10-22T22:23:53.711Z" }, + { url = "https://files.pythonhosted.org/packages/b4/07/4d5bcd49e3dfed2d38e2dcb49ab6615f2ceb9f89f5a372c46dbdebb4e028/rpds_py-0.28.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:4b0cb8a906b1a0196b863d460c0222fb8ad0f34041568da5620f9799b83ccf0b", size = 390355, upload-time = "2025-10-22T22:23:55.299Z" }, + { url = "https://files.pythonhosted.org/packages/3f/79/9f14ba9010fee74e4f40bf578735cfcbb91d2e642ffd1abe429bb0b96364/rpds_py-0.28.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cf681ac76a60b667106141e11a92a3330890257e6f559ca995fbb5265160b56e", size = 403146, upload-time = "2025-10-22T22:23:56.929Z" }, + { url = "https://files.pythonhosted.org/packages/39/4c/f08283a82ac141331a83a40652830edd3a4a92c34e07e2bbe00baaea2f5f/rpds_py-0.28.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1e8ee6413cfc677ce8898d9cde18cc3a60fc2ba756b0dec5b71eb6eb21c49fa1", size = 552656, upload-time = "2025-10-22T22:23:58.62Z" }, + { url = "https://files.pythonhosted.org/packages/61/47/d922fc0666f0dd8e40c33990d055f4cc6ecff6f502c2d01569dbed830f9b/rpds_py-0.28.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b3072b16904d0b5572a15eb9d31c1954e0d3227a585fc1351aa9878729099d6c", size = 576782, upload-time = "2025-10-22T22:24:00.312Z" }, + { url = "https://files.pythonhosted.org/packages/d3/0c/5bafdd8ccf6aa9d3bfc630cfece457ff5b581af24f46a9f3590f790e3df2/rpds_py-0.28.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b670c30fd87a6aec281c3c9896d3bae4b205fd75d79d06dc87c2503717e46092", size = 544671, upload-time = "2025-10-22T22:24:02.297Z" }, + { url = "https://files.pythonhosted.org/packages/2c/37/dcc5d8397caa924988693519069d0beea077a866128719351a4ad95e82fc/rpds_py-0.28.0-cp314-cp314t-win32.whl", hash = "sha256:8014045a15b4d2b3476f0a287fcc93d4f823472d7d1308d47884ecac9e612be3", size = 205749, upload-time = "2025-10-22T22:24:03.848Z" }, + { url = "https://files.pythonhosted.org/packages/d7/69/64d43b21a10d72b45939a28961216baeb721cc2a430f5f7c3bfa21659a53/rpds_py-0.28.0-cp314-cp314t-win_amd64.whl", hash = "sha256:7a4e59c90d9c27c561eb3160323634a9ff50b04e4f7820600a2beb0ac90db578", size = 216233, upload-time = "2025-10-22T22:24:05.471Z" }, +] + [[package]] name = "rsa" version = "4.9.1" @@ -3467,6 +3617,7 @@ name = "simplexity" version = "0.1" source = { editable = "." } dependencies = [ + { name = "altair" }, { name = "chex" }, { name = "dotenv" }, { name = "equinox" }, @@ -3503,6 +3654,7 @@ dev = [ { name = "diff-cover" }, { name = "jaxtyping" }, { name = "pylint" }, + { name = "pylint-per-file-ignores" }, { name = "pyright" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -3519,6 +3671,7 @@ penzai = [ [package.metadata] requires-dist = [ + { name = "altair", specifier = ">=5.3.0" }, { name = "boto3", marker = "extra == 'aws'", specifier = ">=1.37.24" }, { name = "chex" }, { name = "diff-cover", marker = "extra == 'dev'" }, @@ -3547,6 +3700,7 @@ requires-dist = [ { name = "penzai", marker = "extra == 'penzai'" }, { name = "plotly" }, { name = "pylint", marker = "extra == 'dev'" }, + { name = "pylint-per-file-ignores", marker = "extra == 'dev'" }, { name = "pyright", marker = "extra == 'dev'" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-cov", marker = "extra == 'dev'" }, From e34fea521e32d360e45c54a7b380a54c2ab45d49 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 13 Nov 2025 18:23:42 -0800 Subject: [PATCH 3/9] Add plotly support --- docs/visualization.md | 64 +++-- .../configs/visualization/3d_scatter.yaml | 5 + .../configs/visualization/plot/scatter3d.yaml | 26 ++ examples/visualization_3d_demo.py | 198 ++++++++++++++ simplexity/visualization/__init__.py | 6 +- simplexity/visualization/altair_renderer.py | 251 ++++-------------- simplexity/visualization/data_pipeline.py | 194 ++++++++++++++ simplexity/visualization/plotly_renderer.py | 158 +++++++++++ .../visualization/structured_configs.py | 43 +-- 9 files changed, 690 insertions(+), 255 deletions(-) create mode 100644 examples/configs/visualization/3d_scatter.yaml create mode 100644 examples/configs/visualization/plot/scatter3d.yaml create mode 100644 examples/visualization_3d_demo.py create mode 100644 simplexity/visualization/data_pipeline.py create mode 100644 simplexity/visualization/plotly_renderer.py diff --git a/docs/visualization.md b/docs/visualization.md index 8e2a076e..8e656a68 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -435,6 +435,7 @@ class AestheticsConfig: """ x: ChannelAestheticsConfig | None = None y: ChannelAestheticsConfig | None = None + z: ChannelAestheticsConfig | None = None color: ChannelAestheticsConfig | None = None size: ChannelAestheticsConfig | None = None @@ -456,6 +457,9 @@ This structure is expressive enough for: - ggplot2: `aes(x=..., y=..., color=...)` - Matplotlib: parameters to `scatter`, `plot`, etc. +The optional `z` channel is ignored by strictly 2D backends, but enables 3D +scatter support for engines such as Plotly. + --- ## 5.7 Geometry @@ -915,31 +919,49 @@ def build_altair_chart(plot_cfg: PlotConfig, data_registry: dict[str, pd.DataFrame]) -> alt.Chart ``` +### Plotly Renderer (Prototype) + +A lightweight Plotly backend focuses on interactive 3D scatter plots. It reuses +the shared pandas pipeline (filters + transforms) and maps the first layer of a +`PlotConfig` into `plotly.express.scatter_3d`. Current constraints: + +- Single point geometry layer per plot (sufficient for demos/prototypes) +- Requires `x`, `y`, and `z` aesthetics +- Honors `color`, `size`, `opacity`, and tooltip channel lists +- Writes self-contained HTML via `Figure.write_html` + +Renderer API: + +```python +def build_plotly_figure(plot_cfg: PlotConfig, + data_registry: dict[str, pd.DataFrame]) -> plotly.Figure +``` + --- ## 9. Backend Capability Matrix The following table outlines which features are supported by each backend: -| Feature | Altair | Plotnine | Matplotlib | Notes | -| ----------------- | ------ | -------- | ---------- | ----------------------------- | -| **Transforms** | | | | | -| filter | βœ… | βœ… | βœ… | pandas query / Vega-Lite | -| calculate | βœ… | ⚠️ | ⚠️ | Via pandas eval / limited | -| aggregate | βœ… | βœ… | βœ… | Via pandas groupby | -| bin | βœ… | βœ… | ⚠️ | Manual binning for matplotlib | -| window | βœ… | ⚠️ | ⚠️ | Limited pandas support | -| fold/pivot | βœ… | βœ… | βœ… | Via pandas | -| **Geometries** | | | | | -| point, line, bar | βœ… | βœ… | βœ… | Core geometries | -| area, rect | βœ… | βœ… | βœ… | | -| text | βœ… | βœ… | βœ… | | -| boxplot, errorbar | βœ… | βœ… | βœ… | | -| **Interactivity** | | | | | -| Selections | βœ… | ❌ | ❌ | Altair-only | -| Tooltips | βœ… | ❌ | ⚠️ | Limited in matplotlib | -| **Facets** | βœ… | βœ… | ⚠️ | Manual subplots in matplotlib | -| **Scales** | βœ… | βœ… | βœ… | All support log/linear/etc | +| Feature | Altair | Plotnine | Matplotlib | Plotly | Notes | +| ----------------- | ------ | -------- | ---------- | ------ | ------------------------------------------ | +| **Transforms** | | | | | | +| filter | βœ… | βœ… | βœ… | βœ… | pandas query / Vega-Lite | +| calculate | βœ… | ⚠️ | ⚠️ | βœ… | Via pandas eval / limited | +| aggregate | βœ… | βœ… | βœ… | βœ… | Via pandas groupby | +| bin | βœ… | βœ… | ⚠️ | βœ… | Manual binning for matplotlib | +| window | βœ… | ⚠️ | ⚠️ | βœ… | Limited pandas support | +| fold/pivot | βœ… | βœ… | βœ… | βœ… | Via pandas | +| **Geometries** | | | | | | +| point, line, bar | βœ… | βœ… | βœ… | ⚠️ | Plotly prototype currently supports point | +| area, rect | βœ… | βœ… | βœ… | ❌ | | +| text | βœ… | βœ… | βœ… | ❌ | | +| boxplot, errorbar | βœ… | βœ… | βœ… | ❌ | | +| **Interactivity** | | | | | | +| Selections | βœ… | ❌ | ❌ | ❌ | Altair-only | +| Tooltips | βœ… | ❌ | ⚠️ | βœ… | Plotly hover support is built-in | +| **Facets** | βœ… | βœ… | ⚠️ | ⚠️ | Plotly demo does not yet facet 3D charts | +| **Scales** | βœ… | βœ… | βœ… | βœ… | All support log/linear/etc | **Legend:** @@ -978,11 +1000,15 @@ This configuration system provides: - A **unified Grammar-of-Graphics-style schema** - **Declarative, reproducible plots** - **Immediate support for Altair** (interactive plots) +- Prototype support for **Plotly 3D scatter** rendering - A path to **plotnine/matplotlib** (static publication-ready plots) - A clean separation between _configuration_, _data_, and _rendering backends_ By adopting this design, we gain long-term flexibility in visualization tooling while keeping plot definitions clean, expressive, and consistent across projects. +See `examples/visualization_3d_demo.py` plus the Hydra configs in +`examples/configs/visualization/` for a complete YAML-driven demo. + --- ## 13. Future Extensions diff --git a/examples/configs/visualization/3d_scatter.yaml b/examples/configs/visualization/3d_scatter.yaml new file mode 100644 index 00000000..c28699a3 --- /dev/null +++ b/examples/configs/visualization/3d_scatter.yaml @@ -0,0 +1,5 @@ +defaults: + - data: synthetic_cloud + - plot: scatter3d + +output_html: scatter3d_demo.html diff --git a/examples/configs/visualization/plot/scatter3d.yaml b/examples/configs/visualization/plot/scatter3d.yaml new file mode 100644 index 00000000..d5d2d7ae --- /dev/null +++ b/examples/configs/visualization/plot/scatter3d.yaml @@ -0,0 +1,26 @@ +backend: plotly +data: + source: cloud +layers: + - name: cluster_cloud + geometry: + type: point + props: + size: 8 + aesthetics: + x: { field: x, type: quantitative, title: "X position" } + y: { field: y, type: quantitative, title: "Y position" } + z: { field: z, type: quantitative, title: "Z position" } + color: { field: cluster, type: nominal, title: Cluster } + size: { field: magnitude, type: quantitative } + opacity: { value: 0.85 } + tooltip: + - { field: cluster, type: nominal, title: Cluster } + - { field: magnitude, type: quantitative, title: Magnitude } +size: + width: 800 + height: 600 +guides: + title: "Synthetic 3D Scatter" + subtitle: "Points sampled from multivariate Gaussians" + caption: "Configured entirely via Hydra YAML" diff --git a/examples/visualization_3d_demo.py b/examples/visualization_3d_demo.py new file mode 100644 index 00000000..ed9b0a8e --- /dev/null +++ b/examples/visualization_3d_demo.py @@ -0,0 +1,198 @@ +"""Hydra-powered demo that renders a 3D scatter plot via PlotConfig YAML.""" + +from __future__ import annotations + +import types +from dataclasses import dataclass, field, fields, is_dataclass +from pathlib import Path +from typing import Any, Union, cast, get_args, get_origin, get_type_hints + +import hydra +import numpy as np +import pandas as pd +from hydra.utils import get_original_cwd +from omegaconf import DictConfig, OmegaConf + +from simplexity.visualization import DictDataRegistry, PlotConfig, build_altair_chart, build_plotly_figure + + +@dataclass +class SyntheticDataConfig: + """Configuration for generating synthetic 3D clusters.""" + + source_name: str = "cloud" + num_points: int = 600 + clusters: int = 4 + cluster_spread: float = 0.8 + seed: int = 11 + + +@dataclass +class Scatter3DDemoConfig: + """Root Hydra config for the demo.""" + + data: SyntheticDataConfig = field(default_factory=SyntheticDataConfig) + plot: PlotConfig = field(default_factory=PlotConfig) + output_html: str = "scatter3d_demo.html" + + +@hydra.main(version_base=None, config_path="configs/visualization", config_name="3d_scatter") +def main(cfg: DictConfig) -> None: + """Main entry point for the demo.""" + data_cfg = _convert_cfg(cfg.data, SyntheticDataConfig) + plot_cfg = _convert_cfg(cfg.plot, PlotConfig) + output_html = cast(str, cfg.get("output_html", "scatter3d_demo.html")) + dataframe = _generate_dataset(data_cfg) + registry = DictDataRegistry({data_cfg.source_name: dataframe}) + + if plot_cfg.backend == "plotly": + figure = build_plotly_figure(plot_cfg, registry) + _save_plotly_figure(figure, output_html) + else: + chart = build_altair_chart(plot_cfg, registry) + _save_altair_chart(chart, output_html) + + print(f"Saved interactive plot to {output_html}") # noqa: T201 - demo script output + + +def _generate_dataset(cfg: SyntheticDataConfig) -> pd.DataFrame: + rng = np.random.default_rng(cfg.seed) + points_per_cluster = max(1, cfg.num_points // cfg.clusters) + remainder = cfg.num_points % cfg.clusters + records: list[dict[str, float | int | str]] = [] + for cluster_idx in range(cfg.clusters): + center = rng.normal(0.0, cfg.cluster_spread * 3.0, size=3) + count = points_per_cluster + (1 if cluster_idx < remainder else 0) + for _ in range(count): + noise = rng.normal(0.0, cfg.cluster_spread, size=3) + x, y, z = center + noise + magnitude = float(np.sqrt(x**2 + y**2 + z**2)) + records.append( + { + "cluster": f"C{cluster_idx + 1}", + "x": float(x), + "y": float(y), + "z": float(z), + "magnitude": magnitude, + } + ) + return pd.DataFrame.from_records(records) + + +def _convert_cfg[T](cfg_section: DictConfig, schema: type[T]) -> T: + """Convert DictConfig to dataclass instance, handling nested dataclasses recursively.""" + # Convert DictConfig to plain dict to avoid OmegaConf's Union/Literal type validation issues + cfg_dict = OmegaConf.to_container(cfg_section, resolve=True) or {} + return _dict_to_dataclass(cfg_dict, schema) + + +def _convert_value_by_type(value: Any, field_type: Any) -> Any: + """Convert a value based on its expected type (handles lists, dataclasses, etc.).""" + origin = get_origin(field_type) + + # Handle list types + if origin is list: + args = get_args(field_type) + if isinstance(value, list) and args: + item_type = args[0] + if is_dataclass(item_type): + return [ + _dict_to_dataclass(item, item_type) if isinstance(item, dict) else item # type: ignore[arg-type] + for item in value + ] + return value + # Handle dataclass types + elif isinstance(value, dict) and is_dataclass(field_type): + return _dict_to_dataclass(value, field_type) # type: ignore[arg-type] + + return value + + +def _dict_to_dataclass(data: dict[str, Any] | Any, schema: type[Any]) -> Any: + """Recursively convert dict to dataclass instance, handling nested structures.""" + if not isinstance(data, dict): + return data + + if not is_dataclass(schema): + return data + + # Get field types from the dataclass schema, resolving string annotations + try: + field_types = get_type_hints(schema) + except (TypeError, NameError): + # Fallback to field.type if get_type_hints fails (e.g., forward references) + field_types = {f.name: f.type for f in fields(schema)} + + # Convert nested dicts to their corresponding dataclass types + converted: dict[str, Any] = {} + for key, value in data.items(): + if key not in field_types: + converted[key] = value + continue + + field_type = field_types[key] + origin = get_origin(field_type) + + # Handle Optional types (Union[X, None] or X | None) + if origin is Union or origin is types.UnionType: + args = get_args(field_type) + # Handle Optional[X] -> Union[X, None] + if args and len(args) == 2 and type(None) in args: + if value is None: + converted[key] = None + else: + non_none_type = next((t for t in args if t is not type(None)), None) + if non_none_type: + # Recursively handle the non-None type (could be a list, dict, etc.) + converted[key] = _convert_value_by_type(value, non_none_type) + else: + converted[key] = value + elif args and isinstance(value, dict): + # For other Union types, try to find a dataclass type that matches + dataclass_type = next((t for t in args if is_dataclass(t)), None) + if dataclass_type: + converted[key] = _dict_to_dataclass(value, dataclass_type) # type: ignore[arg-type] + else: + converted[key] = value + else: + # For other Union types, try to convert based on the first non-None type + non_none_types = [t for t in args if t is not type(None)] if args else [] + if non_none_types and value is not None: + converted[key] = _convert_value_by_type(value, non_none_types[0]) + else: + converted[key] = value + # Handle list types + elif origin is list: + args = get_args(field_type) + if isinstance(value, list) and args: + item_type = args[0] + if is_dataclass(item_type): + converted[key] = [ + _dict_to_dataclass(item, item_type) if isinstance(item, dict) else item # type: ignore[arg-type] + for item in value + ] + else: + converted[key] = value + else: + converted[key] = value + # Handle direct dataclass types + elif isinstance(value, dict) and is_dataclass(field_type): + converted[key] = _dict_to_dataclass(value, field_type) # type: ignore[arg-type] + else: + converted[key] = value + + return schema(**converted) + + +def _save_plotly_figure(figure, filename: str) -> None: + output_path = Path(get_original_cwd()) / filename + figure.write_html(str(output_path), include_plotlyjs="cdn") + + +def _save_altair_chart(chart, filename: str) -> None: + output_path = Path(get_original_cwd()) / filename + chart.save(str(output_path)) + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/simplexity/visualization/__init__.py b/simplexity/visualization/__init__.py index eb036267..885d46f9 100644 --- a/simplexity/visualization/__init__.py +++ b/simplexity/visualization/__init__.py @@ -1,6 +1,8 @@ -"""Visualization utilities and configuration schemas.""" +"""Visualization utilities, renderers, and structured config schemas.""" +from .altair_renderer import build_altair_chart from .data_registry import DataRegistry, DictDataRegistry, resolve_data_source +from .plotly_renderer import build_plotly_figure from .structured_configs import ( AestheticsConfig, AxisConfig, @@ -26,6 +28,8 @@ ) __all__ = [ + "build_altair_chart", + "build_plotly_figure", "DataRegistry", "DictDataRegistry", "resolve_data_source", diff --git a/simplexity/visualization/altair_renderer.py b/simplexity/visualization/altair_renderer.py index 53a659bd..970e9a07 100644 --- a/simplexity/visualization/altair_renderer.py +++ b/simplexity/visualization/altair_renderer.py @@ -3,21 +3,21 @@ from __future__ import annotations import logging -import math from collections.abc import Mapping from typing import Any -import altair as alt -import numpy as np import pandas as pd from simplexity.exceptions import ConfigValidationError -from simplexity.visualization.data_registry import DataRegistry, resolve_data_source +from simplexity.visualization.data_pipeline import ( + build_plot_level_dataframe, + resolve_layer_dataframe, +) +from simplexity.visualization.data_registry import DataRegistry from simplexity.visualization.structured_configs import ( AestheticsConfig, AxisConfig, ChannelAestheticsConfig, - DataConfig, FacetConfig, GeometryConfig, LayerConfig, @@ -27,22 +27,10 @@ PlotSizeConfig, ScaleConfig, SelectionConfig, - TransformConfig, ) LOGGER = logging.getLogger(__name__) -_CALC_ENV = { - "np": np, - "pd": pd, - "math": math, - "log": np.log, - "exp": np.exp, - "sqrt": np.sqrt, - "abs": np.abs, - "clip": np.clip, -} - _CHANNEL_CLASS_MAP = { "x": "X", "y": "Y", @@ -55,180 +43,51 @@ } -def build_altair_chart(plot_cfg: PlotConfig, data_registry: DataRegistry | Mapping[str, pd.DataFrame]): - """Render a PlotConfig into an Altair chart.""" +def build_altair_chart( + plot_cfg: PlotConfig, + data_registry: DataRegistry | Mapping[str, pd.DataFrame], +): + """Render a PlotConfig into an Altair Chart.""" + alt = _import_altair() if not plot_cfg.layers: - raise ConfigValidationError("PlotConfig.layers must include at least one layer for Altair rendering") + raise ConfigValidationError("PlotConfig.layers must include at least one layer for Altair rendering.") - plot_df = _materialize_data(plot_cfg.data, data_registry) - plot_df = _apply_transforms(plot_df, plot_cfg.transforms) + plot_df = build_plot_level_dataframe(plot_cfg.data, plot_cfg.transforms, data_registry) layer_charts = [ - _build_layer_chart(layer, _resolve_layer_dataframe(layer, plot_df, data_registry)) for layer in plot_cfg.layers + _build_layer_chart(alt, layer, resolve_layer_dataframe(layer, plot_df, data_registry)) + for layer in plot_cfg.layers ] chart = layer_charts[0] if len(layer_charts) == 1 else alt.layer(*layer_charts) if plot_cfg.selections: - chart = chart.add_params(*[_build_selection_param(sel) for sel in plot_cfg.selections]) + chart = chart.add_params(*[_build_selection_param(alt, sel) for sel in plot_cfg.selections]) if plot_cfg.facet: - chart = _apply_facet(chart, plot_cfg.facet) + chart = _apply_facet(alt, chart, plot_cfg.facet) - chart = _apply_plot_level_properties(chart, plot_cfg.guides, plot_cfg.size, plot_cfg.background) + chart = _apply_plot_level_properties(alt, chart, plot_cfg.guides, plot_cfg.size, plot_cfg.background) return chart -def _materialize_data(data_cfg: DataConfig, data_registry: DataRegistry | Mapping[str, pd.DataFrame]) -> pd.DataFrame: - df = resolve_data_source(data_cfg.source, data_registry).copy() - if data_cfg.filters: - df = _apply_filters(df, data_cfg.filters) - if data_cfg.columns: - missing = [col for col in data_cfg.columns if col not in df.columns] - if missing: - raise ConfigValidationError(f"Columns {missing} are not present in data source '{data_cfg.source}'") - df = df.loc[:, data_cfg.columns] - return df +def _import_altair(): + try: + import altair as alt # type: ignore import-not-found + except ImportError as exc: # pragma: no cover - dependency missing only in unsupported envs + raise ImportError("Altair is required for visualization rendering. Install `altair` to continue.") from exc + return alt -def _resolve_layer_dataframe( - layer: LayerConfig, - plot_df: pd.DataFrame, - data_registry: DataRegistry | Mapping[str, pd.DataFrame], -) -> pd.DataFrame: - if layer.data is None: - layer_df = plot_df.copy() - else: - layer_df = _materialize_data(layer.data, data_registry) - - if layer.transforms: - layer_df = _apply_transforms(layer_df, layer.transforms) - return layer_df - - -def _apply_filters(df: pd.DataFrame, filters: list[str]) -> pd.DataFrame: - result = df.copy() - for expr in filters: - norm_expr = _normalize_expression(expr) - result = result.query(norm_expr, engine="python", local_dict=_CALC_ENV) - return result - - -def _apply_transforms(df: pd.DataFrame, transforms: list[TransformConfig]) -> pd.DataFrame: - result = df.copy() - for transform in transforms: - result = _apply_transform(result, transform) - return result - - -def _apply_transform(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: - if transform.op == "filter": - if transform.filter is None: - raise ConfigValidationError("Filter transforms require the `filter` expression.") - return _apply_filters(df, [transform.filter]) - if transform.op == "calculate": - return _apply_calculate(df, transform) - if transform.op == "aggregate": - return _apply_aggregate(df, transform) - if transform.op == "bin": - return _apply_bin(df, transform) - if transform.op == "window": - return _apply_window(df, transform) - if transform.op == "fold": - return _apply_fold(df, transform) - if transform.op == "pivot": - raise ConfigValidationError("Pivot transforms are not implemented yet for the Altair renderer.") - raise ConfigValidationError(f"Unsupported transform operation '{transform.op}'") - - -def _apply_calculate(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: - expr = _normalize_expression(transform.expr or "") - target = transform.as_field or "" - if not target: - raise ConfigValidationError("TransformConfig.as_field is required for calculate transforms") - result = df.copy() - result[target] = result.eval(expr, engine="python", local_dict=_CALC_ENV) - return result - - -def _apply_aggregate(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: - groupby = transform.groupby or [] - aggregations = transform.aggregations or {} - if not groupby or not aggregations: - raise ConfigValidationError("Aggregate transforms require `groupby` and `aggregations` fields.") - - agg_kwargs: dict[str, tuple[str, str]] = {} - for alias, expr in aggregations.items(): - func, field = _parse_function_expr(expr, expected_arg=True) - agg_kwargs[alias] = (field, func) - - grouped = df.groupby(groupby, dropna=False).agg(**agg_kwargs).reset_index() - return grouped - - -def _apply_bin(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: - if not transform.field or not transform.binned_as: - raise ConfigValidationError("Bin transforms require `field` and `binned_as`.") - bins = transform.maxbins or 10 - result = df.copy() - result[transform.binned_as] = pd.cut(result[transform.field], bins=bins, include_lowest=True) - return result - - -def _apply_window(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: - if not transform.window: - raise ConfigValidationError("Window transforms require the `window` mapping.") - result = df.copy() - for alias, expr in transform.window.items(): - func, field = _parse_function_expr(expr, expected_arg=True) - if func == "rank": - result[alias] = result[field].rank(method="average") - elif func == "cumsum": - result[alias] = result[field].cumsum() - else: - raise ConfigValidationError(f"Window function '{func}' is not supported.") - return result - - -def _apply_fold(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: - if not transform.fold_fields: - raise ConfigValidationError("Fold transforms require `fold_fields`.") - var_name, value_name = _derive_fold_names(transform.as_fields) - return df.melt(value_vars=transform.fold_fields, var_name=var_name, value_name=value_name) - - -def _parse_function_expr(expr: str, expected_arg: bool) -> tuple[str, str]: - if "(" not in expr or not expr.endswith(")"): - raise ConfigValidationError(f"Expression '{expr}' must be of the form func(field).") - func, rest = expr.split("(", 1) - value = rest[:-1].strip() - func = func.strip() - if expected_arg and not value: - raise ConfigValidationError(f"Expression '{expr}' must supply an argument.") - return func, value - - -def _derive_fold_names(as_fields: list[str] | None) -> tuple[str, str]: - if not as_fields: - return "key", "value" - if len(as_fields) == 1: - return as_fields[0], "value" - return as_fields[0], as_fields[1] - - -def _normalize_expression(expr: str) -> str: - return expr.replace("datum.", "").strip() - - -def _build_layer_chart(layer: LayerConfig, df: pd.DataFrame): +def _build_layer_chart(alt, layer: LayerConfig, df: pd.DataFrame): chart = alt.Chart(df) chart = _apply_geometry(chart, layer.geometry) - encoding_kwargs = _encode_aesthetics(layer.aesthetics) + encoding_kwargs = _encode_aesthetics(alt, layer.aesthetics) if encoding_kwargs: chart = chart.encode(**encoding_kwargs) if layer.selections: - chart = chart.add_params(*[_build_selection_param(sel) for sel in layer.selections]) + chart = chart.add_params(*[_build_selection_param(alt, sel) for sel in layer.selections]) return chart @@ -240,26 +99,25 @@ def _apply_geometry(chart, geometry: GeometryConfig): return mark_fn(**(geometry.props or {})) -def _encode_aesthetics(aesthetics: AestheticsConfig) -> dict[str, Any]: +def _encode_aesthetics(alt, aesthetics: AestheticsConfig) -> dict[str, Any]: encodings: dict[str, Any] = {} for channel_name in ("x", "y", "color", "size", "shape", "opacity", "row", "column"): channel_cfg = getattr(aesthetics, channel_name) - channel_value = _channel_to_alt(channel_name, channel_cfg) + channel_value = _channel_to_alt(alt, channel_name, channel_cfg) if channel_value is not None: encodings[channel_name] = channel_value if aesthetics.tooltip: - encodings["tooltip"] = [_tooltip_to_alt(tooltip_cfg) for tooltip_cfg in aesthetics.tooltip] + encodings["tooltip"] = [_tooltip_to_alt(alt, tooltip_cfg) for tooltip_cfg in aesthetics.tooltip] return encodings -def _channel_to_alt(channel_name: str, cfg: ChannelAestheticsConfig | None): +def _channel_to_alt(alt, channel_name: str, cfg: ChannelAestheticsConfig | None): if cfg is None: return None if cfg.value is not None and cfg.field is None: return alt.value(cfg.value) - channel_cls_name = _CHANNEL_CLASS_MAP[channel_name] channel_cls = getattr(alt, channel_cls_name) kwargs: dict[str, Any] = {} @@ -276,20 +134,17 @@ def _channel_to_alt(channel_name: str, cfg: ChannelAestheticsConfig | None): if cfg.time_unit: kwargs["timeUnit"] = cfg.time_unit if cfg.sort is not None: - if isinstance(cfg.sort, list): - kwargs["sort"] = alt.Sort(cfg.sort) - else: - kwargs["sort"] = cfg.sort + kwargs["sort"] = alt.Sort(cfg.sort) if isinstance(cfg.sort, list) else cfg.sort if cfg.scale: - kwargs["scale"] = _scale_to_alt(cfg.scale) + kwargs["scale"] = _scale_to_alt(alt, cfg.scale) if cfg.axis and channel_name in {"x", "y", "row", "column"}: - kwargs["axis"] = _axis_to_alt(cfg.axis) + kwargs["axis"] = _axis_to_alt(alt, cfg.axis) if cfg.legend and channel_name in {"color", "size", "shape", "opacity"}: - kwargs["legend"] = _legend_to_alt(cfg.legend) + kwargs["legend"] = _legend_to_alt(alt, cfg.legend) return channel_cls(**kwargs) -def _tooltip_to_alt(cfg: ChannelAestheticsConfig): +def _tooltip_to_alt(alt, cfg: ChannelAestheticsConfig): if cfg.value is not None and cfg.field is None: return alt.Tooltip(value=cfg.value, title=cfg.title) if cfg.field is None: @@ -303,40 +158,32 @@ def _tooltip_to_alt(cfg: ChannelAestheticsConfig): return alt.Tooltip(**kwargs) -def _scale_to_alt(cfg: ScaleConfig): +def _scale_to_alt(alt, cfg: ScaleConfig): kwargs = {k: v for k, v in vars(cfg).items() if v is not None} return alt.Scale(**kwargs) -def _axis_to_alt(cfg: AxisConfig): +def _axis_to_alt(alt, cfg: AxisConfig): kwargs = {k: v for k, v in vars(cfg).items() if v is not None} return alt.Axis(**kwargs) -def _legend_to_alt(cfg: LegendConfig): +def _legend_to_alt(alt, cfg: LegendConfig): kwargs = {k: v for k, v in vars(cfg).items() if v is not None} return alt.Legend(**kwargs) -def _build_selection_param(cfg: SelectionConfig): - kwargs: dict[str, Any] = {"name": cfg.name} - if cfg.encodings is not None: - kwargs["encodings"] = cfg.encodings # type: ignore[assignment] - if cfg.fields is not None: - kwargs["fields"] = cfg.fields - if cfg.bind is not None: - kwargs["bind"] = cfg.bind # type: ignore[assignment] - +def _build_selection_param(alt, cfg: SelectionConfig): if cfg.type == "interval": - return alt.selection_interval(**kwargs) + return alt.selection_interval(name=cfg.name, encodings=cfg.encodings, fields=cfg.fields, bind=cfg.bind) if cfg.type == "single": - return alt.selection_single(**kwargs) + return alt.selection_single(name=cfg.name, encodings=cfg.encodings, fields=cfg.fields, bind=cfg.bind) if cfg.type == "multi": - return alt.selection_multi(**kwargs) + return alt.selection_multi(name=cfg.name, encodings=cfg.encodings, fields=cfg.fields, bind=cfg.bind) raise ConfigValidationError(f"Unsupported selection type '{cfg.type}' for Altair renderer.") -def _apply_facet(chart, facet_cfg: FacetConfig): +def _apply_facet(alt, chart, facet_cfg: FacetConfig): facet_args: dict[str, Any] = {} if facet_cfg.row: facet_args["row"] = alt.Row(facet_cfg.row) @@ -349,8 +196,10 @@ def _apply_facet(chart, facet_cfg: FacetConfig): return chart.facet(**facet_args) -def _apply_plot_level_properties(chart, guides: PlotLevelGuideConfig, size: PlotSizeConfig, background: str | None): - title_params = _build_title_params(guides) +def _apply_plot_level_properties( + alt, chart, guides: PlotLevelGuideConfig, size: PlotSizeConfig, background: str | None +): + title_params = _build_title_params(alt, guides) if title_params is not None: chart = chart.properties(title=title_params) width = size.width @@ -360,13 +209,13 @@ def _apply_plot_level_properties(chart, guides: PlotLevelGuideConfig, size: Plot if size.autosize: chart.autosize = size.autosize if background: - chart.background = background + chart = chart.configure(background=background) if guides.labels: LOGGER.info("Plot-level labels are not yet implemented for Altair; skipping %s labels.", len(guides.labels)) return chart -def _build_title_params(guides: PlotLevelGuideConfig): +def _build_title_params(alt, guides: PlotLevelGuideConfig): subtitle_lines = [text for text in (guides.subtitle, guides.caption) if text] if not guides.title and not subtitle_lines: return None diff --git a/simplexity/visualization/data_pipeline.py b/simplexity/visualization/data_pipeline.py new file mode 100644 index 00000000..be7d88cf --- /dev/null +++ b/simplexity/visualization/data_pipeline.py @@ -0,0 +1,194 @@ +"""Reusable helpers for preparing data prior to rendering.""" + +from __future__ import annotations + +import math +from collections.abc import Mapping + +import numpy as np +import pandas as pd + +from simplexity.exceptions import ConfigValidationError +from simplexity.visualization.data_registry import DataRegistry, resolve_data_source +from simplexity.visualization.structured_configs import ( + DataConfig, + LayerConfig, + TransformConfig, +) + +CALC_ENV = { + "np": np, + "pd": pd, + "math": math, + "log": np.log, + "exp": np.exp, + "sqrt": np.sqrt, + "abs": np.abs, + "clip": np.clip, +} + + +def normalize_expression(expr: str) -> str: + """Normalize expressions shared between pandas and Vega-Lite syntaxes.""" + return expr.replace("datum.", "").strip() + + +def materialize_data(data_cfg: DataConfig, data_registry: DataRegistry | Mapping[str, pd.DataFrame]) -> pd.DataFrame: + """Resolve a logical data source and apply lightweight filters/column selection.""" + df = resolve_data_source(data_cfg.source, data_registry).copy() + if data_cfg.filters: + df = apply_filters(df, data_cfg.filters) + if data_cfg.columns: + missing = [col for col in data_cfg.columns if col not in df.columns] + if missing: + raise ConfigValidationError(f"Columns {missing} are not present in data source '{data_cfg.source}'") + df = df.loc[:, data_cfg.columns] + return df + + +def build_plot_level_dataframe( + data_cfg: DataConfig, + transforms: list[TransformConfig], + data_registry: DataRegistry | Mapping[str, pd.DataFrame], +) -> pd.DataFrame: + """Materialize the base dataframe for a plot, applying plot-level transforms.""" + df = materialize_data(data_cfg, data_registry) + return apply_transforms(df, transforms) + + +def resolve_layer_dataframe( + layer: LayerConfig, + plot_df: pd.DataFrame, + data_registry: DataRegistry | Mapping[str, pd.DataFrame], +) -> pd.DataFrame: + """Resolve the dataframe for an individual layer.""" + if layer.data is None: + df = plot_df.copy() + else: + df = materialize_data(layer.data, data_registry) + if layer.transforms: + df = apply_transforms(df, layer.transforms) + return df + + +def apply_filters(df: pd.DataFrame, filters: list[str]) -> pd.DataFrame: + """Apply pandas-compatible query filters.""" + result = df.copy() + for expr in filters: + norm_expr = normalize_expression(expr) + result = result.query(norm_expr, engine="python", local_dict=CALC_ENV) + return result + + +def apply_transforms(df: pd.DataFrame, transforms: list[TransformConfig]) -> pd.DataFrame: + """Sequentially apply configured transforms to a dataframe.""" + result = df.copy() + for transform in transforms: + result = _apply_transform(result, transform) + return result + + +def _apply_transform(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if transform.op == "filter": + if transform.filter is None: + raise ConfigValidationError("Filter transforms require the `filter` expression.") + return apply_filters(df, [transform.filter]) + if transform.op == "calculate": + return _apply_calculate(df, transform) + if transform.op == "aggregate": + return _apply_aggregate(df, transform) + if transform.op == "bin": + return _apply_bin(df, transform) + if transform.op == "window": + return _apply_window(df, transform) + if transform.op == "fold": + return _apply_fold(df, transform) + if transform.op == "pivot": + raise ConfigValidationError("Pivot transforms are not implemented yet.") + raise ConfigValidationError(f"Unsupported transform operation '{transform.op}'") + + +def _apply_calculate(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + expr = normalize_expression(transform.expr or "") + target = transform.as_field or "" + if not target: + raise ConfigValidationError("TransformConfig.as_field is required for calculate transforms") + result = df.copy() + result[target] = result.eval(expr, engine="python", local_dict=CALC_ENV) + return result + + +def _apply_aggregate(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + groupby = transform.groupby or [] + aggregations = transform.aggregations or {} + if not groupby or not aggregations: + raise ConfigValidationError("Aggregate transforms require `groupby` and `aggregations` fields.") + + agg_kwargs: dict[str, tuple[str, str]] = {} + for alias, expr in aggregations.items(): + func, field = _parse_function_expr(expr, expected_arg=True) + agg_kwargs[alias] = (field, func) + + grouped = df.groupby(groupby, dropna=False).agg(**agg_kwargs).reset_index() + return grouped + + +def _apply_bin(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if not transform.field or not transform.binned_as: + raise ConfigValidationError("Bin transforms require `field` and `binned_as`.") + bins = transform.maxbins or 10 + result = df.copy() + result[transform.binned_as] = pd.cut(result[transform.field], bins=bins, include_lowest=True) + return result + + +def _apply_window(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if not transform.window: + raise ConfigValidationError("Window transforms require the `window` mapping.") + result = df.copy() + for alias, expr in transform.window.items(): + func, field = _parse_function_expr(expr, expected_arg=True) + if func == "rank": + result[alias] = result[field].rank(method="average") + elif func == "cumsum": + result[alias] = result[field].cumsum() + else: + raise ConfigValidationError(f"Window function '{func}' is not supported.") + return result + + +def _apply_fold(df: pd.DataFrame, transform: TransformConfig) -> pd.DataFrame: + if not transform.fold_fields: + raise ConfigValidationError("Fold transforms require `fold_fields`.") + var_name, value_name = _derive_fold_names(transform.as_fields) + return df.melt(value_vars=transform.fold_fields, var_name=var_name, value_name=value_name) + + +def _parse_function_expr(expr: str, expected_arg: bool) -> tuple[str, str]: + if "(" not in expr or not expr.endswith(")"): + raise ConfigValidationError(f"Expression '{expr}' must be of the form func(field).") + func, rest = expr.split("(", 1) + value = rest[:-1].strip() + func = func.strip() + if expected_arg and not value: + raise ConfigValidationError(f"Expression '{expr}' must supply an argument.") + return func, value + + +def _derive_fold_names(as_fields: list[str] | None) -> tuple[str, str]: + if not as_fields: + return "key", "value" + if len(as_fields) == 1: + return as_fields[0], "value" + return as_fields[0], as_fields[1] + + +__all__ = [ + "CALC_ENV", + "apply_filters", + "apply_transforms", + "build_plot_level_dataframe", + "materialize_data", + "normalize_expression", + "resolve_layer_dataframe", +] diff --git a/simplexity/visualization/plotly_renderer.py b/simplexity/visualization/plotly_renderer.py new file mode 100644 index 00000000..2606132d --- /dev/null +++ b/simplexity/visualization/plotly_renderer.py @@ -0,0 +1,158 @@ +"""Plotly renderer for visualization PlotConfigs.""" + +from __future__ import annotations + +import logging +from collections.abc import Mapping +from typing import Any + +import pandas as pd +import plotly.express as px + +from simplexity.exceptions import ConfigValidationError +from simplexity.visualization.data_pipeline import ( + build_plot_level_dataframe, + resolve_layer_dataframe, +) +from simplexity.visualization.data_registry import DataRegistry +from simplexity.visualization.structured_configs import ( + AestheticsConfig, + ChannelAestheticsConfig, + LayerConfig, + PlotConfig, + PlotLevelGuideConfig, + PlotSizeConfig, +) + +LOGGER = logging.getLogger(__name__) + + +def build_plotly_figure( + plot_cfg: PlotConfig, + data_registry: DataRegistry | Mapping[str, pd.DataFrame], +): + """Render a PlotConfig into a Plotly Figure (currently 3D scatter only).""" + if not plot_cfg.layers: + raise ConfigValidationError("PlotConfig.layers must include at least one layer for Plotly rendering.") + if len(plot_cfg.layers) != 1: + raise ConfigValidationError("Plotly renderer currently supports exactly one layer.") + + layer = plot_cfg.layers[0] + if layer.geometry.type != "point": + raise ConfigValidationError("Plotly renderer currently supports point geometry for 3D scatter demo.") + + plot_df = build_plot_level_dataframe(plot_cfg.data, plot_cfg.transforms, data_registry) + layer_df = resolve_layer_dataframe(layer, plot_df, data_registry) + + figure = _build_scatter3d(layer, layer_df) + figure = _apply_plot_level_properties(figure, plot_cfg.guides, plot_cfg.size, plot_cfg.background, layer.aesthetics) + return figure + + +def _build_scatter3d(layer: LayerConfig, df: pd.DataFrame): + aes = layer.aesthetics + x_field = _require_field(aes.x, "x") + y_field = _require_field(aes.y, "y") + z_field = _require_field(aes.z, "z") + + color_field = _optional_field(aes.color) + size_field = _optional_field(aes.size) + opacity_value = _resolve_opacity(aes.opacity) + hover_fields = _collect_tooltip_fields(aes.tooltip) + + figure = px.scatter_3d( + df, + x=x_field, + y=y_field, + z=z_field, + color=color_field, + size=size_field, + hover_data=hover_fields or None, + opacity=opacity_value, + ) + + if aes.color and aes.color.value is not None: + figure.update_traces(marker=dict(color=aes.color.value)) + if aes.size and aes.size.value is not None: + figure.update_traces(marker=dict(size=aes.size.value)) + + trace_name = layer.name or (color_field or "3d_scatter") + figure.update_traces(name=trace_name, selector=dict(type="scatter3d")) + return figure + + +def _apply_plot_level_properties( + figure, + guides: PlotLevelGuideConfig, + size: PlotSizeConfig, + background: str | None, + aes: AestheticsConfig, +): + title_lines = [guides.title] if guides.title else [] + title_lines += [text for text in (guides.subtitle, guides.caption) if text] + if title_lines: + figure.update_layout(title="
".join(title_lines)) + if size.width or size.height: + figure.update_layout(width=size.width, height=size.height) + + scene_updates: dict[str, Any] = {} + x_title = _axis_title(aes.x) + y_title = _axis_title(aes.y) + z_title = _axis_title(aes.z) + if x_title: + scene_updates.setdefault("xaxis", {})["title"] = x_title + if y_title: + scene_updates.setdefault("yaxis", {})["title"] = y_title + if z_title: + scene_updates.setdefault("zaxis", {})["title"] = z_title + if background: + scene_updates["bgcolor"] = background + if scene_updates: + figure.update_layout(scene=scene_updates) + + if guides.labels: + LOGGER.info("Plot-level labels are not yet implemented for Plotly; skipping %s labels.", len(guides.labels)) + return figure + + +def _require_field(channel: ChannelAestheticsConfig | None, name: str) -> str: + if channel is None or not channel.field: + raise ConfigValidationError(f"Plotly renderer requires '{name}' channel with a field specified.") + return channel.field + + +def _optional_field(channel: ChannelAestheticsConfig | None) -> str | None: + if channel is None: + return None + return channel.field + + +def _collect_tooltip_fields(tooltips: list[ChannelAestheticsConfig] | None) -> list[str]: + if not tooltips: + return [] + fields: list[str] = [] + for tooltip in tooltips: + if tooltip.field is None: + raise ConfigValidationError("Plotly renderer tooltip entries must reference a data field.") + fields.append(tooltip.field) + return fields + + +def _resolve_opacity(channel: ChannelAestheticsConfig | None) -> float | None: + if channel is None: + return None + if channel.value is None: + raise ConfigValidationError("Plotly renderer opacity channel must specify a constant value.") + try: + opacity = float(channel.value) + except (TypeError, ValueError) as exc: + raise ConfigValidationError("Opacity channel must be a numeric constant.") from exc + if not 0.0 <= opacity <= 1.0: + raise ConfigValidationError("Opacity value must be between 0 and 1.") + return opacity + + +def _axis_title(channel: ChannelAestheticsConfig | None) -> str | None: + if channel is None: + return None + return channel.title or channel.field diff --git a/simplexity/visualization/structured_configs.py b/simplexity/visualization/structured_configs.py index 6dcea953..c156a589 100644 --- a/simplexity/visualization/structured_configs.py +++ b/simplexity/visualization/structured_configs.py @@ -9,36 +9,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Literal +from typing import Any from simplexity.exceptions import ConfigValidationError -BackendType = Literal["altair"] # Currently only Altair is supported, but we could add other backends later - -TransformOp = Literal["filter", "calculate", "aggregate", "bin", "window", "fold", "pivot"] - -ScaleType = Literal["linear", "log", "sqrt", "pow", "symlog", "time", "utc", "ordinal", "band", "point"] - -ChannelType = Literal["quantitative", "ordinal", "nominal", "temporal"] - -GeometryType = Literal[ - "point", - "line", - "area", - "bar", - "rect", - "rule", - "tick", - "circle", - "square", - "text", - "boxplot", - "errorbar", - "errorband", -] - -SelectionType = Literal["interval", "single", "multi"] - def _ensure(condition: bool, message: str) -> None: """Raise ConfigValidationError if condition is not met.""" @@ -59,7 +33,7 @@ class DataConfig: class TransformConfig: # pylint: disable=too-many-instance-attributes """Represents a single data transform stage.""" - op: TransformOp + op: str # ["filter", "calculate", "aggregate", "bin", "window", "fold", "pivot"] filter: str | None = None as_field: str | None = None expr: str | None = None @@ -96,7 +70,7 @@ def __post_init__(self) -> None: class ScaleConfig: """Describes how raw data values are mapped to visual ranges.""" - type: ScaleType | None = None + type: str | None = None # ["linear", "log", "sqrt", "pow", "symlog", "time", "utc", "ordinal", "band", "point"] domain: list[Any] | None = None range: list[Any] | None = None clamp: bool | None = None @@ -130,7 +104,7 @@ class ChannelAestheticsConfig: """Represents one visual encoding channel (x, y, color, etc.).""" field: str | None = None - type: ChannelType | None = None + type: str | None = None # ["quantitative", "ordinal", "nominal", "temporal"] value: Any | None = None aggregate: str | None = None bin: bool | None = None @@ -154,6 +128,7 @@ class AestheticsConfig: x: ChannelAestheticsConfig | None = None y: ChannelAestheticsConfig | None = None + z: ChannelAestheticsConfig | None = None color: ChannelAestheticsConfig | None = None size: ChannelAestheticsConfig | None = None shape: ChannelAestheticsConfig | None = None @@ -167,7 +142,7 @@ class AestheticsConfig: class GeometryConfig: """Visual primitive used to draw the layer.""" - type: GeometryType + type: str # [point, line, area, bar, rect, rule, tick, circle, square, text, boxplot, errorbar, errorband] props: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: @@ -179,7 +154,7 @@ class SelectionConfig: """Interactive selection definition.""" name: str - type: SelectionType = "interval" + type: str = "interval" # ["interval", "single", "multi"] encodings: list[str] | None = None fields: list[str] | None = None bind: dict[str, Any] | None = None @@ -239,7 +214,7 @@ class LayerConfig: class PlotConfig: """Top-level configuration for one plot.""" - backend: BackendType = "altair" + backend: str = "altair" # ["altair", "plotly"] data: DataConfig = field(default_factory=DataConfig) transforms: list[TransformConfig] = field(default_factory=list) layers: list[LayerConfig] = field(default_factory=list) @@ -257,5 +232,5 @@ def __post_init__(self) -> None: class GraphicsConfig: """Root Visualization config that multiplexes multiple named plots.""" - default_backend: BackendType = "altair" + default_backend: str = "altair" # ["altair", "plotly"] plots: dict[str, PlotConfig] = field(default_factory=dict) From 337cd899f27ea038fbb5f69f7aee1302e7ea8b56 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 14 Nov 2025 12:17:52 -0800 Subject: [PATCH 4/9] Disable too many instance attributes in configs --- simplexity/visualization/structured_configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/simplexity/visualization/structured_configs.py b/simplexity/visualization/structured_configs.py index c156a589..0ead344d 100644 --- a/simplexity/visualization/structured_configs.py +++ b/simplexity/visualization/structured_configs.py @@ -100,7 +100,7 @@ class LegendConfig: @dataclass -class ChannelAestheticsConfig: +class ChannelAestheticsConfig: # pylint: disable=too-many-instance-attributes """Represents one visual encoding channel (x, y, color, etc.).""" field: str | None = None @@ -123,7 +123,7 @@ def __post_init__(self) -> None: @dataclass -class AestheticsConfig: +class AestheticsConfig: # pylint: disable=too-many-instance-attributes """Collection of channel encodings for a layer.""" x: ChannelAestheticsConfig | None = None @@ -211,7 +211,7 @@ class LayerConfig: @dataclass -class PlotConfig: +class PlotConfig: # pylint: disable=too-many-instance-attributes """Top-level configuration for one plot.""" backend: str = "altair" # ["altair", "plotly"] From 63f2d5a6c77c094a979c9e12b9469baef52518b9 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 14 Nov 2025 12:18:29 -0800 Subject: [PATCH 5/9] Replace dict with {} --- simplexity/visualization/plotly_renderer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/simplexity/visualization/plotly_renderer.py b/simplexity/visualization/plotly_renderer.py index 2606132d..7fb3d8a5 100644 --- a/simplexity/visualization/plotly_renderer.py +++ b/simplexity/visualization/plotly_renderer.py @@ -72,12 +72,12 @@ def _build_scatter3d(layer: LayerConfig, df: pd.DataFrame): ) if aes.color and aes.color.value is not None: - figure.update_traces(marker=dict(color=aes.color.value)) + figure.update_traces(marker={"color": aes.color.value}) if aes.size and aes.size.value is not None: - figure.update_traces(marker=dict(size=aes.size.value)) + figure.update_traces(marker={"size": aes.size.value}) trace_name = layer.name or (color_field or "3d_scatter") - figure.update_traces(name=trace_name, selector=dict(type="scatter3d")) + figure.update_traces(name=trace_name, selector={"type": "scatter3d"}) return figure From 2b9ae8cd8b6a89dd00b030c0d2b6fb4c6360b705 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 14 Nov 2025 12:18:54 -0800 Subject: [PATCH 6/9] Import altair in a normal way --- simplexity/visualization/altair_renderer.py | 78 +++++++++++---------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/simplexity/visualization/altair_renderer.py b/simplexity/visualization/altair_renderer.py index 970e9a07..74108c36 100644 --- a/simplexity/visualization/altair_renderer.py +++ b/simplexity/visualization/altair_renderer.py @@ -6,6 +6,11 @@ from collections.abc import Mapping from typing import Any +try: + import altair as alt # type: ignore import-not-found +except ImportError as exc: # pragma: no cover - dependency missing only in unsupported envs + raise ImportError("Altair is required for visualization rendering. Install `altair` to continue.") from exc + import pandas as pd from simplexity.exceptions import ConfigValidationError @@ -48,46 +53,36 @@ def build_altair_chart( data_registry: DataRegistry | Mapping[str, pd.DataFrame], ): """Render a PlotConfig into an Altair Chart.""" - alt = _import_altair() if not plot_cfg.layers: raise ConfigValidationError("PlotConfig.layers must include at least one layer for Altair rendering.") plot_df = build_plot_level_dataframe(plot_cfg.data, plot_cfg.transforms, data_registry) layer_charts = [ - _build_layer_chart(alt, layer, resolve_layer_dataframe(layer, plot_df, data_registry)) - for layer in plot_cfg.layers + _build_layer_chart(layer, resolve_layer_dataframe(layer, plot_df, data_registry)) for layer in plot_cfg.layers ] chart = layer_charts[0] if len(layer_charts) == 1 else alt.layer(*layer_charts) if plot_cfg.selections: - chart = chart.add_params(*[_build_selection_param(alt, sel) for sel in plot_cfg.selections]) + chart = chart.add_params(*[_build_selection_param(sel) for sel in plot_cfg.selections]) if plot_cfg.facet: - chart = _apply_facet(alt, chart, plot_cfg.facet) + chart = _apply_facet(chart, plot_cfg.facet) - chart = _apply_plot_level_properties(alt, chart, plot_cfg.guides, plot_cfg.size, plot_cfg.background) + chart = _apply_plot_level_properties(chart, plot_cfg.guides, plot_cfg.size, plot_cfg.background) return chart -def _import_altair(): - try: - import altair as alt # type: ignore import-not-found - except ImportError as exc: # pragma: no cover - dependency missing only in unsupported envs - raise ImportError("Altair is required for visualization rendering. Install `altair` to continue.") from exc - return alt - - -def _build_layer_chart(alt, layer: LayerConfig, df: pd.DataFrame): +def _build_layer_chart(layer: LayerConfig, df: pd.DataFrame): chart = alt.Chart(df) chart = _apply_geometry(chart, layer.geometry) - encoding_kwargs = _encode_aesthetics(alt, layer.aesthetics) + encoding_kwargs = _encode_aesthetics(layer.aesthetics) if encoding_kwargs: chart = chart.encode(**encoding_kwargs) if layer.selections: - chart = chart.add_params(*[_build_selection_param(alt, sel) for sel in layer.selections]) + chart = chart.add_params(*[_build_selection_param(sel) for sel in layer.selections]) return chart @@ -99,21 +94,21 @@ def _apply_geometry(chart, geometry: GeometryConfig): return mark_fn(**(geometry.props or {})) -def _encode_aesthetics(alt, aesthetics: AestheticsConfig) -> dict[str, Any]: +def _encode_aesthetics(aesthetics: AestheticsConfig) -> dict[str, Any]: encodings: dict[str, Any] = {} for channel_name in ("x", "y", "color", "size", "shape", "opacity", "row", "column"): channel_cfg = getattr(aesthetics, channel_name) - channel_value = _channel_to_alt(alt, channel_name, channel_cfg) + channel_value = _channel_to_alt(channel_name, channel_cfg) if channel_value is not None: encodings[channel_name] = channel_value if aesthetics.tooltip: - encodings["tooltip"] = [_tooltip_to_alt(alt, tooltip_cfg) for tooltip_cfg in aesthetics.tooltip] + encodings["tooltip"] = [_tooltip_to_alt(tooltip_cfg) for tooltip_cfg in aesthetics.tooltip] return encodings -def _channel_to_alt(alt, channel_name: str, cfg: ChannelAestheticsConfig | None): +def _channel_to_alt(channel_name: str, cfg: ChannelAestheticsConfig | None): if cfg is None: return None if cfg.value is not None and cfg.field is None: @@ -136,15 +131,15 @@ def _channel_to_alt(alt, channel_name: str, cfg: ChannelAestheticsConfig | None) if cfg.sort is not None: kwargs["sort"] = alt.Sort(cfg.sort) if isinstance(cfg.sort, list) else cfg.sort if cfg.scale: - kwargs["scale"] = _scale_to_alt(alt, cfg.scale) + kwargs["scale"] = _scale_to_alt(cfg.scale) if cfg.axis and channel_name in {"x", "y", "row", "column"}: - kwargs["axis"] = _axis_to_alt(alt, cfg.axis) + kwargs["axis"] = _axis_to_alt(cfg.axis) if cfg.legend and channel_name in {"color", "size", "shape", "opacity"}: - kwargs["legend"] = _legend_to_alt(alt, cfg.legend) + kwargs["legend"] = _legend_to_alt(cfg.legend) return channel_cls(**kwargs) -def _tooltip_to_alt(alt, cfg: ChannelAestheticsConfig): +def _tooltip_to_alt(cfg: ChannelAestheticsConfig): if cfg.value is not None and cfg.field is None: return alt.Tooltip(value=cfg.value, title=cfg.title) if cfg.field is None: @@ -158,32 +153,41 @@ def _tooltip_to_alt(alt, cfg: ChannelAestheticsConfig): return alt.Tooltip(**kwargs) -def _scale_to_alt(alt, cfg: ScaleConfig): +def _scale_to_alt(cfg: ScaleConfig): kwargs = {k: v for k, v in vars(cfg).items() if v is not None} return alt.Scale(**kwargs) -def _axis_to_alt(alt, cfg: AxisConfig): +def _axis_to_alt(cfg: AxisConfig): kwargs = {k: v for k, v in vars(cfg).items() if v is not None} return alt.Axis(**kwargs) -def _legend_to_alt(alt, cfg: LegendConfig): +def _legend_to_alt(cfg: LegendConfig): kwargs = {k: v for k, v in vars(cfg).items() if v is not None} return alt.Legend(**kwargs) -def _build_selection_param(alt, cfg: SelectionConfig): +def _build_selection_param(cfg: SelectionConfig): + kwargs: dict[str, Any] = {} + if cfg.name: + kwargs["name"] = cfg.name + if cfg.encodings: + kwargs["encodings"] = cfg.encodings + if cfg.fields: + kwargs["fields"] = cfg.fields + if cfg.bind: + kwargs["bind"] = cfg.bind if cfg.type == "interval": - return alt.selection_interval(name=cfg.name, encodings=cfg.encodings, fields=cfg.fields, bind=cfg.bind) + return alt.selection_interval(**kwargs) if cfg.type == "single": - return alt.selection_single(name=cfg.name, encodings=cfg.encodings, fields=cfg.fields, bind=cfg.bind) + return alt.selection_single(**kwargs) if cfg.type == "multi": - return alt.selection_multi(name=cfg.name, encodings=cfg.encodings, fields=cfg.fields, bind=cfg.bind) + return alt.selection_multi(**kwargs) raise ConfigValidationError(f"Unsupported selection type '{cfg.type}' for Altair renderer.") -def _apply_facet(alt, chart, facet_cfg: FacetConfig): +def _apply_facet(chart, facet_cfg: FacetConfig): facet_args: dict[str, Any] = {} if facet_cfg.row: facet_args["row"] = alt.Row(facet_cfg.row) @@ -196,10 +200,8 @@ def _apply_facet(alt, chart, facet_cfg: FacetConfig): return chart.facet(**facet_args) -def _apply_plot_level_properties( - alt, chart, guides: PlotLevelGuideConfig, size: PlotSizeConfig, background: str | None -): - title_params = _build_title_params(alt, guides) +def _apply_plot_level_properties(chart, guides: PlotLevelGuideConfig, size: PlotSizeConfig, background: str | None): + title_params = _build_title_params(guides) if title_params is not None: chart = chart.properties(title=title_params) width = size.width @@ -215,7 +217,7 @@ def _apply_plot_level_properties( return chart -def _build_title_params(alt, guides: PlotLevelGuideConfig): +def _build_title_params(guides: PlotLevelGuideConfig): subtitle_lines = [text for text in (guides.subtitle, guides.caption) if text] if not guides.title and not subtitle_lines: return None From caf0453baa53b66544552e2d2939e487c4d5bd5e Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 14 Nov 2025 13:40:40 -0800 Subject: [PATCH 7/9] Remove init --- simplexity/visualization/__init__.py | 57 ---------------------------- 1 file changed, 57 deletions(-) delete mode 100644 simplexity/visualization/__init__.py diff --git a/simplexity/visualization/__init__.py b/simplexity/visualization/__init__.py deleted file mode 100644 index 885d46f9..00000000 --- a/simplexity/visualization/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Visualization utilities, renderers, and structured config schemas.""" - -from .altair_renderer import build_altair_chart -from .data_registry import DataRegistry, DictDataRegistry, resolve_data_source -from .plotly_renderer import build_plotly_figure -from .structured_configs import ( - AestheticsConfig, - AxisConfig, - BackendType, - ChannelAestheticsConfig, - ChannelType, - DataConfig, - FacetConfig, - GeometryConfig, - GeometryType, - GraphicsConfig, - LabelConfig, - LayerConfig, - LegendConfig, - PlotConfig, - PlotLevelGuideConfig, - PlotSizeConfig, - ScaleConfig, - SelectionConfig, - SelectionType, - TransformConfig, - TransformOp, -) - -__all__ = [ - "build_altair_chart", - "build_plotly_figure", - "DataRegistry", - "DictDataRegistry", - "resolve_data_source", - "AestheticsConfig", - "AxisConfig", - "BackendType", - "ChannelAestheticsConfig", - "ChannelType", - "DataConfig", - "FacetConfig", - "GeometryConfig", - "GeometryType", - "GraphicsConfig", - "LabelConfig", - "LayerConfig", - "LegendConfig", - "PlotConfig", - "PlotLevelGuideConfig", - "PlotSizeConfig", - "ScaleConfig", - "SelectionConfig", - "SelectionType", - "TransformConfig", - "TransformOp", -] From ad5af1bbb56e2e5010b9fb2abd0c79d5807d1d7f Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 14 Nov 2025 13:46:49 -0800 Subject: [PATCH 8/9] Reorganize altair dependency in pyproject.toml --- .github/workflows/simplexity.yaml | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/simplexity.yaml b/.github/workflows/simplexity.yaml index 271308f7..ac76842c 100644 --- a/.github/workflows/simplexity.yaml +++ b/.github/workflows/simplexity.yaml @@ -23,7 +23,7 @@ jobs: run: uv run --extra dev --extra penzai pylint simplexity tests - name: pyright - run: uv run --extra aws --extra dev --extra penzai pyright + run: uv run --extra altair --extra aws --extra dev --extra penzai pyright unit-tests: runs-on: ubuntu-latest @@ -41,7 +41,7 @@ jobs: # Coverage is tracked but won't fail - new code in PRs is checked via diff-cover # TODO: Enable strict overall coverage enforcement once coverage improves # Note: --cov-fail-under=0 is explicit but redundant since pyproject.toml doesn't set a threshold - uv run --extra aws --extra dev --extra cuda --extra penzai pytest \ + uv run --extra altair --extra aws --extra dev --extra cuda --extra penzai pytest \ --capture=no \ --verbose \ --cov-fail-under=0 diff --git a/pyproject.toml b/pyproject.toml index a0f40023..c8646f6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ description = "Computational Mechanics of sequence prediction models." readme = "README.md" requires-python = ">=3.12" dependencies = [ - "altair>=5.3.0", "chex", "dotenv", "equinox>=0.13.0", @@ -28,6 +27,7 @@ dependencies = [ ] [project.optional-dependencies] +altair = ["altair>=5.3.0"] aws = ["boto3>=1.37.24"] # Deprecated S3 Persister. cuda = [ "jax[cuda12_pip]", From ab981cb11fcebbd7825d029b51e35da199659848 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Fri, 14 Nov 2025 13:47:02 -0800 Subject: [PATCH 9/9] Fix demo imports --- examples/visualization_3d_demo.py | 5 ++++- examples/visualization_demo.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/visualization_3d_demo.py b/examples/visualization_3d_demo.py index ed9b0a8e..aa7b4a26 100644 --- a/examples/visualization_3d_demo.py +++ b/examples/visualization_3d_demo.py @@ -13,7 +13,10 @@ from hydra.utils import get_original_cwd from omegaconf import DictConfig, OmegaConf -from simplexity.visualization import DictDataRegistry, PlotConfig, build_altair_chart, build_plotly_figure +from simplexity.visualization.altair_renderer import build_altair_chart +from simplexity.visualization.data_registry import DictDataRegistry +from simplexity.visualization.plotly_renderer import build_plotly_figure +from simplexity.visualization.structured_configs import PlotConfig @dataclass diff --git a/examples/visualization_demo.py b/examples/visualization_demo.py index 4532c1bc..2da72bb8 100644 --- a/examples/visualization_demo.py +++ b/examples/visualization_demo.py @@ -7,11 +7,12 @@ import numpy as np import pandas as pd -from simplexity.visualization import ( +from simplexity.visualization.altair_renderer import build_altair_chart +from simplexity.visualization.data_registry import DictDataRegistry +from simplexity.visualization.structured_configs import ( AestheticsConfig, ChannelAestheticsConfig, DataConfig, - DictDataRegistry, GeometryConfig, LayerConfig, PlotConfig, @@ -19,7 +20,6 @@ PlotSizeConfig, TransformConfig, ) -from simplexity.visualization.altair_renderer import build_altair_chart def main() -> None: