Coverage for peakipy/fitting.py: 98%

624 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-09-14 14:49 -0400

1import re 

2from pathlib import Path 

3from dataclasses import dataclass, field 

4from typing import List, Tuple, Optional 

5 

6import numpy as np 

7from numpy import sqrt 

8import pandas as pd 

9from rich import print 

10from lmfit import Model, Parameters, Parameter 

11from lmfit.model import ModelResult 

12from pydantic import BaseModel 

13 

14from peakipy.lineshapes import ( 

15 Lineshape, 

16 pvoigt2d, 

17 pv_pv, 

18 pv_g, 

19 pv_l, 

20 voigt2d, 

21 gaussian_lorentzian, 

22 get_lineshape_function, 

23) 

24from peakipy.constants import log2 

25 

26 

class FitDataModel(BaseModel):
    """Pydantic schema with the fit-result fields common to every lineshape.

    ``validate_fit_data`` instantiates one of the subclasses of this model,
    which add the lineshape-specific fields.
    """

    # bookkeeping columns
    plane: int
    clustid: int
    assignment: str
    memcnt: int
    # fitted intensity
    amp: float
    height: float
    # columns whose names carry a ppm / Hz suffix
    center_x_ppm: float
    center_y_ppm: float
    fwhm_x_hz: float
    fwhm_y_hz: float
    lineshape: str
    # mask radii
    x_radius: float
    y_radius: float
    # columns without a unit suffix
    center_x: float
    center_y: float
    sigma_x: float
    sigma_y: float

45 

46 

class FitDataModelPVGL(FitDataModel):
    """Fit-result schema used by ``validate_fit_data`` for "PV", "G" and "L"."""

    fraction: float

49 

50 

class FitDataModelVoigt(FitDataModel):
    """Fit-result schema used by ``validate_fit_data`` for lineshape "V"."""

    fraction: float
    gamma_x: float
    gamma_y: float

55 

56 

class FitDataModelPVPV(FitDataModel):
    """Fit-result schema with one fraction per axis.

    ``validate_fit_data`` falls back to this model for any lineshape that is
    not "PV", "G", "L" or "V".
    """

    fraction_x: float
    fraction_y: float

60 

61 

def validate_fit_data(dict):
    """Validate a single fit record against the schema for its lineshape.

    :param dict: mapping of fit-result fields; the ``"lineshape"`` entry
        selects which pydantic model is used
    :type dict: dict

    :return: the validated record as produced by ``model_dump()``
    :rtype: dict
    """
    lineshape = dict.get("lineshape")

    if lineshape == "V":
        schema = FitDataModelVoigt
    elif lineshape in ("PV", "G", "L"):
        schema = FitDataModelPVGL
    else:
        # every other value (including a missing lineshape) is treated as PV_PV
        schema = FitDataModelPVPV

    return schema(**dict).model_dump()

72 

73 

def validate_fit_dataframe(df):
    """Validate every row of a fit-result dataframe.

    Each row is converted to a dict, passed through ``validate_fit_data`` and
    the validated records are collected into a new dataframe.

    :param df: dataframe of fit results
    :type df: pandas.DataFrame

    :return: dataframe built from the validated records
    :rtype: pandas.DataFrame
    """
    records = [validate_fit_data(row.to_dict()) for _, row in df.iterrows()]
    return pd.DataFrame(records)

80 

81 

def make_mask(data, c_x, c_y, r_x, r_y):
    """Create an elliptical mask

    Build a boolean mask shaped like ``data`` that is True inside (and on) the
    ellipse centred at c_x/c_y (in points) with radii r_x and r_y.

    :param data: 2D array (only its shape is used)
    :type data: np.array

    :param c_x: x center
    :type c_x: float

    :param c_y: y center
    :type c_y: float

    :param r_x: radius in x
    :type r_x: float

    :param r_y: radius in y
    :type r_y: float

    :return: boolean mask of data.shape
    :rtype: numpy.array

    """
    n_rows, n_cols = data.shape
    # open grids of offsets from the centre: rows run along y, columns along x
    dy, dx = np.ogrid[-c_y : n_rows - c_y, -c_x : n_cols - c_x]
    x_term = dx**2.0 / r_x**2.0
    y_term = dy**2.0 / r_y**2.0
    return x_term + y_term <= 1.0

112 

113 

def fix_params(params, to_fix):
    """Stop the selected parameters from varying

    A parameter is fixed (``vary = False``) when any entry of ``to_fix`` is a
    substring of its name.

    :param params: lmfit parameters
    :type params: lmfit.Parameters

    :param to_fix: list of parameter name fragments to fix
    :type to_fix: list

    :return: the same parameter object, updated in place
    :rtype: lmfit.Parameters

    """
    for name in params:
        if any(fragment in name for fragment in to_fix):
            params[name].vary = False

    return params

134 

135 

def get_params(params, name):
    """Collect every parameter whose key contains ``name``

    :param params: mapping of parameter name to an object with ``value`` and
        ``stderr`` attributes
    :param name: substring to look for in the parameter names
    :type name: str

    :return: values, standard errors, full names and the text preceding
        ``name`` in each full name, as four parallel lists
    :rtype: tuple
    """
    matching = [key for key in params if name in key]
    values = [params[key].value for key in matching]
    errors = [params[key].stderr for key in matching]
    prefixes = [key.split(name)[0] for key in matching]
    return values, errors, matching, prefixes

148 

149 

@dataclass
class PeakLimits:
    """Bounding box (in points) around a peak, clipped to the data

    The box extends one linewidth (XW/YW) either side of the peak position
    (X_AXIS/Y_AXIS) and never leaves the array.

    Arguments
    ---------
    peak: pd.DataFrame
        peak is a row from a pandas dataframe
    data: np.array
        2D numpy array
    """

    peak: pd.DataFrame
    data: np.array
    min_x: int = field(init=False)
    max_x: int = field(init=False)
    min_y: int = field(init=False)
    max_y: int = field(init=False)

    def __post_init__(self):
        peak = self.peak

        n_rows = self.data.shape[0]
        assert peak.Y_AXIS <= n_rows
        n_cols = self.data.shape[1]
        assert peak.X_AXIS <= n_cols

        # upper limits: round up, add one point, clip to the array size
        self.max_y = min(int(np.ceil(peak.Y_AXIS + peak.YW)) + 1, n_rows)
        self.max_x = min(int(np.ceil(peak.X_AXIS + peak.XW)) + 1, n_cols)

        # lower limits: truncate and clip to zero
        self.min_y = max(int(peak.Y_AXIS - peak.YW), 0)
        self.min_x = max(int(peak.X_AXIS - peak.XW), 0)

