from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Final, Literal, cast
import torch
from e3nn import o3
from graph_pes.atomic_graph import (
DEFAULT_CUTOFF,
AtomicGraph,
PropertyKey,
index_over_neighbours,
neighbour_distances,
neighbour_vectors,
)
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.graph_property_model import GraphTensorModel
from graph_pes.models.components.aggregation import (
NeighbourAggregation,
NeighbourAggregationMode,
)
from graph_pes.models.components.distances import (
DistanceExpansion,
PolynomialEnvelope,
get_distance_expansion,
)
from graph_pes.models.components.scaling import (
LocalEnergiesScaler,
LocalTensorScaler,
)
from graph_pes.models.e3nn.mace_utils import (
Contraction,
ContractionConfig,
UnflattenIrreps,
parse_irreps,
)
from graph_pes.models.e3nn.utils import (
LinearReadOut,
LinearTPReadOut,
NonLinearReadOut,
NonLinearTPReadOut,
ReadOut,
SphericalHarmonics,
UnrestrictedLinearReadOut,
UnrestrictedNonLinearReadOut,
as_irreps,
build_limited_tensor_product,
to_full_irreps,
)
from graph_pes.utils.nn import (
MLP,
AtomicOneHot,
HaddamardProduct,
MLPConfig,
PerElementEmbedding,
UniformModuleList,
)
class MACEInteraction(torch.nn.Module):
"""
The MACE interaction block.
Generates new node embeddings from the old node embeddings and the
spherical harmonic expansion and magnitudes of the neighbour vectors.
"""
def __init__(
self,
# input nodes
irreps_in: list[o3.Irrep],
nodes: NodeDescription,
# input edges
sph_harmonics: o3.Irreps,
radial_basis_features: int,
mlp: MLPConfig,
# other
aggregation: NeighbourAggregationMode,
mix_attributes: bool,
):
super().__init__()
irreps_out = [ir for _, ir in sph_harmonics]
features_in = as_irreps([(nodes.channels, ir) for ir in irreps_in])
self.pre_linear = o3.Linear(
features_in,
features_in,
internal_weights=True,
shared_weights=True,
)
self.tp = build_limited_tensor_product(
features_in,
sph_harmonics,
irreps_out,
)
mid_features = self.tp.irreps_out.simplify()
assert all(ir in mid_features for ir in irreps_out)
self.weight_generator = MLP.from_config(
mlp,
input_features=radial_basis_features,
output_features=self.tp.weight_numel,
bias=False,
)
features_out = as_irreps(
[(nodes.channels, ir) for (_, ir) in sph_harmonics]
)
self.post_linear = o3.Linear(
mid_features,
features_out,
internal_weights=True,
shared_weights=True,
)
self.aggregator = NeighbourAggregation.parse(aggregation)
if mix_attributes:
self.attribute_mixer = o3.FullyConnectedTensorProduct(
irreps_in1=features_out,
irreps_in2=o3.Irreps(f"{nodes.attributes}x0e"),
irreps_out=features_out,
)
else:
self.attribute_mixer = None
self.reshape = UnflattenIrreps(irreps_out, nodes.channels)
# book-keeping
self.irreps_in = features_in
self.irreps_out = features_out
def forward(
self,
node_features: torch.Tensor,
node_attributes: torch.Tensor,
sph_harmonics: torch.Tensor,
radial_basis: torch.Tensor,
graph: AtomicGraph,
) -> torch.Tensor:
# pre-linear
node_features = self.pre_linear(node_features) # (N, a)
# tensor product: mix node and edge features
neighbour_features = index_over_neighbours(
node_features, graph
) # (E, a)
weights = self.weight_generator(radial_basis) # (E, b)
messages = self.tp(
neighbour_features,
sph_harmonics,
weights,
) # (E, c)
# aggregate
total_message = self.aggregator(messages, graph) # (N, c)
# post-linear
node_features = self.post_linear(total_message) # (N, d)
if self.attribute_mixer is not None:
node_features = self.attribute_mixer(node_features, node_attributes)
return self.reshape(node_features) # (N, channels, d')
# type hints for mypy
def __call__(
self,
node_features: torch.Tensor,
node_attributes: torch.Tensor,
sph_harmonics: torch.Tensor,
radial_basis: torch.Tensor,
graph: AtomicGraph,
) -> torch.Tensor:
return super().__call__(
node_features,
node_attributes,
sph_harmonics,
radial_basis,
graph,
)
@dataclass
class NodeDescription:
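"""
Describes the per-node feature layout used throughout the MACE layers:
the channel multiplicity, the number of scalar node attributes, and the
hidden irreps carried between message passing layers.
"""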
channels: int
attributes: int
hidden_features: list[o3.Irrep]
def hidden_irreps(self) -> o3.Irreps:
return to_full_irreps(self.channels, self.hidden_features)
class MACELayer(torch.nn.Module):
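"""
A single MACE message passing layer: a :class:`MACEInteraction` block,
per-irrep ``Contraction`` modules that build the higher body-order
product basis, an optional residual (self-connection) update, and a
final linear mix onto the output irreps.
"""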
def __init__(
self,
irreps_in: list[o3.Irrep],
nodes: NodeDescription,
correlation: int,
sph_harmonics: o3.Irreps,
radial_basis_features: int,
mlp: MLPConfig,
use_sc: bool,
aggregation: NeighbourAggregationMode,
residual: bool,
final_layer: bool,
output_irrep: str | None = None,
is_pes: bool = True,
):
super().__init__()
self.interaction = MACEInteraction(
irreps_in=irreps_in,
nodes=nodes,
sph_harmonics=sph_harmonics,
radial_basis_features=radial_basis_features,
mlp=mlp,
aggregation=aggregation,
# only mix attributes in the interaction block
# if we **aren't** using a residual connection
mix_attributes=not residual,
)
actual_mid_features = [ir for _, ir in self.interaction.irreps_out]
if is_pes:
output_features = o3.Irreps(
nodes.hidden_irreps()
if not final_layer
else o3.Irreps(f"{nodes.channels}x0e")
)
else:
output_features = o3.Irreps(
nodes.hidden_irreps()
if not final_layer
else o3.Irreps(
f"{nodes.channels}x{output_irrep}"
) # we want this to be the irreps of the output
)
self.contractions = UniformModuleList(
[
Contraction(
config=ContractionConfig(
num_features=nodes.channels,
n_node_attributes=nodes.attributes,
irrep_s_in=actual_mid_features,
irrep_out=target_irrep,
),
correlation=correlation,
)
for target_irrep in [o.ir for o in output_features]
]
)
# TODO: should we change the irreps_in2 to match the tp tensors?
if use_sc and residual:
# links input features to output features via a tensor product
self.residual_update = o3.FullyConnectedTensorProduct(
irreps_in1=[(nodes.channels, ir) for ir in irreps_in],
irreps_in2=o3.Irreps(f"{nodes.attributes}x0e"),
irreps_out=output_features,
)
else:
self.residual_update = None
# update the hidden features from the interaction block
# and target the output features
self.post_linear = o3.Linear(
output_features,
output_features,
internal_weights=True,
shared_weights=True,
)
# book-keeping
self.irreps_in = irreps_in
self.irreps_out: o3.Irreps = output_features # type: ignore
def forward(
self,
node_features: torch.Tensor,
node_attributes: torch.Tensor,
sph_harmonics: torch.Tensor,
radial_basis: torch.Tensor,
graph: AtomicGraph,
) -> torch.Tensor:
# A MACE layer operates on:
# - node features with multiplicity M, e.g. M=16: 16x0e + 16x1o
# - node attributes with multiplicity A e.g. A=5: 5x0e
# - spherical harmonics up to l_max, e.g. l_max=2: 1x0e + 1x1o + 1x2e
# interact
internal_node_features = self.interaction(
node_features,
node_attributes,
sph_harmonics,
radial_basis,
graph,
) # (N, M, irreps)
# contract using the contractions directly
contracted_features = torch.cat(
[
contraction(internal_node_features, node_attributes)
for contraction in self.contractions
],
dim=-1,
) # (N, irreps_out)
# residual update
if self.residual_update is not None:
update = self.residual_update(
node_features,
node_attributes,
) # (N, irreps_out)
contracted_features = contracted_features + update
# linear update
node_features = self.post_linear(contracted_features) # (N, irreps_out)
return node_features
class _BaseMACE(GraphPESModel):
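"""
Common implementation shared by :class:`MACE` and :class:`ZEmbeddingMACE`:
radial and spherical-harmonic edge features feed a stack of
:class:`MACELayer` blocks, and per-layer readouts produce local energies
that are summed and scaled.
"""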
def __init__(
self,
# radial things
cutoff: float,
n_radial: int,
radial_expansion: type[DistanceExpansion] | str,
weights_mlp: MLPConfig,
# node things
nodes: NodeDescription,
node_attribute_generator: Callable[[torch.Tensor], torch.Tensor],
# message passing
layers: int,
l_max: int,
correlation: int,
neighbour_aggregation: NeighbourAggregationMode,
use_self_connection: bool,
# readout
readout_width: int,
):
super().__init__(
cutoff=cutoff,
implemented_properties=["local_energies"],
)
if o3.Irrep("0e") not in nodes.hidden_features:
raise ValueError("MACE requires a `0e` hidden feature")
# radial things
sph_harmonics = cast(o3.Irreps, o3.Irreps.spherical_harmonics(l_max))
self.spherical_harmonics = SphericalHarmonics(
sph_harmonics,
normalize=True,
normalization="component",
)
if isinstance(radial_expansion, str):
radial_expansion = get_distance_expansion(radial_expansion)
self.radial_expansion = HaddamardProduct(
radial_expansion(
n_features=n_radial, cutoff=cutoff, trainable=True
),
PolynomialEnvelope(cutoff=cutoff, p=5),
)
# node things
self.node_attribute_generator = node_attribute_generator
self.initial_node_embedding = PerElementEmbedding(nodes.channels)
# message passing
current_node_irreps = [o3.Irrep("0e")]
self.layers: UniformModuleList[MACELayer] = UniformModuleList([])
for i in range(layers):
# only use residual skip after the first layer
use_residual = i != 0
final_layer = i == layers - 1
layer = MACELayer(
irreps_in=current_node_irreps,
nodes=nodes,
correlation=correlation,
sph_harmonics=sph_harmonics,
radial_basis_features=n_radial,
mlp=weights_mlp,
use_sc=use_self_connection,
aggregation=neighbour_aggregation,
residual=use_residual,
final_layer=final_layer,
)
self.layers.append(layer)
current_node_irreps = [ir for _, ir in layer.irreps_out]
self.readouts: UniformModuleList[ReadOut] = UniformModuleList(
[LinearReadOut(nodes.hidden_irreps()) for _ in range(layers - 1)]
+ [
NonLinearReadOut(
self.layers[-1].irreps_out, hidden_dim=readout_width
)
],
)
self.scaler = LocalEnergiesScaler()
def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
# pre-compute some things
vectors = neighbour_vectors(graph)
sph_harmonics = self.spherical_harmonics(vectors)
edge_features = self.radial_expansion(
neighbour_distances(graph).view(-1, 1)
)
node_attributes = self.node_attribute_generator(graph.Z)
# generate initial node features
node_features = self.initial_node_embedding(graph.Z)
# update node features through message passing layers
per_atom_energies = []
for layer, readout in zip(self.layers, self.readouts):
node_features = layer(
node_features,
node_attributes,
sph_harmonics,
edge_features,
graph,
)
per_atom_energies.append(readout(node_features))
# sum up the per-atom energies
local_energies = torch.sum(
torch.stack(per_atom_energies), dim=0
).squeeze()
# return scaled local energy predictions
return {"local_energies": self.scaler(local_energies, graph)}
class _BaseTensorMACE(GraphTensorModel):
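"""
Common implementation shared by :class:`TensorMACE` and
:class:`ZEmbeddingTensorMACE`: the same message passing as
:class:`_BaseMACE`, but the readouts target per-atom tensors (either
directly or via tensor products) rather than local energies.
"""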
def __init__(
self,
# radial things
cutoff: float,
n_radial: int,
radial_expansion: type[DistanceExpansion] | str,
weights_mlp: MLPConfig,
# node things
nodes: NodeDescription,
node_attribute_generator: Callable[[torch.Tensor], torch.Tensor],
# message passing
layers: int,
l_max: int,
correlation: int,
neighbour_aggregation: NeighbourAggregationMode,
use_self_connection: bool,
# readout
readout_width: int,
# tensor related stuff
target_method: Literal["direct", "tensor_product"],
number_of_tps: int | None = None,
target_tensor_irreps: str | None = None,
irrep_tp: str | None = None,
props: str = "tensor",
):
if target_tensor_irreps is None:
target_tensor_irreps = o3.Irreps("0e")
super().__init__(
cutoff=cutoff,
implemented_properties=props,
)
if o3.Irrep("0e") not in nodes.hidden_features:
raise ValueError("MACE requires a `0e` hidden feature")
assert target_method in ["direct", "tensor_product"]
if target_method == "tensor_product":
assert number_of_tps is not None and number_of_tps > 1 and number_of_tps % 2 == 0
self.target_method = target_method
self.irrep_tp = irrep_tp
self.number_of_tps = number_of_tps
self.target_tensor_irreps = target_tensor_irreps
# radial things
sph_harmonics = cast(o3.Irreps, o3.Irreps.spherical_harmonics(l_max))
self.spherical_harmonics = SphericalHarmonics(
sph_harmonics,
normalize=True,
normalization="component",
)
if isinstance(radial_expansion, str):
radial_expansion = get_distance_expansion(radial_expansion)
self.radial_expansion = HaddamardProduct(
radial_expansion(
n_features=n_radial, cutoff=cutoff, trainable=True
),
PolynomialEnvelope(cutoff=cutoff, p=5),
)
# node things
self.node_attribute_generator = node_attribute_generator
self.initial_node_embedding = PerElementEmbedding(nodes.channels)
# message passing
current_node_irreps = [o3.Irrep("0e")]
self.layers: UniformModuleList[MACELayer] = UniformModuleList([])
for i in range(layers):
# only use residual skip after the first layer
use_residual = i != 0
final_layer = i == layers - 1
layer = MACELayer(
irreps_in=current_node_irreps,
nodes=nodes,
correlation=correlation,
sph_harmonics=sph_harmonics,
radial_basis_features=n_radial,
mlp=weights_mlp,
use_sc=use_self_connection,
aggregation=neighbour_aggregation,
residual=use_residual,
final_layer=final_layer,
output_irrep=self.irrep_tp,
is_pes=False,
)
self.layers.append(layer)
current_node_irreps = [ir for _, ir in layer.irreps_out]
if self.target_method == "tensor_product":
self.readouts: UniformModuleList[ReadOut] = UniformModuleList(
[
LinearTPReadOut(
nodes.hidden_irreps(),
number_of_tps=self.number_of_tps,
tp_target=self.irrep_tp,
output_irreps=self.target_tensor_irreps,
)
for _ in range(layers - 1)
]
+ [
NonLinearTPReadOut(
self.layers[-1].irreps_out,
hidden_dim=readout_width,
number_of_tps=self.number_of_tps,
tp_target=self.irrep_tp,
output_irreps=self.target_tensor_irreps,
)
],
)
elif self.target_method == "direct":
self.readouts: UniformModuleList[ReadOut] = UniformModuleList(
[
UnrestrictedLinearReadOut(
nodes.hidden_irreps(),
output_irreps=self.target_tensor_irreps,
)
for _ in range(layers - 1)
]
+ [
UnrestrictedNonLinearReadOut(
self.layers[-1].irreps_out,
hidden_dim=readout_width,
output_irreps=self.target_tensor_irreps,
)
],
)
self.target_tensor_irreps = o3.Irreps(self.target_tensor_irreps)
# TODO: do we need a scaler for the tensor properties?
self.scaler = LocalTensorScaler(self.target_tensor_irreps.dim)
def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
# pre-compute some things
vectors = neighbour_vectors(graph)
sph_harmonics = self.spherical_harmonics(vectors)
edge_features = self.radial_expansion(
neighbour_distances(graph).view(-1, 1)
)
node_attributes = self.node_attribute_generator(graph.Z)
# generate initial node features
node_features = self.initial_node_embedding(graph.Z)
# update node features through message passing layers
per_atom_tensors = []
for layer, readout in zip(self.layers, self.readouts):
node_features = layer(
node_features,
node_attributes,
sph_harmonics,
edge_features,
graph,
)
per_atom_tensors.append(readout(node_features).squeeze(-1))
# sum the per-layer atomic tensor contributions
atomic_tensors = torch.sum(
torch.stack(per_atom_tensors, dim=-1), dim=-1
)
atomic_tensors = self.scaler(atomic_tensors, graph)
preds: dict[PropertyKey, torch.Tensor] = {
self.implemented_properties: atomic_tensors
}
# return the scaled per-atom tensor predictions
return preds
DEFAULT_MLP_CONFIG: Final[MLPConfig] = {
"hidden_depth": 3,
"hidden_features": 64,
"activation": "SiLU",
}
class MACE(_BaseMACE):
r"""
The `MACE <https://arxiv.org/abs/2206.07697>`__ architecture.
One-hot encodings of the atomic numbers are used to condition the
``TensorProduct`` update in the residual connection of the message passing
layers, as well as the contractions in the message passing layers.
Following the notation used in `ACEsuit/mace <https://github.com/ACEsuit/mace>`__,
the first layer in this model is a ``RealAgnosticInteractionBlock``. Subsequent
layers are then ``RealAgnosticResidualInteractionBlock``\ s.
Please cite the following if you use this model in your research:
.. code-block:: bibtex
@misc{Batatia2022MACE,
title = {
MACE: Higher Order Equivariant Message Passing
Neural Networks for Fast and Accurate Force Fields
},
author = {
Batatia, Ilyes and Kov{\'a}cs, D{\'a}vid P{\'e}ter and
Simm, Gregor N. C. and Ortner, Christoph and Cs{\'a}nyi, G{\'a}bor
},
year = {2022},
doi = {10.48550/arXiv.2206.07697},
}
Parameters
----------
elements
list of elements that this MACE model will be able to handle.
cutoff
radial cutoff (in Å) for the radial expansion (and message passing)
n_radial
number of bases to expand the radial distances into
radial_expansion
type of radial expansion to use. See :class:`~graph_pes.models.components.distances.DistanceExpansion`
for available options
weights_mlp
configuration for the MLPs that map the radial basis functions
to the weights of the interactions' tensor products
channels
the multiplicity of the node features corresponding to each irrep
specified in ``hidden_irreps``
hidden_irreps
string representations of the :class:`e3nn.o3.Irrep`\ s to use
for representing the node features between each message passing layer
l_max
the highest order to consider in:
* the spherical harmonics expansion of the neighbour vectors
* the irreps of node features used within each message passing layer
layers
number of message passing layers
correlation
maximum correlation (body-order) of the messages
aggregation
the type of aggregation to use when creating total messages from
neighbour messages :math:`m_{j \rightarrow i}`
self_connection
whether to use self-connections in the message passing layers
readout_width
the width of the MLP used to read out the per-atom energies after the
final message passing layer
Examples
--------
Basic usage:
.. code-block:: python
>>> from graph_pes.models import MACE
>>> model = MACE(
... elements=["H", "C", "N", "O"],
... cutoff=5.0,
... channels=16,
... radial_expansion="Bessel",
... )
Specification in a YAML file:
.. code-block:: yaml
model:
+MACE:
elements: [H, C, N, O]
cutoff: 5.0
radial_expansion: Bessel
# change from the default MLP config:
weights_mlp:
hidden_depth: 2
hidden_features: 16
activation: SiLU
""" # noqa: E501
def __init__(
self,
elements: list[str],
# radial things
cutoff: float = DEFAULT_CUTOFF,
n_radial: int = 8,
radial_expansion: type[DistanceExpansion] | str = "Bessel",
weights_mlp: MLPConfig = DEFAULT_MLP_CONFIG,
# node things
channels: int = 128,
hidden_irreps: str | list[str] = "0e + 1o",
# message passing things
l_max: int = 3,
layers: int = 2,
correlation: int = 3,
aggregation: NeighbourAggregationMode = "constant_fixed",
self_connection: bool = True,
# readout
readout_width: int = 16,
):
Z_embedding = AtomicOneHot(elements)
Z_dim = len(elements)
hidden_irrep_s = parse_irreps(hidden_irreps)
nodes = NodeDescription(
channels=channels,
attributes=Z_dim,
hidden_features=hidden_irrep_s,
)
super().__init__(
cutoff=cutoff,
n_radial=n_radial,
radial_expansion=radial_expansion,
weights_mlp={**DEFAULT_MLP_CONFIG, **weights_mlp},
nodes=nodes,
node_attribute_generator=Z_embedding,
l_max=l_max,
layers=layers,
correlation=correlation,
neighbour_aggregation=aggregation,
use_self_connection=self_connection,
readout_width=readout_width,
)
class TensorMACE(_BaseTensorMACE):
r"""
The `MACE <https://arxiv.org/abs/2206.07697>`__ architecture, targeting
`arbitrary rank <https://arxiv.org/abs/2412.15063>`__ atomic tensors.
One-hot encodings of the atomic numbers are used to condition the
``TensorProduct`` update in the residual connection of the message passing
layers, as well as the contractions in the message passing layers.
Following the notation used in `ACEsuit/mace <https://github.com/ACEsuit/mace>`__,
the first layer in this model is a ``RealAgnosticInteractionBlock``. Subsequent
layers are then ``RealAgnosticResidualInteractionBlock``\ s.
Please cite the following if you use this model in your research:
.. code-block:: bibtex
@misc{Batatia2022MACE,
title = {
MACE: Higher Order Equivariant Message Passing
Neural Networks for Fast and Accurate Force Fields
},
author = {
Batatia, Ilyes and Kov{\'a}cs, D{\'a}vid P{\'e}ter and
Simm, Gregor N. C. and Ortner, Christoph and Cs{\'a}nyi, G{\'a}bor
},
year = {2022},
doi = {10.48550/arXiv.2206.07697},
}
@misc{BenMahmoud2025NMR,
title = {
Graph-neural-network predictions of solid-state NMR parameters
in silica from spherical tensor decomposition
},
author = {Ben Mahmoud, Chiheb and Rosset, Louise and Yates, Jonathan and
Deringer, Volker
},
year = {2025},
doi = {10.1063/5.0274240},
}
Parameters
----------
elements
list of elements that this MACE model will be able to handle.
cutoff
radial cutoff (in Å) for the radial expansion (and message passing)
n_radial
number of bases to expand the radial distances into
radial_expansion
type of radial expansion to use. See :class:`~graph_pes.models.components.distances.DistanceExpansion`
for available options
weights_mlp
configuration for the MLPs that map the radial basis functions
to the weights of the interactions' tensor products
channels
the multiplicity of the node features corresponding to each irrep
specified in ``hidden_irreps``
hidden_irreps
string representations of the :class:`e3nn.o3.Irrep`\ s to use
for representing the node features between each message passing layer
l_max
the highest order to consider in:
* the spherical harmonics expansion of the neighbour vectors
* the irreps of node features used within each message passing layer
layers
number of message passing layers
correlation
maximum correlation (body-order) of the messages
aggregation
the type of aggregation to use when creating total messages from
neighbour messages :math:`m_{j \rightarrow i}`
self_connection
whether to use self-connections in the message passing layers
readout_width
the width of the MLP used to read out the per-atom energies after the
final message passing layer
props
the name of the property predicted by the model (defaults to ``"tensor"``)
target_method
determines how to reconstruct the target tensor: use ``"tensor_product"``
when the spherical tensor decomposition contains irreps such as ``"0o"``,
``"1e"`` or ``"2o"``, and ``"direct"`` otherwise
number_of_tps
the number of tensors involved in the tensor product
target_tensor_irreps
the irreps of the target tensor
irrep_tp
the irrep of the tensors involved in the tensor product used to
reconstruct the target
Examples
--------
Basic usage:
.. code-block:: python
>>> from graph_pes.models import TensorMACE
>>> model = TensorMACE(
... elements=["H", "C", "N", "O"],
... cutoff=5.0,
... channels=16,
... radial_expansion="Bessel",
... )
Specification in a YAML file:
.. code-block:: yaml
model:
+TensorMACE:
elements: [H, C, N, O]
cutoff: 5.0
radial_expansion: Bessel
target_method: tensor_product
target_tensor_irreps: 0e + 1e + 2e
number_of_tps: 128
irrep_tp: 3o
# change from the default MLP config:
weights_mlp:
hidden_depth: 2
hidden_features: 16
activation: SiLU
""" # noqa: E501
def __init__(
self,
elements: list[str],
# radial things
cutoff: float = DEFAULT_CUTOFF,
n_radial: int = 8,
radial_expansion: type[DistanceExpansion] | str = "Bessel",
weights_mlp: MLPConfig = DEFAULT_MLP_CONFIG,
# node things
channels: int = 128,
hidden_irreps: str | list[str] = "0e + 1o",
# message passing things
l_max: int = 3,
layers: int = 2,
correlation: int = 3,
aggregation: NeighbourAggregationMode = "constant_fixed",
self_connection: bool = True,
# readout
readout_width: int = 16,
# tensor related
props: str = "tensor",
target_method: str = "tensor_product",
number_of_tps=None,
target_tensor_irreps=None,
irrep_tp="1o",
):
Z_embedding = AtomicOneHot(elements)
Z_dim = len(elements)
hidden_irrep_s = parse_irreps(hidden_irreps)
nodes = NodeDescription(
channels=channels,
attributes=Z_dim,
hidden_features=hidden_irrep_s,
)
super().__init__(
cutoff=cutoff,
n_radial=n_radial,
radial_expansion=radial_expansion,
weights_mlp={**DEFAULT_MLP_CONFIG, **weights_mlp},
nodes=nodes,
node_attribute_generator=Z_embedding,
l_max=l_max,
layers=layers,
correlation=correlation,
neighbour_aggregation=aggregation,
use_self_connection=self_connection,
readout_width=readout_width,
target_method=target_method,
props=props,
number_of_tps=number_of_tps,
target_tensor_irreps=target_tensor_irreps,
irrep_tp=irrep_tp,
)
class ZEmbeddingMACE(_BaseMACE):
"""
A variant of MACE that uses a fixed-size (``z_embed_dim``) per-element
embedding of the atomic numbers to condition the ``TensorProduct`` update
in the residual connection of the message passing layers, as well as the
contractions in the message passing layers.
Please cite the following if you use this model in your research:
.. code-block:: bibtex
@misc{Batatia2022MACE,
title = {
MACE: Higher Order Equivariant Message Passing
Neural Networks for Fast and Accurate Force Fields
},
author = {
Batatia, Ilyes and Kov{\'a}cs, D{\'a}vid P{\'e}ter and
Simm, Gregor N. C. and Ortner, Christoph and Cs{\'a}nyi, G{\'a}bor
},
year = {2022},
doi = {10.48550/arXiv.2206.07697},
}
All parameters are identical to :class:`~graph_pes.models.MACE`, except for the following:
- ``elements`` is not required or used here
- ``z_embed_dim`` controls the size of the per-element embedding
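Examples
--------
Basic usage (a minimal sketch, assuming the same import location as
:class:`~graph_pes.models.MACE`):
.. code-block:: python
>>> from graph_pes.models import ZEmbeddingMACE
>>> model = ZEmbeddingMACE(
...     z_embed_dim=8,
...     cutoff=5.0,
...     channels=16,
...     radial_expansion="Bessel",
... )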
""" # noqa: E501
def __init__(
self,
z_embed_dim: int = 4,
# radial things
cutoff: float = DEFAULT_CUTOFF,
n_radial: int = 8,
radial_expansion: type[DistanceExpansion] | str = "Bessel",
weights_mlp: MLPConfig = DEFAULT_MLP_CONFIG,
# node things
channels: int = 128,
hidden_irreps: str | list[str] = "0e + 1o",
# message passing things
l_max: int = 3,
layers: int = 2,
correlation: int = 3,
aggregation: NeighbourAggregationMode = "constant_fixed",
self_connection: bool = True,
# readout
readout_width: int = 16,
):
Z_embedding = PerElementEmbedding(z_embed_dim)
hidden_irrep_s = parse_irreps(hidden_irreps)
nodes = NodeDescription(
channels=channels,
attributes=z_embed_dim,
hidden_features=hidden_irrep_s,
)
super().__init__(
cutoff=cutoff,
n_radial=n_radial,
radial_expansion=radial_expansion,
weights_mlp={**DEFAULT_MLP_CONFIG, **weights_mlp},
nodes=nodes,
node_attribute_generator=Z_embedding,
l_max=l_max,
layers=layers,
correlation=correlation,
neighbour_aggregation=aggregation,
use_self_connection=self_connection,
readout_width=readout_width,
)
class ZEmbeddingTensorMACE(_BaseTensorMACE):
"""
A variant of TensorMACE that uses a fixed-size (``z_embed_dim``) per-element
embedding of the atomic numbers to condition the ``TensorProduct`` update
in the residual connection of the message passing layers, as well as the
contractions in the message passing layers.
Please cite the following if you use this model in your research:
.. code-block:: bibtex
@misc{Batatia2022MACE,
title = {
MACE: Higher Order Equivariant Message Passing
Neural Networks for Fast and Accurate Force Fields
},
author = {
Batatia, Ilyes and Kov{\'a}cs, D{\'a}vid P{\'e}ter and
Simm, Gregor N. C. and Ortner, Christoph and Cs{\'a}nyi, G{\'a}bor
},
year = {2022},
doi = {10.48550/arXiv.2206.07697},
}
@misc{BenMahmoud2025NMR,
title = {
Graph-neural-network predictions of solid-state NMR parameters
in silica from spherical tensor decomposition
},
author = {Ben Mahmoud, Chiheb and Rosset, Louise and Yates, Jonathan and
Deringer, Volker
},
year = {2025},
doi = {10.1063/5.0274240},
}
All parameters are identical to :class:`TensorMACE`, except for the following:
- ``elements`` is not required or used here
- ``z_embed_dim`` controls the size of the per-element embedding
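Examples
--------
Basic usage (a minimal sketch, assuming the same import location as
:class:`~graph_pes.models.MACE`, with the tensor settings from the
:class:`TensorMACE` example):
.. code-block:: python
>>> from graph_pes.models import ZEmbeddingTensorMACE
>>> model = ZEmbeddingTensorMACE(
...     z_embed_dim=8,
...     cutoff=5.0,
...     channels=16,
...     target_method="tensor_product",
...     target_tensor_irreps="0e + 1e + 2e",
...     number_of_tps=128,
...     irrep_tp="3o",
... )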
""" # noqa: E501
def __init__(
self,
z_embed_dim: int = 4,
# radial things
cutoff: float = DEFAULT_CUTOFF,
n_radial: int = 8,
radial_expansion: type[DistanceExpansion] | str = "Bessel",
weights_mlp: MLPConfig = DEFAULT_MLP_CONFIG,
# node things
channels: int = 128,
hidden_irreps: str | list[str] = "0e + 1o",
# message passing things
l_max: int = 3,
layers: int = 2,
correlation: int = 3,
aggregation: NeighbourAggregationMode = "constant_fixed",
self_connection: bool = True,
# readout
readout_width: int = 16,
# tensor related
props: str = "tensor",
target_method: str = "tensor_product",
number_of_tps=None,
target_tensor_irreps=None,
irrep_tp="1o",
):
Z_embedding = PerElementEmbedding(z_embed_dim)
hidden_irrep_s = parse_irreps(hidden_irreps)
nodes = NodeDescription(
channels=channels,
attributes=z_embed_dim,
hidden_features=hidden_irrep_s,
)
super().__init__(
cutoff=cutoff,
n_radial=n_radial,
radial_expansion=radial_expansion,
weights_mlp={**DEFAULT_MLP_CONFIG, **weights_mlp},
nodes=nodes,
node_attribute_generator=Z_embedding,
l_max=l_max,
layers=layers,
correlation=correlation,
neighbour_aggregation=aggregation,
use_self_connection=self_connection,
readout_width=readout_width,
target_method=target_method,
props=props,
number_of_tps=number_of_tps,
target_tensor_irreps=target_tensor_irreps,
irrep_tp=irrep_tp,
)