class AllowedStructuresFilter(base.IPSNode):
    """Search a given dataset for outliers.

    Iterates all structures in the dataset, uses covalent radii to determine
    the atoms in each molecule, and checks if the molecule is allowed.

    A molecule counts as allowed when its multiset of atomic numbers (its
    composition) matches that of at least one allowed molecule. Geometry and
    connectivity are not compared, so e.g. isomers are indistinguishable.
    If neither ``molecules`` nor ``smiles`` is given, every molecule is
    reported as an outlier.

    Attributes
    ----------
    data : list[ase.Atoms]
        The dataset to search.
    molecules : list[ase.Atoms], optional
        The molecules that are allowed.
    smiles : list[str], optional
        The SMILES strings of the allowed molecules.
    cutoffs : dict[str, float] | None, optional
        The cutoffs for each element. If None, use the
        `ase.data.covalent_radii`. Default: None
    fail : bool, optional
        If True, raise a ``ValueError`` at the first outlier molecule instead
        of recording it. Default: False
    outliers : list[int]
        Output: ascending indices into ``data`` of the structures that contain
        at least one molecule whose composition is not allowed.
    """

    data: list[ase.Atoms] = zntrack.deps()
    molecules: list[ase.Atoms] = zntrack.deps(default_factory=list)
    smiles: list[str] = zntrack.params(default_factory=list)
    cutoffs: dict[str, float] | None = zntrack.params(None)
    fail: bool = zntrack.params(False)

    outliers: list[int] = zntrack.outs()

    def run(self):
        """Scan ``data`` and store the indices of outlier structures in ``outliers``.

        Each structure is split into molecules via
        ``BarycenterMapping.forward_mapping`` and every molecule's composition
        is checked against the allowed set. Each outlier molecule is printed.

        Raises
        ------
        ValueError
            If ``fail`` is True and an outlier molecule is found.
        """

        def _composition_key(atoms) -> tuple[int, ...]:
            # Order-independent, hashable view of the atomic numbers; plain
            # ints so keys from different numpy dtypes still compare equal.
            return tuple(sorted(int(z) for z in atoms.get_atomic_numbers()))

        molecules = self.molecules + [rdkit2ase.smiles2atoms(s) for s in self.smiles]
        # Build the allowed compositions once instead of re-sorting every
        # allowed molecule for each molecule of each structure.
        allowed = {_composition_key(m) for m in molecules}

        mapping = BarycenterMapping(cutoffs=self.cutoffs)
        outliers_set = set()
        for idx, atoms in enumerate(tqdm.tqdm(self.data)):
            _, mols = mapping.forward_mapping(atoms)
            for mol in mols:
                # check if the atomic numbers are the same
                if _composition_key(mol) in allowed:
                    continue
                if self.fail:
                    raise ValueError(f"Outlier found at index {idx} for molecule {mol}")
                print(f"Outlier found at index {idx} for molecule {mol}")
                outliers_set.add(idx)

        # Sort so the output does not depend on set iteration order.
        self.outliers = sorted(outliers_set)