Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Fixed
- Fix issue where Plotly Express ignored trace-specific color sequences defined in templates via `template.data.<trace_type>` [[#5437](https://github.com/plotly/plotly.py/pull/5437)]

### Updated
- Speed up `validate_gantt` function [[#5386](https://github.com/plotly/plotly.py/pull/5386)], with thanks to @misrasaurabh1 for the contribution!

Expand Down
25 changes: 23 additions & 2 deletions plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def one_group(x):
return ""


def apply_default_cascade(args):
def apply_default_cascade(args, constructor=None):
# first we apply px.defaults to unspecified args

for param in defaults.__slots__:
Expand Down Expand Up @@ -1037,9 +1037,30 @@ def apply_default_cascade(args):
if args["color_continuous_scale"] is None:
args["color_continuous_scale"] = sequential.Viridis

# if color_discrete_sequence not set explicitly or in px.defaults,
# see if we can defer to template. Try trace-specific colors first,
# then layout.colorway, then set reasonable defaults
if "color_discrete_sequence" in args:
if args["color_discrete_sequence"] is None and constructor is not None:
if constructor == "timeline":
trace_type = "bar"
else:
trace_type = constructor().type
if trace_data_list := getattr(args["template"].data, trace_type, None):
args["color_discrete_sequence"] = [
trace_data.marker.color
for trace_data in trace_data_list
if hasattr(trace_data, "marker")
and hasattr(trace_data.marker, "color")
]
if not args["color_discrete_sequence"] or not any(
args["color_discrete_sequence"]
):
args["color_discrete_sequence"] = None
# fallback to layout.colorway if trace-specific colors not available
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
args["color_discrete_sequence"] = args["template"].layout.colorway
# final fallback to default qualitative palette
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.D3

Expand Down Expand Up @@ -2486,7 +2507,7 @@ def get_groups_and_orders(args, grouper):
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trace_patch = trace_patch or {}
layout_patch = layout_patch or {}
apply_default_cascade(args)
apply_default_cascade(args, constructor=constructor)

args = build_dataframe(args, constructor)
if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_optional/test_px/test_px.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,52 @@ def test_px_templates(backend):
pio.templates.default = "plotly"


def test_px_templates_trace_specific_colors(backend):
import pandas as pd

tips = px.data.tips(return_type=backend)

# trace-specific colors: each trace type uses its own template colors
template = {
"data_histogram": [
{"marker": {"color": "orange"}},
{"marker": {"color": "purple"}},
],
"data_bar": [
{"marker": {"color": "red"}},
{"marker": {"color": "blue"}},
],
"layout_colorway": ["yellow", "green"],
}
# histogram uses histogram colors
fig = px.histogram(tips, x="total_bill", color="sex", template=template)
assert fig.data[0].marker.color == "orange"
assert fig.data[1].marker.color == "purple"
# fallback to layout.colorway when trace-specific colors don't exist
fig = px.box(tips, x="day", y="total_bill", color="sex", template=template)
assert fig.data[0].marker.color == "yellow"
assert fig.data[1].marker.color == "green"
# timeline special case (maps to bar)
df_timeline = pd.DataFrame(
{
"Task": ["Job A", "Job B"],
"Start": ["2009-01-01", "2009-03-05"],
"Finish": ["2009-02-28", "2009-04-15"],
"Resource": ["Alex", "Max"],
}
)
fig = px.timeline(
df_timeline,
x_start="Start",
x_end="Finish",
y="Task",
color="Resource",
template=template,
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"


def test_px_defaults():
px.defaults.labels = dict(x="hey x")
px.defaults.category_orders = dict(color=["b", "a"])
Expand Down