186 

187 

def estimate_amplitude(peak, data):
    """Estimate a peak amplitude as the sum of the data inside its limits

    :param peak: row of the peak list (passed on to ``PeakLimits``)
    :param data: 2D array
    :type data: np.array

    :return: sum of ``data`` over the box given by ``PeakLimits``
    """
    assert len(data.shape) == 2
    box = PeakLimits(peak, data)
    rows = slice(box.min_y, box.max_y)
    columns = slice(box.min_x, box.max_x)
    return data[rows, columns].sum()

193 

194 

195def make_param_dict(peaks, data, lineshape: Lineshape = Lineshape.PV): 

196 """Make dict of parameter names using prefix""" 

197 

198 param_dict = {} 

199 

200 for _, peak in peaks.iterrows(): 

201 str_form = lambda x: "%s%s" % (to_prefix(peak.ASS), x) 

202 # using exact value of points (i.e decimal) 

203 param_dict[str_form("center_x")] = peak.X_AXISf 

204 param_dict[str_form("center_y")] = peak.Y_AXISf 

205 # estimate peak volume 

206 amplitude_est = estimate_amplitude(peak, data) 

207 param_dict[str_form("amplitude")] = amplitude_est 

208 # sigma linewidth esimate 

209 param_dict[str_form("sigma_x")] = peak.XW / 2.0 

210 param_dict[str_form("sigma_y")] = peak.YW / 2.0 

211 

212 match lineshape: 

213 case lineshape.V: 

214 #  Voigt G sigma from linewidth esimate 

215 param_dict[str_form("sigma_x")] = peak.XW / ( 

216 2.0 * sqrt(2.0 * log2) 

217 ) # 3.6013 

218 param_dict[str_form("sigma_y")] = peak.YW / ( 

219 2.0 * sqrt(2.0 * log2) 

220 ) # 3.6013 

221 #  Voigt L gamma from linewidth esimate 

222 param_dict[str_form("gamma_x")] = peak.XW / 2.0 

223 param_dict[str_form("gamma_y")] = peak.YW / 2.0 

224 # height 

225 # add height here 

226 

227 case lineshape.G: 

228 param_dict[str_form("fraction")] = 0.0 

229 case lineshape.L: 

230 param_dict[str_form("fraction")] = 1.0 

231 case lineshape.PV_PV: 

232 param_dict[str_form("fraction_x")] = 0.5 

233 param_dict[str_form("fraction_y")] = 0.5 

234 case _: 

235 param_dict[str_form("fraction")] = 0.5 

236 

237 return param_dict 

238 

239 

def to_prefix(x):
    """
    Peak assignments with characters that are not compatible lmfit model naming
    are converted to lmfit "safe" names.

    Known punctuation is rewritten using a fixed table (for example ``/``
    becomes ``or`` and ``?`` becomes ``maybe``); any other character that is
    not an ASCII letter, digit or underscore becomes ``_``. Letter case is
    preserved so that assignments differing only in their letters do not
    collapse to the same prefix.

    :param x: Peak assignment to be used as prefix for lmfit model
    :type x: str

    :returns: lmfit model prefix (_Peak_assignment_)
    :rtype: str

    """
    # non-string assignments (e.g. numbers) are converted first
    prefix = "_" + str(x)

    replacements = str.maketrans(
        {
            ".": "_",
            " ": "",
            "{": "_",
            "}": "_",
            "[": "_",
            "]": "_",
            "-": "",
            "/": "or",
            "?": "maybe",
            "\\": "",
            "(": "_",
            ")": "_",
            "@": "_at_",
        }
    )
    prefix = prefix.translate(replacements)

    # Replace any remaining disallowed characters with underscore. Upper-case
    # letters are allowed: the previous pattern ``[^a-z0-9_]`` rewrote them too.
    prefix = re.sub(r"[^A-Za-z0-9_]", "_", prefix)
    return prefix + "_"

278 

279 

def make_models(
    model,
    peaks,
    data,
    lineshape: Lineshape = Lineshape.PV,
    xy_bounds=None,
):
    """Make composite models for multiple peaks

    One ``lmfit.Model`` is created per peak, prefixed with
    ``to_prefix(peak.ASS)``, and the models are summed. Starting values come
    from ``make_param_dict`` and bounds are applied by ``update_params``.

    :param model: lineshape function
    :type model: function

    :param peaks: peaks belonging to one cluster
    :type peaks: pandas.DataFrame

    :param data: NMR data
    :type data: numpy.array

    :param lineshape: lineshape to use for fit (PV/G/L/PV_PV)
    :type lineshape: Lineshape

    :param xy_bounds: bounds for peak centers (+/-x, +/-y)
    :type xy_bounds: tuple

    :return mod: Composite lmfit model containing all peaks
    :rtype mod: lmfit.CompositeModel

    :return p_guess: params for composite model with starting values
    :rtype p_guess: lmfit.Parameters

    :raises ValueError: if ``peaks`` is empty

    """
    if len(peaks) < 1:
        # previously this fell through both branches and failed later with an
        # UnboundLocalError on ``mod``
        raise ValueError("make_models requires at least one peak")

    # the first peak starts the model, every further peak is added to it;
    # prefixes are taken from the same row iteration make_param_dict uses
    (_, first_peak), *remaining_peaks = peaks.iterrows()
    mod = Model(model, prefix=to_prefix(first_peak.ASS))
    for _, peak in remaining_peaks:
        mod += Model(model, prefix=to_prefix(peak.ASS))

    param_dict = make_param_dict(
        peaks,
        data,
        lineshape=lineshape,
    )
    p_guess = mod.make_params(**param_dict)

    # set starting values and bounds on p_guess in place
    update_params(p_guess, param_dict, lineshape=lineshape, xy_bounds=xy_bounds)

    return mod, p_guess

340 

341 

342def update_params( 

343 params, param_dict, lineshape: Lineshape = Lineshape.PV, xy_bounds=None 

344): 

345 """Update lmfit parameters with values from Peak 

346 

347 :param params: lmfit parameters 

348 :type params: lmfit.Parameters object 

349 :param param_dict: parameters corresponding to each peak in fit 

350 :type param_dict: dict 

351 :param lineshape: Lineshape (PV, G, L, PV_PV etc.) 

352 :type lineshape: Lineshape 

353 :param xy_bounds: bounds on xy peak positions 

354 :type xy_bounds: tuple 

355 

356 :returns: None 

357 :rtype: None 

358 

359 ToDo 

360 -- deal with boundaries 

361 -- currently positions in points 

362 

363 """ 

364 for k, v in param_dict.items(): 

365 params[k].value = v 

