Source code for ipsuite.dynamics.checks

import collections
import dataclasses

import ase
import numpy as np
from ase.geometry import conditional_find_mic
from ase.neighborlist import build_neighbor_list, natural_cutoffs

from ipsuite import base
from ipsuite.utils.ase_sim import get_density_from_atoms, get_energy


@dataclasses.dataclass
class DebugCheck(base.Check):
    """Interrupt the dynamics after a fixed number of ``check`` calls.

    Meant for testing purposes only.

    Attributes
    ----------
    n_iterations: int
        number of iterations before stopping
    """

    n_iterations: int = 10

    def initialize(self, atoms: ase.Atoms) -> None:
        """Reset the call counter and flag the check as initialized.

        ``atoms`` is accepted for interface compatibility and is not read.
        """
        self.status = "n_iterations not reached"
        self.is_initialized = True
        self.counter = 0

    def check(self, atoms):
        """Return ``True`` once the counter has reached ``n_iterations``.

        While the limit is not reached the counter is incremented by one
        and ``False`` is returned. When the limit is reached the counter
        is reset to zero so that the check can trigger again later.
        """
        if self.counter < self.n_iterations:
            # Still below the limit: count this call and keep running.
            self.counter += 1
            self.status = "n_iterations not reached"
            return False

        self.counter = 0
        self.status = "n_iterations reached"
        return True
@dataclasses.dataclass
class NaNCheck(base.Check):
    """Check whether positions, energy or forces become NaN during a simulation.

    A quantity that is ``None`` is treated as invalid as well.
    """

    def initialize(self, atoms: ase.Atoms) -> None:
        """Flag the check as initialized; ``atoms`` is not read."""
        self.is_initialized = True

    @staticmethod
    def _is_invalid(values) -> bool:
        """Return ``True`` if ``values`` is ``None`` or contains a NaN."""
        if values is None:
            return True
        # ``x is None`` is never true for an array holding NaNs, so the
        # NaN test has to be done element-wise with ``np.isnan``.
        return bool(np.isnan(np.asarray(values, dtype=float)).any())

    def check(self, atoms: ase.Atoms) -> bool:
        """Return ``True`` if positions, potential energy or forces are invalid.

        Parameters
        ----------
        atoms: ase.Atoms
            Structure whose ``positions``, ``get_potential_energy()`` and
            ``get_forces()`` are inspected.

        Returns
        -------
        bool
            ``True`` if any of the three quantities is ``None`` or
            contains NaN, ``False`` otherwise. ``self.status`` is updated
            in both cases.
        """
        quantities = (
            atoms.positions,
            atoms.get_potential_energy(),
            atoms.get_forces(),
        )

        if any(self._is_invalid(quantity) for quantity in quantities):
            self.status = (
                "NaN check failed: last iteration positions energy or forces = NaN"
            )
            return True

        self.status = "No NaN occurred"
        return False
