Skip to content

Commit

Permalink
Update inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
r-sarma authored Oct 14, 2024
1 parent d64bd14 commit d544901
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/itwinai/torch/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import os
from typing import Any, Dict, List, Literal, Optional, Union

Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(
self.config = TrainingConfiguration(**config)
else:
self.config = config
self.model = self.model
self.model = self.model.eval()
self.strategy = strategy
self.logger = logger
self.checkpoints_location = checkpoints_location
Expand Down Expand Up @@ -154,7 +155,7 @@ def distribute_model(self) -> None:
)
distribute_kwargs = {}
# Distributed model, optimizer, and scheduler
self.model,_,_ = self.strategy.distributed(
self.model, _, _ = self.strategy.distributed(
self.model, None, None, **distribute_kwargs
)

Expand Down Expand Up @@ -214,9 +215,10 @@ def execute(
self.logger.save_hyperparameters(hparams)

all_predictions = dict()
for samples_ids, samples in inference_dataset:
for ids, (samples_ids, samples) in enumerate(self.inference_dataloader):
with torch.no_grad():
pred = self.model(samples)
pred = self.model(samples.to(self.device))
pred = self.transform_predictions(pred)
for idx, pre in zip(samples_ids, pred):
# For each item in the batch
if pre.numel() == 1:
Expand Down Expand Up @@ -263,6 +265,13 @@ def log(
batch_idx=batch_idx,
**kwargs
)

@abc.abstractmethod
def transform_predictions(self, batch: Batch) -> Batch:
"""
Post-process the predictions of the torch model (e.g., apply
threshold in case of multi-label classifier).
"""


class MulticlassTorchPredictor(TorchPredictor):
Expand Down

0 comments on commit d544901

Please sign in to comment.