Skip to content

Commit 50289e3

Browse files
authored
Adds the shift_value argument to plot_ref_sld (#174)
1 parent 45e44ef commit 50289e3

File tree

2 files changed

+20
-204
lines changed

2 files changed

+20
-204
lines changed

ratapi/utils/plotting.py

Lines changed: 15 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ratapi.rat_core import PlotEventData, makeSLDProfile
2222

2323

24-
def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool):
24+
def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool, shift_value: float):
2525
"""Extract the plot data for the sld, ref, error plot lines.
2626
2727
Parameters
@@ -33,6 +33,8 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
3333
Controls whether Q^4 is plotted on the reflectivity plot
3434
show_error_bar : bool, default: True
3535
Controls whether the error bars are shown
36+
shift_value : float
37+
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts
3638
3739
Returns
3840
-------
@@ -42,9 +44,12 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
4244
"""
4345
results = {"ref": [], "error": [], "sld": [], "sld_resample": []}
4446

47+
if shift_value < 1 or shift_value > 100:
48+
raise ValueError("Parameter `shift_value` must be between 1 and 100")
49+
4550
for i, (r, data, sld) in enumerate(zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles)):
4651
# Calculate the divisor
47-
div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1))
52+
div = 1 if i == 0 and not q4 else 10 ** ((i / 100) * shift_value)
4853
q4_data = 1 if not q4 or not event_data.dataPresent[i] else data[:, 0] ** 4
4954
mult = q4_data / div
5055

@@ -87,194 +92,6 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
8792
return results
8893

8994

