# Source code for tad_dftd3.reference

# This file is part of tad-dftd3.
# SPDX-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reference model
===============

This module defines the reference systems for the D3 model to compute the
C6 dispersion coefficients.
"""
import os.path as op
from typing import Optional

import torch
from tad_mctc._version import __tversion__

from .typing import Any, NoReturn, Tensor, get_default_device, get_default_dtype

# Public API of this module: only the ``Reference`` class is exported by
# ``from tad_dftd3.reference import *``.
__all__ = ["Reference"]


def _load_cn(
    dtype: torch.dtype = torch.double, device: Optional[torch.device] = None
) -> Tensor:
    """
    Load reference coordination numbers.

    The values are hard-coded in the table below. Each row is labelled with
    an element symbol in periodic-table order: row 0 is a placeholder
    labelled "None", followed by H (row 1) through Lr (row 103). Every row
    has 7 slots for reference systems, so the returned tensor has shape
    ``(104, 7)``.

    Parameters
    ----------
    dtype : torch.dtype, optional
        Floating point precision for tensor. Defaults to `torch.double`.
    device : Optional[torch.device], optional
        Device of tensor. Defaults to None, which lets ``torch.tensor`` pick
        its default device.

    Returns
    -------
    Tensor
        Reference coordination numbers, shape ``(104, 7)``.
    """
    # Keep the formatter from reflowing the aligned table.
    # fmt: off
    # NOTE(review): the -1.0000 entries pad each row to 7 columns and
    # presumably mark unused reference slots; confirm against the consumers
    # of this table before relying on that convention.
    return torch.tensor(
        [
            [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # None
            [+0.9118, +0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # H
            [+0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # He
            [+0.0000, +0.9865, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Li
            [+0.0000, +0.9808, +1.9697, -1.0000, -1.0000, -1.0000, -1.0000],  # Be
            [+0.0000, +0.9706, +1.9441, +2.9128, +4.5856, -1.0000, -1.0000],  # B
            [+0.0000, +0.9868, +1.9985, +2.9987, +3.9844, -1.0000, -1.0000],  # C
            [+0.0000, +0.9944, +2.0143, +2.9903, -1.0000, -1.0000, -1.0000],  # N
            [+0.0000, +0.9925, +1.9887, -1.0000, -1.0000, -1.0000, -1.0000],  # O
            [+0.0000, +0.9982, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # F
            [+0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Ne
            [+0.0000, +0.9684, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Na
            [+0.0000, +0.9628, +1.9496, -1.0000, -1.0000, -1.0000, -1.0000],  # Mg
            [+0.0000, +0.9648, +1.9311, +2.9146, -1.0000, -1.0000, -1.0000],  # Al
            [+0.0000, +0.9507, +1.9435, +2.9407, +3.8677, -1.0000, -1.0000],  # Si
            [+0.0000, +0.9947, +2.0102, +2.9859, -1.0000, -1.0000, -1.0000],  # P
            [+0.0000, +0.9948, +1.9903, -1.0000, -1.0000, -1.0000, -1.0000],  # S
            [+0.0000, +0.9972, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Cl
            [+0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Ar
            [+0.0000, +0.9767, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # K
            [+0.0000, +0.9831, +1.9349, -1.0000, -1.0000, -1.0000, -1.0000],  # Ca
            [+0.0000, +1.8627, +2.8999, -1.0000, -1.0000, -1.0000, -1.0000],  # Sc
            [+0.0000, +1.8299, +3.8675, -1.0000, -1.0000, -1.0000, -1.0000],  # Ti
            [+0.0000, +1.9138, +2.9110, -1.0000, -1.0000, -1.0000, -1.0000],  # V
            [+0.0000, +1.8269, 10.6191, -1.0000, -1.0000, -1.0000, -1.0000],  # Cr
            [+0.0000, +1.6406, +9.8849, -1.0000, -1.0000, -1.0000, -1.0000],  # Mn
            [+0.0000, +1.6483, +9.1376, -1.0000, -1.0000, -1.0000, -1.0000],  # Fe
            [+0.0000, +1.7149, +2.9263, +7.7785, -1.0000, -1.0000, -1.0000],  # Co
            [+0.0000, +1.7937, +6.5458, +6.2918, -1.0000, -1.0000, -1.0000],  # Ni
            [+0.0000, +0.9576, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Cu
            [+0.0000, +1.9419, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Zn
            [+0.0000, +0.9601, +1.9315, +2.9233, -1.0000, -1.0000, -1.0000],  # Ga
            [+0.0000, +0.9434, +1.9447, +2.9186, +3.8972, -1.0000, -1.0000],  # Ge
            [+0.0000, +0.9889, +1.9793, +2.9709, -1.0000, -1.0000, -1.0000],  # As
            [+0.0000, +0.9901, +1.9812, -1.0000, -1.0000, -1.0000, -1.0000],  # Se
            [+0.0000, +0.9974, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Br
            [+0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Kr
            [+0.0000, +0.9738, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Rb
            [+0.0000, +0.9801, +1.9143, -1.0000, -1.0000, -1.0000, -1.0000],  # Sr
            [+0.0000, +1.9153, +2.8903, -1.0000, -1.0000, -1.0000, -1.0000],  # Y
            [+0.0000, +1.9355, +3.9106, -1.0000, -1.0000, -1.0000, -1.0000],  # Zr
            [+0.0000, +1.9545, +2.9225, -1.0000, -1.0000, -1.0000, -1.0000],  # Nb
            [+0.0000, +1.9420, 11.0556, -1.0000, -1.0000, -1.0000, -1.0000],  # Mo
            [+0.0000, +1.6682, +9.5402, -1.0000, -1.0000, -1.0000, -1.0000],  # Tc
            [+0.0000, +1.8584, +8.8895, -1.0000, -1.0000, -1.0000, -1.0000],  # Ru
            [+0.0000, +1.9003, +2.9696, -1.0000, -1.0000, -1.0000, -1.0000],  # Rh
            [+0.0000, +1.8630, +5.7095, -1.0000, -1.0000, -1.0000, -1.0000],  # Pd
            [+0.0000, +0.9679, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Ag
            [+0.0000, +1.9539, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Cd
            [+0.0000, +0.9633, +1.9378, +2.9353, -1.0000, -1.0000, -1.0000],  # In
            [+0.0000, +0.9514, +1.9505, +2.9259, +3.9123, -1.0000, -1.0000],  # Sn
            [+0.0000, +0.9749, +1.9523, +2.9315, -1.0000, -1.0000, -1.0000],  # Sb
            [+0.0000, +0.9811, +1.9639, -1.0000, -1.0000, -1.0000, -1.0000],  # Te
            [+0.0000, +0.9968, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # I
            [+0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Xe
            [+0.0000, +0.9909, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Cs
            [+0.0000, +0.9797, +1.8467, -1.0000, -1.0000, -1.0000, -1.0000],  # Ba
            [+0.0000, +1.9373, +2.9175, -1.0000, -1.0000, -1.0000, -1.0000],  # La
            [+2.7991, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Ce
            [+0.0000, +2.9425, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Pr
            [+0.0000, +2.9455, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Nd
            [+0.0000, +2.9413, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Pm
            [+0.0000, +2.9300, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Sm
            [+0.0000, +1.8286, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Eu
            [+0.0000, +2.8732, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Gd
            [+0.0000, +2.9086, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Tb
            [+0.0000, +2.8965, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Dy
            [+0.0000, +2.9242, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Ho
            [+0.0000, +2.9282, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Er
            [+0.0000, +2.9246, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Tm
            [+0.0000, +2.8482, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Yb
            [+0.0000, +2.9219, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Lu
            [+0.0000, +1.9254, +3.8840, -1.0000, -1.0000, -1.0000, -1.0000],  # Hf
            [+0.0000, +1.9459, +2.8988, -1.0000, -1.0000, -1.0000, -1.0000],  # Ta
            [+0.0000, +1.9292, 10.9153, -1.0000, -1.0000, -1.0000, -1.0000],  # W
            [+0.0000, +1.8104, +9.8054, -1.0000, -1.0000, -1.0000, -1.0000],  # Re
            [+0.0000, +1.8858, +9.1527, -1.0000, -1.0000, -1.0000, -1.0000],  # Os
            [+0.0000, +1.8648, +2.9424, -1.0000, -1.0000, -1.0000, -1.0000],  # Ir
            [+0.0000, +1.9188, +6.6669, -1.0000, -1.0000, -1.0000, -1.0000],  # Pt
            [+0.0000, +0.9846, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Au
            [+0.0000, +1.9896, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Hg
            [+0.0000, +0.9267, +1.9302, +2.9420, -1.0000, -1.0000, -1.0000],  # Tl
            [+0.0000, +0.9383, +1.9356, +2.9081, +3.9098, -1.0000, -1.0000],  # Pb
            [+0.0000, +0.9820, +1.9655, +2.9500, -1.0000, -1.0000, -1.0000],  # Bi
            [+0.0000, +0.9815, +1.9639, -1.0000, -1.0000, -1.0000, -1.0000],  # Po
            [+0.0000, +0.9954, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # At
            [+0.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Rn
            [+0.0000, +0.9705, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],  # Fr
            [+0.0000, +0.9661, +1.9251, -1.0000, -1.0000, -1.0000, -1.0000],  # Ra
            [+0.0000, +0.9802, +1.9445, +2.9070, +3.8174, +4.6723, +5.5599],  # Ac
            [+0.0000, +0.9847, +1.9560, +2.9302, +3.8997, -1.0000, -1.0000],  # Th
            [+0.0000, +0.9647, +1.9079, +2.9037, +3.8711, +4.9094, +4.5318],  # Pa
            [+0.0000, +0.9766, +2.8888, +3.9129, +4.1181, +5.9187, -1.0000],  # U
            [+0.0000, +0.9838, +1.9499, +2.9159, +3.9358, +4.9069, +5.9005],  # Np
            [+0.0000, +0.9537, +1.9439, +2.9323, +3.9441, +4.9192, +5.8888],  # Pu
            [+0.0000, +0.9163, +1.8563, +2.8823, +4.8005, +5.7794, -1.0000],  # Am
            [+0.0000, +0.9762, +1.9288, +2.8929, +3.8167, +4.7478, +5.6866],  # Cm
            [+0.0000, +0.9705, +1.9511, +2.9262, +3.9342, -1.0000, -1.0000],  # Bk
            [+0.0000, +0.9581, +1.9123, +2.9327, +3.9105, +5.8285, -1.0000],  # Cf
            [+0.0000, +0.9346, +1.8816, +2.9075, +3.8705, +4.8131, +5.7244],  # Es
            [+0.0000, +0.9500, +1.9165, +2.9377, +3.8956, +4.8540, +5.8160],  # Fm
            [+0.0000, +0.9710, +1.9564, +2.9515, +3.9353, -1.0000, -1.0000],  # Md
            [+0.0000, +0.9722, +1.9605, +2.9452, +3.9296, +4.2582, +4.5511],  # No
            [+0.0000, +0.9569, +1.9215, +2.8958, +3.7644, +4.6808, +5.5939],  # Lr
        ],
        device=device,
        dtype=dtype,
    )
    # fmt: on


def _load_c6(
    dtype: torch.dtype = torch.double, device: Optional[torch.device] = None
) -> Tensor:
    """
    Read the reference C6 coefficients from the bundled data file.

    The file ``reference-c6.pt`` is looked up in the directory of this
    module. It is deserialized with ``torch.load``, which places it on
    ``device`` via ``map_location``. The result is then cast to ``dtype``.

    Parameters
    ----------
    dtype : torch.dtype, optional
        Floating point precision of the returned tensor. Defaults to
        `torch.double`.
    device : Optional[torch.device], optional
        Device handed to ``torch.load`` as ``map_location``. Defaults to None.

    Returns
    -------
    Tensor
        Reference C6 coefficients.
    """
    data_file = op.join(op.dirname(__file__), "reference-c6.pt")

    options: dict[str, Any] = dict(map_location=device)
    # Only request the restricted ``weights_only`` unpickler on PyTorch
    # releases newer than 1.12.1; older versions do not get the flag.
    if __tversion__ > (1, 12, 1):  # pragma: no cover
        options["weights_only"] = True

    c6_table = torch.load(data_file, **options)
    return c6_table.type(dtype=dtype)


class Reference:
    """
    Reference systems for the D3 dispersion model.

    Bundles the reference coordination numbers (``cn``) and the reference C6
    coefficients (``c6``). On construction it checks that both tensors live
    on the same device, share one dtype, and have consistent shapes.
    """

    c6: Tensor
    """C6 coefficients for all pairs of reference systems"""

    cn: Tensor
    """Coordination numbers for all reference systems"""

    __slots__ = [
        "c6",
        "cn",
        "__dtype",
        "__device",
    ]

    def __init__(
        self,
        cn: Optional[Tensor] = None,
        c6: Optional[Tensor] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Create the reference container, loading missing tables.

        Parameters
        ----------
        cn : Optional[Tensor], optional
            Reference coordination numbers, shape ``(..., n_element,
            n_reference)``. Loaded from the built-in table if None.
        c6 : Optional[Tensor], optional
            Reference C6 coefficients, shape ``(..., n_element, n_element,
            n_reference, n_reference)``. Loaded from file if None.
        device : Optional[torch.device], optional
            Device used when loading a missing table. If None, the device of
            the supplied tensor is used, or the default device if neither
            tensor is supplied. Ignored for tensors passed in explicitly.
        dtype : Optional[torch.dtype], optional
            Dtype used when loading a missing table. It is resolved the same
            way as ``device``.

        Raises
        ------
        RuntimeError
            If ``cn`` and ``c6`` are on different devices, have different
            dtypes, or have inconsistent/insufficient dimensions.
        """
        # A supplied tensor fixes dtype/device for loading the other one.
        # Otherwise e.g. ``Reference(cn=my_cn)`` would fail the consistency
        # checks below whenever ``my_cn`` differs from the global defaults.
        given = cn if cn is not None else c6
        if dtype is None:
            dtype = given.dtype if given is not None else get_default_dtype()
        if device is None:
            device = given.device if given is not None else get_default_device()

        if cn is None:
            cn = _load_cn(dtype=dtype, device=device)
        self.cn = cn

        if c6 is None:
            c6 = _load_c6(dtype=dtype, device=device)
        self.c6 = c6

        self.__dtype = self.c6.dtype
        self.__device = self.c6.device

        if any(tensor.device != self.device for tensor in (self.cn, self.c6)):
            raise RuntimeError("All tensors must be on the same device!")

        if any(tensor.dtype != self.dtype for tensor in (self.cn, self.c6)):
            raise RuntimeError("All tensors must have the same dtype!")

        # Guard the rank first so that too-small tensors raise the documented
        # RuntimeError instead of an IndexError from ``shape[-4]``.
        if (
            self.c6.ndim < 4
            or self.cn.ndim < 2
            or any(
                (
                    self.c6.shape[-2] != self.c6.shape[-1],
                    self.c6.shape[-1] != self.cn.shape[-1],
                    self.c6.shape[-4] != self.c6.shape[-3],
                    self.c6.shape[-3] != self.cn.shape[-2],
                )
            )
        ):
            raise RuntimeError("`c6` & `cn` size mismatch found")

    @property
    def device(self) -> torch.device:
        """The device on which the `Reference` object resides."""
        return self.__device

    @device.setter
    def device(self, *_: Any) -> NoReturn:
        """
        Instruct users to use the ".to" method if wanting to change device.
        """
        raise AttributeError("Move object to device using the `.to` method")

    @property
    def dtype(self) -> torch.dtype:
        """Floating point dtype used by reference object."""
        return self.__dtype

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Reference":
        """
        Returns a copy of the `Reference` instance on the specified device.

        This method creates and returns a new copy of the `Reference` instance
        on the specified device "``device``".

        Parameters
        ----------
        device : torch.device, optional
            Device to which all associated tensors should be moved. If None,
            the current device is kept.
        dtype : torch.dtype, optional
            Floating point type of the tensors.

        Returns
        -------
        Reference
            A copy of the `Reference` instance placed on the specified device.

        Notes
        -----
        If the `Reference` instance is already on the desired device (and no
        dtype change is requested) `self` will be returned.
        """
        # ``None`` means "stay where we are". Normalizing through
        # ``torch.device`` also makes e.g. ``"cpu"`` compare equal.
        target = self.__device if device is None else torch.device(device)

        if target == self.__device:
            if dtype is not None:
                return self.type(dtype)
            return self

        return self.__class__(
            self.cn.to(device=target, dtype=dtype),
            self.c6.to(device=target, dtype=dtype),
        )

    def type(self, dtype: torch.dtype) -> "Reference":
        """
        Returns a copy of the `Reference` instance with specified floating
        point type.

        This method creates and returns a new copy of the `Reference` instance
        with the specified dtype.

        Parameters
        ----------
        dtype : torch.dtype
            Floating point type of the tensors.

        Returns
        -------
        Reference
            A copy of the `Reference` instance with the specified dtype.

        Notes
        -----
        If the `Reference` instance has already the desired dtype `self` will
        be returned.
        """
        if self.__dtype == dtype:
            return self

        return self.__class__(
            self.cn.type(dtype),
            self.c6.type(dtype),
        )

    def __str__(self) -> str:
        """Creates a string representation of the Reference object."""
        return (
            f"{self.__class__.__name__}(n_element={self.cn.shape[-2]}, "
            f"n_reference={self.cn.shape[-1]}, dtype={self.__dtype}, "
            f"device={self.__device})"
        )

    def __repr__(self) -> str:
        """Creates a string representation of the Reference object."""
        return str(self)