366 # print("update", k, v) 

367 if "center" in k: 

368 if xy_bounds == None: 

369 # no bounds set 

370 pass 

371 else: 

372 if "center_x" in k: 

373 # set x bounds 

374 x_bound = xy_bounds[0] 

375 params[k].min = v - x_bound 

376 params[k].max = v + x_bound 

377 elif "center_y" in k: 

378 # set y bounds 

379 y_bound = xy_bounds[1] 

380 params[k].min = v - y_bound 

381 params[k].max = v + y_bound 

382 # pass 

383 # print( 

384 # "setting limit of %s, min = %.3e, max = %.3e" 

385 # % (k, params[k].min, params[k].max) 

386 # ) 

387 elif "sigma" in k: 

388 params[k].min = 0.0 

389 params[k].max = 1e4 

390 

391 elif "gamma" in k: 

392 params[k].min = 0.0 

393 params[k].max = 1e4 

394 # print( 

395 # "setting limit of %s, min = %.3e, max = %.3e" 

396 # % (k, params[k].min, params[k].max) 

397 # ) 

398 elif "fraction" in k: 

399 # fix weighting between 0 and 1 

400 params[k].min = 0.0 

401 params[k].max = 1.0 

402 

403 #  fix fraction of G or L 

404 match lineshape: 

405 case lineshape.G | lineshape.L: 

406 params[k].vary = False 

407 case lineshape.PV | lineshape.PV_PV: 

408 params[k].vary = True 

409 case _: 

410 pass 

411 

412 # return params 

413 

414 

def make_mask_from_peak_cluster(group, data):
    """Combine the elliptical masks of every peak in a cluster

    :param group: peaks of one cluster; rows provide X_AXISf, Y_AXISf,
        X_RADIUS and Y_RADIUS
    :type group: pandas.DataFrame

    :param data: array whose shape the mask takes
    :type data: numpy.array

    :return: the combined boolean mask and the last peak row iterated
    :rtype: tuple
    """
    combined_mask = np.zeros(data.shape, dtype=bool)
    for _, peak in group.iterrows():
        peak_mask = make_mask(
            data, peak.X_AXISf, peak.Y_AXISf, peak.X_RADIUS, peak.Y_RADIUS
        )
        # in-place addition of boolean arrays
        combined_mask += peak_mask
    return combined_mask, peak

422 

423 

def select_reference_planes_using_indices(data, indices: List[int]):
    """Select planes (first axis of ``data``) by index

    :param data: array whose first axis indexes planes
    :type data: numpy.array

    :param indices: plane indices to keep; negative indices count from the
        end. Any empty sequence means "keep everything".
    :type indices: list

    :return: ``data`` unchanged when ``indices`` is empty, otherwise the
        selected planes in the order given
    :rtype: numpy.array

    :raises IndexError: if an index lies outside the available planes
    """
    if len(indices) == 0:
        return data

    n_planes = data.shape[0]
    max_index = max(indices)
    min_index = min(indices)

    if max_index >= n_planes:
        raise IndexError(
            f"Your data has {n_planes} planes. You selected plane {max_index} (allowed indices between 0 and {n_planes-1})"
        )
    if min_index < -n_planes:
        raise IndexError(
            f"Your data has {n_planes} planes. You selected plane {min_index} (allowed indices between -{n_planes} and {n_planes-1})"
        )

    # a list always selects along the first axis; a tuple would be treated
    # by numpy as one index per dimension
    return data[list(indices)]

443 

444 

def select_planes_above_threshold_from_masked_data(data, threshold=None):
    """This function returns planes with data above the threshold.

    A plane (row of ``data``) is kept when the maximum of its absolute values
    exceeds ``threshold``. Negative thresholds therefore just result in
    return of the original data. If no plane passes, or no threshold is
    given, the original data is returned.

    :param data: masked data with one row per plane
    :type data: numpy.array

    :param threshold: minimum absolute intensity, or None to disable
    :type threshold: float

    :return: the selected planes
    :rtype: numpy.array
    """
    if threshold is None:
        return data

    above_threshold = np.abs(data).max(axis=1) > threshold
    selected_data = data[above_threshold]

    # never hand back an empty selection
    if selected_data.shape[0] == 0:
        return data

    return selected_data

461 

462 

def validate_plane_selection(plane, pseudo3D):
    """Check a user selection of planes against the data

    :param plane: selected plane numbers; None or an empty sequence selects
        every plane
    :type plane: list

    :param pseudo3D: object exposing the number of planes as ``n_planes``

    :return: sorted list of plane numbers
    :rtype: list

    :raises ValueError: if a plane number is negative or not smaller than
        ``pseudo3D.n_planes``
    """
    n_planes = pseudo3D.n_planes

    if plane is None or len(plane) == 0:
        return list(range(n_planes))

    if max(plane) > (n_planes - 1):
        # the closing tag was previously written as "[red]", leaving the rich
        # markup unbalanced
        raise ValueError(
            f"[red]There are {n_planes} planes in your data you selected --plane {max(plane)}...[/red] "
            f"plane numbering starts from 0."
        )
    if min(plane) < 0:
        raise ValueError(
            f"[red]Plane number can not be negative; you selected --plane {min(plane)}...[/red]"
        )

    return sorted(plane)

480 

481 

def slice_peaks_from_data_using_mask(data, mask):
    """Apply ``mask`` to every plane of ``data``

    :param data: iterable of planes
    :param mask: boolean mask used to index each plane
    :type mask: numpy.array

    :return: array with one row of masked values per plane
    :rtype: numpy.array
    """
    masked_planes = (plane[mask] for plane in data)
    return np.array(list(masked_planes))

485 

486 

def get_limits_for_axis_in_points(group_axis_points, mask_radius_in_points):
    """Integer limits along one axis that enclose a group of peaks

    :param group_axis_points: peak positions (in points) along the axis
    :param mask_radius_in_points: mask radius (in points) along the axis

    :return: ``(max_point, min_point)`` - the largest position plus the radius
        plus one, rounded up, and the smallest position minus the radius,
        rounded down
    :rtype: tuple
    """
    upper = max(group_axis_points) + mask_radius_in_points + 1
    lower = min(group_axis_points) - mask_radius_in_points
    max_point = int(np.ceil(upper))
    min_point = int(np.floor(lower))
    return max_point, min_point

493 

494 

