diff --git a/src/corner/core.py b/src/corner/core.py index 50fd9f2..b26c01a 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -619,7 +619,13 @@ def hist2d( levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2) # This is the base color of the axis (background color) - base_color = ax.get_facecolor() + if isinstance(pcolor_kwargs, dict): + if "cmap" in list(pcolor_kwargs.keys()): + base_color = pcolor_kwargs["cmap"].resampled(512)(512) + else: + base_color = ax.get_facecolor() + else: + base_color = ax.get_facecolor() # This is the color map for the density plot, over-plotted to indicate the # density of the points near the center. @@ -745,18 +751,9 @@ def hist2d( ] ) - if plot_datapoints: - if data_kwargs is None: - data_kwargs = dict() - data_kwargs["color"] = data_kwargs.get("color", color) - data_kwargs["ms"] = data_kwargs.get("ms", 2.0) - data_kwargs["mec"] = data_kwargs.get("mec", "none") - data_kwargs["alpha"] = data_kwargs.get("alpha", 0.1) - ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs) - # Plot the base fill to hide the densest data points. if (plot_contours or plot_density) and not no_fill_contours: - ax.contourf( + clevels = ax.contourf( X2, Y2, H2.T, @@ -765,6 +762,25 @@ def hist2d( antialiased=False, ) + if plot_datapoints: + if data_kwargs is None: + data_kwargs = dict() + data_kwargs["color"] = data_kwargs.get("color", color) + data_kwargs["ms"] = data_kwargs.get("ms", 2.0) + data_kwargs["mec"] = data_kwargs.get("mec", "none") + data_kwargs["alpha"] = data_kwargs.get("alpha", 0.1) + + try: + p = clevels.collections[0].get_paths() + indices = np.zeros_like(x, dtype=bool) + for level in p: + indices |= level.contains_points(list(zip(*(x, y)))) + ax.plot( + x[~indices], y[~indices], "o", rasterized=True, **data_kwargs + ) + except NameError: + ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs) + if plot_contours and fill_contours: if contourf_kwargs is None: contourf_kwargs = dict() @@ -785,7 +801,11 @@ def hist2d( elif plot_density: if pcolor_kwargs is None: pcolor_kwargs = dict() - ax.pcolor(X, Y, H.max() - H.T, cmap=density_cmap, **pcolor_kwargs) + + if "cmap" not in pcolor_kwargs.keys(): + pcolor_kwargs["cmap"] = density_cmap + + ax.pcolor(X, Y, H.max() - H.T, **pcolor_kwargs) # Plot the contour edge colors. if plot_contours: diff --git a/tests/test_hist2d.py b/tests/test_hist2d.py index cadd867..f14d71a 100644 --- a/tests/test_hist2d.py +++ b/tests/test_hist2d.py @@ -110,6 +110,13 @@ def test_levels1(): _run_hist2d("levels1", levels=[0.68, 0.95]) +@image_comparison( + baseline_images=["levels1"], remove_text=True, extensions=["png"] +) +def test_cmap(): + _run_hist2d("cmap", levels=[0.68, 0.95], pcolor_kwargs={"cmap": pl.cm.jet}) + + @image_comparison( baseline_images=["levels2"], remove_text=True, extensions=["png"] )