Skip to content

Commit

Permalink
Support force gaussian prior for CFM model, fix some typing
Browse files Browse the repository at this point in the history
  • Loading branch information
enhuiz committed Dec 3, 2024
1 parent 1eb5fc8 commit 8e97814
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 15 deletions.
5 changes: 4 additions & 1 deletion resemble_enhance/enhancer/enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
y_mel = _maybe(self.to_mel)(y) # (b d t)
y_mel = _maybe(self.normalizer)(y_mel)

lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t)
if self.hp.force_gaussian_prior:
lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=None) # (b d t)
else:
lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t)

if lcfm_decoded is None:
o = None
Expand Down
4 changes: 4 additions & 0 deletions resemble_enhance/enhancer/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class HParams(HParamsBase):

lcfm_latent_dim: int = 64
lcfm_training_mode: str = "ae"
# This value should be tuned carefully when training; it is best to estimate it from the latent vectors first.
lcfm_z_scale: float = 5

vocoder_extra_dim: int = 32
Expand All @@ -21,3 +22,6 @@ class HParams(HParamsBase):
enhancer_stage1_run_dir: Path | None = None

denoiser_run_dir: Path | None = None

# Enabling this increases training stability (but also disables changing eval_tau)
force_gaussian_prior: bool = False
3 changes: 1 addition & 2 deletions resemble_enhance/enhancer/lcfm/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@


class VelocityField(Protocol):
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
...
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ...


class Solver:
Expand Down
6 changes: 3 additions & 3 deletions resemble_enhance/enhancer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def load_G(run_dir: Path, hp: HParams | None = None, training=True):
return engine


def load_D(run_dir: Path, hp: HParams):
def load_D(run_dir: Path, hp: HParams | None):
if hp is None:
hp = HParams.load(run_dir)
assert isinstance(hp, HParams)
Expand All @@ -41,8 +41,8 @@ def load_D(run_dir: Path, hp: HParams):


def save_wav(path: Path, wav: Tensor, rate: int):
wav = wav.detach().cpu().numpy()
soundfile.write(path, wav, samplerate=rate)
wav_numpy = wav.detach().cpu().numpy()
soundfile.write(path, wav_numpy, samplerate=rate)


def main():
Expand Down
14 changes: 5 additions & 9 deletions resemble_enhance/utils/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,19 @@


class EvalFn(Protocol):
def __call__(self, engine: Engine, eval_dir: Path) -> None:
...
def __call__(self, engine: Engine, eval_dir: Path) -> None: ...


class EngineLoader(Protocol):
def __call__(self, run_dir: Path) -> Engine:
...
def __call__(self, run_dir: Path) -> Engine: ...


class GenFeeder(Protocol):
def __call__(self, engine: Engine, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
...
def __call__(self, engine: Engine, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: ...


class DisFeeder(Protocol):
def __call__(self, engine: Engine, batch: dict[str, Tensor] | None, fake: Tensor) -> dict[str, Tensor]:
...
def __call__(self, engine: Engine, batch: dict[str, Tensor] | None, fake: Tensor) -> dict[str, Tensor]: ...


@dataclass
Expand Down Expand Up @@ -239,7 +235,7 @@ def run(self, max_steps: int = -1):
@classmethod
def set_running_loop_(cls, loop):
assert isinstance(loop, cls), f"Expected {cls}, got {type(loop)}"
cls._running_loop: cls = loop
cls._running_loop = loop

@classmethod
def get_running_loop(cls) -> "TrainLoop | None":
Expand Down

0 comments on commit 8e97814

Please sign in to comment.