Coverage for peakipy/plotting.py: 96%
181 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-15 20:42 -0400
1from dataclasses import dataclass, field
2from typing import List
4import pandas as pd
5import numpy as np
6import plotly.graph_objects as go
7import matplotlib.pyplot as plt
8from matplotlib import cm
9from matplotlib.widgets import Button
10from matplotlib.backends.backend_pdf import PdfPages
11from rich import print
13from peakipy.io import Pseudo3D
14from peakipy.utils import df_to_rich_table, bad_color_selection, bad_column_selection
@dataclass
class PlottingDataForPlane:
    """Plot-ready arrays for one plane of a fitted cluster.

    The init fields carry the spectrum, the fitted lineshape parameters, the
    coordinate grids, the fit masks and the simulated fit. ``__post_init__``
    blanks points outside the masks with NaN, crops everything to the window
    ``[min_y:max_y, min_x:max_x]`` and converts the coordinate grids to ppm.
    None of the arrays passed in are modified.
    """

    pseudo3D: Pseudo3D
    # Index into ``pseudo3D.data`` selecting the plane to plot.
    plane_id: int
    # Fitted lineshape parameters (one row per peak) for this plane.
    plane_lineshape_parameters: pd.DataFrame
    # Coordinate grids in points (presumably a meshgrid); converted to ppm
    # with ``pseudo3D.uc_f2`` / ``pseudo3D.uc_f1``.
    X: np.ndarray
    Y: np.ndarray
    # Mask of fitted points; points where it is False are set to NaN.
    mask: np.ndarray
    # Per-peak masks, paired positionally with ``sim_data_singles``.
    individual_masks: List[np.ndarray]
    # Simulated data for the whole cluster.
    sim_data: np.ndarray
    # Per-peak simulated data.
    sim_data_singles: List[np.ndarray]
    # Crop window (array indices) applied to every plotted array.
    min_x: int
    max_x: int
    min_y: int
    max_y: int
    fit_color: str
    data_color: str
    # Wireframe sampling counts forwarded to matplotlib's plot_wireframe.
    rcount: int
    ccount: int

    # Derived in __post_init__.
    plane_data: np.ndarray = field(init=False)
    x_plot: np.ndarray = field(init=False)
    y_plot: np.ndarray = field(init=False)
    masked_data: np.ndarray = field(init=False)
    masked_sim_data: np.ndarray = field(init=False)
    sim_plot: np.ndarray = field(init=False)
    residual: np.ndarray = field(init=False)
    single_colors: List = field(init=False)

    def __post_init__(self):
        # Crop window shared by every plotted array.
        window = (slice(self.min_y, self.max_y), slice(self.min_x, self.max_x))

        self.plane_data = self.pseudo3D.data[self.plane_id]

        # Blank points outside the fit mask on copies, so the spectrum and the
        # simulated data passed in stay untouched. masked_sim_data is kept at
        # full (uncropped) size.
        masked_full_data = self.plane_data.copy()
        masked_full_data[~self.mask] = np.nan
        self.masked_sim_data = self.sim_data.copy()
        self.masked_sim_data[~self.mask] = np.nan

        # Coordinate grids for the window, converted to ppm.
        self.x_plot = self.pseudo3D.uc_f2.ppm(self.X[window])
        self.y_plot = self.pseudo3D.uc_f1.ppm(self.Y[window])

        self.masked_data = masked_full_data[window]
        self.sim_plot = self.masked_sim_data[window]
        self.residual = self.masked_data - self.sim_plot

        # Mask each per-peak simulation with its own mask. Copies are used so
        # the caller's arrays are not overwritten with NaN. Singles without a
        # paired mask are still cropped and kept.
        singles = [single.copy() for single in self.sim_data_singles]
        for single_mask, single in zip(self.individual_masks, singles):
            single[~single_mask] = np.nan
        self.sim_data_singles = [single[window] for single in singles]

        # One viridis colour per peak, evenly spaced along the colormap.
        self.single_colors = [
            cm.viridis(i) for i in np.linspace(0, 1, len(self.sim_data_singles))
        ]
def _cluster_label(plot_data: PlottingDataForPlane):
    """Return the first ``clustid`` of the plane's peaks as an int, or None if there are no rows."""
    clustids = plot_data.plane_lineshape_parameters.clustid
    if len(clustids) == 0:
        return None
    return int(clustids.iloc[0])


