diff --git a/refellips/dispersion.py b/refellips/dispersion.py index aeccc77..15e4970 100644 --- a/refellips/dispersion.py +++ b/refellips/dispersion.py @@ -293,12 +293,12 @@ def complex(self, wavelength): wav = wavelength # Convert between μm & nm (constants are typically given in μm) - wav *= 1e-3 + wav_um = wav * 1e-3 real = np.sqrt( self.Einf.value - + (self.Am.value * wav**2) / (wav**2 - self.En.value**2) - - (self.P.value * wav**2) + + (self.Am.value * wav_um**2) / (wav_um**2 - self.En.value**2) + - (self.P.value * wav_um**2) ) return real + 1j * 0.0 @@ -311,12 +311,12 @@ def epsilon(self, wavelength): wav = wavelength # Convert between μm & nm (constants are typically given in μm) - wav *= 1e-3 + wav_um = wav * 1e-3 real = ( self.Einf.value - + (self.Am.value * wav**2) / (wav**2 - self.En.value**2) - - (self.P.value * wav**2) + + (self.Am.value * wav_um**2) / (wav_um**2 - self.En.value**2) + - (self.P.value * wav_um**2) ) return real + 1j * 0 diff --git a/tools/plottools.py b/tools/plottools.py index fb0387e..23f8e5d 100644 --- a/tools/plottools.py +++ b/tools/plottools.py @@ -12,6 +12,7 @@ def plot_ellipsdata( xaxis="aoi", plot_labels=True, legend=True, + resax=None, ): """ Plots delta and psi values as a function of wavelength or angle of incidence. @@ -33,31 +34,35 @@ def plot_ellipsdata( The default is None. objective : refellips.objectiveSE.ObjectiveSE, optional Objective (containing model and data) to plot. If the objective is provided, - neither model or data should be provided. The default is None. + neither model nor data should be provided. The default is None. xaxis : String, optional Either 'aoi' or 'wavelength'. The default is 'aoi'. plot_labels : Bool, optional Whether to plot axis labels. The default is True. legend : Bool, optional Whether to plot the legend. The default is True. - + resax : matplotlib.axes._subplots.AxesSubplot, optional + Axis object on which delta and psi residuals. If not provided, + no residual will be plotted. Returns ------- None. """ - if objective != None: + if objective is not None: assert ( - data == None and model == None + data is None and model is None ), "If objective is supplied, model and data should not be passed" data = objective.data model = objective.model - elif model != None: - assert data != None, "If you supply a model, you must also supply data" + elif model is not None: + assert ( + data is not None + ), "If you supply a model, you must also supply data" else: assert ( - data != None + data is not None ), "must supply at least one of data, model or objective" assert ( @@ -72,40 +77,70 @@ def plot_ellipsdata( x = data.aoi xlab = "AOI (°)" - if model != None: + if model is not None: for wav in unique_wavs: psis, deltas = model(np.c_[np.ones_like(aois) * wav, aois]) ax.plot(aois, psis, color="r") axt.plot(aois, deltas, color="b") elif xaxis == "wavelength": - unique_aois = np.unique(data.aoi) + # Doesn't work for VASE wavs = np.linspace( np.min(data.wavelength) - 50, np.max(data.wavelength) + 50 ) x = data.wavelength + xlab = "Wavelength (nm)" - if model != None: - print(x) - for idx, wav in enumerate(np.unique(data.wavelength)): - wavelength, aoi, d_psi, d_delta = list( - data.unique_wavelength_data() - )[idx] + if model is not None: + wavelength_aois = np.c_[wavs, data.aoi[0] * np.ones_like(wavs)] + psi, delta = model(wavelength_aois) + ax.plot(wavs, psi, color="r") + axt.plot(wavs, delta, color="b") + # for idx, wav in enumerate(np.unique(data.wavelength)): + # wavelength, aoi, d_psi, d_delta = list( + # data.unique_wavelength_data() + # )[idx] - psi, delta = model(np.c_[np.ones_like(aoi) * wavelength, aoi]) - ax.plot(np.ones_like(psi) * wavelength, psi, color="r") - axt.plot(np.ones_like(delta) * wavelength, delta, color="b") + # psi, delta = model(np.c_[np.ones_like(aoi) * wavelength, aoi]) + # ax.plot(np.ones_like(psi) * wavelength, psi, color="r") + # axt.plot(np.ones_like(delta) * wavelength, delta, color="b") - xlab = "Wavelength (nm)" + else: + assert False, "xaxis must be 'aoi' or 'wavelength'" - p = ax.scatter(x, data.psi, color="r") - d = axt.scatter(x, data.delta, color="b") + # Plot data + p = ax.scatter(x, data.psi, color="r", alpha=0.5) + d = axt.scatter(x, data.delta, color="b", alpha=0.5) - ax.legend(handles=[p, d], labels=["Psi", "Delta"]) + if legend: + ax.legend(handles=[p, d], labels=["Psi", "Delta"], loc="center right") + + if resax is not None: + assert ( + objective is not None + ), "To plot residuals you must supply an objective" + res = objective.residuals() + numdp = int(len(res) / 2) + psires = res[:numdp] + delres = res[numdp:] + resax.scatter(x, psires, color="r") + resax.scatter(x, delres, color="b") + resax.text( + 0.95, + 0.1, + s=r"$\chi^2 = $" + f"{np.round(objective.chisqr(),3)}", + transform=resax.transAxes, + ha="right", + va="bottom", + ) if plot_labels: - ax.set(ylabel="Psi", xlabel=xlab) - axt.set(ylabel="Delta") + ax.set_ylabel("Psi", color="red") + ax.set_xlabel(xlab) + + axt.set_ylabel("Delta", color="blue") + if resax is not None: + resax.set(ylabel="error") def plot_structure( @@ -131,7 +166,7 @@ def plot_structure( objective : refellips.objectiveSE.ObjectiveSE, optional Objective (containing model and data) to plot. If the objective is provided, structure should not be provided. The default is None. - structure : refnx.reflect.structure.Structure + structure : refnx.reflect.structure.Structure. Structure (which represents the interface) to be plotted. If structure is provided, objective should not be provided. The default is None. reverse_structure : bool @@ -144,19 +179,23 @@ def plot_structure( None. """ - if objective != None: + if objective is not None: assert ( - structure == None + structure is None ), "you must supply either an objective or structure, not both" structure = objective.model.structure wavelengths = np.unique(objective.data.wavelength) else: assert ( - structure != None + structure is not None ), "you must supply either an objective or structure" wavelengths = [658] if len(wavelengths) > 1: + if len(wavelengths) > 8: + wavelengths = np.linspace( + np.min(wavelengths), np.max(wavelengths), 6 + ) colors = plt.cm.viridis(np.linspace(0, 1, len(wavelengths))) alpha = 0.5 else: