Source code for ipsuite.analysis.model.predict

import multiprocessing
import os
import pathlib
from concurrent.futures import ProcessPoolExecutor
from typing import List, Optional

import ase
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import typing_extensions as tyex
import uncertainty_toolbox as uct
import zntrack
from ase.calculators.singlepoint import PropertyNotImplementedError

from ipsuite import base, models, utils
from ipsuite.analysis.model.math import (
    compute_uncertainty_metrics,
    decompose_stress_tensor,
    force_decomposition,
)
from ipsuite.analysis.model.plots import (  # get_cdf_figure,
    get_calibration_figure,
    get_figure,
    get_gaussianicity_figure,
    get_hist,
    slice_ensemble_uncertainty,
)
from ipsuite.geometry import BarycenterMapping
from ipsuite.utils.ase_sim import freeze_copy_atoms


@tyex.deprecated(
    "Use `ipsuite.ApplyCalculator` instead."
    " Reason: Generalization and misleading node name."
)
class Prediction(base.ProcessAtoms):
    """Evaluate a model's calculator on every input structure.

    Deprecated: use ``ipsuite.ApplyCalculator`` instead.

    Attributes
    ----------
    model: models.MLModel
        Node providing ``get_calculator()``; the returned calculator is
        attached to a copy of each input structure.
    frames: list[ase.Atoms]
        Written by ``run``: one ``freeze_copy_atoms`` copy per input
        structure, taken after the calculations were triggered (presumably
        preserving the computed results).
    """

    model: models.MLModel = zntrack.deps()

    def run(self):
        """Run the model calculator on all structures from ``get_data()``.

        The potential energy is always requested. Stress is requested only
        when ``"stress"`` is listed in the calculator's
        ``implemented_properties``; a ``PropertyNotImplementedError`` or
        ``ValueError`` raised while doing so is ignored.
        """
        self.frames = []
        calculator = self.model.get_calculator()
        progress = tqdm.tqdm(self.get_data(), ncols=70)
        for structure in progress:
            structure: ase.Atoms
            frame = structure.copy()
            frame.calc = calculator
            frame.get_potential_energy()
            supports_stress = "stress" in calculator.implemented_properties
            if supports_stress:
                try:
                    frame.get_stress()
                except (PropertyNotImplementedError, ValueError):
                    # Some calculators (e.g. nequip, GAP) list stress but
                    # cannot deliver it; stress is optional here.
                    pass
            self.frames.append(freeze_copy_atoms(frame))
