Skip to content
This repository was archived by the owner on Aug 29, 2025. It is now read-only.

Commit 385c062

Browse files
density_heatmap and fixes to density_contour
1 parent 2b01357 commit 385c062

File tree

5 files changed

+147
-34
lines changed

5 files changed

+147
-34
lines changed

gallery.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ px.scatter(gapminder, x="gdpPercap", y="lifeExp", animation_frame="year", animat
112112

113113
```python
114114
px.line(gapminder, x="year", y="lifeExp", color="continent", line_group="country", hover_name="country",
115-
line_shape="spline")
115+
line_shape="spline", render_mode="svg")
116116
```
117117

118118
```python
@@ -129,6 +129,10 @@ px.density_contour(iris, x="sepal_width", y="sepal_length")
129129
px.density_contour(iris, x="sepal_width", y="sepal_length", color="species", marginal_x="rug", marginal_y="histogram")
130130
```
131131

132+
```python
133+
px.density_heatmap(iris, x="sepal_width", y="sepal_length", marginal_x="rug", marginal_y="histogram")
134+
```
135+
132136
```python
133137
px.bar(tips, x="sex", y="total_bill", color="smoker", barmode="group")
134138
```

plotly_express/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
parallel_categories,
3131
choropleth,
3232
density_contour,
33+
density_heatmap,
3334
)
3435

