diff --git a/plotly/express/_imshow.py b/plotly/express/_imshow.py index 7a3dd8a6df..ed0e7a6919 100644 --- a/plotly/express/_imshow.py +++ b/plotly/express/_imshow.py @@ -63,6 +63,7 @@ def imshow( y=None, animation_frame=None, facet_col=None, + facet_row=None, facet_col_wrap=None, facet_col_spacing=None, facet_row_spacing=None, @@ -128,10 +129,15 @@ def imshow( axis number along which the image array is sliced to create a facetted plot. If `img` is an xarray, `facet_col` can be the name of one the dimensions. + facet_row: int or str, optional (default None) + axis number along which the image array is sliced to create a vertically + facetted plot. If `img` is an xarray, `facet_row` can be the name of one + the dimensions. + facet_col_wrap: int Maximum number of facet columns. Wraps the column variable at this width, so that the column facets span multiple rows. - Ignored if `facet_col` is None. + Ignored if `facet_col` is None or if `facet_row` is set. facet_col_spacing: float between 0 and 1 Spacing between facet columns, in paper units. Default is 0.02. @@ -235,30 +241,43 @@ def imshow( args = locals() apply_default_cascade(args, constructor=None) labels = labels.copy() - nslices_facet = 1 + nslices_facet_col = 1 + nslices_facet_row = 1 + facet_col_slices = None + facet_row_slices = None if facet_col is not None: if isinstance(facet_col, str): facet_col = img.dims.index(facet_col) - nslices_facet = img.shape[facet_col] - facet_slices = range(nslices_facet) - ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet + nslices_facet_col = img.shape[facet_col] + facet_col_slices = range(nslices_facet_col) + if facet_row is not None: + if isinstance(facet_row, str): + facet_row = img.dims.index(facet_row) + nslices_facet_row = img.shape[facet_row] + facet_row_slices = range(nslices_facet_row) + # facet_col_wrap is ignored when facet_row is set + if facet_row is not None or facet_col_wrap is None: + ncols = nslices_facet_col + nrows = nslices_facet_row + else: + ncols = min(int(facet_col_wrap), nslices_facet_col) nrows = ( - nslices_facet // ncols + 1 - if nslices_facet % ncols - else nslices_facet // ncols + nslices_facet_col // ncols + 1 + if nslices_facet_col % ncols + else nslices_facet_col // ncols ) - else: - nrows = 1 - ncols = 1 if animation_frame is not None: if isinstance(animation_frame, str): animation_frame = img.dims.index(animation_frame) nslices_animation = img.shape[animation_frame] animation_slices = range(nslices_animation) - slice_dimensions = (facet_col is not None) + ( - animation_frame is not None - ) # 0, 1, or 2 - facet_label = None + slice_dimensions = ( + (facet_col is not None) + + (facet_row is not None) + + (animation_frame is not None) + ) # 0, 1, 2, or 3 + facet_col_label = None + facet_row_label = None animation_label = None img_is_xarray = False # ----- Define x and y, set labels if img is an xarray ------------------- @@ -267,9 +286,13 @@ def imshow( img_is_xarray = True pop_indexes = [] if facet_col is not None: - facet_slices = img.coords[img.dims[facet_col]].values + facet_col_slices = img.coords[img.dims[facet_col]].values pop_indexes.append(facet_col) - facet_label = img.dims[facet_col] + facet_col_label = img.dims[facet_col] + if facet_row is not None: + facet_row_slices = img.coords[img.dims[facet_row]].values + pop_indexes.append(facet_row) + facet_row_label = img.dims[facet_row] if animation_frame is not None: animation_slices = img.coords[img.dims[animation_frame]].values pop_indexes.append(animation_frame) @@ -295,7 +318,9 @@ def imshow( if labels.get("animation_frame", None) is None: labels["animation_frame"] = animation_label if labels.get("facet_col", None) is None: - labels["facet_col"] = facet_label + labels["facet_col"] = facet_col_label + if labels.get("facet_row", None) is None: + labels["facet_row"] = facet_row_label if labels.get("color", None) is None: labels["color"] = xarray.plot.utils.label_from_attrs(img) labels["color"] = labels["color"].replace("\n", "
") @@ -331,12 +356,20 @@ def imshow( # --------------- Starting from here img is always a numpy array -------- img = np.asanyarray(img) - # Reshape array so that animation dimension comes first, then facets, then images + # Reshape array so that animation dimension comes first, then facet_row, then facet_col, then images + # We move axes to front in reverse order so each axis ends up at position 0 in the final order if facet_col is not None: img = np.moveaxis(img, facet_col, 0) if animation_frame is not None and animation_frame < facet_col: animation_frame += 1 + if facet_row is not None and facet_row < facet_col: + facet_row += 1 facet_col = True + if facet_row is not None: + img = np.moveaxis(img, facet_row, 0) + if animation_frame is not None and animation_frame < facet_row: + animation_frame += 1 + facet_row = True if animation_frame is not None: img = np.moveaxis(img, animation_frame, 0) animation_frame = True @@ -348,8 +381,10 @@ def imshow( iterables = () if animation_frame is not None: iterables += (range(nslices_animation),) + if facet_row is not None: + iterables += (range(nslices_facet_row),) if facet_col is not None: - iterables += (range(nslices_facet),) + iterables += (range(nslices_facet_col),) # Default behaviour of binary_string: True for RGB images, False for 2D if binary_string is None: @@ -535,19 +570,25 @@ def imshow( raise ValueError( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " "An image of shape %s was provided. " - "Alternatively, 3- or 4-D single or multichannel datasets can be " - "visualized using the `facet_col` or/and `animation_frame` arguments." + "Alternatively, 3-, 4-, or 5-D single or multichannel datasets can be " + "visualized using the `facet_col`, `facet_row`, and/or `animation_frame` arguments." % str(img.shape) ) # Now build figure col_labels = [] + row_labels = [] if facet_col is not None: slice_label = ( "facet_col" if labels.get("facet_col") is None else labels["facet_col"] ) - col_labels = [f"{slice_label}={i}" for i in facet_slices] - fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) + col_labels = [f"{slice_label}={i}" for i in facet_col_slices] + if facet_row is not None: + slice_label = ( + "facet_row" if labels.get("facet_row") is None else labels["facet_row"] + ) + row_labels = [f"{slice_label}={i}" for i in facet_row_slices] + fig = init_figure(args, "xy", [], nrows, ncols, col_labels, row_labels) for attr_name in ["height", "width"]: if args[attr_name]: layout[attr_name] = args[attr_name] @@ -556,15 +597,22 @@ def imshow( elif args["template"].layout.margin.t is None: layout["margin"] = {"t": 60} + nslices_facets = nslices_facet_row * nslices_facet_col frame_list = [] for index, trace in enumerate(traces): - if (facet_col and index < nrows * ncols) or index == 0: - fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) + if ((facet_col or facet_row) and index < nrows * ncols) or index == 0: + # Calculate row and col position + # index is ordered by (facet_row, facet_col) from itertools.product + # When facet_col_wrap is used (and facet_row is None), traces are laid out + # across wrapped columns, so we use ncols for the calculation + row_idx = index // ncols + col_idx = index % ncols + fig.add_trace(trace, row=nrows - row_idx, col=col_idx + 1) if animation_frame is not None: for i, index in zip(range(nslices_animation), animation_slices): frame_list.append( dict( - data=traces[nslices_facet * i : nslices_facet * (i + 1)], + data=traces[nslices_facets * i : nslices_facets * (i + 1)], layout=layout, name=str(index), ) diff --git a/tests/test_optional/test_px/test_imshow.py b/tests/test_optional/test_px/test_imshow.py index 86e843a6ff..c4ecea9448 100644 --- a/tests/test_optional/test_px/test_imshow.py +++ b/tests/test_optional/test_px/test_imshow.py @@ -450,3 +450,89 @@ def test_animation_and_facet(binary_string): nslices = img.shape[0] assert len(fig.frames) == nslices assert len(fig.data) == img.shape[1] + + +@pytest.mark.parametrize("facet_row", [0, 1, 2, -1]) +@pytest.mark.parametrize("binary_string", [False, True]) +def test_facet_row(facet_row, binary_string): + img = np.random.randint(255, size=(10, 9, 8)) + fig = px.imshow( + img, + facet_row=facet_row, + binary_string=binary_string, + ) + nslices = img.shape[facet_row] + nrows = nslices + ncols = 1 + nmax = ncols * nrows + assert "yaxis%d" % nmax in fig.layout + assert "yaxis%d" % (nmax + 1) not in fig.layout + assert len(fig.data) == nslices + + +@pytest.mark.parametrize("binary_string", [False, True]) +def test_facet_row_and_col(binary_string): + img = np.random.randint(255, size=(4, 3, 9, 8)) + fig = px.imshow( + img, + facet_row=0, + facet_col=1, + binary_string=binary_string, + ) + nrows = img.shape[0] + ncols = img.shape[1] + nmax = ncols * nrows + assert "yaxis%d" % nmax in fig.layout + assert "yaxis%d" % (nmax + 1) not in fig.layout + assert len(fig.data) == nrows * ncols + + +@pytest.mark.parametrize("binary_string", [False, True]) +def test_animation_facet_row_and_col(binary_string): + img = np.random.randint(255, size=(5, 4, 3, 9, 8)).astype(np.uint8) + fig = px.imshow( + img, + animation_frame=0, + facet_row=1, + facet_col=2, + binary_string=binary_string, + ) + nslices_animation = img.shape[0] + nrows = img.shape[1] + ncols = img.shape[2] + assert len(fig.frames) == nslices_animation + assert len(fig.data) == nrows * ncols + + +def test_imshow_xarray_facet_row(): + img = np.random.random((3, 4, 5)) + da = xr.DataArray( + img, dims=["row_dim", "dim_1", "dim_2"], coords={"row_dim": ["A", "B", "C"]} + ) + fig = px.imshow(da, facet_row="row_dim") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_2" + assert fig.layout.yaxis.title.text == "dim_1" + assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"])) + assert len(fig.data) == 3 + # Check row labels are present + annotations = [a.text for a in fig.layout.annotations] + assert any("row_dim=A" in a for a in annotations) + + +def test_imshow_xarray_facet_row_and_col(): + img = np.random.random((3, 4, 5, 6)) + da = xr.DataArray( + img, + dims=["row_dim", "col_dim", "dim_y", "dim_x"], + coords={"row_dim": ["R1", "R2", "R3"], "col_dim": ["C1", "C2", "C3", "C4"]}, + ) + fig = px.imshow(da, facet_row="row_dim", facet_col="col_dim") + # Dimensions are used for axis labels and coordinates + assert fig.layout.xaxis.title.text == "dim_x" + assert fig.layout.yaxis.title.text == "dim_y" + assert len(fig.data) == 3 * 4 + # Check labels are present + annotations = [a.text for a in fig.layout.annotations] + assert any("row_dim=R1" in a for a in annotations) + assert any("col_dim=C1" in a for a in annotations)