def deal_with_peaks_on_edge_of_spectrum(data_shape, max_x, min_x, max_y, min_y):
    """Clip fit limits so that they stay inside the spectrum

    Lower limits are raised to 0 and upper limits are lowered to the size of
    the last (x) and second to last (y) dimension of ``data_shape``.

    :return: ``(max_x, min_x, max_y, min_y)`` after clipping
    :rtype: tuple
    """
    n_y, n_x = data_shape[-2], data_shape[-1]

    min_y = max(min_y, 0)
    min_x = max(min_x, 0)
    max_y = min(max_y, n_y)
    max_x = min(max_x, n_x)

    return max_x, min_x, max_y, min_y

508 

509 

def make_meshgrid(data_shape):
    """Point-index meshgrid for the last two dimensions of ``data_shape``

    :param data_shape: shape of the data; the last entry is the x size and
        the second to last the y size

    :return: result of ``np.meshgrid(x, y)``
    """
    n_y, n_x = data_shape[-2], data_shape[-1]
    return np.meshgrid(np.arange(n_x), np.arange(n_y))

516 

517 

518def unpack_xy_bounds(xy_bounds, peakipy_data): 

519 match xy_bounds: 

520 case (0, 0): 

521 xy_bounds = None 

522 case (x, y): 

523 # convert ppm to points 

524 xy_bounds = list(xy_bounds) 

525 xy_bounds[0] = xy_bounds[0] * peakipy_data.pt_per_ppm_f2 

526 xy_bounds[1] = xy_bounds[1] * peakipy_data.pt_per_ppm_f1 

527 case _: 

528 raise TypeError( 

529 "xy_bounds should be a tuple (<x_bounds_ppm>, <y_bounds_ppm>)" 

530 ) 

531 return xy_bounds 

532 

533 

def select_specified_planes(plane, peakipy_data):
    """Keep only the requested planes of ``peakipy_data.data``

    :param plane: plane numbers to keep; a falsy value keeps everything
    :type plane: list

    :param peakipy_data: object with ``data`` and ``dims``; ``data`` is
        replaced in place by the selected planes

    :return: the plane numbers that were kept and the updated
        ``peakipy_data``
    :rtype: tuple

    Exits the program if the selection leaves no data.
    """
    plane_axis = peakipy_data.dims[0]
    n_planes = peakipy_data.data.shape[plane_axis]
    plane_numbers = np.arange(n_planes)

    # only fit specified planes
    if plane:
        requested = list(plane)
        keep = [number in requested for number in range(n_planes)]
        plane_numbers = plane_numbers[keep]
        peakipy_data.data = peakipy_data.data[keep]
        # f-string: the message previously printed the literal text "{plane}"
        print(
            f"[yellow]Using only planes {plane} data now has the following shape[/yellow]",
            peakipy_data.data.shape,
        )
        if peakipy_data.data.shape[plane_axis] == 0:
            print("[red]You have excluded all the data![/red]", peakipy_data.data.shape)
            exit()

    return plane_numbers, peakipy_data

554 

555 

def exclude_specified_planes(exclude_plane, peakipy_data):
    """Drop the listed planes from ``peakipy_data.data``

    :param exclude_plane: plane numbers to drop; a falsy value drops nothing
    :type exclude_plane: list

    :param peakipy_data: object with ``data`` and ``dims``; ``data`` is
        replaced in place by the remaining planes

    :return: the plane numbers that remain and the updated ``peakipy_data``
    :rtype: tuple

    Exits the program if every plane has been excluded.
    """
    plane_axis = peakipy_data.dims[0]
    plane_numbers = np.arange(peakipy_data.data.shape[plane_axis])

    if not exclude_plane:
        return plane_numbers, peakipy_data

    # do not fit these planes
    dropped = list(exclude_plane)
    keep = [
        number not in dropped
        for number in range(peakipy_data.data.shape[plane_axis])
    ]
    plane_numbers = plane_numbers[keep]
    peakipy_data.data = peakipy_data.data[keep]
    print(
        f"[yellow]Excluding planes {exclude_plane} data now has the following shape[/yellow]",
        peakipy_data.data.shape,
    )
    if peakipy_data.data.shape[plane_axis] == 0:
        print("[red]You have excluded all the data![/red]", peakipy_data.data.shape)
        exit()
    return plane_numbers, peakipy_data

577 

578 

579def get_fit_data_for_selected_peak_clusters(fits, clusters): 

580 match clusters: 

581 case None | []: 

582 pass 

583 case _: 

584 # only use these clusters 

585 fits = fits[fits.clustid.isin(clusters)] 

586 if len(fits) < 1: 

587 exit(f"Are you sure clusters {clusters} exist?") 

588 return fits 

589 

590 

def make_masks_from_plane_data(empty_mask_array, plane_data):
    """Build one mask per peak and accumulate them into ``empty_mask_array``

    :param empty_mask_array: array that receives the sum of all masks; it is
        modified in place
    :param plane_data: table with center_x, center_y, x_radius, y_radius and
        assignment columns

    :return: list of the individual masks and the filled mask array
    :rtype: tuple
    """
    peak_columns = zip(
        plane_data.center_x,
        plane_data.center_y,
        plane_data.x_radius,
        plane_data.y_radius,
        plane_data.assignment,
    )
    individual_masks = []
    for center_x, center_y, radius_x, radius_y, _assignment in peak_columns:
        peak_mask = make_mask(
            empty_mask_array, center_x, center_y, radius_x, radius_y
        )
        empty_mask_array += peak_mask
        individual_masks.append(peak_mask)
    return individual_masks, empty_mask_array

606 

607 

def simulate_pv_pv_lineshapes_from_fitted_peak_parameters(
    peak_parameters, XY, sim_data, sim_data_singles
):
    """Simulate every fitted peak with ``pv_pv`` and add it to ``sim_data``

    :param peak_parameters: table with amp, center_x, center_y, sigma_x,
        sigma_y, fraction_x, fraction_y and lineshape columns
    :param XY: grid passed to ``pv_pv``
    :param sim_data: array the simulated peaks are added to (in place)
    :param sim_data_singles: list the individual simulations are appended to

    :return: ``sim_data`` and ``sim_data_singles``
    :rtype: tuple
    """
    target_shape = sim_data.shape
    rows = zip(
        peak_parameters.amp,
        peak_parameters.center_x,
        peak_parameters.center_y,
        peak_parameters.sigma_x,
        peak_parameters.sigma_y,
        peak_parameters.fraction_x,
        peak_parameters.fraction_y,
        peak_parameters.lineshape,
    )
    for amp, c_x, c_y, s_x, s_y, frac_x, frac_y, _lineshape in rows:
        single_peak = pv_pv(XY, amp, c_x, c_y, s_x, s_y, frac_x, frac_y)
        single_peak = single_peak.reshape(target_shape)
        sim_data += single_peak
        sim_data_singles.append(single_peak)
    return sim_data, sim_data_singles

