Skip to content

Commit

Permalink
fixup! Add pytorch losses and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
timokau committed Apr 8, 2021
1 parent 7cf35f2 commit b43b25d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 1 addition & 2 deletions csrank/discrete_choice_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class CategoricalHingeLossMax:
arXiv:1901.10860.
"""

# Should be: First true, then predicted (as in Cross-Entropy Loss)
# But for some reason skorch calls it with a swapped argument order.
# The argument order is chosen to be compatible with skorch.
def __call__(self, scores, true_choice):
"""Compute the loss of a scoring in the context of a choice.
Expand Down
3 changes: 1 addition & 2 deletions csrank/rank_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ class HingedRankLoss:
The total loss, summed over all instances.
"""

# Should be: First true, then predicted (as in Cross-Entropy Loss)
# But for some reason skorch calls it with a swapped argument order.
# The argument order is chosen to be compatible with skorch.
def __call__(self, comparison_scores, true_rankings):
# 2d matrix which is 1 if the row-element *should* be ranked higher than the column element
mask = true_rankings[:, :, None] > true_rankings[:, None]
Expand Down

0 comments on commit b43b25d

Please sign in to comment.