diff --git a/vamb/encode.py b/vamb/encode.py index 174ed5a0..a303984a 100644 --- a/vamb/encode.py +++ b/vamb/encode.py @@ -1,3 +1,4 @@ +import datetime from typing import Optional, IO, Union from pathlib import Path import vamb.vambtools as _vambtools @@ -342,11 +343,13 @@ def calc_loss( kld = 0.5 * (mu.pow(2)).sum(dim=1) sse_weight = self.alpha / self.ntnf kld_weight = 1 / (self.nlatent * self.beta) - reconstruction_loss = ce * ce_weight + sse * sse_weight - kld_loss = kld * kld_weight - loss = (reconstruction_loss + kld_loss) * weights + weighed_ce = ce * ce_weight + weighed_sse = sse * sse_weight + weighed_kld = kld * kld_weight + reconstruction_loss = weighed_ce + weighed_sse + loss = (reconstruction_loss + weighed_kld) * weights - return loss.mean(), ce.mean(), sse.mean(), kld.mean() + return loss.mean(), weighed_ce.mean(), weighed_sse.mean(), weighed_kld.mean() def trainepoch( self, @@ -393,7 +396,8 @@ def trainepoch( if logfile is not None: print( - "\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}".format( + "\tTime: {}\tEpoch: {:>3} Loss: {:.5e} CE: {:.5e} SSE: {:.5e} KLD: {:.5e} Batchsize: {}".format( + datetime.datetime.now().strftime("%H:%M:%S"), epoch + 1, epoch_loss / len(data_loader), epoch_celoss / len(data_loader),