def plot_data_is_valid(plot_data: PlottingDataForPlane) -> bool:
    """Check whether ``plot_data`` has anything to plot.

    Returns True when the ppm grids and the masked data window are non-empty.
    Otherwise it prints a diagnostic, closes the current matplotlib figure and
    returns False. The diagnostic includes the cluster id and a table of the
    lineshape parameters; for an empty data window it also includes the
    spectrum limits.
    """
    if len(plot_data.x_plot) < 1 or len(plot_data.y_plot) < 1:
        print(f"[red]Nothing to plot for cluster {_cluster_label(plot_data)}[/red]")
        print(f"[red]x={plot_data.x_plot},y={plot_data.y_plot}[/red]")
        print(
            df_to_rich_table(
                plot_data.plane_lineshape_parameters,
                title="",
                columns=bad_column_selection,
                styles=bad_color_selection,
            )
        )
        # A possible cause: the F1/F2 radii used for fitting were too small.
        plt.close()
        return False

    if plot_data.masked_data.shape[0] == 0 or plot_data.masked_data.shape[1] == 0:
        print(f"[red]Nothing to plot for cluster {_cluster_label(plot_data)}[/red]")
        print(
            df_to_rich_table(
                plot_data.plane_lineshape_parameters,
                title="Bad plane",
                columns=bad_column_selection,
                styles=bad_color_selection,
            )
        )
        spectrum = plot_data.pseudo3D
        spec_lim_f1 = " - ".join(f"{limit:8.3f}" for limit in spectrum.f1_ppm_limits)
        spec_lim_f2 = " - ".join(f"{limit:8.3f}" for limit in spectrum.f2_ppm_limits)
        print(f"Spectrum limits are {spectrum.f2_label:4s}:{spec_lim_f2} ppm")
        print(f" {spectrum.f1_label:4s}:{spec_lim_f1} ppm")
        plt.close()
        return False

    return True
def create_matplotlib_figure(
    plot_data: PlottingDataForPlane,
    pdf: PdfPages,
    individual=False,
    label=False,
    ccpn_flag=False,
    show=True,
    test=False,
):
    """Draw a 3D matplotlib figure of one cluster plane: data, fit and residual.

    The residual is drawn as filled contours offset below the data. The fit is
    a dashed wireframe and the data a solid wireframe. A text box lists each
    peak's volume (``amp``) and height.

    Args:
        plot_data: Prepared arrays for the plane. If ``plot_data_is_valid``
            rejects them, nothing is drawn.
        pdf: Destination the figure is saved to when ``show`` is False.
        individual: Also draw a translucent surface for each peak.
        label: Annotate each peak with its assignment.
        ccpn_flag: Blank the window title and size the figure to about
            1000x500 pixels before showing it.
        show: Display the figure interactively with "Exit" and "Next" buttons
            instead of saving it to ``pdf``.
        test: With ``show``, return just before ``plt.show()``. The figure is
            left open.
    """
    import sys

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(projection="3d")
    if not plot_data_is_valid(plot_data):
        # The validator has already reported the problem and closed the figure.
        return

    cset = ax.contourf(
        plot_data.x_plot,
        plot_data.y_plot,
        plot_data.residual,
        zdir="z",
        offset=np.nanmin(plot_data.masked_data) * 1.1,
        alpha=0.5,
        cmap=cm.coolwarm,
    )
    cbl = fig.colorbar(cset, ax=ax, shrink=0.5, format="%.2e")
    cbl.ax.set_title("Residual", pad=20)

    if individual:
        # One translucent surface per peak, using the precomputed viridis colours.
        for color, z_single in zip(plot_data.single_colors, plot_data.sim_data_singles):
            ax.plot_surface(
                plot_data.x_plot,
                plot_data.y_plot,
                z_single,
                color=color,
                alpha=0.5,
            )

    ax.plot_wireframe(
        plot_data.x_plot,
        plot_data.y_plot,
        plot_data.sim_plot,
        colors=plot_data.fit_color,
        linestyle="--",
        label="fit",
        rcount=plot_data.rcount,
        ccount=plot_data.ccount,
    )
    ax.plot_wireframe(
        plot_data.x_plot,
        plot_data.y_plot,
        plot_data.masked_data,
        colors=plot_data.data_color,
        linestyle="-",
        label="data",
        rcount=plot_data.rcount,
        ccount=plot_data.ccount,
    )
    ax.set_ylabel(plot_data.pseudo3D.f1_label)
    ax.set_xlabel(plot_data.pseudo3D.f2_label)

    # axes will appear inverted
    ax.view_init(30, 120)

    title = f"Plane={plot_data.plane_id},Cluster={plot_data.plane_lineshape_parameters.clustid.iloc[0]}"
    ax.set_title(title)
    print(f"[green]Plotting: {title}[/green]")

    volume_lines = ["Volumes (Heights)", "==========="]
    for _, row in plot_data.plane_lineshape_parameters.iterrows():
        volume_lines.append(f"{row.assignment} = {row.amp:.3e} ({row.height:.3e})")
        if label:
            ax.text(
                row.center_x_ppm,
                row.center_y_ppm,
                row.height * 1.2,
                row.assignment,
                (1, 1, 1),
            )
    out_str = "\n".join(volume_lines) + "\n"

    ax.text2D(
        -0.5,
        1.0,
        out_str,
        transform=ax.transAxes,
        fontsize=10,
        fontfamily="sans-serif",
        va="top",
        bbox=dict(boxstyle="round", ec="k", fc="k", alpha=0.5),
    )

    ax.legend()

    if show:

        def exit_program(event):
            sys.exit()

        def next_plot(event):
            plt.close()

        # Keep the Button objects bound to locals so they stay responsive
        # while plt.show() blocks.
        axexit = fig.add_axes([0.81, 0.05, 0.1, 0.075])
        bnexit = Button(axexit, "Exit")
        bnexit.on_clicked(exit_program)
        axnext = fig.add_axes([0.71, 0.05, 0.1, 0.075])
        bnnext = Button(axnext, "Next")
        bnnext.on_clicked(next_plot)
        if test:
            return
        if ccpn_flag:
            # pyplot.show() only accepts ``block``, so set the window title and
            # size on the figure itself.
            manager = fig.canvas.manager
            if manager is not None:
                manager.set_window_title("")
            fig.set_size_inches(1000 / fig.dpi, 500 / fig.dpi)
        plt.show()
    else:
        pdf.savefig(fig)

    plt.close(fig)
