Skip to content

Commit adfdfcd

Browse files
authored
Merge pull request #5437 from antonymilne/main
[Fix] Support trace-specific color sequences in Plotly Express via templates
2 parents 509d1fc + e8339e0 commit adfdfcd

File tree

4 files changed

+75
-3
lines changed

4 files changed

+75
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
44

55
## Unreleased
66

7+
### Fixed
8+
- Fix issue where Plotly Express ignored trace-specific color sequences defined in templates via `template.data.<trace_type>` [[#5437](https://github.com/plotly/plotly.py/pull/5437)]
9+
710
### Updated
811
- Speed up `validate_gantt` function [[#5386](https://github.com/plotly/plotly.py/pull/5386)], with thanks to @misrasaurabh1 for the contribution!
912

plotly/express/_core.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def one_group(x):
10031003
return ""
10041004

10051005

1006-
def apply_default_cascade(args):
1006+
def apply_default_cascade(args, constructor):
10071007
# first we apply px.defaults to unspecified args
10081008

10091009
for param in defaults.__slots__:
@@ -1037,9 +1037,29 @@ def apply_default_cascade(args):
10371037
if args["color_continuous_scale"] is None:
10381038
args["color_continuous_scale"] = sequential.Viridis
10391039

1040+
# if color_discrete_sequence not set explicitly or in px.defaults,
1041+
# see if we can defer to template. Try trace-specific colors first,
1042+
# then layout.colorway, then set reasonable defaults
10401043
if "color_discrete_sequence" in args:
1044+
if args["color_discrete_sequence"] is None and constructor is not None:
1045+
if constructor == "timeline":
1046+
trace_type = "bar"
1047+
else:
1048+
trace_type = constructor().type
1049+
if trace_data_list := getattr(args["template"].data, trace_type, None):
1050+
trace_specific_colors = [
1051+
trace_data.marker.color
1052+
for trace_data in trace_data_list
1053+
if hasattr(trace_data, "marker")
1054+
and hasattr(trace_data.marker, "color")
1055+
]
1056+
# If template contains at least one color for this trace type, assign to color_discrete_sequence
1057+
if any(trace_specific_colors):
1058+
args["color_discrete_sequence"] = trace_specific_colors
1059+
# fallback to layout.colorway if trace-specific colors not available
10411060
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
10421061
args["color_discrete_sequence"] = args["template"].layout.colorway
1062+
# final fallback to default qualitative palette
10431063
if args["color_discrete_sequence"] is None:
10441064
args["color_discrete_sequence"] = qualitative.D3
10451065

@@ -2486,7 +2506,7 @@ def get_groups_and_orders(args, grouper):
24862506
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
24872507
trace_patch = trace_patch or {}
24882508
layout_patch = layout_patch or {}
2489-
apply_default_cascade(args)
2509+
apply_default_cascade(args, constructor=constructor)
24902510

24912511
args = build_dataframe(args, constructor)
24922512
if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:

plotly/express/_imshow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def imshow(
233233
axes labels and ticks.
234234
"""
235235
args = locals()
236-
apply_default_cascade(args)
236+
apply_default_cascade(args, constructor=None)
237237
labels = labels.copy()
238238
nslices_facet = 1
239239
if facet_col is not None:

tests/test_optional/test_px/test_px.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from itertools import permutations
22
import warnings
33

4+
import pandas as pd
45
import plotly.express as px
56
import plotly.io as pio
67
import narwhals.stable.v1 as nw
@@ -226,6 +227,54 @@ def test_px_templates(backend):
226227
pio.templates.default = "plotly"
227228

228229

230+
def test_px_templates_trace_specific_colors(backend):
231+
tips = px.data.tips(return_type=backend)
232+
233+
# trace-specific colors: each trace type uses its own template colors
234+
template = {
235+
"data": {
236+
"histogram": [
237+
{"marker": {"color": "orange"}},
238+
{"marker": {"color": "purple"}},
239+
],
240+
"bar": [
241+
{"marker": {"color": "red"}},
242+
{"marker": {"color": "blue"}},
243+
],
244+
},
245+
"layout": {
246+
"colorway": ["yellow", "green"],
247+
},
248+
}
249+
# histogram uses histogram colors
250+
fig = px.histogram(tips, x="total_bill", color="sex", template=template)
251+
assert fig.data[0].marker.color == "orange"
252+
assert fig.data[1].marker.color == "purple"
253+
# fallback to layout.colorway when trace-specific colors don't exist
254+
fig = px.box(tips, x="day", y="total_bill", color="sex", template=template)
255+
assert fig.data[0].marker.color == "yellow"
256+
assert fig.data[1].marker.color == "green"
257+
# timeline special case (maps to bar)
258+
df_timeline = pd.DataFrame(
259+
{
260+
"Task": ["Job A", "Job B"],
261+
"Start": ["2009-01-01", "2009-03-05"],
262+
"Finish": ["2009-02-28", "2009-04-15"],
263+
"Resource": ["Alex", "Max"],
264+
}
265+
)
266+
fig = px.timeline(
267+
df_timeline,
268+
x_start="Start",
269+
x_end="Finish",
270+
y="Task",
271+
color="Resource",
272+
template=template,
273+
)
274+
assert fig.data[0].marker.color == "red"
275+
assert fig.data[1].marker.color == "blue"
276+
277+
229278
def test_px_defaults():
230279
px.defaults.labels = dict(x="hey x")
231280
px.defaults.category_orders = dict(color=["b", "a"])

0 commit comments

Comments
 (0)