class PredictionMetrics(base.ComparePredictions):
    """Compare model predictions against reference data with comprehensive metrics.

    Computes and visualizes prediction accuracy for energy, forces, and stress
    with statistical measures including MAE, RMSE, and correlation coefficients.
    Useful for benchmarking different models and analyzing prediction quality.

    Parameters
    ----------
    x : list[ase.Atoms]
        Reference (ground truth) atomic configurations with calculated properties.
    y : list[ase.Atoms]
        Model predictions on the same configurations.
    figure_ymax : dict[str, float], optional
        Y-axis limits for plots by property type (``energy``, ``forces``,
        ``stress``, ``stress_hydro``, ``stress_deviat``).

    Attributes
    ----------
    energy : dict
        Energy prediction metrics (MAE, RMSE, R², etc.) in meV/atom.
    forces : dict
        Force prediction metrics (MAE, RMSE, R², etc.) in meV/Å.
    stress : dict
        Stress prediction metrics in eV/Å³.
    stress_hydro : dict
        Metrics of the hydrostatic part of the stress tensor.
    stress_deviat : dict
        Metrics of the deviatoric part of the stress tensor.
    data_file : Path
        ``.npz`` file with the raw true/predicted/error arrays.
    plots_dir : Path
        Directory containing generated comparison plots.

    Examples
    --------
    >>> medium_model = ips.MACEMPModel(model="medium")
    >>> small_model = ips.MACEMPModel(model="small")
    >>> with project:
    ...     data = ips.AddData(file="ethanol.xyz")
    ...     medium_data = ips.ApplyCalculator(data=data.frames, model=medium_model)
    ...     small_data = ips.ApplyCalculator(data=data.frames, model=small_model)
    ...     metrics = ips.PredictionMetrics(x=medium_data.frames, y=small_data.frames)
    >>> project.repro()
    >>> print(f"Energy MAE: {metrics.energy['mae']:.2f} meV/atom")
    Energy MAE: 52.23 meV/atom
    >>> print(f"Force MAE: {metrics.forces['mae']:.2f} meV/Å")
    Force MAE: 213.22 meV/Å
    """

    figure_ymax: dict[str, float] = zntrack.params(default_factory=dict)

    data_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "data.npz")

    energy: dict = zntrack.metrics()
    forces: dict = zntrack.metrics()
    stress: dict = zntrack.metrics()
    stress_hydro: dict = zntrack.metrics()
    stress_deviat: dict = zntrack.metrics()

    plots_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "plots")

    def __post_init__(self):
        self.content = {}

    def _post_load_(self):
        """Load the stored comparison arrays - if available."""
        try:
            with self.state.fs.open(self.data_file, "rb") as f:
                self.content = dict(np.load(f))
        except FileNotFoundError:
            self.content = {}

    def _store_comparison(self, key, true, pred):
        """Store flattened ``true``/``pred`` and ``true - pred`` under ``key``."""
        true = np.reshape(true, (-1,))
        pred = np.reshape(pred, (-1,))
        self.content[f"{key}_true"] = true
        self.content[f"{key}_pred"] = pred
        self.content[f"{key}_error"] = true - pred

    def get_data(self):
        """Collect reference and predicted values into ``self.content``.

        Energies are stored per atom and scaled by 1000 (meV/atom), force
        components are scaled by 1000 (meV/Å), stresses (full tensor plus
        hydrostatic and deviatoric parts) are stored unscaled. Forces and
        stress are only collected when both the first reference and the first
        predicted structure carry them in ``calc.results``.
        """
        true_keys = self.x[0].calc.results.keys()
        pred_keys = self.y[0].calc.results.keys()

        energy_true = np.array([a.get_potential_energy() / len(a) for a in self.x])
        energy_pred = np.array([a.get_potential_energy() / len(a) for a in self.y])
        self._store_comparison("energy", energy_true * 1000, energy_pred * 1000)

        if "forces" in true_keys and "forces" in pred_keys:
            # Concatenate along the atom axis: structures may differ in size.
            forces_true = np.concatenate([a.get_forces() for a in self.x], axis=0)
            forces_pred = np.concatenate([a.get_forces() for a in self.y], axis=0)
            self._store_comparison("forces", forces_true * 1000, forces_pred * 1000)

        if "stress" in true_keys and "stress" in pred_keys:
            stress_true = np.array([a.get_stress(voigt=False) for a in self.x])
            stress_pred = np.array([a.get_stress(voigt=False) for a in self.y])
            hydro_true, deviat_true = decompose_stress_tensor(stress_true)
            hydro_pred, deviat_pred = decompose_stress_tensor(stress_pred)
            self._store_comparison("stress", stress_true, stress_pred)
            self._store_comparison("stress_hydro", hydro_true, hydro_pred)
            self._store_comparison("stress_deviat", deviat_true, deviat_pred)

    def get_metrics(self):
        """Compute all metric dictionaries from ``self.content``.

        Properties without stored data get an empty dict so that every
        declared metric is always assigned.
        """
        content = self.content
        self.energy = utils.metrics.get_full_metrics(
            content["energy_true"], content["energy_pred"]
        )

        if "forces_true" in content:
            self.forces = utils.metrics.get_full_metrics(
                content["forces_true"], content["forces_pred"]
            )
        else:
            self.forces = {}

        if "stress_true" in content:
            self.stress = utils.metrics.get_full_metrics(
                content["stress_true"], content["stress_pred"]
            )
            self.stress_hydro = utils.metrics.get_full_metrics(
                content["stress_hydro_true"], content["stress_hydro_pred"]
            )
            self.stress_deviat = utils.metrics.get_full_metrics(
                content["stress_deviat_true"], content["stress_deviat_pred"]
            )
        else:
            self.stress = {}
            self.stress_hydro = {}
            self.stress_deviat = {}

    def get_plots(self, save=False):
        """Create error-vs-reference figures for all available data.

        Parameters
        ----------
        save: bool
            If True, write the figures as PNG files into ``plots_dir``.
        """
        self.plots_dir.mkdir(exist_ok=True)

        energy_plot = get_figure(
            self.content["energy_true"],
            self.content["energy_error"],
            datalabel=f"MAE: {self.energy['mae']:.2f} meV/atom",
            xlabel=r"$ab~initio$ energy $E$ / meV/atom",
            ylabel=r"$\Delta E$ / meV/atom",
            ymax=self.figure_ymax.get("energy", None),
        )
        if save:
            energy_plot.savefig(self.plots_dir / "energy.png")

        if "forces_true" in self.content:
            # ``\alpha`` needs the backslash, otherwise mathtext prints "alpha".
            xlabel = (
                r"$ab~initio$ force components per atom $F_{\alpha,i}$ / meV$ \cdot"
                r" \AA^{-1}$"
            )
            ylabel = r"$\Delta F_{\alpha,i}$ / meV$ \cdot \AA^{-1}$"
            forces_plot = get_figure(
                self.content["forces_true"],
                self.content["forces_error"],
                datalabel=rf"MAE: {self.forces['mae']:.2f} meV$ / \AA$",
                xlabel=xlabel,
                ylabel=ylabel,
                ymax=self.figure_ymax.get("forces", None),
            )
            if save:
                forces_plot.savefig(self.plots_dir / "forces.png")

        if "stress_true" in self.content:
            # (content key / figure_ymax key, metrics, axis label, file name)
            stress_plots = (
                ("stress", self.stress, "stress", "stress.png"),
                (
                    "stress_hydro",
                    self.stress_hydro,
                    "hydrostatic stress",
                    "hydrostatic_stress.png",
                ),
                (
                    "stress_deviat",
                    self.stress_deviat,
                    "deviatoric stress",
                    "deviatoric_stress.png",
                ),
            )
            for key, metrics, label, filename in stress_plots:
                stress_fig = get_figure(
                    self.content[f"{key}_true"],
                    self.content[f"{key}_error"],
                    datalabel=rf"Max: {metrics['max']:.4f}",
                    xlabel=rf"$ab~initio$ {label}",
                    ylabel=rf"$\Delta$ {label}",
                    ymax=self.figure_ymax.get(key, None),
                )
                if save:
                    stress_fig.savefig(self.plots_dir / filename)

    def run(self):
        """Collect data, store it, compute metrics and save the plots."""
        self.nwd.mkdir(exist_ok=True, parents=True)
        self.get_data()
        np.savez(self.data_file, **self.content)
        self.get_metrics()
        self.get_plots(save=True)

    def get_content(self):
        """Read and return the stored comparison arrays from ``data_file``."""
        with self.state.fs.open(self.data_file, mode="rb") as f:
            content = dict(np.load(f))
        return content
