diff --git a/modulus/sym/utils/io/plotter.py b/modulus/sym/utils/io/plotter.py index bbc41e69..2ca431e1 100644 --- a/modulus/sym/utils/io/plotter.py +++ b/modulus/sym/utils/io/plotter.py @@ -139,7 +139,7 @@ def __call__(self, invar, outvar): for k in outvar: f = plt.figure(figsize=(5, 4), dpi=100) if ndim == 1: - plt.plot(invar[dims[0]][:, 0], outvar[:, 0]) + plt.plot(invar[dims[0]][:, 0], outvar[k][:, 0]) plt.xlabel(dims[0]) elif ndim == 2: plt.imshow(outvar[k].T, origin="lower", extent=extent)