Skip to content

Commit 7addb8b

Browse files
committed
Update plottools.py
Added option to plot residuals. Still some work to be done on getting plotting working nicely for VASE measurements.
1 parent 87ae7e8 commit 7addb8b

File tree

1 file changed

+42
-20
lines changed

1 file changed

+42
-20
lines changed

tools/plottools.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def plot_ellipsdata(
1212
xaxis="aoi",
1313
plot_labels=True,
1414
legend=True,
15+
resAx=None
1516
):
1617
"""
1718
Plots delta and psi values as a function of wavelength or angle of incidence.
@@ -79,34 +80,53 @@ def plot_ellipsdata(
7980
axt.plot(aois, deltas, color="b")
8081

8182
elif xaxis == "wavelength":
82-
unique_aois = np.unique(data.aoi)
83+
# Doesn't work for VASE
8384
wavs = np.linspace(
8485
np.min(data.wavelength) - 50, np.max(data.wavelength) + 50
8586
)
8687
x = data.wavelength
8788

8889
if model != None:
89-
print(x)
90-
for idx, wav in enumerate(np.unique(data.wavelength)):
91-
wavelength, aoi, d_psi, d_delta = list(
92-
data.unique_wavelength_data()
93-
)[idx]
94-
95-
psi, delta = model(np.c_[np.ones_like(aoi) * wavelength, aoi])
96-
ax.plot(np.ones_like(psi) * wavelength, psi, color="r")
97-
axt.plot(np.ones_like(delta) * wavelength, delta, color="b")
98-
99-
xlab = "Wavelength (nm)"
100-
101-
p = ax.scatter(x, data.psi, color="r")
102-
d = axt.scatter(x, data.delta, color="b")
103-
104-
ax.legend(handles=[p, d], labels=["Psi", "Delta"])
90+
wavelength_aois = np.c_[wavs, data.aoi[0]*np.ones_like(wavs)]
91+
psi, delta = model(wavelength_aois)
92+
ax.plot(wavs, psi, color="r")
93+
axt.plot(wavs, delta, color="b")
94+
# for idx, wav in enumerate(np.unique(data.wavelength)):
95+
# wavelength, aoi, d_psi, d_delta = list(
96+
# data.unique_wavelength_data()
97+
# )[idx]
98+
99+
# psi, delta = model(np.c_[np.ones_like(aoi) * wavelength, aoi])
100+
# ax.plot(np.ones_like(psi) * wavelength, psi, color="r")
101+
# axt.plot(np.ones_like(delta) * wavelength, delta, color="b")
102+
103+
xlab = "Wavelength (nm)"
104+
105+
# Plot data
106+
p = ax.scatter(x, data.psi, color="r", alpha=0.5)
107+
d = axt.scatter(x, data.delta, color="b", alpha=0.5)
108+
109+
# ax.legend(handles=[p, d], labels=["Psi", "Delta"], loc='center right')
110+
111+
if resAx != None:
112+
assert objective != None, 'To plot residuals you must supply an objective'
113+
res = objective.residuals()
114+
numdp = int(len(res)/2)
115+
psires = res[:numdp]
116+
delres = res[numdp:]
117+
resAx.scatter(x, psires, color='r')
118+
resAx.scatter(x, delres, color='b')
119+
resAx.text(0.95, 0.1, s=r'$\chi^2 = $' + f'{np.round(objective.chisqr(),3)}', transform=resAx.transAxes, ha='right', va='bottom')
120+
121+
105122

106123
if plot_labels:
107-
ax.set(ylabel="Psi", xlabel=xlab)
108-
axt.set(ylabel="Delta")
109-
124+
ax.set_ylabel("Psi", color='red')
125+
ax.set_xlabel(xlab)
126+
127+
axt.set_ylabel("Delta", color='blue')
128+
if resAx != None:
129+
resAx.set(ylabel='error')
110130

111131
def plot_structure(
112132
ax,
@@ -157,6 +177,8 @@ def plot_structure(
157177
wavelengths = [658]
158178

159179
if len(wavelengths) > 1:
180+
if len(wavelengths) >8:
181+
wavelengths = np.linspace(np.min(wavelengths), np.max(wavelengths), 6)
160182
colors = plt.cm.viridis(np.linspace(0, 1, len(wavelengths)))
161183
alpha = 0.5
162184
else:

0 commit comments

Comments
 (0)