class CalibrationMetrics(base.ComparePredictions):
    """Analyse the calibration of a model's uncertainty estimate.

    Plots the empirical vs predicted error distribution, a log-log calibration
    plot and the miscalibration area. Further, various UQ metrics are computed:

    - Mean absolute calibration error
    - Root mean square miscalibration error
    - Miscalibration area
    - NLL
    - RLL

    For more information check out the uncertainty toolbox or the following
    paper: 10.1088/2632-2153/ad594a

    Parameters
    ----------
    force_dist_slices: List[tuple], optional
        ``(start, end)`` intervals in which to analyse the Gaussianity of the
        force error distributions. Only used when the predictions provide a
        ``forces_ensemble``. ``None`` (default) skips this analysis.
    """

    force_dist_slices: Optional[List[tuple]] = zntrack.params(None)

    data_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "data.npz")

    energy: dict = zntrack.metrics()
    forces: dict = zntrack.metrics()

    plots_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "plots")

    def __post_init__(self):
        self.content = {}
        # ``force_dist_slices`` is deliberately not reset here: overwriting it
        # would discard user-provided slices. ``None`` is treated as "no
        # slices" in ``get_plots``.

    def _post_load_(self):
        """Load the stored arrays - if available."""
        try:
            with self.state.fs.open(self.data_file, "rb") as f:
                self.content = dict(np.load(f))
        except FileNotFoundError:
            self.content = {}

    def get_data(self):
        """Collect values and uncertainties into ``self.content``.

        Energies and energy uncertainties are stored per atom and scaled by
        1000. Forces, force uncertainties and (if present) the force ensemble
        are collected only if the first predicted structure has ``forces`` in
        its results, and are scaled by 1000 as well. Predictions must provide
        ``energy_uncertainty`` (and ``forces_uncertainty`` when forces are
        present) in ``calc.results``.
        """
        pred_keys = self.y[0].calc.results.keys()

        energy_true = [a.get_potential_energy() / len(a) for a in self.x]
        energy_pred = [a.get_potential_energy() / len(a) for a in self.y]
        energy_unc = [a.calc.results["energy_uncertainty"] / len(a) for a in self.y]
        self.content["energy_true"] = np.array(energy_true) * 1000
        self.content["energy_pred"] = np.array(energy_pred) * 1000
        self.content["energy_unc"] = np.array(energy_unc) * 1000

        if "forces" not in pred_keys:
            return

        # Concatenate along the atom axis: structures may differ in size.
        true_forces = np.concatenate([a.get_forces() for a in self.x], axis=0) * 1000
        pred_forces = np.concatenate([a.get_forces() for a in self.y], axis=0) * 1000
        forces_unc = (
            np.concatenate(
                [a.calc.results["forces_uncertainty"] for a in self.y], axis=0
            )
            * 1000
        )
        self.content["forces_true"] = np.reshape(true_forces, (-1,))
        self.content["forces_pred"] = np.reshape(pred_forces, (-1,))
        self.content["forces_unc"] = np.reshape(forces_unc, (-1,))

        if "forces_ensemble" in pred_keys:
            # The ensemble axis is axis 2; flatten everything else so each row
            # lines up with one entry of the flattened force arrays above.
            n_ens = self.y[0].calc.results["forces_ensemble"].shape[2]
            forces_ensemble = [
                np.reshape(a.calc.results["forces_ensemble"], (-1, n_ens))
                for a in self.y
            ]
            self.content["forces_ensemble"] = (
                np.concatenate(forces_ensemble, axis=0) * 1000
            )

    def get_metrics(self):
        """Compute the uncertainty metrics.

        ``forces`` is set to an empty dict when no force data is available,
        so that every declared metric is always assigned.
        """
        self.energy = compute_uncertainty_metrics(
            self.content["energy_pred"],
            self.content["energy_unc"],
            self.content["energy_true"],
        )

        if "forces_unc" in self.content:
            self.forces = compute_uncertainty_metrics(
                self.content["forces_pred"],
                self.content["forces_unc"],
                self.content["forces_true"],
            )
        else:
            self.forces = {}

    def get_plots(self, save=False):
        """Create calibration figures for all available data.

        Parameters
        ----------
        save: bool
            If True, write the figures as PNG files into ``plots_dir``.
        """
        self.plots_dir.mkdir(exist_ok=True)

        e_err = np.abs(self.content["energy_pred"] - self.content["energy_true"])
        energy_plot = get_calibration_figure(
            e_err,
            self.content["energy_unc"],
            markersize=10,
            datalabel=rf"RLL={self.energy['rll']:.1f}",
            forces=False,
        )
        energy_gauss = get_gaussianicity_figure(
            e_err, self.content["energy_unc"], forces=False
        )
        energy_cdf_plot, e_cdf_ax = plt.subplots()
        uct.plot_calibration(
            self.content["energy_pred"],
            self.content["energy_unc"],
            self.content["energy_true"],
            ax=e_cdf_ax,
        )

        if save:
            energy_plot.savefig(self.plots_dir / "energy.png")
            energy_gauss.savefig(self.plots_dir / "energy_gaussianicity.png")
            energy_cdf_plot.savefig(self.plots_dir / "energy_cdf.png")

        if "forces_unc" not in self.content:
            return

        f_err = np.abs(self.content["forces_pred"] - self.content["forces_true"])
        forces_plot = get_calibration_figure(
            f_err,
            self.content["forces_unc"],
            datalabel=rf"RLL={self.forces['rll']:.1f}",
            forces=True,
        )
        forces_cdf_plot, f_cdf_ax = plt.subplots()
        uct.plot_calibration(
            self.content["forces_pred"],
            self.content["forces_unc"],
            self.content["forces_true"],
            ax=f_cdf_ax,
        )

        gaussianity_figures = []
        if "forces_ensemble" in self.content:
            # ``force_dist_slices`` defaults to None, meaning "no slices".
            for start, end in self.force_dist_slices or []:
                error_true, error_pred = slice_ensemble_uncertainty(
                    self.content["forces_true"],
                    self.content["forces_ensemble"],
                    start,
                    end,
                )
                gaussianity_figures.append(
                    get_gaussianicity_figure(error_true, error_pred, forces=True)
                )

        if save:
            forces_plot.savefig(self.plots_dir / "forces.png")
            forces_cdf_plot.savefig(self.plots_dir / "forces_cdf.png")
            for ii, fig in enumerate(gaussianity_figures):
                fig.savefig(self.plots_dir / f"forces_gaussianicity_{ii}.png")

    def run(self):
        """Collect data, store it, compute metrics and save the plots."""
        self.nwd.mkdir(exist_ok=True, parents=True)
        self.get_data()
        np.savez(self.data_file, **self.content)
        self.get_metrics()
        self.get_plots(save=True)