627 

628 

629def simulate_lineshapes_from_fitted_peak_parameters( 

630 peak_parameters, XY, sim_data, sim_data_singles 

631): 

632 shape = sim_data.shape 

633 for amp, c_x, c_y, s_x, s_y, frac, lineshape in zip( 

634 peak_parameters.amp, 

635 peak_parameters.center_x, 

636 peak_parameters.center_y, 

637 peak_parameters.sigma_x, 

638 peak_parameters.sigma_y, 

639 peak_parameters.fraction, 

640 peak_parameters.lineshape, 

641 ): 

642 # print(amp) 

643 match lineshape: 

644 case "G" | "L" | "PV": 

645 sim_data_i = pvoigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) 

646 case "PV_L": 

647 sim_data_i = pv_l(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) 

648 

649 case "PV_G": 

650 sim_data_i = pv_g(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) 

651 

652 case "G_L": 

653 sim_data_i = gaussian_lorentzian( 

654 XY, amp, c_x, c_y, s_x, s_y, frac 

655 ).reshape(shape) 

656 

657 case "V": 

658 sim_data_i = voigt2d(XY, amp, c_x, c_y, s_x, s_y, frac).reshape(shape) 

659 sim_data += sim_data_i 

660 sim_data_singles.append(sim_data_i) 

661 return sim_data, sim_data_singles 

662 

663 

@dataclass
class FitPeaksArgs:
    """Options controlling a peak-fitting run

    Several defaults used to carry a stray trailing comma, which wrapped them
    in a one-element tuple (for example ``verbose=(False,)``, which is
    truthy). They are now the plain values.
    """

    noise: float
    uc_dics: dict
    lineshape: Lineshape
    dims: List[int] = field(default_factory=lambda: [0, 1, 2])
    colors: Tuple[str] = ("#5e3c99", "#e66101")
    max_cluster_size: Optional[int] = None
    # name fragments handed to set_parameters_to_fix_during_fit
    to_fix: List[str] = field(default_factory=lambda: ["fraction", "sigma", "center"])
    # NOTE(review): unpack_xy_bounds maps (0, 0) to None ("no bounds"), while
    # update_params applies any non-None value as +/- bounds - confirm that
    # callers unpack this value before fitting.
    xy_bounds: Tuple[float, float] = (0, 0)
    vclist: Optional[Path] = None
    plane: Optional[List[int]] = None
    exclude_plane: Optional[List[int]] = None
    # an empty list makes select_reference_planes_using_indices keep all planes
    reference_plane_indices: List[int] = field(default_factory=list)
    initial_fit_threshold: Optional[float] = None
    jack_knife_sample_errors: bool = False
    mp: bool = True
    verbose: bool = False
    vclist_data: Optional[np.array] = None

683 

684 

@dataclass
class Config:
    """Fitting configuration"""

    # name of the fitting method; defaults to "leastsq"
    fit_method: str = "leastsq"

688 

689 

@dataclass
class FitPeaksInput:
    """Everything the fit_peaks function needs as input"""

    # fitting options
    args: FitPeaksArgs
    # NMR data; the first axis is indexed per plane by the fitting code
    data: np.array
    config: Config
    # plane number for each entry along the first axis of ``data``
    plane_numbers: list

698 

699 

@dataclass
class FitPeakClusterInput:
    """Input needed to fit a single cluster of peaks

    ``masked_plane_data`` is not passed in; it is derived from ``data`` and
    ``mask`` after initialisation.
    """

    args: FitPeaksArgs
    data: np.array
    config: Config
    plane_numbers: list
    clustid: int
    group: pd.DataFrame
    last_peak: pd.DataFrame
    mask: np.array
    mod: Model
    p_guess: Parameters
    XY: np.array
    peak_slices: np.array
    XY_slices: np.array
    min_x: float
    max_x: float
    min_y: float
    max_y: float
    uc_dics: dict
    first_plane_data: np.array
    weights: np.array
    fit_method: str = "leastsq"
    verbose: bool = False
    masked_plane_data: np.array = field(init=False)

    def __post_init__(self):
        # one row of masked values per plane
        per_plane = [plane[self.mask] for plane in self.data]
        self.masked_plane_data = np.array(per_plane)

728 

729 

@dataclass
class FitResult:
    """Result of the initial lineshape fit of one peak cluster"""

    # lmfit result of the fit
    out: ModelResult
    mask: np.array
    fit_str: str
    log: str
    group: pd.core.groupby.generic.DataFrameGroupBy
    uc_dics: dict
    # limits of the fitted region
    min_x: float
    min_y: float
    max_x: float
    max_y: float
    # grid, data and simulated data
    X: np.array
    Y: np.array
    Z: np.array
    Z_sim: np.array
    # inputs that were passed to the fit
    peak_slices: np.array
    XY_slices: np.array
    weights: np.array
    mod: Model

    def check_shifts(self):
        """Calculate difference between initial peak positions
        and check whether they moved too much from original
        position

        Not implemented yet.
        """
        pass

758 

759 

@dataclass
class FitPeaksResult:
    """Dataframe of fit results together with a log string"""

    df: pd.DataFrame
    log: str

764 

765 

class FitPeaksResultDfRow(BaseModel):
    """Pydantic schema for the columns shared by every row of fit results

    The lineshape-specific columns are added by the subclasses returned from
    ``get_fit_peaks_result_validation_model``.
    """

    # identification
    fit_prefix: str
    assignment: str
    # fitted values next to their initial positions
    amp: float
    amp_err: float
    center_x: float
    init_center_x: float
    center_y: float
    init_center_y: float
    sigma_x: float
    sigma_y: float
    # bookkeeping
    clustid: int
    memcnt: int
    plane: int
    x_radius: float
    y_radius: float
    x_radius_ppm: float
    y_radius_ppm: float
    lineshape: str
    # fit statistics
    aic: float
    chisqr: float
    redchi: float
    residual_sum: float
    # derived values
    height: float
    height_err: float
    fwhm_x: float
    fwhm_y: float
    # columns whose names carry a ppm / Hz suffix
    center_x_ppm: float
    center_y_ppm: float
    init_center_x_ppm: float
    init_center_y_ppm: float
    sigma_x_ppm: float
    sigma_y_ppm: float
    fwhm_x_ppm: float
    fwhm_y_ppm: float
    fwhm_x_hz: float
    fwhm_y_hz: float
    jack_knife_sample_index: Optional[int]

804 

805 

class FitPeaksResultRowGLPV(FitPeaksResultDfRow):
    """Result row with a single ``fraction`` column

    ``get_fit_peaks_result_validation_model`` returns this model for every
    lineshape other than V and PV_PV.
    """

    fraction: float

