Skip to content

Commit 9afec7e

Browse files
authored
Fix inverted transfunc condition in render_shapes; add transfunc to render_labels (#635)
1 parent 525c94a commit 9afec7e

7 files changed

Lines changed: 45 additions & 1 deletion

File tree

src/spatialdata_plot/pl/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,7 @@ def render_labels(
854854
scale=param_values["scale"],
855855
table_name=param_values["table_name"],
856856
table_layer=param_values["table_layer"],
857+
transfunc=kwargs.get("transfunc"),
857858
zorder=n_steps,
858859
colorbar=param_values["colorbar"],
859860
colorbar_params=param_values["colorbar_params"],

src/spatialdata_plot/pl/render.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def _render_shapes(
402402
sdata_filt[element] = shapes
403403

404404
# color_source_vector is None when the values aren't categorical
405-
if values_are_categorical and render_params.transfunc is not None:
405+
if not values_are_categorical and render_params.transfunc is not None:
406406
color_vector = render_params.transfunc(color_vector)
407407

408408
norm = copy(render_params.cmap_params.norm)
@@ -1702,6 +1702,10 @@ def _render_labels(
17021702
if isinstance(color_vector.dtype, pd.CategoricalDtype):
17031703
color_vector = color_vector.remove_unused_categories()
17041704

1705+
# color_source_vector is None when the values aren't categorical
1706+
if color_source_vector is None and render_params.transfunc is not None:
1707+
color_vector = render_params.transfunc(color_vector)
1708+
17051709
def _draw_labels(
17061710
seg_erosionpx: int | None,
17071711
seg_boundaries: bool,

src/spatialdata_plot/pl/render_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ class LabelsRenderParams:
304304
scale: str | None = None
305305
table_name: str | None = None
306306
table_layer: str | None = None
307+
transfunc: Callable[[float], float] | None = None
307308
zorder: int = 0
308309
colorbar: bool | str | None = "auto"
309310
colorbar_params: dict[str, object] | None = None
56.5 KB
Loading
40.9 KB
Loading

tests/pl/test_render_labels.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
295295
cmap=_viridis_with_under_over(),
296296
).pl.show()
297297

298+
def test_plot_transfunc_applied_to_continuous_labels(self, sdata_blobs: SpatialData):
299+
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", transfunc=lambda x: x * 100).pl.show(
300+
title="transfunc: x * 100"
301+
)
302+
298303
def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
299304
sdata_blobs["table"].layers["normalized"] = get_standard_RNG().random(sdata_blobs["table"].X.shape)
300305
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()
@@ -454,3 +459,17 @@ def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, capl
454459
sdata_blobs.pl.render_labels(
455460
labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None
456461
).pl.show()
462+
463+
464+
def test_transfunc_is_applied_for_continuous_labels(sdata_blobs: SpatialData):
465+
called = []
466+
467+
def track(x):
468+
called.append(True)
469+
return x
470+
471+
fig, ax = plt.subplots()
472+
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", transfunc=track).pl.show(ax=ax)
473+
plt.close(fig)
474+
475+
assert called, "transfunc was not called for continuous labels data"

tests/pl/test_render_shapes.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,11 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated
732732
element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=_viridis_with_under_over()
733733
).pl.show()
734734

735+
def test_plot_transfunc_applied_to_continuous_shapes(self, sdata_blobs_shapes_annotated: SpatialData):
736+
sdata_blobs_shapes_annotated.pl.render_shapes(
737+
element="blobs_polygons", color="value", transfunc=lambda x: x * 100
738+
).pl.show(title="transfunc: x * 100")
739+
735740
def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
736741
sdata_blobs_shapes_annotated.pl.render_shapes(
737742
element="blobs_polygons",
@@ -1310,3 +1315,17 @@ def test_datashader_na_color_nan_overlay(sdata_blobs: SpatialData, na_color: str
13101315
f"Expected {expected_images} image(s), got {len(ax.get_images())} for na_color={na_color!r}"
13111316
)
13121317
plt.close(fig)
1318+
1319+
1320+
def test_transfunc_is_applied_for_continuous_shapes(sdata_blobs_shapes_annotated: SpatialData):
1321+
called = []
1322+
1323+
def track(x):
1324+
called.append(True)
1325+
return x
1326+
1327+
fig, ax = plt.subplots()
1328+
sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="value", transfunc=track).pl.show(ax=ax)
1329+
plt.close(fig)
1330+
1331+
assert called, "transfunc was not called for continuous shapes data"

0 commit comments

Comments
 (0)