Basis transfer code #216

Open · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions dptb/data/AtomicData.py
@@ -104,6 +104,7 @@
AtomicDataDict.HAMILTONIAN_KEY, # new # should be nested
AtomicDataDict.OVERLAP_KEY, # new # should be nested
AtomicDataDict.ENERGY_EIGENVALUE_KEY, # new # should be nested
AtomicDataDict.EIGENVECTOR_KEY, # new # should be nested
AtomicDataDict.ENERGY_WINDOWS_KEY, # new,
AtomicDataDict.BAND_WINDOW_KEY, # new,
AtomicDataDict.NODE_SOC_SWITCH_KEY # new
1 change: 1 addition & 0 deletions dptb/data/_keys.py
@@ -42,6 +42,7 @@
ATOM_TYPE_KEY: Final[str] = "atom_types"
# [n_batch, n_kpoint, n_orb]
ENERGY_EIGENVALUE_KEY: Final[str] = "eigenvalue"
EIGENVECTOR_KEY: Final[str] = "eigenvector"

# [n_batch, 2]
ENERGY_WINDOWS_KEY = "ewindow"
3 changes: 2 additions & 1 deletion dptb/nn/__init__.py
@@ -4,14 +4,15 @@
from .dftbsk import DFTBSK
from .hamiltonian import E3Hamiltonian, SKHamiltonian
from .hr2hk import HR2HK
-from .energy import Eigenvalues
+from .energy import Eigenvalues, Eigh

