@@ -182,7 +182,8 @@ def plot(self, colors: str | list[str] | None = None, show: bool | None = None)
182182 """Plot cluster assignment visualization.
183183
184184 Shows which cluster each original period belongs to, and the
185- number of occurrences per cluster.
185+ number of occurrences per cluster. For multi-period/scenario structures,
186+ creates a faceted grid plot.
186187
187188 Args:
188189 colors: Colorscale name (str) or list of colors.
@@ -198,34 +199,42 @@ def plot(self, colors: str | list[str] | None = None, show: bool | None = None)
198199 n_clusters = (
199200 int (self .n_clusters ) if isinstance (self .n_clusters , (int , np .integer )) else int (self .n_clusters .values )
200201 )
202+ colorscale = colors or CONFIG .Plotting .default_sequential_colorscale
201203
202- cluster_order = self .get_cluster_order_for_slice ()
203-
204- # Build DataArray for fxplot heatmap
205- cluster_da = xr .DataArray (
206- cluster_order .reshape (1 , - 1 ),
207- dims = ['y' , 'original_cluster' ],
208- coords = {'y' : ['Cluster' ], 'original_cluster' : range (1 , len (cluster_order ) + 1 )},
209- name = 'cluster_assignment' ,
204+ # Build DataArray with 1-based original_cluster coords
205+ cluster_da = self .cluster_order .assign_coords (
206+ original_cluster = np .arange (1 , self .cluster_order .sizes ['original_cluster' ] + 1 )
210207 )
211208
212- # Use fxplot.heatmap for smart defaults
213- colorscale = colors or CONFIG .Plotting .default_sequential_colorscale
214- fig = cluster_da .fxplot .heatmap (
209+ has_period = 'period' in cluster_da .dims
210+ has_scenario = 'scenario' in cluster_da .dims
211+
212+ # Transpose for heatmap: first dim = y-axis, second dim = x-axis
213+ if has_period :
214+ cluster_da = cluster_da .transpose ('period' , 'original_cluster' , ...)
215+ elif has_scenario :
216+ cluster_da = cluster_da .transpose ('scenario' , 'original_cluster' , ...)
217+
218+ # Data to return (without dummy dims)
219+ ds = xr .Dataset ({'cluster_order' : cluster_da })
220+
221+ # For plotting: add dummy y-dim if needed (heatmap requires 2D)
222+ if not has_period and not has_scenario :
223+ plot_da = cluster_da .expand_dims (y = ['' ]).transpose ('y' , 'original_cluster' )
224+ plot_ds = xr .Dataset ({'cluster_order' : plot_da })
225+ else :
226+ plot_ds = ds
227+
228+ fig = plot_ds .fxplot .heatmap (
215229 colors = colorscale ,
216- title = f'Cluster Assignment ({ self .n_original_clusters } periods → { n_clusters } clusters)' ,
230+ title = f'Cluster Assignment ({ self .n_original_clusters } → { n_clusters } clusters)' ,
217231 )
218- fig . update_yaxes ( showticklabels = False )
232+
219233 fig .update_coloraxes (colorbar_title = 'Cluster' )
234+ if not has_period and not has_scenario :
235+ fig .update_yaxes (showticklabels = False )
220236
221- # Build data for PlotResult
222- data = xr .Dataset (
223- {
224- 'cluster_order' : self .cluster_order ,
225- 'cluster_occurrences' : self .cluster_occurrences ,
226- }
227- )
228- plot_result = PlotResult (data = data , figure = fig )
237+ plot_result = PlotResult (data = ds , figure = fig )
229238
230239 if show is None :
231240 show = CONFIG .Plotting .default_show
0 commit comments