class ForceAngles(base.ComparePredictions):
    """Histogram of the angles between reference and predicted force vectors.

    Attributes
    ----------
    plot: Path
        Histogram of the angles with a linear y-axis.
    log_plot: Path
        The same histogram with a logarithmic y-axis.
    angles: dict
        ``rmse``, ``lp4``, ``max`` and ``mae`` of the angles (plotted in
        degrees).
    """

    plot: pathlib.Path = zntrack.outs_path(zntrack.nwd / "angle.png")
    log_plot: pathlib.Path = zntrack.outs_path(zntrack.nwd / "angle_ylog.png")
    angles: dict = zntrack.metrics()

    def run(self):
        """Compute per-atom force angles, their metrics and the histograms."""
        # Concatenate instead of ``np.reshape`` on a list of arrays: the list
        # is ragged when structures have different numbers of atoms.
        true_forces = np.concatenate([a.get_forces() for a in self.x], axis=0)
        pred_forces = np.concatenate([a.get_forces() for a in self.y], axis=0)

        angles = utils.metrics.get_angles(true_forces, pred_forces)
        # Metrics are measured as deviations from a perfect angle of zero.
        zeros = np.zeros_like(angles)
        self.angles = {
            "rmse": utils.metrics.calculate_l_p_norm(zeros, angles, p=2),
            "lp4": utils.metrics.calculate_l_p_norm(zeros, angles, p=4),
            "max": utils.metrics.maximum_error(zeros, angles),
            "mae": utils.metrics.calculate_l_p_norm(zeros, angles, p=1),
        }

        fig, ax = get_hist(
            data=angles,
            label=rf"MAE: ${self.angles['mae']:.2f}^\circ$",
            xlabel=r"Angle between true and predicted forces $\theta / ^\circ$",
            ylabel="Probability / %",
        )
        fig.savefig(self.plot)
        ax.set_yscale("log")
        fig.savefig(self.log_plot)
