diff --git a/factories/loss_factory.py b/factories/loss_factory.py index 06fbb3d..efcf988 100644 --- a/factories/loss_factory.py +++ b/factories/loss_factory.py @@ -4,5 +4,5 @@ class LossFactory(Factory): def __init__(self): super().__init__() - self.register("CrossEntropyLoss", lambda: CrossEntropyLoss()) - self.register("MSELoss", lambda: MSELoss()) + self.register("CrossEntropyLoss", lambda **kwargs: CrossEntropyLoss(**kwargs)) + self.register("MSELoss", lambda **kwargs: MSELoss(**kwargs)) \ No newline at end of file