Source code for ipsuite.analysis.molecules

import ase
import rdkit2ase
import tqdm
import zntrack

from ipsuite import base
from ipsuite.geometry import BarycenterMapping


class AllowedStructuresFilter(base.IPSNode):
    """Search a given dataset for outliers.

    Iterates all structures in the dataset, uses covalent radii to determine
    the atoms in each molecule, and checks if the molecule is allowed.

    A molecule counts as allowed when its sorted atomic numbers match those of
    one of the reference molecules. Only the elemental composition is
    compared; connectivity and geometry are not.

    Attributes
    ----------
    data : list[ase.Atoms]
        The dataset to search.
    molecules : list[ase.Atoms], optional
        The molecules that are allowed.
    smiles : list[str], optional
        The SMILES strings of the allowed molecules.
    cutoffs : dict[str, float] | None, optional
        The cutoffs for each element.
        If None, use the `ase.data.covalent_radii`.
        Default: None
    fail : bool, optional
        If True, raise a ``ValueError`` as soon as the first molecule that is
        not allowed is found. If False, print a message and record the index
        of the frame instead.
        Default: False
    outliers : list[int]
        Output. Indices into ``data`` of the frames that contain at least one
        molecule that is not allowed, in ascending order.
    """

    data: list[ase.Atoms] = zntrack.deps()
    molecules: list[ase.Atoms] = zntrack.deps(default_factory=list)
    smiles: list[str] = zntrack.params(default_factory=list)
    cutoffs: dict[str, float] | None = zntrack.params(None)
    fail: bool = zntrack.params(False)

    outliers: list[int] = zntrack.outs()

    @staticmethod
    def _composition(atoms: ase.Atoms) -> tuple[int, ...]:
        """Return the sorted atomic numbers of ``atoms`` as a hashable tuple."""
        return tuple(sorted(int(number) for number in atoms.get_atomic_numbers()))

    def run(self):
        """Collect the indices of all frames containing a disallowed molecule.

        Raises
        ------
        ValueError
            If ``fail`` is True and a molecule that is not allowed is found.
        """
        allowed_molecules = [
            *self.molecules,
            *(rdkit2ase.smiles2atoms(s) for s in self.smiles),
        ]
        # The allowed compositions do not depend on the frame, so build them
        # once as a set for constant-time lookups inside the loop.
        allowed_compositions = {self._composition(m) for m in allowed_molecules}

        mapping = BarycenterMapping(cutoffs=self.cutoffs)
        outliers_set = set()

        for idx, atoms in enumerate(tqdm.tqdm(self.data)):
            _, mols = mapping.forward_mapping(atoms)
            for mol in mols:
                # check if the atomic numbers are the same
                if self._composition(mol) in allowed_compositions:
                    continue
                if self.fail:
                    raise ValueError(
                        f"Outlier found at index {idx} for molecule {mol}"
                    )
                print(f"Outlier found at index {idx} for molecule {mol}")
                outliers_set.add(idx)

        # Sort so that the stored result does not depend on set iteration order.
        self.outliers = sorted(outliers_set)

    @property
    def excluded_frames(self) -> list[ase.Atoms]:
        """The frames of ``data`` that were flagged as outliers."""
        return [self.data[idx] for idx in self.outliers]

    @property
    def frames(self) -> list[ase.Atoms]:
        """The frames of ``data`` that were not flagged as outliers."""
        excluded = set(self.outliers)
        return [atoms for idx, atoms in enumerate(self.data) if idx not in excluded]