class ForceDecomposition(base.ComparePredictions):
    """Node for decomposing forces in a system of molecular units into
    translational, rotational and vibrational components.

    The implementation follows the method described in
    https://doi.org/10.26434/chemrxiv-2022-l4tb9

    Currently, single atoms and diatomic molecules are simply filtered out.
    Please raise an issue if you need those cases to be treated correctly
    in your work.

    Attributes
    ----------
    trans_forces, rot_forces, vib_forces: dict
        Error metrics of the translational, rotational and vibrational force
        components (forces scaled by 1000, i.e. meV/Å).
    wasserstein_distance: dict
        Nested dict ``{label: {part: distance}}``: the Wasserstein distance
        between the true and predicted distributions of the projection of the
        ``part`` component onto the ``label`` forces (see ``get_histogram``).
    """

    trans_forces: dict = zntrack.metrics()
    rot_forces: dict = zntrack.metrics()
    vib_forces: dict = zntrack.metrics()
    wasserstein_distance: dict = zntrack.metrics()

    rot_force_plt: pathlib.Path = zntrack.outs_path(zntrack.nwd / "rot_force.png")
    trans_force_plt: pathlib.Path = zntrack.outs_path(zntrack.nwd / "trans_force.png")
    vib_force_plt: pathlib.Path = zntrack.outs_path(zntrack.nwd / "vib_force.png")
    histogram_plt: pathlib.Path = zntrack.outs_path(zntrack.nwd / "histogram.png")

    def get_plots(self):
        """Save an error-vs-reference plot for each force component."""
        components = (
            ("trans", "Trans.", self.trans_forces, self.trans_force_plt),
            ("rot", "Rot.", self.rot_forces, self.rot_force_plt),
            ("vib", "Vib.", self.vib_forces, self.vib_force_plt),
        )
        for key, short_name, metrics, path in components:
            true = np.reshape(self.true_forces[key], -1)
            pred = np.reshape(self.pred_forces[key], -1)
            fig = get_figure(
                true,
                true - pred,
                datalabel=rf"{short_name} MAE: {metrics['mae']:.2f} meV$ / \AA$",
                xlabel=r"$ab~initio$ forces / meV$ \cdot \AA^{-1}$",
                # ``\alpha`` needs the backslash, otherwise mathtext prints "alpha".
                ylabel=rf"$\Delta F_{{\alpha,i,{key}}}$ / meV$ \cdot \AA^{{-1}}$",
            )
            fig.savefig(path)

    def get_metrics(self):
        """Update the metrics."""
        self.trans_forces = utils.metrics.get_full_metrics(
            self.true_forces["trans"], self.pred_forces["trans"]
        )
        self.rot_forces = utils.metrics.get_full_metrics(
            self.true_forces["rot"], self.pred_forces["rot"]
        )
        self.vib_forces = utils.metrics.get_full_metrics(
            self.true_forces["vib"], self.pred_forces["vib"]
        )

    def get_histogram(self):
        """Histogram projected force components and compute Wasserstein distances.

        For every force label (keys of ``true_forces``) and every component
        ``vib``/``rot``/``trans``, the per-atom scalar projection
        ``a . b / ||a||_2`` (``a`` = label forces, ``b`` = component forces) is
        histogrammed for reference and prediction. The two projection samples
        are compared with the 1D Wasserstein distance.
        """
        from scipy.stats import wasserstein_distance

        def get_rel_scalar_prod(main, relative) -> np.ndarray:
            x = np.einsum("ij,ij->i", main, relative)
            x /= np.linalg.norm(main, axis=-1)
            return x

        fig, axes = plt.subplots(4, 3, figsize=(4 * 5, 3 * 3))
        fig.suptitle(
            (
                r"A fraction $\dfrac{\vec{a} \cdot"
                r" \vec{b}}{\left|\left|\vec{a}\right|\right|_{2}} $ of $\vec{b}$ that"
                r" contributes to $\vec{a}$"
            ),
            fontsize=16,
        )

        self.wasserstein_distance = {}
        for label, ax_row in zip(self.true_forces.keys(), axes):
            self.wasserstein_distance[label] = {}
            for part, ax in zip(["vib", "rot", "trans"], ax_row):
                true_data = get_rel_scalar_prod(
                    self.true_forces[label], self.true_forces[part]
                )
                pred_data = get_rel_scalar_prod(
                    self.pred_forces[label], self.pred_forces[part]
                )
                _, bin_edges, _ = ax.hist(
                    true_data, bins=50, density=True, label=f"true {label} {part}"
                )
                ax.hist(
                    pred_data,
                    bins=bin_edges,
                    density=True,
                    alpha=0.5,
                    label=f"pred {label} {part}",
                )
                ax.legend()
                # Compare the samples themselves. Passing histogram heights
                # (as before) treats bin densities as samples and does not
                # measure the distance between the two distributions.
                self.wasserstein_distance[label][part] = float(
                    wasserstein_distance(true_data, pred_data)
                )
        fig.savefig(self.histogram_plt, bbox_inches="tight")

    @staticmethod
    def _decompose_frames(frames, mapping) -> dict:
        """Decompose the forces of ``frames``; return arrays scaled by 1000."""
        forces = {"all": [], "trans": [], "rot": [], "vib": []}
        for atoms in tqdm.tqdm(frames, ncols=70):
            trans, rot, vib = force_decomposition(atoms, mapping)
            forces["all"].append(atoms.get_forces())
            forces["trans"].append(trans)
            forces["rot"].append(rot)
            forces["vib"].append(vib)
        return {k: np.concatenate(v) * 1000 for k, v in forces.items()}

    def run(self):
        """Decompose true and predicted forces, then compute metrics and plots."""
        mapping = BarycenterMapping()
        # TODO make the force_decomposition return full forces
        # TODO check if you sum the forces they yield the full forces
        # TODO make mapping a 'zn.nodes' with Mapping(species="BF4")
        #  maybe allow smiles and enumeration 0, 1, ...

        self.true_forces = self._decompose_frames(self.x, mapping)
        self.pred_forces = self._decompose_frames(self.y, mapping)

        self.get_metrics()
        self.get_plots()
        self.get_histogram()
