Source code for ipsuite.configuration_selection.base
"""Base Node for ConfigurationSelection."""

# Standard library
import logging
import typing
from pathlib import Path

# Third party
import ase
import matplotlib.pyplot as plt
import numpy as np
import zntrack

# Local
from ipsuite import base

# Module-level logger shared by the selection nodes defined below.
log: logging.Logger = logging.getLogger(__name__)
class ConfigurationSelection(base.IPSNode):
    """Base Node for ConfigurationSelection.

    Subclasses implement :meth:`select_atoms`. :meth:`run` calls it on the
    input data, stores the returned indices in ``selected_ids`` and writes a
    plot of the selection to ``img_selection``.

    Attributes
    ----------
    data: list[ase.Atoms]
        The configurations to select from.
    selected_ids: list[int]
        Output: indices into ``data`` of the selected configurations.
    img_selection: Path
        Output: path of the PNG image visualising the selection.
    """

    data: list[ase.Atoms] = zntrack.deps()
    selected_ids: list[int] = zntrack.outs(independent=True)
    img_selection: Path = zntrack.outs_path(zntrack.nwd / "selection.png")

    def get_data(self) -> list[ase.Atoms]:
        """Get the atoms data to process.

        Returns
        -------
        list[ase.Atoms]
            The value of ``data``.

        Raises
        ------
        ValueError
            If ``data`` is ``None``.
        """
        if self.data is None:
            raise ValueError("No data given.")
        return self.data

    def run(self):
        """ZnTrack Node Run method.

        Selects configurations via :meth:`select_atoms`, stores the indices in
        ``selected_ids`` and writes the selection plot.
        """
        # Go through get_data() so missing input fails with a clear
        # ValueError instead of a TypeError from len(None).
        data = self.get_data()
        log.debug("Selecting from %d configurations.", len(data))
        self.selected_ids = self.select_atoms(data)
        self._get_plot(data, self.selected_ids)

    def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]:
        """Run the selection method.

        Must be implemented by subclasses.

        Parameters
        ----------
        atoms_lst: List[ase.Atoms]
            List of ase Atoms objects to select configurations from.

        Returns
        -------
        List[int]:
            A list of the selected ids from 0 .. len(atoms_lst)

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    @property
    def frames(self) -> list[ase.Atoms]:
        """Get a list of the selected atoms objects, in the order of ``data``."""
        # Build the lookup set once; list membership would be O(len(selected_ids))
        # for every frame.
        selected = set(self.selected_ids)
        return [atoms for idx, atoms in enumerate(self.data) if idx in selected]

    @property
    def excluded_frames(self) -> list[ase.Atoms]:
        """Get a list of the atoms objects that were not selected, in ``data`` order."""
        selected = set(self.selected_ids)
        return [atoms for idx, atoms in enumerate(self.data) if idx not in selected]

    def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int]):
        """Plot the selected configurations and save the figure to ``img_selection``.

        If every configuration yields a potential energy, the energies are
        plotted; otherwise the configuration index is plotted instead. The
        selected configurations are highlighted in red.

        Parameters
        ----------
        atoms_lst: List[ase.Atoms]
            All configurations the selection was made from.
        indices: List[int]
            Indices into ``atoms_lst`` of the selected configurations.
        """
        fig, ax = plt.subplots()
        try:
            try:
                line_data = np.array(
                    [atoms.get_potential_energy() for atoms in atoms_lst]
                )
                ax.set_ylabel("Energy")
            except Exception as err:  # best effort: fall back to plotting indices
                log.debug(
                    "Potential energies unavailable (%s); plotting indices instead.",
                    err,
                )
                line_data = np.arange(len(atoms_lst))
                ax.set_ylabel("Configuration")

            # Explicit integer index array, so an empty selection still indexes cleanly.
            selected = np.asarray(indices, dtype=int)
            ax.plot(line_data)
            ax.scatter(selected, line_data[selected], c="r")
            ax.set_xlabel("Configuration")
            fig.savefig(self.img_selection, bbox_inches="tight")
        finally:
            # Close this specific figure (not merely the "current" one), even if
            # plotting or saving failed, so figures do not accumulate.
            plt.close(fig)
class BatchConfigurationSelection(ConfigurationSelection):
    """Base node for BatchConfigurationSelection.

    Extends ``ConfigurationSelection`` with a training-data dependency.
    ``data`` (input) and ``selected_ids`` (output) are inherited from that
    base class.

    Attributes
    ----------
    data: list[ase.Atoms]
        The atoms data to select from. This must be an input to the Node.
    train_data: list[ase.Atoms]
        Batch active learning methods usually take into account the data a
        model was trained on. The training dataset has to be supplied with
        this argument.
    """

    train_data: list[ase.Atoms] = zntrack.deps()