Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 75 additions & 27 deletions plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def imshow(
y=None,
animation_frame=None,
facet_col=None,
facet_row=None,
facet_col_wrap=None,
facet_col_spacing=None,
facet_row_spacing=None,
Expand Down Expand Up @@ -128,10 +129,15 @@ def imshow(
axis number along which the image array is sliced to create a facetted plot.
If `img` is an xarray, `facet_col` can be the name of one the dimensions.

facet_row: int or str, optional (default None)
axis number along which the image array is sliced to create a vertically
facetted plot. If `img` is an xarray, `facet_row` can be the name of one
the dimensions.

facet_col_wrap: int
Maximum number of facet columns. Wraps the column variable at this width,
so that the column facets span multiple rows.
Ignored if `facet_col` is None.
Ignored if `facet_col` is None or if `facet_row` is set.

facet_col_spacing: float between 0 and 1
Spacing between facet columns, in paper units. Default is 0.02.
Expand Down Expand Up @@ -235,30 +241,43 @@ def imshow(
args = locals()
apply_default_cascade(args, constructor=None)
labels = labels.copy()
nslices_facet = 1
nslices_facet_col = 1
nslices_facet_row = 1
facet_col_slices = None
facet_row_slices = None
if facet_col is not None:
if isinstance(facet_col, str):
facet_col = img.dims.index(facet_col)
nslices_facet = img.shape[facet_col]
facet_slices = range(nslices_facet)
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
nslices_facet_col = img.shape[facet_col]
facet_col_slices = range(nslices_facet_col)
if facet_row is not None:
if isinstance(facet_row, str):
facet_row = img.dims.index(facet_row)
nslices_facet_row = img.shape[facet_row]
facet_row_slices = range(nslices_facet_row)
# facet_col_wrap is ignored when facet_row is set
if facet_row is not None or facet_col_wrap is None:
ncols = nslices_facet_col
nrows = nslices_facet_row
else:
ncols = min(int(facet_col_wrap), nslices_facet_col)
nrows = (
nslices_facet // ncols + 1
if nslices_facet % ncols
else nslices_facet // ncols
nslices_facet_col // ncols + 1
if nslices_facet_col % ncols
else nslices_facet_col // ncols
)
else:
nrows = 1
ncols = 1
if animation_frame is not None:
if isinstance(animation_frame, str):
animation_frame = img.dims.index(animation_frame)
nslices_animation = img.shape[animation_frame]
animation_slices = range(nslices_animation)
slice_dimensions = (facet_col is not None) + (
animation_frame is not None
) # 0, 1, or 2
facet_label = None
slice_dimensions = (
(facet_col is not None)
+ (facet_row is not None)
+ (animation_frame is not None)
) # 0, 1, 2, or 3
facet_col_label = None
facet_row_label = None
animation_label = None
img_is_xarray = False
# ----- Define x and y, set labels if img is an xarray -------------------
Expand All @@ -267,9 +286,13 @@ def imshow(
img_is_xarray = True
pop_indexes = []
if facet_col is not None:
facet_slices = img.coords[img.dims[facet_col]].values
facet_col_slices = img.coords[img.dims[facet_col]].values
pop_indexes.append(facet_col)
facet_label = img.dims[facet_col]
facet_col_label = img.dims[facet_col]
if facet_row is not None:
facet_row_slices = img.coords[img.dims[facet_row]].values
pop_indexes.append(facet_row)
facet_row_label = img.dims[facet_row]
if animation_frame is not None:
animation_slices = img.coords[img.dims[animation_frame]].values
pop_indexes.append(animation_frame)
Expand All @@ -295,7 +318,9 @@ def imshow(
if labels.get("animation_frame", None) is None:
labels["animation_frame"] = animation_label
if labels.get("facet_col", None) is None:
labels["facet_col"] = facet_label
labels["facet_col"] = facet_col_label
if labels.get("facet_row", None) is None:
labels["facet_row"] = facet_row_label
if labels.get("color", None) is None:
labels["color"] = xarray.plot.utils.label_from_attrs(img)
labels["color"] = labels["color"].replace("\n", "<br>")
Expand Down Expand Up @@ -331,12 +356,20 @@ def imshow(

# --------------- Starting from here img is always a numpy array --------
img = np.asanyarray(img)
# Reshape array so that animation dimension comes first, then facets, then images
# Reshape array so that animation dimension comes first, then facet_row, then facet_col, then images
# We move axes to front in reverse order so each axis ends up at position 0 in the final order
if facet_col is not None:
img = np.moveaxis(img, facet_col, 0)
if animation_frame is not None and animation_frame < facet_col:
animation_frame += 1
if facet_row is not None and facet_row < facet_col:
facet_row += 1
facet_col = True
if facet_row is not None:
img = np.moveaxis(img, facet_row, 0)
if animation_frame is not None and animation_frame < facet_row:
animation_frame += 1
facet_row = True
if animation_frame is not None:
img = np.moveaxis(img, animation_frame, 0)
animation_frame = True
Expand All @@ -348,8 +381,10 @@ def imshow(
iterables = ()
if animation_frame is not None:
iterables += (range(nslices_animation),)
if facet_row is not None:
iterables += (range(nslices_facet_row),)
if facet_col is not None:
iterables += (range(nslices_facet),)
iterables += (range(nslices_facet_col),)

# Default behaviour of binary_string: True for RGB images, False for 2D
if binary_string is None:
Expand Down Expand Up @@ -535,19 +570,25 @@ def imshow(
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
"An image of shape %s was provided. "
"Alternatively, 3- or 4-D single or multichannel datasets can be "
"visualized using the `facet_col` or/and `animation_frame` arguments."
"Alternatively, 3-, 4-, or 5-D single or multichannel datasets can be "
"visualized using the `facet_col`, `facet_row`, and/or `animation_frame` arguments."
% str(img.shape)
)

# Now build figure
col_labels = []
row_labels = []
if facet_col is not None:
slice_label = (
"facet_col" if labels.get("facet_col") is None else labels["facet_col"]
)
col_labels = [f"{slice_label}={i}" for i in facet_slices]
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
col_labels = [f"{slice_label}={i}" for i in facet_col_slices]
if facet_row is not None:
slice_label = (
"facet_row" if labels.get("facet_row") is None else labels["facet_row"]
)
row_labels = [f"{slice_label}={i}" for i in facet_row_slices]
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, row_labels)
for attr_name in ["height", "width"]:
if args[attr_name]:
layout[attr_name] = args[attr_name]
Expand All @@ -556,15 +597,22 @@ def imshow(
elif args["template"].layout.margin.t is None:
layout["margin"] = {"t": 60}

nslices_facets = nslices_facet_row * nslices_facet_col
frame_list = []
for index, trace in enumerate(traces):
if (facet_col and index < nrows * ncols) or index == 0:
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
if ((facet_col or facet_row) and index < nrows * ncols) or index == 0:
# Calculate row and col position
# index is ordered by (facet_row, facet_col) from itertools.product
# When facet_col_wrap is used (and facet_row is None), traces are laid out
# across wrapped columns, so we use ncols for the calculation
row_idx = index // ncols
col_idx = index % ncols
fig.add_trace(trace, row=nrows - row_idx, col=col_idx + 1)
if animation_frame is not None:
for i, index in zip(range(nslices_animation), animation_slices):
frame_list.append(
dict(
data=traces[nslices_facet * i : nslices_facet * (i + 1)],
data=traces[nslices_facets * i : nslices_facets * (i + 1)],
layout=layout,
name=str(index),
)
Expand Down
86 changes: 86 additions & 0 deletions tests/test_optional/test_px/test_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,89 @@ def test_animation_and_facet(binary_string):
nslices = img.shape[0]
assert len(fig.frames) == nslices
assert len(fig.data) == img.shape[1]


@pytest.mark.parametrize("facet_row", [0, 1, 2, -1])
@pytest.mark.parametrize("binary_string", [False, True])
def test_facet_row(facet_row, binary_string):
img = np.random.randint(255, size=(10, 9, 8))
fig = px.imshow(
img,
facet_row=facet_row,
binary_string=binary_string,
)
nslices = img.shape[facet_row]
nrows = nslices
ncols = 1
nmax = ncols * nrows
assert "yaxis%d" % nmax in fig.layout
assert "yaxis%d" % (nmax + 1) not in fig.layout
assert len(fig.data) == nslices


@pytest.mark.parametrize("binary_string", [False, True])
def test_facet_row_and_col(binary_string):
img = np.random.randint(255, size=(4, 3, 9, 8))
fig = px.imshow(
img,
facet_row=0,
facet_col=1,
binary_string=binary_string,
)
nrows = img.shape[0]
ncols = img.shape[1]
nmax = ncols * nrows
assert "yaxis%d" % nmax in fig.layout
assert "yaxis%d" % (nmax + 1) not in fig.layout
assert len(fig.data) == nrows * ncols


@pytest.mark.parametrize("binary_string", [False, True])
def test_animation_facet_row_and_col(binary_string):
img = np.random.randint(255, size=(5, 4, 3, 9, 8)).astype(np.uint8)
fig = px.imshow(
img,
animation_frame=0,
facet_row=1,
facet_col=2,
binary_string=binary_string,
)
nslices_animation = img.shape[0]
nrows = img.shape[1]
ncols = img.shape[2]
assert len(fig.frames) == nslices_animation
assert len(fig.data) == nrows * ncols


def test_imshow_xarray_facet_row():
img = np.random.random((3, 4, 5))
da = xr.DataArray(
img, dims=["row_dim", "dim_1", "dim_2"], coords={"row_dim": ["A", "B", "C"]}
)
fig = px.imshow(da, facet_row="row_dim")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_2"
assert fig.layout.yaxis.title.text == "dim_1"
assert np.all(np.array(fig.data[0].x) == np.array(da.coords["dim_2"]))
assert len(fig.data) == 3
# Check row labels are present
annotations = [a.text for a in fig.layout.annotations]
assert any("row_dim=A" in a for a in annotations)


def test_imshow_xarray_facet_row_and_col():
img = np.random.random((3, 4, 5, 6))
da = xr.DataArray(
img,
dims=["row_dim", "col_dim", "dim_y", "dim_x"],
coords={"row_dim": ["R1", "R2", "R3"], "col_dim": ["C1", "C2", "C3", "C4"]},
)
fig = px.imshow(da, facet_row="row_dim", facet_col="col_dim")
# Dimensions are used for axis labels and coordinates
assert fig.layout.xaxis.title.text == "dim_x"
assert fig.layout.yaxis.title.text == "dim_y"
assert len(fig.data) == 3 * 4
# Check labels are present
annotations = [a.text for a in fig.layout.annotations]
assert any("row_dim=R1" in a for a in annotations)
assert any("col_dim=C1" in a for a in annotations)