From 8e978149bfe8abab3eb77d965d579a111afdb0ff Mon Sep 17 00:00:00 2001 From: enhuiz Date: Tue, 3 Dec 2024 10:28:47 +0800 Subject: [PATCH] Support force gaussian prior for CFM model, fix some typing --- resemble_enhance/enhancer/enhancer.py | 5 ++++- resemble_enhance/enhancer/hparams.py | 4 ++++ resemble_enhance/enhancer/lcfm/cfm.py | 3 +-- resemble_enhance/enhancer/train.py | 6 +++--- resemble_enhance/utils/train_loop.py | 14 +++++--------- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/resemble_enhance/enhancer/enhancer.py b/resemble_enhance/enhancer/enhancer.py index cd5e1ab..7fef9eb 100644 --- a/resemble_enhance/enhancer/enhancer.py +++ b/resemble_enhance/enhancer/enhancer.py @@ -185,7 +185,10 @@ def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None): y_mel = _maybe(self.to_mel)(y) # (b d t) y_mel = _maybe(self.normalizer)(y_mel) - lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t) + if self.hp.force_gaussian_prior: + lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=None) # (b d t) + else: + lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t) if lcfm_decoded is None: o = None diff --git a/resemble_enhance/enhancer/hparams.py b/resemble_enhance/enhancer/hparams.py index ca89bea..87d0588 100644 --- a/resemble_enhance/enhancer/hparams.py +++ b/resemble_enhance/enhancer/hparams.py @@ -13,6 +13,7 @@ class HParams(HParamsBase): lcfm_latent_dim: int = 64 lcfm_training_mode: str = "ae" + # This value should be carefully tuned when training. 
Better estimate it from the latent vectors first lcfm_z_scale: float = 5 vocoder_extra_dim: int = 32 @@ -21,3 +22,6 @@ class HParams(HParamsBase): enhancer_stage1_run_dir: Path | None = None denoiser_run_dir: Path | None = None + + # Enabling this increases training stability (but also disables changing eval_tau) + force_gaussian_prior: bool = False diff --git a/resemble_enhance/enhancer/lcfm/cfm.py b/resemble_enhance/enhancer/lcfm/cfm.py index a512526..4abe8d1 100644 --- a/resemble_enhance/enhancer/lcfm/cfm.py +++ b/resemble_enhance/enhancer/lcfm/cfm.py @@ -17,8 +17,7 @@ class VelocityField(Protocol): - def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: - ... + def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ... class Solver: diff --git a/resemble_enhance/enhancer/train.py b/resemble_enhance/enhancer/train.py index 7b99d81..e9d2a6f 100644 --- a/resemble_enhance/enhancer/train.py +++ b/resemble_enhance/enhancer/train.py @@ -30,7 +30,7 @@ def load_G(run_dir: Path, hp: HParams | None = None, training=True): return engine -def load_D(run_dir: Path, hp: HParams): +def load_D(run_dir: Path, hp: HParams | None): if hp is None: hp = HParams.load(run_dir) assert isinstance(hp, HParams) @@ -41,8 +41,8 @@ def load_D(run_dir: Path, hp: HParams): def save_wav(path: Path, wav: Tensor, rate: int): - wav = wav.detach().cpu().numpy() - soundfile.write(path, wav, samplerate=rate) + wav_numpy = wav.detach().cpu().numpy() + soundfile.write(path, wav_numpy, samplerate=rate) def main(): diff --git a/resemble_enhance/utils/train_loop.py b/resemble_enhance/utils/train_loop.py index dcfbf2e..174a993 100644 --- a/resemble_enhance/utils/train_loop.py +++ b/resemble_enhance/utils/train_loop.py @@ -18,23 +18,19 @@ class EvalFn(Protocol): - def __call__(self, engine: Engine, eval_dir: Path) -> None: - ... + def __call__(self, engine: Engine, eval_dir: Path) -> None: ... 
class EngineLoader(Protocol): - def __call__(self, run_dir: Path) -> Engine: - ... + def __call__(self, run_dir: Path) -> Engine: ... class GenFeeder(Protocol): - def __call__(self, engine: Engine, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: - ... + def __call__(self, engine: Engine, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: ... class DisFeeder(Protocol): - def __call__(self, engine: Engine, batch: dict[str, Tensor] | None, fake: Tensor) -> dict[str, Tensor]: - ... + def __call__(self, engine: Engine, batch: dict[str, Tensor] | None, fake: Tensor) -> dict[str, Tensor]: ... @dataclass @@ -239,7 +235,7 @@ def run(self, max_steps: int = -1): @classmethod def set_running_loop_(cls, loop): assert isinstance(loop, cls), f"Expected {cls}, got {type(loop)}" - cls._running_loop: cls = loop + cls._running_loop = loop @classmethod def get_running_loop(cls) -> "TrainLoop | None":