@dataclasses.dataclass
class ConnectivityCheck(base.Check):
    """Check whether the covalent connectivity changes during a simulation.

    The bonded pairs are determined once in ``initialize`` from ASE's
    natural cutoffs. During ``check`` the distances of exactly these pairs
    are compared against ``bonded_min_dist`` and ``bonded_max_dist``.
    The pair of atoms which triggered this check will be converted to
    Lithium (atomic number 3) for easy visibility, i.e. ``atoms`` is
    modified in place.

    Attributes
    ----------
    bonded_min_dist: float
        Smallest allowed distance between two initially bonded atoms.
        A falsy value (``0`` or ``None``) disables the lower bound.
    bonded_max_dist: float
        Largest allowed distance between two initially bonded atoms.
        A falsy value (``0`` or ``None``) disables the upper bound.
    """

    bonded_min_dist: float = 0.6
    bonded_max_dist: float = 2.0

    def __post_init__(self) -> None:
        self.nl = None
        self.first_cm = None

    def initialize(self, atoms):
        """Record the index pairs of all atoms that are bonded in ``atoms``.

        The pairs are taken from the non-zero entries of the connectivity
        matrix of a neighbor list built with ``natural_cutoffs`` and stored
        in ``self.indices`` (one row per pair), ``self.idx_i`` and
        ``self.idx_j``.
        """
        cutoffs = natural_cutoffs(atoms, mult=1.0)
        nl = build_neighbor_list(
            atoms,
            cutoffs=cutoffs,
            skin=0.0,
            self_interaction=False,
            bothways=False,
        )
        first_cm = nl.get_connectivity_matrix(sparse=True)
        self.indices = np.vstack(first_cm.nonzero()).T
        self.idx_i, self.idx_j = self.indices.T
        self.is_initialized = True

    def _mark_pair(self, atoms: ase.Atoms, pair_idx: int) -> None:
        """Turn both atoms of bonded pair ``pair_idx`` into Lithium."""
        atoms.numbers[self.idx_i[pair_idx]] = 3
        atoms.numbers[self.idx_j[pair_idx]] = 3

    def check(self, atoms: ase.Atoms) -> bool:
        """Return ``True`` if an initially bonded pair is too close or too far.

        Parameters
        ----------
        atoms: ase.Atoms
            Current structure. The atomic numbers of the offending pair(s)
            are overwritten with 3 when the check triggers.

        Returns
        -------
        bool
            ``True`` if the shortest bonded distance is below
            ``bonded_min_dist`` or the longest one is above
            ``bonded_max_dist``; ``False`` otherwise, including the case
            that no bonded pairs were found during ``initialize``.
        """
        if len(self.idx_i) == 0:
            # No bonded pairs were recorded: nothing can break, and
            # np.min / np.max below would raise on an empty sequence.
            self.status = "covalent connectivity of the system is intact"
            return False

        p1 = atoms.positions[self.idx_i]
        p2 = atoms.positions[self.idx_j]
        _, dists = conditional_find_mic(p1 - p2, atoms.cell, atoms.pbc)

        unstable = False

        if self.bonded_min_dist:
            too_close = np.min(dists) < self.bonded_min_dist
            unstable = unstable or too_close
            if too_close:
                self._mark_pair(atoms, np.argmin(dists))

        if self.bonded_max_dist:
            too_far = np.max(dists) > self.bonded_max_dist
            unstable = unstable or too_far
            if too_far:
                self._mark_pair(atoms, np.argmax(dists))

        if unstable:
            self.status = (
                "Connectivity check failed: last iteration "
                "covalent connectivity of the system changed"
            )
            return True

        self.status = "covalent connectivity of the system is intact"
        return False
@dataclasses.dataclass
class EnergySpikeCheck(base.Check):
    """Check whether the potential energy leaves a window around its initial value.

    The window is spanned by ``E(initial) * min_factor`` and
    ``E(initial) * max_factor``. For the usual case of a negative
    potential energy this means:

    Attributes
    ----------
    min_factor:
        Simulation stops if `E(current) > E(initial) * min_factor`
    max_factor:
        Simulation stops if `E(current) < E(initial) * max_factor`
    max_energy:
        ``E(initial) * max_factor``; overwritten by ``initialize``.
    min_energy:
        ``E(initial) * min_factor``; overwritten by ``initialize``.
    """

    min_factor: float = 0.5
    max_factor: float = 2.0
    max_energy: float | None = None
    min_energy: float | None = None

    def initialize(self, atoms: ase.Atoms) -> None:
        """Derive both energy thresholds from the current potential energy."""
        epot = atoms.get_potential_energy()
        self.max_energy = epot * self.max_factor
        self.min_energy = epot * self.min_factor
        # The other checks in this module set this flag in ``initialize``;
        # presumably the caller reads it -- kept consistent here.
        self.is_initialized = True

    def check(self, atoms: ase.Atoms) -> bool:
        """Return ``True`` if the potential energy is outside the allowed window.

        Parameters
        ----------
        atoms: ase.Atoms
            Structure providing ``get_potential_energy()``.

        Returns
        -------
        bool
            ``True`` if the energy is below the lower or above the upper
            threshold, ``False`` otherwise. ``self.status`` is updated in
            both cases.
        """
        epot = atoms.get_potential_energy()

        # For a negative reference energy ``max_energy`` is the lower and
        # ``min_energy`` the upper bound; for a positive one it is the
        # other way round. Ordering the two keeps the window valid for
        # either sign instead of triggering on the very first call.
        lower, upper = sorted((self.min_energy, self.max_energy))

        if epot < lower:
            self.status = (
                "Energy spike check failed: last iteration "
                f"E {epot} < lower bound {lower}"
            )
            return True

        if epot > upper:
            self.status = (
                "Energy spike check failed: last iteration "
                f"E {epot} > upper bound {upper}"
            )
            return True

        self.status = "No energy spike occurred"
        return False
