diff --git a/tests/test_dp_cgans.py b/tests/test_dp_cgans.py index 4afeedd..fe3c189 100644 --- a/tests/test_dp_cgans.py +++ b/tests/test_dp_cgans.py @@ -13,7 +13,7 @@ def test_dp_cgans(): tabular_data=pd.read_csv("resources/example_tabular_data_UCIAdult.csv") model = DP_CGAN( - epochs=1, # number of training epochs + epochs=10, # number of training epochs batch_size=1000, # the size of each batch log_frequency=True, verbose=False, @@ -23,6 +23,7 @@ def test_dp_cgans(): discriminator_lr=2e-4, discriminator_steps=1, private=False, + wandb=False ) model.fit(tabular_data)