Skip to content

Commit

Permalink
fixed issue with local_rank on colab
Browse files Browse the repository at this point in the history
  • Loading branch information
Damien Sileo committed Jun 30, 2023
1 parent a6c7489 commit aa7edca
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,8 @@ def get_single_train_dataloader(self, task_name, train_dataset):
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_sampler = (
RandomSampler(train_dataset)
if self.args.local_rank == -1
else DistributedSampler(train_dataset)
)

train_sampler = (RandomSampler(train_dataset) if torch.cuda.device_count()<2 or self.args.local_rank == -1 else DistributedSampler(train_dataset))

data_loader = DataLoaderWithTaskname(
task_name=task_name,
data_loader=DataLoader(
Expand Down

0 comments on commit aa7edca

Please sign in to comment.