def _wireframe_traces(x, y, z, color, name):
    """Return Scatter3d line traces drawing the rows, then the columns, of a 2D grid.

    Every trace carries ``name`` and ``legendgroup=name``. Only the first row
    trace is shown in the legend, and that single entry toggles the whole
    wireframe.
    """
    line = dict(color=color, width=4)
    row_traces = [
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="lines",
            line=line,
            name=name,
            legendgroup=name,
            showlegend=index == 0,
        )
        for index, (xs, ys, zs) in enumerate(zip(x, y, z))
    ]
    column_traces = [
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="lines",
            line=line,
            name=name,
            legendgroup=name,
            showlegend=False,
        )
        for xs, ys, zs in zip(x.T, y.T, z.T)
    ]
    return row_traces + column_traces


def create_plotly_wireframe_lines(plot_data: PlottingDataForPlane):
    """Return plotly line traces drawing the fit and the data as wireframes.

    The fit (``sim_plot`` in ``fit_color``) comes first, followed by the data
    (``masked_data`` in ``data_color``). Each wireframe has a single legend
    entry, "fit" or "data".
    """
    fit_lines = _wireframe_traces(
        plot_data.x_plot,
        plot_data.y_plot,
        plot_data.sim_plot,
        plot_data.fit_color,
        "fit",
    )
    data_lines = _wireframe_traces(
        plot_data.x_plot,
        plot_data.y_plot,
        plot_data.masked_data,
        plot_data.data_color,
        "data",
    )
    return fit_lines + data_lines
def construct_surface_legend_string(row):
    """Return the legend label for a peak's surface: its assignment as text.

    ``str()`` is applied so numeric or missing (NaN) assignments still give a
    label instead of raising TypeError.
    """
    return str(row.assignment)
def create_plotly_surfaces(plot_data: PlottingDataForPlane):
    """Return one translucent ``go.Surface`` per fitted peak.

    Peak ``k`` gets a constant ``surfacecolor`` equal to the ``k``-th of
    evenly spaced stops in [0, 1]. The shared colorscale maps each stop to the
    first three (RGB) channels of the matching entry in
    ``plot_data.single_colors``. Each surface is named by
    ``construct_surface_legend_string`` applied to the peak's row in
    ``plane_lineshape_parameters``.
    """
    stops = np.linspace(0, 1, len(plot_data.single_colors))

    colorscale = []
    for stop, rgba in zip(stops, plot_data.single_colors):
        rgb_text = ", ".join("%d" % (channel * 255) for channel in rgba[0:3])
        colorscale.append([stop, "rgb(" + rgb_text + ")"])

    peaks = zip(
        stops,
        plot_data.sim_data_singles,
        plot_data.plane_lineshape_parameters.itertuples(),
    )
    surfaces = []
    for stop, peak_surface, peak_row in peaks:
        legend_name = construct_surface_legend_string(peak_row)
        uniform_color = np.zeros(shape=peak_surface.shape) + stop
        surfaces.append(
            go.Surface(
                z=peak_surface,
                x=plot_data.x_plot,
                y=plot_data.y_plot,
                opacity=0.5,
                surfacecolor=uniform_color,
                colorscale=colorscale,
                showscale=False,
                cmin=0,
                cmax=1,
                name=legend_name,
            )
        )
    return surfaces
