Source code for ipsuite.configuration_selection.threshold

"""Selecting atoms with a given step between them."""

import typing

import ase
import matplotlib.pyplot as plt
import numpy as np
import zntrack

from ipsuite.configuration_selection import ConfigurationSelection


def mean_reduction(values, axis):
    """Collapse ``values`` along ``axis`` by averaging (``numpy.mean``).

    ``axis`` is forwarded unchanged, so it may be an int or a tuple of ints.
    """
    averaged = np.mean(values, axis=axis)
    return averaged
def max_reduction(values, axis):
    """Collapse ``values`` along ``axis`` by taking the maximum (``numpy.max``).

    ``axis`` is forwarded unchanged, so it may be an int or a tuple of ints.
    """
    largest = np.max(values, axis=axis)
    return largest
def check_dimension(values):
    """Ensure ``values`` has at most one dimension.

    Parameters
    ----------
    values:
        Array-like exposing an ``ndim`` attribute (e.g. a numpy array).

    Raises
    ------
    ValueError
        If ``values.ndim`` is greater than 1. The message suggests setting
        ``dim_reduction`` (mean or max) to collapse the extra axes.
    """
    if values.ndim <= 1:
        return
    message = (
        f"Value dimension is {values.ndim} != 1. "
        "Reduce the dimension by defining dim_reduction, "
        "use mean or max to get (n_structures,) shape."
    )
    raise ValueError(message)
# Maps the ``dim_reduction`` parameter value to the function that collapses
# the extra axes of the per-structure values.
REDUCTIONS = dict(
    mean=mean_reduction,
    max=max_reduction,
)
class ThresholdSelection(ConfigurationSelection):
    """Select atoms based on a given threshold.

    Select atoms above a given threshold or the n_configurations with the
    highest / lowest value. Typically useful for uncertainty based selection.

    Attributes
    ----------
    key: str
        The key in 'calc.results' to select from
    threshold: float, optional
        All values above (or below if negative) this threshold will be selected.
        If n_configurations is given, 'self.threshold' will be prioritized,
        but a maximum of n_configurations will be selected. In that case the
        most extreme qualifying values (highest, or lowest for a negative
        threshold) are preferred.
    reference: str, optional
        For visualizing the selection a reference value can be given.
        For 'energy_uncertainty' this would typically be 'energy'.
        NOTE(review): not read by the plotting code in this class.
    n_configurations: int, optional
        Number of configurations to select.
    min_distance: int, optional
        Minimum distance between selected configurations.
    dim_reduction: str, optional
        Reduces the dimensionality of the chosen uncertainty along the specified axis
        by calculating either the maximum or mean value.

        Choose from ["max", "mean"]
    reduction_axis: tuple(int), optional
        Specifies the axis along which the reduction occurs.
    """

    key: str = zntrack.params("energy_uncertainty")
    reference: str = zntrack.params("energy")
    threshold: float | None = zntrack.params(None)
    n_configurations: int | None = zntrack.params(None)
    min_distance: int = zntrack.params(1)
    dim_reduction: str | None = zntrack.params(None)
    reduction_axis: list[int] = zntrack.params((1, 2))

    def __post_init__(self):
        if self.threshold is None and self.n_configurations is None:
            raise ValueError("Either 'threshold' or 'n_configurations' must not be None.")

    def _get_values(self, atoms_lst: typing.List[ase.Atoms]) -> np.ndarray:
        """Collect ``calc.results[self.key]`` per structure, reduced if requested.

        Raises ``KeyError`` if ``self.dim_reduction`` is not a key of ``REDUCTIONS``.
        """
        values = np.array([atoms.calc.results[self.key] for atoms in atoms_lst])
        if self.dim_reduction is not None:
            reduction_fn = REDUCTIONS[self.dim_reduction]
            # numpy accepts only an int or a *tuple* of ints as ``axis``; the
            # parameter may be stored as a list, so convert locally instead of
            # mutating the tracked parameter.
            values = reduction_fn(values, tuple(self.reduction_axis))
        return values

    def select_atoms(
        self, atoms_lst: typing.List[ase.Atoms], save_fig: bool = True
    ) -> typing.List[int]:
        """Select configurations by thresholding or ranking a per-structure value.

        Parameters
        ----------
        atoms_lst: typing.List[ase.Atoms]
            Structures to select from; ``atoms.calc.results[self.key]`` must
            exist for every entry.
        save_fig: bool
            Not used by this method; kept for interface compatibility.

        Returns
        -------
        typing.List[int]:
            Indices into ``atoms_lst`` of the selected configurations, at most
            ``n_configurations`` of them and pairwise at least
            ``min_distance`` apart.

        Raises
        ------
        ValueError
            If the values are not one-dimensional after the optional reduction.
        """
        if len(atoms_lst) == 0:
            return []

        values = self._get_values(atoms_lst)
        check_dimension(values)

        if self.threshold is not None:
            below = self.threshold < 0
            if below:
                indices = np.where(values < self.threshold)[0]
            else:
                indices = np.where(values > self.threshold)[0]
            if self.n_configurations is not None:
                # Rank only the qualifying configurations so the cap keeps the
                # most extreme ones (lowest first for a negative threshold,
                # highest first otherwise). Indexing the full argsort with
                # ``indices`` would pick unrelated configurations.
                ranking = np.argsort(values[indices])
                if not below:
                    ranking = ranking[::-1]
                indices = indices[ranking]
        elif np.mean(values) > 0:
            # No threshold: rank all configurations, largest first when the
            # values are positive on average, smallest first otherwise.
            indices = np.argsort(values)[::-1]
        else:
            indices = np.argsort(values)

        selection: typing.List[int] = []
        for candidate in indices:
            if (
                self.n_configurations is not None
                and len(selection) >= self.n_configurations
            ):
                break
            candidate = int(candidate)
            # Greedily skip candidates too close (in index) to an accepted one.
            if all(abs(candidate - chosen) >= self.min_distance for chosen in selection):
                selection.append(candidate)
        return selection

    def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int]):
        """Plot the per-structure values, mark the selected ones with red crosses.

        The figure is saved to ``self.img_selection``.
        """
        # Force an integer dtype: an empty list would otherwise become a float
        # array, which numpy refuses as an index.
        indices = np.asarray(indices, dtype=int)
        values = self._get_values(atoms_lst)

        fig, ax = plt.subplots()
        ax.plot(values, label=self.key)
        ax.plot(indices, values[indices], "x", color="red")
        ax.set_ylabel(self.key)
        ax.set_xlabel("configuration")
        fig.savefig(self.img_selection, bbox_inches="tight")
        plt.close(fig)