808 

809 

class FitPeaksResultRowPVPV(FitPeaksResultDfRow):
    """Result row for the PV_PV model, with one fraction per axis"""

    fraction_x: float
    fraction_y: float

813 

814 

class FitPeaksResultRowVoigt(FitPeaksResultDfRow):
    """Result row for the Voigt model, with gamma columns in ppm"""

    gamma_x_ppm: float
    gamma_y_ppm: float

818 

819 

820def get_fit_peaks_result_validation_model(lineshape): 

821 match lineshape: 

822 case lineshape.V: 

823 validation_model = FitPeaksResultRowVoigt 

824 case lineshape.PV_PV: 

825 validation_model = FitPeaksResultRowPVPV 

826 case _: 

827 validation_model = FitPeaksResultRowGLPV 

828 return validation_model 

829 

830 

def filter_peak_clusters_by_max_cluster_size(grouped_peak_clusters, max_cluster_size):
    """Drop clusters that contain more peaks than ``max_cluster_size``

    :param grouped_peak_clusters: peaks grouped by cluster
    :type grouped_peak_clusters: pandas groupby object

    :param max_cluster_size: largest number of peaks a kept cluster may have
    :type max_cluster_size: int

    :return: rows of the clusters that pass the size limit
    :rtype: pandas.DataFrame
    """

    def small_enough(cluster):
        return len(cluster) <= max_cluster_size

    return grouped_peak_clusters.filter(small_enough)

836 

837 

838def set_parameters_to_fix_during_fit(first_plane_fit_params, to_fix): 

839 # fix sigma center and fraction parameters 

840 # could add an option to select params to fix 

841 match to_fix: 

842 case None | () | []: 

843 float_str = "Floating all parameters" 

844 parameter_set = first_plane_fit_params 

845 case ["None"] | ["none"]: 

846 float_str = "Floating all parameters" 

847 parameter_set = first_plane_fit_params 

848 case _: 

849 float_str = f"Fixing parameters: {to_fix}" 

850 parameter_set = fix_params(first_plane_fit_params, to_fix) 

851 return parameter_set, float_str 

852 

853 

854def get_default_lineshape_param_names(lineshape: Lineshape): 

855 match lineshape: 

856 case Lineshape.PV | Lineshape.G | Lineshape.L: 

857 param_names = Model(pvoigt2d).param_names 

858 case Lineshape.V: 

859 param_names = Model(voigt2d).param_names 

860 case Lineshape.PV_PV: 

861 param_names = Model(pv_pv).param_names 

862 return param_names 

863 

864 

def split_parameter_sets_by_peak(
    default_param_names: List, params: List[Tuple[str, Parameter]]
):
    """Cut a flat list of fitted parameters into one chunk per peak

    params is a list of tuples where the first element of each tuple is a
    prefixed parameter name and the second element is the corresponding
    Parameter object. This is created by calling .items() on a Parameters
    object. Each chunk has as many entries as ``default_param_names``.
    """
    chunk_size = len(default_param_names)
    total = len(params)
    expected_peaks = total // chunk_size

    chunks = []
    for start in range(0, total, chunk_size):
        chunks.append(params[start : start + chunk_size])

    # a trailing partial chunk means the list does not divide evenly
    assert len(chunks) == expected_peaks
    return chunks

882 

883 

def create_parameter_dict(prefix, parameters: List[Tuple[str, Parameter]]):
    """Flatten the parameters of one peak into a dict

    The result holds the prefix, then each parameter value under its name
    with the prefix removed, then each standard error under
    ``<name>_stderr``.

    :param prefix: lmfit prefix of the peak
    :type prefix: str
    :param parameters: (prefixed name, Parameter) pairs of the peak

    :rtype: dict
    """
    stripped = [(key.replace(prefix, ""), param) for key, param in parameters]

    parameter_dict = {"prefix": prefix}
    for short_name, param in stripped:
        parameter_dict[short_name] = param.value
    for short_name, param in stripped:
        parameter_dict[f"{short_name}_stderr"] = param.stderr
    return parameter_dict

891 

892 

def get_prefix_from_parameter_names(
    default_param_names: List, parameters: List[Tuple[str, Parameter]]
):
    """Recover the common prefix of one peak's parameters

    Each prefixed name has the default name at the same position removed;
    all of the results must agree.

    :return: the shared prefix
    :rtype: str
    """
    prefixes = []
    for (prefixed_name, _parameter), default_name in zip(
        parameters, default_param_names
    ):
        prefixes.append(prefixed_name.replace(default_name, ""))

    assert len(set(prefixes)) == 1
    return prefixes[0]

902 

903 

def unpack_fitted_parameters_for_lineshape(
    lineshape: Lineshape, params: List[dict], plane_number: int
):
    """Turn the fitted parameters of one plane into one dict per peak

    :param lineshape: lineshape that was fitted
    :param params: flat list of (prefixed name, Parameter) pairs
    :param plane_number: stored under the ``"plane"`` key of every dict

    :return: list with one parameter dict per peak
    :rtype: list
    """
    default_names = get_default_lineshape_param_names(lineshape)
    per_peak_params = split_parameter_sets_by_peak(default_names, params)
    prefixes = [
        get_prefix_from_parameter_names(default_names, peak_params)
        for peak_params in per_peak_params
    ]

    unpacked_params = []
    for peak_params, prefix in zip(per_peak_params, prefixes):
        entry = create_parameter_dict(prefix, peak_params)
        entry["plane"] = plane_number
        unpacked_params.append(entry)
    return unpacked_params

919 

920 

def perform_initial_lineshape_fit_on_cluster_of_peaks(
    fit_peak_cluster_input: FitPeakClusterInput,
) -> FitResult:
    """Run the first fit of a peak cluster

    The composite model is fitted to ``peak_slices``, then evaluated on the
    full grid. Points outside the mask are set to NaN in both the simulated
    data and a copy of the first plane.

    :param fit_peak_cluster_input: prepared input for the cluster
    :return: the fit output together with the inputs that were used
    """
    cluster = fit_peak_cluster_input
    model = cluster.mod
    grid_x, grid_y = cluster.XY

    out = model.fit(
        cluster.peak_slices,
        XY=cluster.XY_slices,
        params=cluster.p_guess,
        weights=cluster.weights,
        method=cluster.fit_method,
    )

    if cluster.verbose:
        print(out.fit_report())

    outside_mask = ~cluster.mask
    simulated = model.eval(XY=cluster.XY, params=out.params)
    simulated[outside_mask] = np.nan
    observed = cluster.first_plane_data.copy()
    observed[outside_mask] = np.nan

    return FitResult(
        out=out,
        mask=cluster.mask,
        fit_str="",
        log="",
        group=cluster.group,
        uc_dics=cluster.uc_dics,
        min_x=cluster.min_x,
        min_y=cluster.min_y,
        max_x=cluster.max_x,
        max_y=cluster.max_y,
        X=grid_x,
        Y=grid_y,
        Z=observed,
        Z_sim=simulated,
        peak_slices=cluster.peak_slices,
        XY_slices=cluster.XY_slices,
        weights=cluster.weights,
        mod=model,
    )