3536
from ._core import ( # noqa: F401
@@ -50,6 +51,7 @@
5051
"scatter_geo",
5152
"scatter_matrix",
5253
"density_contour",
54+
"density_heatmap",
5355
"line",
5456
"line_polar",
5557
"line_ternary",

plotly_express/_chart_types.py

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def density_contour(
5858
data_frame,
5959
x=None,
6060
y=None,
61+
z=None,
6162
color=None,
6263
facet_row=None,
6364
facet_col=None,
@@ -77,25 +78,91 @@ def density_contour(
7778
log_y=False,
7879
range_x=None,
7980
range_y=None,
81+
histfunc=None,
82+
histnorm=None,
83+
nbinsx=None,
84+
nbinsy=None,
8085
title=None,
8186
template=None,
8287
width=None,
8388
height=None,
8489
):
8590
"""
8691
In a density contour plot, rows of `data_frame` are grouped together into contour marks to \
87-
visualize the density of their distribution in 2D space.
92+
visualize the 2D distribution of an aggregate function `histfunc` (e.g. the count or sum) \
93+
of the value `z`.
8894
"""
8995
return make_figure(
9096
args=locals(),
9197
constructor=go.Histogram2dContour,
92-
trace_patch=dict(contours=dict(coloring="none")),
98+
trace_patch=dict(
99+
contours=dict(coloring="none"),
100+
histfunc=histfunc,
101+
histnorm=histnorm,
102+
nbinsx=nbinsx,
103+
nbinsy=nbinsy,
104+
xbingroup="x",
105+
ybingroup="y",
106+
),
93107
)
94108

95109

96110
density_contour.__doc__ = make_docstring(density_contour)
97111

98112

113+
def density_heatmap(
114+
data_frame,
115+
x=None,
116+
y=None,
117+
z=None,
118+
facet_row=None,
119+
facet_col=None,
120+
hover_name=None,
121+
hover_data=None,
122+
animation_frame=None,
123+
animation_group=None,
124+
category_orders={},
125+
labels={},
126+
color_continuous_scale=None,
127+
color_continuous_midpoint=None,
128+
marginal_x=None,
129+
marginal_y=None,
130+
opacity=None,
131+
log_x=False,
132+
log_y=False,
133+
range_x=None,
134+
range_y=None,
135+
histfunc=None,
136+
histnorm=None,
137+
nbinsx=None,
138+
nbinsy=None,
139+
title=None,
140+
template=None,
141+
width=None,
142+
height=None,
143+
):
144+
"""
145+
In a density heatmap, rows of `data_frame` are grouped together into colored \
146+
rectangular tiles to visualize the 2D distribution of an aggregate function \
147+
`histfunc` (e.g. the count or sum) of the value `z`.
148+
"""
149+
return make_figure(
150+
args=locals(),
151+
constructor=go.Histogram2d,
152+
trace_patch=dict(
153+
histfunc=histfunc,
154+
histnorm=histnorm,
155+
nbinsx=nbinsx,
156+
nbinsy=nbinsy,
157+
xbingroup="x",
158+
ybingroup="y",
159+
),
160+
)
161+
162+
163+
density_heatmap.__doc__ = make_docstring(density_heatmap)
164+
165+
99166
def line(
100167
data_frame,
101168
x=None,
@@ -164,7 +231,6 @@ def area(
164231
range_x=None,
165232
range_y=None,
166233
line_shape=None,
167-
render_mode="auto",
168234
title=None,
169235
template=None,
170236
width=None,
@@ -268,7 +334,8 @@ def histogram(
268334
):
269335
"""
270336
In a histogram, rows of `data_frame` are grouped together into a rectangular mark to \
271-
visualize some aggregate quantity like count or sum.
337+
visualize the 1D distribution of an aggregate function `histfunc` (e.g. the count or sum) \
338+
of the value `y` (or `x` if `orientation` is `'h'`).
272339
"""
273340
return make_figure(
274341
args=locals(),
@@ -278,8 +345,9 @@ def histogram(
278345
histnorm=histnorm,
279346
histfunc=histfunc,
280347
nbinsx=nbins if orientation == "v" else None,
281-
nbinsy=nbins if orientation == "h" else None,
348+
nbinsy=None if orientation == "v" else nbins,
282349
cumulative=dict(enabled=cumulative),
350+
bingroup="x" if orientation == "v" else "y",
283351
),
284352
layout_patch=dict(barmode=barmode, barnorm=barnorm),
285353
)

plotly_express/_core.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import plotly.io as pio
44
from collections import namedtuple, OrderedDict
55
from .colors import qualitative, sequential
6-
import math, pandas
6+
import math
7+
import pandas
78

89

910
class PxDefaults(object):
@@ -88,8 +89,9 @@ def get_label(args, column):
8889
def get_decorated_label(args, column, role):
8990
label = get_label(args, column)
9091
if "histfunc" in args and (
91-
(role == "x" and args["orientation"] == "h")
92-
or (role == "y" and args["orientation"] == "v")
92+
(role == "x" and "orientation" in args and args["orientation"] == "h")
93+
or (role == "y" and "orientation" in args and args["orientation"] == "v")
94+
or (role == "z")
9395
):
9496
if label:
9597
return "%s of %s" % (args["histfunc"] or "count", label)
@@ -171,8 +173,13 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
171173
mapping_labels["%{xaxis.title.text}"] = "%{x}"
172174
mapping_labels["%{yaxis.title.text}"] = "%{y}"
173175

174-
elif v is not None or (
175-
trace_spec.constructor == go.Histogram and k in ["x", "y"]
176+
elif (
177+
v is not None
178+
or (trace_spec.constructor == go.Histogram and k in ["x", "y"])
179+
or (
180+
trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
181+
and k == "z"
182+
)
176183
):
177184
if k == "size":
178185
if "marker" not in result:
@@ -222,20 +229,28 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
222229
result[error_xy] = {}
223230
result[error_xy][arr] = g[v]
224231
elif k == "hover_name":
225-
if trace_spec.constructor not in [go.Histogram, go.Histogram2dContour]:
232+
if trace_spec.constructor not in [
233+
go.Histogram,
234+
go.Histogram2d,
235+
go.Histogram2dContour,
236+
]:
226237
result["hovertext"] = g[v]
227238
if hover_header == "":
228239
hover_header = "<b>%{hovertext}</b><br><br>"
229240
elif k == "hover_data":
230-
if trace_spec.constructor not in [go.Histogram, go.Histogram2dContour]:
241+
if trace_spec.constructor not in [
242+
go.Histogram,
243+
go.Histogram2d,
244+
go.Histogram2dContour,
245+
]:
231246
result["customdata"] = g[v].values
232247
for i, col in enumerate(v):
233248
v_label_col = get_decorated_label(args, col, None)
234249
mapping_labels[v_label_col] = "%%{customdata[%d]}" % i
235250
elif k == "color":
236251
if trace_spec.constructor == go.Choropleth:
237252
result["z"] = g[v]
238-
result["z"]["coloraxis"] = "coloraxis1"
253+
result["coloraxis"] = "coloraxis1"
239254
mapping_labels[v_label] = "%{z}"
240255
else:
241256
colorable = "marker"
@@ -255,7 +270,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
255270
if v:
256271
result[k] = g[v]
257272
mapping_labels[v_label] = "%%{%s}" % k
258-
if trace_spec.constructor not in [go.Histogram2dContour, go.Parcoords, go.Parcats]:
273+
if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
259274
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
260275
result["hovertemplate"] = hover_header + "<br>".join(hover_lines)
261276
return result, fit_results
@@ -542,7 +557,7 @@ def make_trace_spec(args, constructor, attrs, trace_patch):
542557
trace_spec = TraceSpec(
543558
constructor=go.Histogram,
544559
attrs=[letter],
545-
trace_patch=dict(opacity=0.5, **axis_map),
560+
trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map),
546561
)
547562
elif args["marginal_" + letter] == "violin":
548563
trace_spec = TraceSpec(
@@ -571,10 +586,10 @@ def make_trace_spec(args, constructor, attrs, trace_patch):
571586
**axis_map
572587
),
573588
)
574-
if "color" in attrs:
589+
if "color" in attrs or "color" not in args:
575590
if "marker" not in trace_spec.trace_patch:
576591
trace_spec.trace_patch["marker"] = dict()
577-
first_default_color = args["color_discrete_sequence"][0]
592+
first_default_color = args["color_continuous_scale"][0]
578593
trace_spec.trace_patch["marker"]["color"] = first_default_color
579594
result.append(trace_spec)
580595
if "trendline" in args and args["trendline"]:
@@ -702,6 +717,11 @@ def infer_config(args, constructor, trace_patch):
702717
grouped_attrs.append("marker.symbol")
703718

704719
trace_patch = trace_patch.copy()
720+
721+
if constructor == go.Histogram2d:
722+
show_colorbar = True
723+
trace_patch["coloraxis"] = "coloraxis1"
724+
705725
if "opacity" in args:
706726
if args["opacity"] is None:
707727
if "barmode" in args and args["barmode"] == "overlay":
@@ -797,8 +817,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
797817
trace_names = trace_names_by_frame[frame_name]
798818

799819
for trace_spec in trace_specs:
800-
constructor = trace_spec.constructor
801-
if constructor in [go.Scatter, go.Scatterpolar]:
820+
constructor_to_use = trace_spec.constructor
821+
if constructor_to_use in [go.Scatter, go.Scatterpolar]:
802822
if "render_mode" in args and (
803823
args["render_mode"] == "webgl"
804824
or (
@@ -807,11 +827,18 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
807827
and args["animation_frame"] is None
808828
)
809829
):
810-
constructor = (
811-
go.Scattergl if constructor == go.Scatter else go.Scatterpolargl
830+
constructor_to_use = (
831+
go.Scattergl
832+
if constructor_to_use == go.Scatter
833+
else go.Scatterpolargl
812834
)
813-
trace = trace_spec.constructor(name=trace_name)
814-
if trace_spec.constructor not in [go.Parcats, go.Parcoords, go.Choropleth]:
835+
trace = constructor_to_use(name=trace_name)
836+
if trace_spec.constructor not in [
837+
go.Parcats,
838+
go.Parcoords,
839+
go.Choropleth,
840+
go.Histogram2d,
841+
]:
815842
trace.update(
816843
legendgroup=trace_name,
817844
showlegend=(trace_name != "" and trace_name not in trace_names),
@@ -870,14 +897,18 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
870897
)
871898
layout_patch = layout_patch.copy()
872899
if show_colorbar:
900+
if "color" in args:
901+
colorvar = "color"
902+
elif constructor == go.Histogram2d:
903+
colorvar = "z"
873904
d = len(args["color_continuous_scale"]) - 1
874905
layout_patch["coloraxis1"] = dict(
875-
colorbar=dict(title=get_decorated_label(args, args["color"], "color")),
876906
colorscale=[
877907
[(1.0 * i) / (1.0 * d), x]
878908
for i, x in enumerate(args["color_continuous_scale"])
879909
],
880910
cmid=args["color_continuous_midpoint"],
911+
colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)),
881912
)
882913
for v in ["title", "height", "width", "template"]:
883914
if args[v]:

plotly_express/_doc.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
x=[
1717
colref,
1818
"Values from this column are used to position marks along the x axis in cartesian coordinates.",
19+
"For horizontal `histogram`s, these values are used as inputs to `histfunc`.",
1920
],
2021
y=[
2122
colref,
2223
"Values from this column are used to position marks along the y axis in cartesian coordinates.",
24+
"For vertical `histogram`s, these values are used as inputs to `histfunc`.",
2325
],
2426
z=[
2527
colref,
2628
"Values from this column are used to position marks along the z axis in cartesian coordinates.",
29+
"For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.",
2730
],
2831
a=[
2932
colref,
@@ -255,12 +258,20 @@
255258
"(integer, default is 90)",
256259
"Sets start angle for the angular axis, with 0 being due east and 90 being due north.",
257260
],
261+
histfunc=[
262+
"(string, one of `'count'`, `'sum'`, `'avg'`, `'min'`, `'max'`. Default is `'count'`)"
263+
"Function used to aggregate values for summarization (note: can be normalized with `histnorm`).",
264+
"The arguments to this function for `histogram` are the values of `y` if `orientation` is `'v'`,",
265+
"otherwise the arguements are the values of `x`.",
266+
"The arguments to this function for `density_heatmap` and `density_contour` are the values of `z`.",
267+
],
258268
histnorm=[
259269
"(string, one of `'percent'`, `'probability'`, `'density'`, `'probability density'`, default `None`)",
260-
"If `None`, the span of each bar corresponds to the number of occurrences (i.e. the number of data points lying inside the bins).",
261-
"If `'percent'` or `'probability'`, the span of each bar corresponds to the percentage / fraction of occurrences with respect to the total number of sample points (here, the sum of all bin HEIGHTS equals 100% / 1).",
262-
"If `'density'`, the span of each bar corresponds to the number of occurrences in a bin divided by the size of the bin interval (here, the sum of all bin AREAS equals the total number of sample points).",
263-
"If `'probability density'`, the area of each bar corresponds to the probability that an event will fall into the corresponding bin (here, the sum of all bin AREAS equals 1).",
270+
"If `None`, the output of `histfunc` is used as is.",
271+
"If `'probability'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins.",
272+
"If `'percent'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins and multiplied by 100.",
273+
"If `'density'`, the output of `histfunc` for a given bin is divided by the size of the bin.",
274+
"If `'probability density'`, the output of `histfunc` for a given bin is normalized such that it corresponds to the probability that a random event whose distribution is described by the output of `histfunc` will fall into that bin.",
264275
],
265276
barnorm=[
266277
"(string, one of `'fraction'` or `'percent'`, default is `None`)",
@@ -298,11 +309,6 @@
298309
"If `True`, an extra line segment is drawn between the first and last point.",
299310
],
300311
line_shape=["(string, one of `'linear'` or `'spline'`)", "Default is `'linear'`."],
301-
histfunc=[
302-
"(string, one of `'count'`, `'sum'`, `'avg'`, `'min'`, `'max'`. Default is `'count'`)"
303-
"Function used to compute histogram bar lengths.",
304-
"The arguments to this function are the values of `y` if `orientation` is `'v'`, otherwise the arguements are the values of `x`.",
305-
],
306312
scope=[
307313
"(string, one of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, `'south america'`)"
308314
"Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`."
@@ -328,6 +334,8 @@
328334
"If `True`, histogram values are cumulative.",
329335
],
330336
nbins=["(positive integer)", "Sets the number of bins."],
337+
nbinsx=["(positive integer)", "Sets the number of bins along the x axis."],
338+
nbinsy=["(positive integer)", "Sets the number of bins along the y axis."],
331339
)
332340

333341

0 commit comments

Comments
 (0)