@dataclasses.dataclass
class TemperatureCheck(base.Check):
    """Calculate and check the temperature during an MD simulation.

    Attributes
    ----------
    max_temperature: float
        maximum temperature, when exceeding it the simulation will be stopped
    """

    max_temperature: float = 10000.0

    def initialize(self, atoms: ase.Atoms) -> None:
        """Flag the check as initialized; ``atoms`` is not read."""
        self.is_initialized = True

    def check(self, atoms):
        """Return ``True`` if the temperature exceeds ``max_temperature``.

        The temperature is the first value returned by ``get_energy(atoms)``
        and is stored in ``self.temperature``. ``self.status`` is updated
        on every call.
        """
        current, _ = get_energy(atoms)
        self.temperature = current

        if not self.temperature > self.max_temperature:
            self.status = (
                f"Temperature Check: T {self.temperature} K <"
                f"T_max {self.max_temperature} K"
            )
            return False

        self.status = (
            "Temperature Check failed last iteration"
            f"T {self.temperature} K > T_max {self.max_temperature} K"
        )
        return True
@dataclasses.dataclass
class ThresholdCheck(base.Check):
    """Calculate and check a given threshold and std during a MD simulation

    Compute the standard deviation of the selected property.
    If the property is off by more than a selected amount from the
    mean, the simulation will be stopped.
    Furthermore, the simulation will be stopped if the property
    exceeds a threshold value.

    Attributes
    ----------
    key: str
        name of the property to check
    max_std: float, optional
        Maximum number of standard deviations away from the mean to stop the
        simulation. Roughly the value corresponds to the following percentiles:
        {1: 68%, 2: 95%, 3: 99.7%}
    window_size: int, optional
        Number of steps to average over
    max_value: float, optional
        Maximum value of the property to check before the simulation is stopped
    minimum_window_size: int, optional
        Minimum number of steps to average over before checking the standard
        deviation. Also minimum number of steps to run, before the simulation
        can be stopped.
    larger_only: bool, optional
        Only check the standard deviation of points that are larger than the
        mean. E.g. useful for uncertainties, where a lower uncertainty is not
        a problem.

    Raises
    ------
    ValueError
        If neither ``max_std`` nor ``max_value`` is given.
    """

    key: str = "energy_uncertainty"
    max_std: float | None = None
    window_size: int = 500
    max_value: float | None = None
    minimum_window_size: int = 1
    larger_only: bool = False

    def __post_init__(self):
        if self.max_std is None and self.max_value is None:
            raise ValueError("Either max_std or max_value must be set")
        # Sliding window holding the most recent ``window_size`` values.
        self.values = collections.deque(maxlen=self.window_size)

    def initialize(self, atoms: ase.Atoms) -> None:
        """Empty the sliding window and reset the status; ``atoms`` is not read."""
        self.values.clear()
        self.status = None
        # The other checks in this module set this flag in ``initialize``;
        # presumably the caller reads it -- kept consistent here.
        self.is_initialized = True

    def get_value(self, atoms):
        """Get the value of the property to check.

        Extracted into method so it can be subclassed.
        """
        return np.max(atoms.calc.results[self.key])

    def get_quantity(self):
        """Return a label built from the key and the active threshold."""
        if self.max_value is None:
            return f"{self.key}-threshold-std-{self.max_std}"
        return f"{self.key}-threshold-max-{self.max_value}"

    def check(self, atoms) -> bool:
        """Return ``True`` if the monitored property violates a threshold.

        The current value of ``atoms.calc.results[self.key]`` is appended
        to the sliding window. Once the window holds at least
        ``minimum_window_size`` entries the check triggers if

        * ``max_value`` is set and the largest current value exceeds it, or
        * ``max_std`` is set and the largest deviation from the window mean
          exceeds ``max_std`` standard deviations. With ``larger_only`` only
          values above the mean count, otherwise the absolute deviation
          is used.

        Returns
        -------
        bool
            ``True`` if the simulation should be stopped. ``self.status`` is
            left untouched while the window is still too small.
        """
        # NOTE(review): the raw result is stored here rather than
        # ``self.get_value(atoms)``; kept as is to not alter the statistics.
        value = atoms.calc.results[self.key]
        self.values.append(value)

        if len(self.values) < self.minimum_window_size:
            return False

        mean = np.mean(self.values)
        std = np.std(self.values)

        distance = value - mean
        if not self.larger_only:
            # Two-sided check: deviations below the mean count as well.
            distance = np.abs(distance)

        latest_max = np.max(self.values[-1])

        if self.max_value is not None and np.max(value) > self.max_value:
            self.status = (
                f"StandardDeviationCheck for {self.key} triggered by"
                f" '{latest_max:.3f}' > max_value {self.max_value}"
            )
            return True

        if self.max_std is not None and np.max(distance) > self.max_std * std:
            self.status = (
                f"StandardDeviationCheck for '{self.key}' triggered by"
                f" '{latest_max:.3f}' for '{mean:.3f} +-"
                f" {std:.3f}' and max value '{self.max_value}'"
            )
            return True

        self.status = (
            f"StandardDeviationCheck for '{self.key}' passed with"
            f" '{latest_max:.3f}' for '{mean:.3f} +-"
            f" {std:.3f}' and max value '{self.max_value}'"
        )
        return False