def decompose_force_uncertainty(atom_true, atom_pred):
    """Decompose reference forces and ensemble predictions into components.

    Parameters
    ----------
    atom_true: ase.Atoms
        Structure whose ``forces`` are used as reference.
    atom_pred: ase.Atoms
        Structure whose results contain ``forces_ensemble``. The decomposed
        ensembles are 3D arrays with the ensemble members on the last axis.

    Returns
    -------
    tuple
        ``(true, pred, unc)``, each a ``(trans, rot, vib)`` tuple. ``pred`` is
        the ensemble mean and ``unc`` the ensemble sample standard deviation
        (``ddof=1``) of each decomposed component.
    """
    mapping = BarycenterMapping(frozen=True)

    trans_true, rot_true, vib_true = force_decomposition(
        atom_true,
        mapping,
        key="forces",
    )
    trans_ens, rot_ens, vib_ens = force_decomposition(
        atom_pred,
        mapping,
        key="forces_ensemble",
    )
    ensembles = (trans_ens, rot_ens, vib_ens)

    # Consumers treat the uncertainty as a standard deviation (calibration
    # metrics/plots, scaling by 1000), so return the sample standard
    # deviation rather than the variance.
    pred = tuple(np.mean(ens, axis=-1) for ens in ensembles)
    unc = tuple(np.std(ens, axis=-1, ddof=1) for ens in ensembles)
    true = (trans_true, rot_true, vib_true)
    return true, pred, unc