__all__ = [
    build_model,
    E3Hamiltonian,
    SKHamiltonian,
    HR2HK,
    Eigenvalues,
    Eigh,
    NNENV,
    NNSK,
    MIX,
84 changes: 84 additions & 0 deletions dptb/nn/energy.py
@@ -80,6 +80,90 @@ def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDataDict.Type:

            eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
        data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
        if nested:
            data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([kpoints])
        else:
            data[AtomicDataDict.KPOINT_KEY] = kpoints

        return data

class Eigh(nn.Module):
    def __init__(
            self,
            idp: Union[OrbitalMapper, None]=None,
            h_edge_field: str = AtomicDataDict.EDGE_FEATURES_KEY,
            h_node_field: str = AtomicDataDict.NODE_FEATURES_KEY,
            h_out_field: str = AtomicDataDict.HAMILTONIAN_KEY,
            eigval_field: str = AtomicDataDict.ENERGY_EIGENVALUE_KEY,
            eigvec_field: str = AtomicDataDict.EIGENVECTOR_KEY,
            s_edge_field: str = None,
            s_node_field: str = None,
            s_out_field: str = None,
            dtype: Union[str, torch.dtype] = torch.float32,
            device: Union[str, torch.device] = torch.device("cpu")):
        super(Eigh, self).__init__()

        self.h2k = HR2HK(
            idp=idp,
            edge_field=h_edge_field,
            node_field=h_node_field,
            out_field=h_out_field,
            dtype=dtype,
            device=device,
        )

        if s_edge_field is not None:
            self.s2k = HR2HK(
                idp=idp,
                overlap=True,
                edge_field=s_edge_field,
                node_field=s_node_field,
                out_field=s_out_field,
                dtype=dtype,
                device=device,
            )

            self.overlap = True
        else:
            self.overlap = False

        self.eigval_field = eigval_field
        self.eigvec_field = eigvec_field
        self.h_out_field = h_out_field
        self.s_out_field = s_out_field

    def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDataDict.Type:
        kpoints = data[AtomicDataDict.KPOINT_KEY]
        if kpoints.is_nested:
            nested = True
            assert kpoints.size(0) == 1
            kpoints = kpoints[0]
        else:
            nested = False
        num_k = kpoints.shape[0]
        eigvals = []
        eigvecs = []
        if nk is None:
            nk = num_k
        # diagonalize in chunks of nk k-points to bound memory usage
        for i in range(int(np.ceil(num_k / nk))):
            data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk]
            data = self.h2k(data)
            if self.overlap:
                # reduce the generalized problem H C = S C E to a standard one
                # via the Cholesky factorization S = L L^H
                data = self.s2k(data)
                chklowt = torch.linalg.cholesky(data[self.s_out_field])
                chklowtinv = torch.linalg.inv(chklowt)
                data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv, dim0=1, dim1=2).conj())

            eigval, eigvec = torch.linalg.eigh(data[self.h_out_field])
            if self.overlap:
                # transform the eigenvectors back to the original, non-orthogonal basis
                eigvec = torch.transpose(chklowtinv, dim0=1, dim1=2).conj() @ eigvec
            eigvecs.append(torch.transpose(eigvec, dim0=1, dim1=2))
            eigvals.append(eigval)

        data[self.eigval_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
        data[self.eigvec_field] = torch.cat(eigvecs, dim=0)

        if nested:
            data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([kpoints])
        else:
102 changes: 102 additions & 0 deletions dptb/nnops/loss.py
@@ -296,6 +296,108 @@ def forward(
# hopping_loss += self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt))

# return hopping_loss + onsite_loss

@Loss.register("eig_ham")
class EigHamLoss(nn.Module):
    def __init__(
            self,
            basis: Dict[str, Union[str, list]]=None,
            idp: Union[OrbitalMapper, None]=None,
            overlap: bool=False,
            onsite_shift: bool=False,
            dtype: Union[str, torch.dtype] = torch.float32,
            device: Union[str, torch.device] = torch.device("cpu"),
            diff_on: bool=False,
            eout_weight: float=0.01,
            diff_weight: float=0.01,
            diff_valence: dict=None,
            spin_deg: int = 2,
            coeff_ham: float=1.,
            coeff_ovp: float=1.,
            **kwargs,
    ):
        super(EigHamLoss, self).__init__()
        self.loss1 = nn.L1Loss()
        self.loss2 = nn.MSELoss()
        self.overlap = overlap
        self.device = device
        self.onsite_shift = onsite_shift
        self.coeff_ham = coeff_ham
        self.coeff_ovp = coeff_ovp

        if basis is not None:
            self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
            if idp is not None:
                assert idp == self.idp, "The basis of idp and basis should be the same."
        else:
            assert idp is not None, "Either basis or idp should be provided."
            self.idp = idp

        self.eigloss = EigLoss(
            idp=self.idp,
            overlap=overlap,
            diff_on=diff_on,
            eout_weight=eout_weight,
            diff_weight=diff_weight,
            diff_valence=diff_valence,
            spin_deg=spin_deg,
            dtype=dtype,
            device=device,
        )

    def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
        # mask the data

        if self.onsite_shift:
            # align the reference onsite energies by a per-structure shift mu,
            # estimated from the diagonal onsite elements, and absorbed as mu * S
            batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
            # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
            if batch.max() == 0: # when the batch holds a single structure
                mu = mu.mean().detach()
                ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
            elif batch.max() >= 1:
                slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                slices = [0] + slices
                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
                mu = mu.detach()
                ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
                for i in range(1, batch.max().item()+1):
                    edge_mu_index[data["__slices__"]["edge_index"][i]:data["__slices__"]["edge_index"][i+1]] += i
                ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu[edge_mu_index, None] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]

        # onsite (node) blocks
        pre = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
        tgt = ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
        onsite_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

        # hopping (edge) blocks
        pre = data[AtomicDataDict.EDGE_FEATURES_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
        tgt = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
        hopping_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

        if self.overlap:
            pre = data[AtomicDataDict.EDGE_OVERLAP_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
            tgt = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
            overlap_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

            pre = data[AtomicDataDict.NODE_OVERLAP_KEY][self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
            tgt = ref_data[AtomicDataDict.NODE_OVERLAP_KEY][self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
            overlap_loss += 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

            ham_loss = (1/3) * (hopping_loss + onsite_loss + (self.coeff_ovp / self.coeff_ham) * overlap_loss)
        else:
            ham_loss = 0.5 * (onsite_loss + hopping_loss)

        eigloss = self.eigloss(data, ref_data)

        return self.coeff_ham * ham_loss + eigloss






@Loss.register("hamil_abs")
6 changes: 6 additions & 0 deletions dptb/utils/argcheck.py
@@ -759,6 +759,11 @@ def loss_options():
        Argument("spin_deg", int, optional=True, default=2, doc="The spin degeneracy of band structure. Default: 2"),
    ]

    eig_ham = [
        Argument("coeff_ham", float, optional=True, default=1., doc="The coefficient of the hamiltonian penalty. Default: 1"),
        Argument("coeff_ovp", float, optional=True, default=1., doc="The coefficient of the overlap penalty. Default: 1"),
    ]

    skints = [
        Argument("skdata", str, optional=False, doc="The path to the skfile or sk database."),
    ]
@@ -769,6 +774,7 @@ def loss_options():
        Argument("skints", dict, sub_fields=skints),
        Argument("hamil_abs", dict, sub_fields=hamil),
        Argument("hamil_blas", dict, sub_fields=hamil),
        Argument("eig_ham", dict, sub_fields=hamil+eigvals+eig_ham),
    ], optional=False, doc=doc_method)


4 changes: 2 additions & 2 deletions dptb/utils/config_check.py
@@ -37,8 +37,8 @@ def check_config_train(
    if train_data_config.get("get_Hamiltonian") and not train_data_config.get("get_eigenvalues"):
        assert jdata['train_options']['loss_options']['train'].get("method").startswith("hamil")

-    if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
-        raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")
+    # if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
+    #     raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")

#if jdata["data_options"].get("validation"):
