"""Base class for all MLModel Implementations."""importpathlibimporttypingimportase.calculators.calculatorimportase.ioimporttqdmimportzntrackfromipsuiteimportbasefromipsuite.utils.ase_simimportfreeze_copy_atoms
class MLModel(base.AnalyseAtoms):
    """Parent class for all MLModel Implementations.

    Subclasses provide a model specific ASE calculator via ``get_calculator``
    and, if they support LAMMPS, the ``lammps_pair_style`` and
    ``lammps_pair_coeff`` properties.

    Attributes
    ----------
    use_energy: bool
        zntrack parameter, default True. Presumably whether energies are used
        as training targets; the exact meaning is up to the subclass.
    use_forces: bool
        zntrack parameter, default True. Presumably whether forces are used
        as training targets. ``predict`` explicitly requests forces when set.
    use_stresses: bool
        zntrack parameter, default False. Presumably whether stresses are used
        as training targets. ``predict`` explicitly requests stresses when
        set.
    """

    _name_ = "MLModel"

    use_energy: bool = zntrack.params(True)
    use_forces: bool = zntrack.params(True)
    use_stresses: bool = zntrack.params(False)

    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
        """Get a model specific ase calculator object.

        Must be implemented by subclasses.

        Parameters
        ----------
        kwargs:
            Subclass specific options. ``predict`` calls this without
            arguments.

        Returns
        -------
        calc: ase calculator object

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def predict(self, atoms_list: typing.List[ase.Atoms]) -> typing.List[ase.Atoms]:
        """Predict energy, forces and stresses for the given atoms objects.

        The energy is always computed. Forces are additionally requested when
        ``use_forces`` is set, and stresses when ``use_stresses`` is set.

        The input atoms are not modified. Each one is copied before the
        model calculator is attached.

        Parameters
        ----------
        atoms_list: typing.List[ase.Atoms]
            list of atoms objects to predict on

        Returns
        -------
        Prediction: typing.List[ase.Atoms]
            One ``freeze_copy_atoms`` result per input structure, in input
            order. Presumably each is a copy carrying the computed results
            detached from the live calculator; verify against
            ``ipsuite.utils.ase_sim``.
        """
        calc = self.get_calculator()
        results = []
        for original in tqdm.tqdm(atoms_list, ncols=120):
            # Work on a copy so the caller's atoms do not keep a reference
            # to the (potentially heavy) model calculator.
            atoms = original.copy()
            atoms.calc = calc
            atoms.get_potential_energy()
            # Request the remaining properties explicitly: not every ASE
            # calculator computes forces / stresses as a by-product of the
            # energy.
            if self.use_forces:
                atoms.get_forces()
            if self.use_stresses:
                atoms.get_stress()
            results.append(freeze_copy_atoms(atoms))
        return results

    @property
    def lammps_pair_style(self) -> str:
        """Get the lammps pair_style command attribute.

        See https://docs.lammps.org/pair_style.html

        Returns
        -------
        str
            The pair style name, e.g. 'quip' or 'allegro'.

        Raises
        ------
        NotImplementedError
            Unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def lammps_pair_coeff(self) -> typing.List[str]:
        """Get the lammps pair_coeff command attribute.

        See https://docs.lammps.org/pair_coeff.html

        Returns
        -------
        typing.List[str]
            A list of pair_coeff attributes,
            e.g. [' * * model/deployed_model.pth B C F H N'].

        Raises
        ------
        NotImplementedError
            Unless overridden by a subclass.
        """
        raise NotImplementedError

    @staticmethod
    def write_data_to_file(file, atoms_list: typing.List[ase.Atoms]):
        """Save e.g. train / test data to a file.

        Missing parent directories are created. Every atoms object is wrapped
        into its cell *in place* before writing, so the caller's objects are
        modified.

        Parameters
        ----------
        file: str|Path
            path to save to. No explicit format is passed to
            ``ase.io.write``, so it infers the format from the file name.
        atoms_list: list[Atoms]
            atoms that should be saved.
        """
        pathlib.Path(file).parent.mkdir(parents=True, exist_ok=True)
        for atoms in atoms_list:
            atoms.wrap()
        ase.io.write(file, images=atoms_list)