def create_residual_contours(plot_data: PlottingDataForPlane):
    """Return a ``go.Contour`` trace of ``plot_data.residual``.

    The x coordinates are the first row of ``x_plot`` and the y coordinates
    the first row of ``y_plot.T`` (its first column). These are the ppm axes,
    assuming the grids come from a meshgrid.
    """
    f2_ppm_axis = plot_data.x_plot[0]
    f1_ppm_axis = plot_data.y_plot.T[0]
    return go.Contour(x=f2_ppm_axis, y=f1_ppm_axis, z=plot_data.residual)
def create_residual_figure(plot_data: PlottingDataForPlane):
    """Return a 2D plotly figure of the fit-residual contours.

    Axis titles are the F2/F1 labels with a "ppm" suffix. Both axis ranges
    run from max to min, so the ppm axes are reversed.
    """
    residual_trace = create_residual_contours(plot_data)
    spectrum = plot_data.pseudo3D
    reversed_x = [plot_data.x_plot.max(), plot_data.x_plot.min()]
    reversed_y = [plot_data.y_plot.max(), plot_data.y_plot.min()]

    figure = go.Figure(data=residual_trace)
    figure.update_layout(
        title="Fit residuals",
        xaxis_title=f"{spectrum.f2_label} ppm",
        yaxis_title=f"{spectrum.f1_label} ppm",
        xaxis={"range": reversed_x},
        yaxis={"range": reversed_y},
    )
    return figure
def create_plotly_figure(plot_data: PlottingDataForPlane):
    """Return the interactive 3D plotly figure for one cluster plane.

    Combines the fit/data wireframe traces with the per-peak surfaces. It then
    applies reversed ppm axis ranges, axis titles and peak annotations via
    ``update_axis_ranges``.
    """
    traces = [
        *create_plotly_wireframe_lines(plot_data),
        *create_plotly_surfaces(plot_data),
    ]
    return update_axis_ranges(go.Figure(data=traces), plot_data)
def update_axis_ranges(fig, plot_data: PlottingDataForPlane):
    """Configure the 3D scene of ``fig`` for ``plot_data`` and return ``fig``.

    On the x (F2) and y (F1) scene axes it sets max-to-min ppm ranges and ppm
    axis titles. It also adds one annotation per peak from
    ``make_annotations``. ``fig`` is updated in place and returned.
    """
    x_values = plot_data.x_plot
    y_values = plot_data.y_plot
    spectrum = plot_data.pseudo3D
    scene = {
        "xaxis": {"range": [x_values.max(), x_values.min()]},
        "yaxis": {"range": [y_values.max(), y_values.min()]},
        "xaxis_title": f"{spectrum.f2_label} ppm",
        "yaxis_title": f"{spectrum.f1_label} ppm",
        "annotations": make_annotations(plot_data),
    }
    fig.update_layout(scene=scene)
    return fig
def make_annotations(plot_data: PlottingDataForPlane):
    """Return plotly 3D scene annotations labelling each peak.

    There is one arrow annotation per row of ``plane_lineshape_parameters``.
    Each is placed at ``(center_x_ppm, center_y_ppm, height)`` and shows the
    row's ``assignment``.
    """
    return [
        {
            "showarrow": True,
            "x": peak.center_x_ppm,
            "y": peak.center_y_ppm,
            "z": peak.height * 1.0,
            "text": peak.assignment,
            "opacity": 0.8,
            "textangle": 0,
            "arrowsize": 1,
        }
        for peak in plot_data.plane_lineshape_parameters.itertuples()
    ]
def validate_sample_count(sample_count):
    """Return ``sample_count`` as an ``int``, raising TypeError if it is not an integer.

    Python ints and NumPy integer scalars are accepted. ``bool`` is rejected
    even though it subclasses ``int``.

    Raises:
        TypeError: if ``sample_count`` is not an integer (ccount/rcount).
    """
    if isinstance(sample_count, bool) or not isinstance(
        sample_count, (int, np.integer)
    ):
        raise TypeError("Sample count (ccount, rcount) should be an integer")
    return int(sample_count)
398def unpack_plotting_colors(colors):
399 match colors:
400 case (data_color, fit_color):
401 data_color, fit_color = colors
402 case _:
403 data_color, fit_color = "green", "blue"
404 return data_color, fit_color