Skip to content

Commit

Permalink
ADD: threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
matbun committed Nov 7, 2023
1 parent be865f3 commit 55aac61
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/itwinai/torch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ class MultilabelTorchPredictor(TorchPredictor):
output of the neural network.
"""

threshold: float

def __init__(
self,
model: Union[nn.Module, ModelLoader],
Expand All @@ -200,6 +202,9 @@ def __init__(
)
self.threshold = threshold

def transform_predictions(self, batch: Batch) -> Batch:
return (batch > self.threshold).float()


class RegressionTorchPredictor(TorchPredictor):
"""
Expand Down

0 comments on commit 55aac61

Please sign in to comment.