@dataclasses.dataclass
class DensityCheck(base.Check):
    """Check whether the density of the system leaves an allowed range.

    The density is whatever ``get_density_from_atoms`` returns for the
    given structure.

    Attributes
    ----------
    max_density: float, optional
        Upper bound; the check triggers if the density is larger.
        ``None`` disables the upper bound.
    min_density: float, optional
        Lower bound; the check triggers if the density is smaller.
        ``None`` disables the lower bound.
    status: str, optional
        Message describing the outcome of the last ``check`` call.
    """

    max_density: float | None = None
    min_density: float | None = None
    status: str | None | bool = None

    def get_quantity(self):
        """Return the label of the monitored quantity."""
        return "Density"

    def get_value(self, atoms):
        """Return the density of ``atoms`` that is being checked."""
        return get_density_from_atoms(atoms)

    def check(self, atoms: ase.Atoms) -> bool:
        """Return ``True`` if the density is above the maximum or below the minimum.

        The upper bound is tested first. ``self.status`` is updated on
        every call.
        """
        density = get_density_from_atoms(atoms)

        above_max = self.max_density is not None and density > self.max_density
        if above_max:
            self.status = (
                "Density Check failed: last iteration"
                f" density {density} > max density {self.max_density}"
            )
            return True

        below_min = self.min_density is not None and density < self.min_density
        if below_min:
            self.status = (
                "Density Check failed: last iteration"
                f" density {density} < min density {self.min_density}"
            )
            return True

        self.status = (
            f"Density Check passed: density {density} "
            f"between min {self.min_density} and max {self.max_density}"
        )
        return False