Skip to content

Commit

Permalink
Timestamp and reformat epoch log message
Browse files Browse the repository at this point in the history
I've found it practical to know how long a given epoch has taken to train.
Also, format the losses nicer:
* Print the weighed loss, not the unweighed. This is more intuitive, and has the
  property that the stated loss is indeed the sum of its components.
* Print the loss in scientific notation, since they now can scale differently
* Use uniform distance between the fields when printing
  • Loading branch information
jakobnissen committed Sep 28, 2023
1 parent cab3e0b commit 546cb74
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions vamb/encode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Optional, IO, Union
from pathlib import Path
import vamb.vambtools as _vambtools
Expand Down Expand Up @@ -342,11 +343,13 @@ def calc_loss(
kld = 0.5 * (mu.pow(2)).sum(dim=1)
sse_weight = self.alpha / self.ntnf
kld_weight = 1 / (self.nlatent * self.beta)
reconstruction_loss = ce * ce_weight + sse * sse_weight
kld_loss = kld * kld_weight
loss = (reconstruction_loss + kld_loss) * weights
weighed_ce = ce * ce_weight
weighed_sse = sse * sse_weight
weighed_kld = kld * kld_weight
reconstruction_loss = weighed_ce + weighed_sse
loss = (reconstruction_loss + weighed_kld) * weights

return loss.mean(), ce.mean(), sse.mean(), kld.mean()
return loss.mean(), weighed_ce.mean(), weighed_sse.mean(), weighed_kld.mean()

def trainepoch(
self,
Expand Down Expand Up @@ -393,7 +396,8 @@ def trainepoch(

if logfile is not None:
print(
"\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}".format(
"\tTime: {}\tEpoch: {:>3} Loss: {:.5e} CE: {:.5e} SSE: {:.5e} KLD: {:.5e} Batchsize: {}".format(
datetime.datetime.now().strftime("%H:%M:%S"),
epoch + 1,
epoch_loss / len(data_loader),
epoch_celoss / len(data_loader),
Expand Down

0 comments on commit 546cb74

Please sign in to comment.