# This file is part of tad-dftd3.
# SPDX-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dispersion energy
=================

This module provides the dispersion energy evaluation for the pairwise
interactions, plus the optional three-body (Axilrod-Teller-Muto) term.

Example
-------
>>> import torch
>>> import tad_dftd3 as d3
>>> numbers = torch.tensor([ # define fragments by setting atomic numbers to zero
... [8, 1, 1, 8, 1, 6, 1, 1, 1],
... [0, 0, 0, 8, 1, 6, 1, 1, 1],
... [8, 1, 1, 0, 0, 0, 0, 0, 0],
... ])
>>> positions = torch.tensor([ # define coordinates once
... [-4.224363834, +0.270465696, +0.527578960],
... [-5.011768887, +1.780116228, +1.143194385],
... [-2.468758653, +0.479766200, +0.982905589],
... [+1.146167671, +0.452771215, +1.257722311],
... [+1.841554378, -0.628298322, +2.538065200],
... [+2.024899840, -0.438480095, -1.127412563],
... [+1.210773578, +0.791908575, -2.550591723],
... [+4.077073644, -0.342495506, -1.267841745],
... [+1.404422261, -2.365753991, -1.503620411],
... ]).repeat(numbers.shape[0], 1, 1)
>>> ref = d3.reference.Reference()
>>> param = dict( # r²SCAN-D3(BJ)
... a1=torch.tensor(0.49484001),
... s8=torch.tensor(0.78981345),
... a2=torch.tensor(5.73083694),
... )
>>> cn = d3.ncoord.coordination_number(numbers, positions)
>>> weights = d3.model.weight_references(numbers, cn, ref)
>>> c6 = d3.model.atomic_c6(numbers, weights, ref)
>>> energy = d3.disp.dispersion(numbers, positions, param, c6)
>>> torch.set_printoptions(precision=7)
>>> print(torch.sum(energy[0] - energy[1] - energy[2])) # energy in Hartree
tensor(-0.0003964, dtype=torch.float64)
"""
from typing import Dict, Optional
import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.data import pse
from . import data, defaults, model, ncoord
from .damping import dispersion_atm, rational_damping
from .reference import Reference
from .typing import (
DD,
Any,
CountingFunction,
DampingFunction,
Tensor,
WeightingFunction,
)
__all__ = ["dftd3", "dispersion", "dispersion2", "dispersion3"]
[docs]
def dftd3(
    numbers: Tensor,
    positions: Tensor,
    param: Dict[str, Tensor],
    *,
    ref: Optional[Reference] = None,
    rcov: Optional[Tensor] = None,
    rvdw: Optional[Tensor] = None,
    r4r2: Optional[Tensor] = None,
    cutoff: Optional[Tensor] = None,
    counting_function: CountingFunction = ncoord.exp_count,
    weighting_function: WeightingFunction = model.gaussian_weight,
    damping_function: DampingFunction = rational_damping,
) -> Tensor:
    """
    Run the complete DFT-D3 pipeline for a batch of geometries.

    The coordination numbers, the reference weights and the atomic C6
    coefficients are computed in that order and then handed over to
    :func:`dispersion` for the actual energy evaluation.

    Parameters
    ----------
    numbers : torch.Tensor
        Atomic numbers of the atoms in the system.
    positions : torch.Tensor
        Cartesian coordinates of the atoms in the system. Device and dtype
        of this tensor are used for every tensor created here.
    param : dict[str, Tensor]
        DFT-D3 damping parameters.
    ref : reference.Reference, optional
        Reference C6 coefficients. A fresh ``Reference`` is created when
        omitted.
    rcov : torch.Tensor, optional
        Covalent radii of the atoms in the system. Looked up from
        ``data.COV_D3`` when omitted.
    rvdw : torch.Tensor, optional
        Van der Waals radii of the atoms in the system. Looked up pairwise
        from ``data.VDW_D3`` when omitted.
    r4r2 : torch.Tensor, optional
        r⁴ over r² expectation values of the atoms in the system. Looked up
        from ``data.R4R2`` when omitted.
    cutoff : torch.Tensor, optional
        Real-space cutoff. Falls back to ``defaults.D3_DISP_CUTOFF``.
    counting_function : Callable, optional
        Calculates counting value in range 0 to 1 for each atom pair.
    weighting_function : Callable, optional
        Function to calculate weight of individual reference systems.
    damping_function : Callable, optional
        Damping function evaluate distance dependent contributions.

    Returns
    -------
    Tensor
        Atom-resolved DFT-D3 dispersion energy for each geometry.

    Raises
    ------
    ValueError
        If the largest atomic number is ``defaults.MAX_ELEMENT`` or above.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    # Reject unsupported elements before any table is indexed with them.
    if torch.max(numbers) >= defaults.MAX_ELEMENT:
        raise ValueError(
            f"No D3 parameters available for Z > {defaults.MAX_ELEMENT-1} "
            f"({pse.Z2S[defaults.MAX_ELEMENT]})."
        )

    # Substitute the package defaults for every input that was left out.
    cutoff = (
        torch.tensor(defaults.D3_DISP_CUTOFF, **dd) if cutoff is None else cutoff
    )
    ref = Reference(**dd) if ref is None else ref
    rcov = data.COV_D3.to(**dd)[numbers] if rcov is None else rcov
    if rvdw is None:
        # Pairwise lookup: one radius for every (atom, atom) combination.
        first, second = numbers.unsqueeze(-1), numbers.unsqueeze(-2)
        rvdw = data.VDW_D3.to(**dd)[first, second]
    r4r2 = data.R4R2.to(**dd)[numbers] if r4r2 is None else r4r2

    # coordination numbers -> reference weights -> atomic C6 coefficients
    coordination = ncoord.cn_d3(
        numbers, positions, counting_function=counting_function, rcov=rcov
    )
    ref_weights = model.weight_references(
        numbers, coordination, ref, weighting_function
    )
    c6_coeffs = model.atomic_c6(numbers, ref_weights, ref)

    return dispersion(
        numbers,
        positions,
        param,
        c6_coeffs,
        rvdw=rvdw,
        r4r2=r4r2,
        damping_function=damping_function,
        cutoff=cutoff,
    )