977 

978 

def refit_peak_cluster_with_constraints(
    fit_input: FitPeakClusterInput, fit_result: FitResult
):
    """Re-fit every plane of a cluster starting from the initial fit

    For each plane the existing lmfit result is re-fitted in place against
    that plane's masked data and the resulting parameters are unpacked.

    :param fit_input: prepared input; provides ``masked_plane_data``,
        ``plane_numbers`` and the lineshape
    :param fit_result: result of the initial fit; its ``out`` and ``weights``
        are used (hence ``FitResult``, not ``FitPeaksResult``)

    :return: one parameter dict per peak and plane
    :rtype: list
    """
    fit_results = []
    for num, plane_data in enumerate(fit_input.masked_plane_data):
        plane_number = fit_input.plane_numbers[num]
        fit_result.out.fit(
            data=plane_data,
            params=fit_result.out.params,
            weights=fit_result.weights,
        )
        fit_results.extend(
            unpack_fitted_parameters_for_lineshape(
                fit_input.args.lineshape,
                list(fit_result.out.params.items()),
                plane_number,
            )
        )
    return fit_results

998 

999 

def merge_unpacked_parameters_with_metadata(cluster_fit_df, group_of_peaks_df):
    """Join the fitted parameters with the peak list on the lmfit prefix

    A ``prefix`` column (``to_prefix`` of ``ASS``) is added to
    ``group_of_peaks_df`` in place before merging. Clashing peak-list columns
    receive the suffix ``_init``.

    :return: the merged dataframe
    :rtype: pandas.DataFrame
    """
    peak_prefixes = group_of_peaks_df.ASS.apply(to_prefix)
    group_of_peaks_df["prefix"] = peak_prefixes
    return cluster_fit_df.merge(
        group_of_peaks_df, on="prefix", suffixes=["", "_init"]
    )

1006 

1007 

def update_cluster_df_with_fit_statistics(cluster_df, fit_result: ModelResult):
    """Add the statistics of a fit as columns of ``cluster_df``

    The columns chisqr, redchi, residual_sum (sum of the residual), aic, bic,
    nfev and ndata are written in place with the same value for every row.

    :return: the updated ``cluster_df``
    :rtype: pandas.DataFrame
    """
    statistics = {
        "chisqr": fit_result.chisqr,
        "redchi": fit_result.redchi,
        "residual_sum": np.sum(fit_result.residual),
        "aic": fit_result.aic,
        "bic": fit_result.bic,
        "nfev": fit_result.nfev,
        "ndata": fit_result.ndata,
    }
    for column, value in statistics.items():
        cluster_df[column] = value
    return cluster_df

1017 

1018 

def rename_columns_for_compatibility(df):
    """Rename fit and peak-list columns to the names used in the output

    Columns not listed in the mapping are left alone; a renamed copy is
    returned.

    :rtype: pandas.DataFrame
    """
    old_to_new = dict(
        amplitude="amp",
        amplitude_stderr="amp_err",
        X_AXIS="init_center_x",
        Y_AXIS="init_center_y",
        ASS="assignment",
        MEMCNT="memcnt",
        X_RADIUS="x_radius",
        Y_RADIUS="y_radius",
    )
    return df.rename(columns=old_to_new)

1032 

1033 

def add_vclist_to_df(fit_input: FitPeaksInput, df: pd.DataFrame):
    """Add a ``vclist`` column looked up from the plane number

    Each value of ``df.plane`` is used as an index into
    ``fit_input.args.vclist_data``; the column is written in place.

    :return: the updated ``df``
    """
    vclist_data = fit_input.args.vclist_data

    def vclist_value_for(plane):
        return vclist_data[plane]

    df["vclist"] = df.plane.apply(vclist_value_for)
    return df

1038 

1039 

def prepare_group_of_peaks_for_fitting(clustid, group, fit_peaks_input: FitPeaksInput):
    """Assemble everything needed to fit one cluster of peaks

    Builds the mask, the fit limits, the composite model with its starting
    parameters, the masked and summed data for the initial fit, the grid and
    the weights (``1 / noise`` for every fitted point).

    :param clustid: identifier of the cluster
    :param group: peaks belonging to the cluster
    :type group: pandas.DataFrame
    :param fit_peaks_input: options, configuration and NMR data

    :return: the prepared input for the cluster fit
    :rtype: FitPeakClusterInput
    """
    args = fit_peaks_input.args
    data = fit_peaks_input.data
    lineshape_function = get_lineshape_function(args.lineshape)

    first_plane_data = data[0]
    mask, peak = make_mask_from_peak_cluster(group, first_plane_data)

    x_radius = group.X_RADIUS.max()
    y_radius = group.Y_RADIUS.max()

    max_x, min_x = get_limits_for_axis_in_points(
        group_axis_points=group.X_AXISf, mask_radius_in_points=x_radius
    )
    max_y, min_y = get_limits_for_axis_in_points(
        group_axis_points=group.Y_AXISf, mask_radius_in_points=y_radius
    )
    max_x, min_x, max_y, min_y = deal_with_peaks_on_edge_of_spectrum(
        data.shape, max_x, min_x, max_y, min_y
    )

    # starting values are estimated from the sum of the reference planes
    selected_data = select_reference_planes_using_indices(
        data, args.reference_plane_indices
    ).sum(axis=0)
    mod, p_guess = make_models(
        lineshape_function,
        group,
        selected_data,
        lineshape=args.lineshape,
        xy_bounds=args.xy_bounds,
    )

    # data for the initial fit: masked, restricted to the reference planes
    # and to planes above the threshold, then summed over planes
    peak_slices = slice_peaks_from_data_using_mask(data, mask)
    peak_slices = select_reference_planes_using_indices(
        peak_slices, args.reference_plane_indices
    )
    peak_slices = select_planes_above_threshold_from_masked_data(
        peak_slices, args.initial_fit_threshold
    )
    peak_slices = peak_slices.sum(axis=0)

    XY = make_meshgrid(data.shape)
    X, Y = XY

    # boolean indexing already returns new arrays, no explicit copy needed
    XY_slices = np.array([X[mask], Y[mask]])
    weights = 1.0 / np.array([args.noise] * len(np.ravel(peak_slices)))
    return FitPeakClusterInput(
        args=args,
        data=data,
        config=fit_peaks_input.config,
        plane_numbers=fit_peaks_input.plane_numbers,
        clustid=clustid,
        group=group,
        last_peak=peak,
        mask=mask,
        mod=mod,
        p_guess=p_guess,
        XY=XY,
        peak_slices=peak_slices,
        XY_slices=XY_slices,
        weights=weights,
        # use the configured method; ``Config.fit_method`` is only the class
        # default and ignored whatever the caller had set
        fit_method=fit_peaks_input.config.fit_method,
        first_plane_data=first_plane_data,
        uc_dics=args.uc_dics,
        min_x=min_x,
        min_y=min_y,
        max_x=max_x,
        max_y=max_y,
        verbose=args.verbose,
    )

