Skip to content

Commit

Permalink
[Feature]: add swanlab logger (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyi-Lin authored May 2, 2024
1 parent cdce915 commit f534a7f
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ datasets/
experiments_t2m/
experiments_t2m_test/
experiments_control/
experiments_control_test/
experiments_control_test/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Experimental results demonstrate the remarkable generation and controlling capab

## 📢 News

- **[2024/05/02]** We support the [SwanLab](https://github.com/SwanHubX/SwanLab) logger, please refer to this [PR](https://github.com/Dai-Wenxun/MotionLCM/pull/7) for details.
- **[2024/05/01]** Upload paper and release code.

## 👨‍🏫 Quick Start
Expand Down
2 changes: 2 additions & 0 deletions mld/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def parse_args() -> DictConfig:
parser.add_argument('--example', type=str, required=False, help="input text and lengths with txt format")
parser.add_argument('--no-plot', action="store_true", required=False, help="whether plot the skeleton-based motion")
parser.add_argument('--replication', type=int, default=1, help="the number of replication of sampling")
parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="the visualization method, tensorboard or swanlab")
args = parser.parse_args()

cfg = OmegaConf.load(args.cfg)
Expand All @@ -44,4 +45,5 @@ def parse_args() -> DictConfig:
cfg.example = args.example
cfg.no_plot = args.no_plot
cfg.replication = args.replication
cfg.vis = args.vis
return cfg
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ torch==1.13.1
gdown
omegaconf
rich
swanlab==0.3.1
torchmetrics==1.3.2
scipy==1.11.2
matplotlib==3.3.4
Expand All @@ -14,4 +15,4 @@ h5py==3.11.0
smplx==0.1.28
chumpy==0.70
numpy==1.23.1
natsort==8.4.0
natsort==8.4.0
21 changes: 17 additions & 4 deletions train_motion_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from omegaconf import OmegaConf

import torch
import swanlab
import diffusers
import transformers
from torch.utils.tensorboard import SummaryWriter
Expand All @@ -30,8 +31,14 @@ def main():
output_dir = osp.join(cfg.FOLDER, name_time_str)
os.makedirs(output_dir, exist_ok=False)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=False)

writer = SummaryWriter(output_dir)

if cfg.vis == "tb":
writer = SummaryWriter(output_dir)
elif cfg.vis == "swanlab":
run = swanlab.init(project="MotionLCM", experiment_name=os.path.normpath(output_dir).replace(os.path.sep, "-"),
suffix=None, config=cfg, logdir=output_dir)
else:
raise ValueError(f"Invalid vis method: {cfg.vis}")

stream_handler = logging.StreamHandler(sys.stdout)
file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
Expand Down Expand Up @@ -133,7 +140,10 @@ def validation():
min_val_tj = metrics['Metrics/traj_fail_50cm']
print_table(f'Metrics@Step-{global_step}', metrics)
for k, v in metrics.items():
writer.add_scalar(k, v, global_step=global_step)
if cfg.vis == "tb":
writer.add_scalar(k, v, global_step=global_step)
elif cfg.vis == "swanlab":
run.log({k: v}, step=global_step)

model.controlnet.train()
model.traj_encoder.train()
Expand Down Expand Up @@ -189,7 +199,10 @@ def validation():
"diff_loss": diff_loss.detach().item(), 'cond_loss': cond_loss.detach().item(), 'rot_loss': rot_loss.detach().item()}
progress_bar.set_postfix(**logs)
for k, v in logs.items():
writer.add_scalar(k, v, global_step=global_step)
if cfg.vis == "tb":
writer.add_scalar(k, v, global_step=global_step)
elif cfg.vis == "swanlab":
run.log({k: v}, step=global_step)

if global_step >= cfg.TRAIN.max_train_steps:
break
Expand Down
23 changes: 18 additions & 5 deletions train_motionlcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from omegaconf import OmegaConf

import torch
import swanlab
import diffusers
import transformers
import torch.nn.functional as F
Expand Down Expand Up @@ -127,7 +128,13 @@ def main():
os.makedirs(output_dir, exist_ok=False)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=False)

writer = SummaryWriter(output_dir)
if cfg.vis == "tb":
writer = SummaryWriter(output_dir)
elif cfg.vis == "swanlab":
run = swanlab.init(project="MotionLCM", experiment_name=os.path.normpath(output_dir).replace(os.path.sep, "-"),
suffix=None, config=cfg, logdir=output_dir)
else:
raise ValueError(f"Invalid vis method: {cfg.vis}")

stream_handler = logging.StreamHandler(sys.stdout)
file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
Expand Down Expand Up @@ -245,7 +252,10 @@ def validation():
min_val_fid = metrics['Metrics/FID']
print_table(f'Metrics@Step-{global_step}', metrics)
for k, v in metrics.items():
writer.add_scalar(k, v, global_step=global_step)
if cfg.vis == "tb":
writer.add_scalar(k, v, global_step=global_step)
elif cfg.vis == "swanlab":
run.log({k: v}, step=global_step)
base_model.train()
return max_val_rp1, min_val_fid

Expand Down Expand Up @@ -411,9 +421,12 @@ def validation():

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
writer.add_scalar('loss', logs['loss'], global_step=global_step)
writer.add_scalar('lr', logs['lr'], global_step=global_step)

if cfg.vis == "tb":
writer.add_scalar('loss', logs['loss'], global_step=global_step)
writer.add_scalar('lr', logs['lr'], global_step=global_step)
elif cfg.vis == "swanlab":
run.log({'loss': logs['loss'], 'lr': logs['lr']}, step=global_step)

if global_step >= cfg.TRAIN.max_train_steps:
break

Expand Down

0 comments on commit f534a7f

Please sign in to comment.