[docs]
def dispersion(
    numbers: Tensor,
    positions: Tensor,
    param: Dict[str, Tensor],
    c6: Tensor,
    rvdw: Optional[Tensor] = None,
    r4r2: Optional[Tensor] = None,
    damping_function: DampingFunction = rational_damping,
    cutoff: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Calculate the dispersion energy from precomputed C6 coefficients.

    The two-body term is always evaluated. The three-body term is added
    only if ``param`` contains a non-zero ``"s9"`` entry.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers of the atoms in the system.
    positions : Tensor
        Cartesian coordinates of the atoms in the system.
    param : dict[str, Tensor]
        DFT-D3 damping parameters.
    c6 : Tensor
        Atomic C6 dispersion coefficients.
    rvdw : Tensor, optional
        Van der Waals radii of the atoms in the system. Only needed for the
        three-body term; looked up pairwise from ``data.VDW_D3`` when
        omitted.
    r4r2 : Tensor, optional
        r⁴ over r² expectation values of the atoms in the system. Looked up
        from ``data.R4R2`` when omitted.
    damping_function : Callable
        Damping function evaluate distance dependent contributions.
        Additional arguments are passed through to the function.
    cutoff : Tensor, optional
        Real-space cutoff. Falls back to ``defaults.D3_DISP_CUTOFF``.
    **kwargs : Any
        Forwarded to :func:`dispersion2`, which passes them on to
        ``damping_function``.

    Returns
    -------
    Tensor
        Atom-resolved DFT-D3 dispersion energy for each geometry.

    Raises
    ------
    ValueError
        If the shape of ``positions`` or ``r4r2`` does not match ``numbers``
        or if the largest atomic number is ``defaults.MAX_ELEMENT`` or
        above.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    # Validate the raw inputs first. In particular, the element check must
    # precede any parameter-table lookup indexed by `numbers`; otherwise an
    # unsupported element hits the lookup before the explicit ValueError.
    if numbers.shape != positions.shape[:-1]:
        raise ValueError(
            "Shape of positions is not consistent with atomic numbers.",
        )
    if torch.max(numbers) >= defaults.MAX_ELEMENT:
        raise ValueError(
            f"No D3 parameters available for Z > {defaults.MAX_ELEMENT-1} "
            f"({pse.Z2S[defaults.MAX_ELEMENT]})."
        )

    if cutoff is None:
        cutoff = torch.tensor(defaults.D3_DISP_CUTOFF, **dd)
    if r4r2 is None:
        r4r2 = data.R4R2.to(**dd)[numbers]

    if numbers.shape != r4r2.shape:
        raise ValueError(
            "Shape of expectation values is not consistent with atomic numbers.",
        )

    # two-body dispersion
    energy = dispersion2(
        numbers, positions, param, c6, r4r2, damping_function, cutoff, **kwargs
    )

    # three-body dispersion (skipped entirely when s9 is absent or zero)
    if "s9" in param and param["s9"] != 0.0:
        if rvdw is None:
            # Pairwise radii are only built when the three-body term needs them.
            rvdw = data.VDW_D3.to(**dd)[
                numbers.unsqueeze(-1), numbers.unsqueeze(-2)
            ]
        energy += dispersion3(numbers, positions, param, c6, rvdw, cutoff)

    return energy
[docs]
def dispersion2(
    numbers: Tensor,
    positions: Tensor,
    param: Dict[str, Tensor],
    c6: Tensor,
    r4r2: Tensor,
    damping_function: DampingFunction,
    cutoff: Tensor,
    **kwargs: Any,
) -> Tensor:
    """
    Evaluate the two-body (C6 and C8) part of the dispersion energy.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers of the atoms in the system.
    positions : Tensor
        Cartesian coordinates of the atoms in the system.
    param : dict[str, Tensor]
        DFT-D3 damping parameters. ``"s6"`` and ``"s8"`` are read here and
        fall back to ``defaults.S6`` and ``defaults.S8``.
    c6 : Tensor
        Atomic C6 dispersion coefficients.
    r4r2 : Tensor
        r⁴ over r² expectation values of the atoms in the system.
    damping_function : Callable
        Damping function evaluate distance dependent contributions.
        Called once with order 6 and once with order 8.
    cutoff : Tensor
        Real-space cutoff; pairs with a larger distance contribute zero.
    **kwargs : Any
        Additional arguments are passed through to ``damping_function``.

    Returns
    -------
    Tensor
        Scaled sum ``s6 * e6 + s8 * e8`` of the pair contributions, reduced
        over the last dimension.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}
    nothing = torch.tensor(0.0, **dd)
    tiny = torch.tensor(torch.finfo(positions.dtype).eps, **dd)

    # Entries rejected by the pair mask get machine epsilon as distance
    # (presumably to keep the damping function away from exact zeros).
    pair_mask = real_pairs(numbers, mask_diagonal=True)
    pair_dist = storch.cdist(positions, positions, p=2)
    distances = torch.where(pair_mask, pair_dist, tiny)

    qq = 3 * r4r2.unsqueeze(-1) * r4r2.unsqueeze(-2)
    c8 = c6 * qq

    # A pair contributes only if it passes the mask and lies within cutoff.
    contributing = pair_mask * (distances <= cutoff)

    def damped(order: int) -> Tensor:
        # Damped term of the given order, zeroed for non-contributing pairs.
        return torch.where(
            contributing,
            damping_function(order, distances, qq, param, **kwargs),
            nothing,
        )

    # The factor 0.5 is applied to every row sum of the pair matrix.
    e6 = -0.5 * torch.sum(c6 * damped(6), dim=-1)
    e8 = -0.5 * torch.sum(c8 * damped(8), dim=-1)

    s6 = param.get("s6", torch.tensor(defaults.S6, **dd))
    s8 = param.get("s8", torch.tensor(defaults.S8, **dd))
    return s6 * e6 + s8 * e8
[docs]
def dispersion3(
    numbers: Tensor,
    positions: Tensor,
    param: Dict[str, Tensor],
    c6: Tensor,
    rvdw: Tensor,
    cutoff: Tensor,
    rs9: Optional[Tensor] = None,
) -> Tensor:
    """
    Three-body dispersion term. Currently this is only a wrapper for the
    Axilrod-Teller-Muto dispersion term.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers of the atoms in the system.
    positions : Tensor
        Cartesian coordinates of the atoms in the system.
    param : dict[str, Tensor]
        Dictionary of dispersion parameters. Default values are used for
        missing keys (``"alp"``: 14.0, ``"s9"``: 1.0).
    c6 : Tensor
        Atomic C6 dispersion coefficients.
    rvdw : Tensor
        Van der Waals radii of the atoms in the system.
    cutoff : Tensor
        Real-space cutoff.
    rs9 : Tensor, optional
        Scaling for van-der-Waals radii in damping function. Defaults to
        `4.0/3.0`, created with the dtype and device of ``positions``.

    Returns
    -------
    Tensor
        Atom-resolved three-body dispersion energy.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}

    alp = param.get("alp", torch.tensor(14.0, **dd))
    s9 = param.get("s9", torch.tensor(1.0, **dd))

    if rs9 is None:
        # Build the default in the working precision. A tensor default in the
        # signature is created once at import time in the global default
        # dtype, so 4/3 would be rounded to that precision before any upcast.
        rs9 = torch.tensor(4.0 / 3.0, **dd)
    else:
        rs9 = rs9.type(positions.dtype).to(positions.device)

    return dispersion_atm(numbers, positions, c6, rvdw, cutoff, s9, rs9, alp)