1106 

1107 

def fit_cluster_of_peaks(data_for_fitting: FitPeakClusterInput) -> pd.DataFrame:
    """Fit one cluster of peaks and return the results as a DataFrame.

    An initial lineshape fit is performed, the parameters listed in
    ``data_for_fitting.args.to_fix`` are passed to
    ``set_parameters_to_fix_during_fit``, and the cluster is refitted with
    those constraints. The refit results are tabulated, updated with the fit
    statistics of the initial fit, tagged with the cluster id and merged with
    the metadata in ``data_for_fitting.group``.

    :param data_for_fitting: prepared input for a single cluster
    :type data_for_fitting: FitPeakClusterInput

    :returns: fit results for the cluster
    :rtype: pd.DataFrame
    """
    initial_fit = perform_initial_lineshape_fit_on_cluster_of_peaks(data_for_fitting)

    # The second value returned by set_parameters_to_fix_during_fit is not used.
    constrained_params, _ = set_parameters_to_fix_during_fit(
        initial_fit.out.params, data_for_fitting.args.to_fix
    )
    initial_fit.out.params = constrained_params

    refit_results = refit_peak_cluster_with_constraints(data_for_fitting, initial_fit)
    cluster_df = update_cluster_df_with_fit_statistics(
        pd.DataFrame(refit_results), initial_fit.out
    )
    cluster_df["clustid"] = data_for_fitting.clustid
    return merge_unpacked_parameters_with_metadata(
        cluster_df, data_for_fitting.group
    )

1121 

1122 

def fit_peak_clusters(peaks: pd.DataFrame, fit_input: FitPeaksInput) -> FitPeaksResult:
    """Fit every peak cluster in a peak list to the lineshape model

    Clusters are first filtered with ``fit_input.args.max_cluster_size``; each
    remaining cluster is prepared and fitted, with jack-knife sampling when
    ``fit_input.args.jack_knife_sample_errors`` is set.

    :param peaks: peak list generated by peakipy read or edit
    :type peaks: pd.DataFrame

    :param fit_input: Data structure containing input parameters (args, config and NMR data)
    :type fit_input: FitPeaksInput

    :returns: Data structure containing pd.DataFrame with the fitted results and a log
    :rtype: FitPeaksResult
    """
    clusters_to_fit = filter_peak_clusters_by_max_cluster_size(
        peaks.groupby("CLUSTID"), fit_input.args.max_cluster_size
    ).groupby("CLUSTID")

    cluster_dfs = []
    for clustid, peak_cluster in clusters_to_fit:
        cluster_input = prepare_group_of_peaks_for_fitting(
            clustid, peak_cluster, fit_input
        )
        fit_function = (
            jack_knife_sample_errors
            if fit_input.args.jack_knife_sample_errors
            else fit_cluster_of_peaks
        )
        cluster_dfs.append(fit_function(cluster_input))

    df = pd.concat(cluster_dfs, ignore_index=True)
    df["lineshape"] = fit_input.args.lineshape.value

    if fit_input.args.vclist:
        df = add_vclist_to_df(fit_input, df)

    # Nothing is written to the log in this function, so it is returned empty.
    return FitPeaksResult(df=rename_columns_for_compatibility(df), log="")

1161 

1162 

def jack_knife_sample_errors(fit_input: FitPeakClusterInput) -> pd.DataFrame:
    """Fit a cluster repeatedly, leaving out one data point each time.

    The cluster is first fitted with all points (``jack_knife_sample_index``
    0). Then, for every 10th point index ``i`` (0, 10, 20, ...), point ``i`` is
    removed from ``peak_slices``, both rows of ``XY_slices``, ``weights`` and
    (along axis 1) ``masked_plane_data``; the cluster is refitted and the
    result labelled with ``jack_knife_sample_index = i + 1``.

    ``fit_input`` is modified temporarily while the reduced samples are
    fitted; its original arrays are put back before returning, also when a
    fit raises.

    :param fit_input: prepared input for a single cluster
    :type fit_input: FitPeakClusterInput

    :returns: concatenated fit results of the full fit and all reduced fits
    :rtype: pd.DataFrame
    """
    # Keep the caller's arrays so they can be restored afterwards.
    original_peak_slices = fit_input.peak_slices
    original_XY_slices = fit_input.XY_slices
    original_weights = fit_input.weights
    original_masked_plane_data = fit_input.masked_plane_data

    # Copies used as the source for every reduced sample.
    peak_slices = original_peak_slices.copy()
    XY_slices = original_XY_slices.copy()
    weights = original_weights.copy()
    masked_plane_data = original_masked_plane_data.copy()

    # first fit without jackknife
    jk_result = fit_cluster_of_peaks(data_for_fitting=fit_input)
    jk_result["jack_knife_sample_index"] = 0
    jk_results = [jk_result]

    try:
        for i in range(0, len(peak_slices), 10):
            fit_input.peak_slices = np.delete(peak_slices, i, None)
            fit_input.XY_slices = np.array(
                [
                    np.delete(XY_slices[0], i, None),
                    np.delete(XY_slices[1], i, None),
                ]
            )
            fit_input.weights = np.delete(weights, i, None)
            fit_input.masked_plane_data = np.delete(masked_plane_data, i, axis=1)

            jk_result = fit_cluster_of_peaks(data_for_fitting=fit_input)
            jk_result["jack_knife_sample_index"] = i + 1
            jk_results.append(jk_result)
    finally:
        # Do not leave the caller's object holding the last reduced sample.
        fit_input.peak_slices = original_peak_slices
        fit_input.XY_slices = original_XY_slices
        fit_input.weights = original_weights
        fit_input.masked_plane_data = original_masked_plane_data

    return pd.concat(jk_results, ignore_index=True)