90-
class PlotSLDWithBlitting:
91-
"""Create a SLD plot that uses blitting to get faster draws.
92-
93-
The blit plot stores the background from an
94-
initial draw then updates the foreground (lines and error bars) if the background is not changed.
95-
96-
Parameters
97-
----------
98-
data : PlotEventData
99-
The plot event data that contains all the information
100-
to generate the ref and sld plots
101-
fig : matplotlib.pyplot.figure, optional
102-
The figure class that has two subplots
103-
linear_x : bool, default: False
104-
Controls whether the x-axis on reflectivity plot uses the linear scale
105-
q4 : bool, default: False
106-
Controls whether Q^4 is plotted on the reflectivity plot
107-
show_error_bar : bool, default: True
108-
Controls whether the error bars are shown
109-
show_grid : bool, default: False
110-
Controls whether the grid is shown
111-
show_legend : bool, default: True
112-
Controls whether the legend is shown
113-
"""
114-
115-
def __init__(
116-
self,
117-
data: PlotEventData,
118-
fig: Optional[matplotlib.pyplot.figure] = None,
119-
linear_x: bool = False,
120-
q4: bool = False,
121-
show_error_bar: bool = True,
122-
show_grid: bool = False,
123-
show_legend: bool = True,
124-
):
125-
self.figure = fig
126-
self.linear_x = linear_x
127-
self.q4 = q4
128-
self.show_error_bar = show_error_bar
129-
self.show_grid = show_grid
130-
self.show_legend = show_legend
131-
self.updatePlot(data)
132-
self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent)
133-
134-
def __del__(self):
135-
self.figure.canvas.mpl_disconnect(self.event_id)
136-
137-
def resizeEvent(self, _event):
138-
"""Ensure the background is updated after a resize event."""
139-
self.__background_changed = True
140-
141-
def update(self, data: PlotEventData):
142-
"""Update the foreground, if background has not changed otherwise it updates full plot.
143-
144-
Parameters
145-
----------
146-
data : PlotEventData
147-
The plot event data that contains all the information
148-
to generate the ref and sld plots
149-
"""
150-
if self.__background_changed:
151-
self.updatePlot(data)
152-
else:
153-
self.updateForeground(data)
154-
155-
def __setattr__(self, name, value):
156-
super().__setattr__(name, value)
157-
if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend"]:
158-
self.__background_changed = True
159-
160-
def setAnimated(self, is_animated: bool):
161-
"""Set the animated property of foreground plot elements.
162-
163-
Parameters
164-
----------
165-
is_animated : bool
166-
Indicates if the animated property should been set.
167-
"""
168-
for line in self.figure.axes[0].lines:
169-
line.set_animated(is_animated)
170-
for line in self.figure.axes[1].lines:
171-
line.set_animated(is_animated)
172-
for container in self.figure.axes[0].containers:
173-
container[2][0].set_animated(is_animated)
174-
175-
def adjustErrorBar(self, error_bar_container, x, y, y_error):
176-
"""Adjust the error bar data.
177-
178-
Parameters
179-
----------
180-
error_bar_container : Tuple
181-
Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines)
182-
x : np.ndarray
183-
The shifted data x axis data
184-
y : np.ndarray
185-
The shifted data y axis data
186-
y_error : np.ndarray
187-
The shifted data y axis error data
188-
"""
189-
line, _, (bars_y,) = error_bar_container
190-
191-
line.set_data(x, y)
192-
x_base = x
193-
y_base = y
194-
195-
y_error_top = y_base + y_error
196-
y_error_bottom = y_base - y_error
197-
198-
new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)]
199-
bars_y.set_segments(new_segments_y)
200-
201-
def updatePlot(self, data: PlotEventData):
202-
"""Update the full plot.
203-
204-
Parameters
205-
----------
206-
data : PlotEventData
207-
The plot event data that contains all the information
208-
to generate the ref and sld plots
209-
"""
210-
if self.figure is not None:
211-
self.figure.clf()
212-
self.figure = plot_ref_sld_helper(
213-
data,
214-
self.figure,
215-
linear_x=self.linear_x,
216-
q4=self.q4,
217-
show_error_bar=self.show_error_bar,
218-
show_grid=self.show_grid,
219-
show_legend=self.show_legend,
220-
animated=True,
221-
)
222-
223-
self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox)
224-
for line in self.figure.axes[0].lines:
225-
self.figure.axes[0].draw_artist(line)
226-
for line in self.figure.axes[1].lines:
227-
self.figure.axes[1].draw_artist(line)
228-
for container in self.figure.axes[0].containers:
229-
self.figure.axes[0].draw_artist(container[2][0])
230-
self.figure.canvas.blit(self.figure.bbox)
231-
self.setAnimated(False)
232-
self.__background_changed = False
233-
234-
def updateForeground(self, data: PlotEventData):
235-
"""Update the plot foreground only.
236-
237-
Parameters
238-
----------
239-
data : PlotEventData
240-
The plot event data that contains all the information
241-
to generate the ref and sld plots
242-
"""
243-
self.setAnimated(True)
244-
self.figure.canvas.restore_region(self.bg)
245-
plot_data = _extract_plot_data(data, self.q4, self.show_error_bar)
246-
247-
offset = 2 if self.show_error_bar else 1
248-
for i in range(
249-
0,
250-
len(self.figure.axes[0].lines),
251-
):
252-
self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1])
253-
self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i])
254-
255-
i = 0
256-
for j in range(len(plot_data["sld"])):
257-
for sld in plot_data["sld"][j]:
258-
self.figure.axes[1].lines[i].set_data(sld[0], sld[1])
259-
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
260-
i += 1
261-
262-
if plot_data["sld_resample"]:
263-
for resampled in plot_data["sld_resample"][j]:
264-
self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1])
265-
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
266-
i += 1
267-
268-
for i, container in enumerate(self.figure.axes[0].containers):
269-
self.adjustErrorBar(container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2])
270-
self.figure.axes[0].draw_artist(container[2][0])
271-
self.figure.axes[0].draw_artist(container[0])
272-
273-
self.figure.canvas.blit(self.figure.bbox)
274-
self.figure.canvas.flush_events()
275-
self.setAnimated(False)
276-
277-
27895
def plot_ref_sld_helper(
27996
data: PlotEventData,
28097
fig: Optional[matplotlib.pyplot.figure] = None,
@@ -285,6 +102,7 @@ def plot_ref_sld_helper(
285102
show_error_bar: bool = True,
286103
show_grid: bool = False,
287104
show_legend: bool = True,
105+
shift_value: float = 100,
288106
animated=False,
289107
):
290108
"""Clear the previous plots and updates the ref and SLD plots.
@@ -311,6 +129,8 @@ def plot_ref_sld_helper(
311129
Controls whether the grid is shown
312130
show_legend : bool, default: True
313131
Controls whether the legend is shown
132+
shift_value : float, default: 100
133+
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts
314134
animated : bool, default: False
315135
Controls whether the animated property of foreground plot elements should be set.
316136
@@ -339,7 +159,7 @@ def plot_ref_sld_helper(
339159
ref_plot.cla()
340160
sld_plot.cla()
341161

342-
plot_data = _extract_plot_data(data, q4, show_error_bar)
162+
plot_data = _extract_plot_data(data, q4, show_error_bar, shift_value)
343163
for i, name in enumerate(data.contrastNames):
344164
ref_plot.plot(plot_data["ref"][i][0], plot_data["ref"][i][1], label=name, linewidth=1, animated=animated)
345165
color = ref_plot.get_lines()[-1].get_color()
@@ -427,6 +247,7 @@ def plot_ref_sld(
427247
show_error_bar: bool = True,
428248
show_grid: bool = False,
429249
show_legend: bool = True,
250+
shift_value: float = 100,
430251
) -> Union[plt.Figure, None]:
431252
"""Plot the reflectivity and SLD profiles.
432253
@@ -454,6 +275,8 @@ def plot_ref_sld(
454275
Controls whether the grid is shown
455276
show_legend : bool, default: True
456277
Controls whether the legend is shown
278+
shift_value : float, default: 100
279+
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts
457280
458281
Returns
459282
-------
@@ -524,6 +347,7 @@ def plot_ref_sld(
524347
show_error_bar=show_error_bar,
525348
show_grid=show_grid,
526349
show_legend=show_legend,
350+
shift_value=shift_value,
527351
)
528352

529353
if return_fig:

tests/test_plotting.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -481,20 +481,12 @@ def test_bayes_validation(input_project, reflectivity_calculation_results):
481481

482482
@pytest.mark.parametrize("data", [data(), domains_data()])
483483
def test_extract_plot_data(data) -> None:
484-
plot_data = RATplot._extract_plot_data(data, False, True)
484+
plot_data = RATplot._extract_plot_data(data, False, True, 50)
485485
assert len(plot_data["ref"]) == len(data.reflectivity)
486486
assert len(plot_data["sld"]) == len(data.shiftedData)
487487

488+
with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 1 and 100"):
489+
RATplot._extract_plot_data(data, False, True, 0)
488490

489-
@patch("ratapi.utils.plotting.plot_ref_sld_helper")
490-
def test_blit_plot(plot_helper, fig: plt.figure) -> None:
491-
plot_helper.return_value = fig
492-
event_data = data()
493-
new_plot = RATplot.PlotSLDWithBlitting(event_data)
494-
assert plot_helper.call_count == 1
495-
new_plot.update(event_data)
496-
assert plot_helper.call_count == 1 # foreground only is updated so no call to plot helper
497-
new_plot.show_grid = False
498-
new_plot.figure = plt.subplots(1, 2)[0]
499-
new_plot.update(event_data) # plot properties have changed so update should call plot_helper
500-
assert plot_helper.call_count == 2
491+
with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 1 and 100"):
492+
RATplot._extract_plot_data(data, False, True, 100.5)

0 commit comments

Comments
 (0)