class ForceUncertaintyDecomposition(base.ComparePredictions):
    """Node for decomposing force uncertainties in a system of molecular units
    into translational, rotational and vibrational components.

    The implementation follows the method described in
    https://doi.org/10.26434/chemrxiv-2022-l4tb9

    Currently, single atoms and diatomic molecules are simply filtered out.
    Please raise an issue if you need those cases to be treated correctly
    in your work.

    The number of worker processes is read from the ``IPSUITE_NPROC``
    environment variable; by default all CPUs but one are used (at least one).

    Attributes
    ----------
    trans_forces, rot_forces, vib_forces: dict
        Uncertainty metrics (from ``compute_uncertainty_metrics``) of the
        translational, rotational and vibrational force components.
    plots_dir: Path
        Directory with calibration and Gaussianity figures per component.
    """

    trans_forces: dict = zntrack.metrics()
    rot_forces: dict = zntrack.metrics()
    vib_forces: dict = zntrack.metrics()

    plots_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "plots")

    def get_plots(self):
        """Save calibration and Gaussianity figures for each component."""
        self.plots_dir.mkdir(exist_ok=True)

        components = (
            ("trans", self.trans_forces),
            ("rot", self.rot_forces),
            ("vib", self.vib_forces),
        )
        for part, metrics in components:
            err = np.abs(self.f_true[part] - self.f_pred[part])
            calibration_plot = get_calibration_figure(
                err,
                self.f_unc[part],
                markersize=5,
                datalabel=rf"RLL={metrics['rll']:.1f}",
                forces=True,
            )
            gauss_plot = get_gaussianicity_figure(err, self.f_unc[part], forces=True)
            calibration_plot.savefig(self.plots_dir / f"{part}.png")
            gauss_plot.savefig(self.plots_dir / f"{part}_gauss.png")

    def get_metrics(self):
        """Update the metrics."""
        self.trans_forces = compute_uncertainty_metrics(
            self.f_pred["trans"], self.f_unc["trans"], self.f_true["trans"]
        )
        self.rot_forces = compute_uncertainty_metrics(
            self.f_pred["rot"], self.f_unc["rot"], self.f_true["rot"]
        )
        self.vib_forces = compute_uncertainty_metrics(
            self.f_pred["vib"], self.f_unc["vib"], self.f_true["vib"]
        )

    @staticmethod
    def _get_nproc() -> int:
        """Return the worker count: ``IPSUITE_NPROC`` or all CPUs but one (>= 1)."""
        value = os.getenv("IPSUITE_NPROC")
        if value is None:
            # cpu_count() - 1 is 0 on single-core machines, which the
            # executor rejects.
            return max(1, multiprocessing.cpu_count() - 1)
        try:
            # os.getenv returns a string; the executor needs an int.
            nproc = int(value)
        except ValueError as err:
            raise ValueError(
                f"IPSUITE_NPROC must be an integer, got {value!r}"
            ) from err
        return max(1, nproc)

    def run(self):
        """Decompose all structure pairs in parallel and evaluate the results."""
        components = ("trans", "rot", "vib")
        self.f_true = {k: [] for k in components}
        self.f_pred = {k: [] for k in components}
        self.f_unc = {k: [] for k in components}

        # Context managers make sure the worker processes are shut down and
        # the progress bar is closed, even if a worker raises.
        with ProcessPoolExecutor(self._get_nproc()) as process_pool, tqdm.tqdm(
            total=len(self.x),
            desc="structures",
            ncols=70,
            leave=True,
            mininterval=0.25,
        ) as pbar:
            results = process_pool.map(decompose_force_uncertainty, self.x, self.y)
            for y_true, y_pred, y_unc in results:
                for key, true, pred, unc in zip(components, y_true, y_pred, y_unc):
                    self.f_true[key].append(true)
                    self.f_pred[key].append(pred)
                    self.f_unc[key].append(unc)
                pbar.update(1)

        # Flatten every component and scale by 1000 (eV-based -> meV-based).
        self.f_true = {
            k: np.reshape(np.concatenate(v), (-1,)) * 1000
            for k, v in self.f_true.items()
        }
        self.f_pred = {
            k: np.reshape(np.concatenate(v), (-1,)) * 1000
            for k, v in self.f_pred.items()
        }
        self.f_unc = {
            k: np.reshape(np.concatenate(v), (-1,)) * 1000
            for k, v in self.f_unc.items()
        }
        self.get_metrics()
        self.get_plots()