diff --git a/python/pyspark/pandas/plot/matplotlib.py b/python/pyspark/pandas/plot/matplotlib.py index 616cf1de340ed..60d3538605bf3 100644 --- a/python/pyspark/pandas/plot/matplotlib.py +++ b/python/pyspark/pandas/plot/matplotlib.py @@ -23,6 +23,7 @@ import numpy as np from matplotlib.axes._base import _process_plot_format # type: ignore[attr-defined] from matplotlib.figure import Figure +import pandas as pd from pandas.core.dtypes.inference import is_list_like from pandas.io.formats.printing import pprint_thing # type: ignore[import-untyped] from pandas.plotting._matplotlib import ( # type: ignore[import-untyped] @@ -968,5 +969,10 @@ def _plot(data, x=None, y=None, subplots=False, ax=None, kind="line", **kwds): plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds) plot_obj.generate() - plot_obj.draw() + if LooseVersion(pd.__version__) < "3.0.0": + plot_obj.draw() + else: + import matplotlib.pyplot as plt + + plt.draw_if_interactive() return plot_obj.result