"""Source code for ipsuite.models.torch_d3."""

import dataclasses
from typing import Dict, Optional, Tuple

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.units import Bohr

try:
    import torch
    from torch import Tensor
    from torch_dftd.functions.edge_extraction import calc_edge_index
    from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
except ImportError as e:
    raise ImportError(
        "torch_dftd is not installed. You can install it using the"
        " extra 'pip install ipsuite[d3]' command."
    ) from e


class TorchDFTD3CalculatorNL(TorchDFTD3Calculator):
    """``TorchDFTD3Calculator`` variant that reuses its edge list between calls.

    The edge list is built with the enlarged radius ``cutoff + skin`` and is
    kept until one of the following happens:

    * the periodic-boundary flags differ from those of the last build,
    * the number of atoms differs from that of the last build,
    * some atom has moved more than ``skin / 2`` since the last build.

    Parameters
    ----------
    skin : float
        Extra radius added to ``cutoff`` when building the edge list.
    dft, atoms, damping, xc, old, device, cutoff, cnthr, abc, dtype, \
bidirectional, cutoff_smoothing, **kwargs
        Forwarded unchanged to ``TorchDFTD3Calculator.__init__``.
    """

    def __init__(
        self,
        dft: Optional[Calculator] = None,
        atoms: Optional[Atoms] = None,
        damping: str = "zero",
        xc: str = "pbe",
        old: bool = False,
        device: str = "cpu",
        cutoff: float = 95.0 * Bohr,
        cnthr: float = 40.0 * Bohr,
        abc: bool = False,
        # --- torch dftd3 specific params ---
        dtype: torch.dtype = torch.float32,
        bidirectional: bool = True,
        cutoff_smoothing: str = "none",
        skin=0.5,
        **kwargs,
    ):
        # Neighbor-list cache. It is set up before the parent constructor runs
        # so the attributes exist no matter what the parent does during init.
        self.skin = skin
        self.pbc = torch.tensor([False, False, False], device=device)
        self.Z = None  # atomic numbers at the last edge-list build
        self.pos0 = None  # positions at the last edge-list build
        self.edge_index = None  # cached edge list
        self.S = None  # cached shifts returned together with the edge list
        super().__init__(
            dft=dft,
            atoms=atoms,
            damping=damping,
            xc=xc,
            old=old,
            device=device,
            cutoff=cutoff,
            cnthr=cnthr,
            abc=abc,
            dtype=dtype,
            bidirectional=bidirectional,
            cutoff_smoothing=cutoff_smoothing,
            **kwargs,
        )

    def _calc_edge_index(
        self,
        pos: Tensor,
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Build the edge list with the search radius enlarged by ``skin``.

        Returns whatever ``calc_edge_index`` returns; the caller unpacks it
        into ``(edge_index, S)``.
        """
        edge_index = calc_edge_index(
            pos,
            cell,
            pbc,
            cutoff=self.cutoff + self.skin,
            bidirectional=self.bidirectional,
        )
        return edge_index

    def _needs_rebuild(self, pos: Tensor, Z: Tensor, pbc: Tensor) -> bool:
        """Return True when the cached edge list cannot be reused for ``pos``."""
        if self.edge_index is None or self.pos0 is None or self.Z is None:
            return True
        if bool(torch.any(self.pbc != pbc)):
            return True
        # Compare the shapes before subtracting: a changed atom count would
        # otherwise make ``self.pos0 - pos`` fail with a shape mismatch.
        if len(self.Z) != len(Z) or self.pos0.shape != pos.shape:
            return True
        # If no atom moved more than skin / 2, no pair distance changed by
        # more than skin, so the enlarged edge list still covers the cutoff.
        max_sq_displacement = ((self.pos0 - pos) ** 2).sum(1).max()
        return bool(max_sq_displacement > self.skin**2 / 4.0)

    def _preprocess_atoms(self, atoms: Atoms) -> Dict[str, Optional[Tensor]]:
        """Convert ``atoms`` to tensors, rebuilding the edge list only if needed.

        Returns
        -------
        dict
            Keys ``pos``, ``Z``, ``cell``, ``pbc``, ``edge_index`` and
            ``shift_pos``. ``cell`` is ``None`` when no direction is periodic.
        """
        pos = torch.tensor(atoms.get_positions(), device=self.device, dtype=self.dtype)
        Z = torch.tensor(atoms.get_atomic_numbers(), device=self.device)

        if any(atoms.pbc):
            cell: Optional[Tensor] = torch.tensor(
                atoms.get_cell(), device=self.device, dtype=self.dtype
            )
        else:
            cell = None
        pbc = torch.tensor(atoms.pbc, device=self.device)

        # NOTE(review): the cell is not part of the rebuild test. A cell that
        # changes while atoms stay (almost) in place may leave a stale edge
        # list for periodic systems - confirm whether that case can occur.
        if self._needs_rebuild(pos, Z, pbc):
            self.edge_index, self.S = self._calc_edge_index(pos, cell, pbc)
            # Keep an independent copy so the reference positions cannot be
            # affected by anything done later to ``pos`` (e.g. enabling
            # gradients on it downstream).
            self.pos0 = pos.detach().clone()
            # Refresh the cached atomic numbers together with the positions;
            # otherwise the length check above keeps comparing against the
            # very first structure ever seen.
            self.Z = Z.clone()
            self.pbc = pbc

        if cell is None:
            shift_pos = self.S
        else:
            shift_pos = torch.mm(self.S, cell.detach())

        input_dicts = {
            "pos": pos,
            "Z": Z,
            "cell": cell,
            "pbc": pbc,
            "edge_index": self.edge_index,
            "shift_pos": shift_pos,
        }
        return input_dicts


[docs] @dataclasses.dataclass class TorchDFTD3: """Compute D3 correction terms using torch-dftd. Attributes ---------- xc : str damping : str cutoff : float abc : bool ATM 3-body interaction cnthr : float Coordination number cutoff distance in angstrom dtype : str Data type used for the calculation. device : str Device used for the calculation. Defaults to "cuda" if available, otherwise "cpu". skin : float If > 0, switches to a D3 implementation that reuses neighborlists. This can significantly improve performance. """ xc: str damping: str cutoff: float abc: bool cnthr: float dtype: str device: str | None = None skin: float = 0.0
[docs] def get_calculator(self, **kwargs): if self.device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" if self.dtype == "float64": dtype = torch.float64 elif self.dtype == "float32": dtype = torch.float32 else: raise ValueError("dtype must be float64 or float32") if self.skin < 1e-5: calc = TorchDFTD3Calculator( xc=self.xc, damping=self.damping, cutoff=self.cutoff, abc=self.abc, cnthr=self.cnthr, dtype=dtype, atoms=None, device=self.device, ) else: calc = TorchDFTD3CalculatorNL( xc=self.xc, damping=self.damping, cutoff=self.cutoff, abc=self.abc, cnthr=self.cnthr, dtype=dtype, atoms=None, device=self.device, skin=self.skin, ) return calc