Skip to content

Commit

Permalink
fixup! Prepare for pytorch tests
Browse files Browse the repository at this point in the history
  • Loading branch information
timokau committed Apr 8, 2021
1 parent 070a56e commit 2d5d425
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
5 changes: 5 additions & 0 deletions csrank/tests/test_discrete_choice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from pymc3.variational.callbacks import CheckParametersConvergence
import pytest
import torch
from torch import optim

from csrank.constants import FATE_DC
Expand Down Expand Up @@ -80,6 +81,10 @@ def trivial_discrete_choice_problem():
@pytest.mark.parametrize("name", list(discrete_choice_functions.keys()))
def test_discrete_choice_function_fixed(trivial_discrete_choice_problem, name):
np.random.seed(123)
# There are some caveats with pytorch reproducibility. See the comment on
# the corresponding line of `test_choice_functions.py` for details.
torch.manual_seed(123)
torch.use_deterministic_algorithms(True)
x, y = trivial_discrete_choice_problem
choice_function = discrete_choice_functions[name][0]
params, accuracies = (
Expand Down
5 changes: 5 additions & 0 deletions csrank/tests/test_ranking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import torch
from torch import optim

from csrank.constants import ERR
Expand Down Expand Up @@ -52,6 +53,10 @@ def trivial_ranking_problem():
@pytest.mark.parametrize("ranker_name", list(object_rankers.keys()))
def test_object_ranker_fixed(trivial_ranking_problem, ranker_name):
np.random.seed(123)
# There are some caveats with pytorch reproducibility. See the comment on
# the corresponding line of `test_choice_functions.py` for details.
torch.manual_seed(123)
torch.use_deterministic_algorithms(True)
x, y = trivial_ranking_problem
ranker, params, (loss, acc) = object_rankers[ranker_name]
ranker = ranker(**params)
Expand Down

0 comments on commit 2d5d425

Please sign in to comment.