-
Notifications
You must be signed in to change notification settings - Fork 2
/
trainer.py
executable file
·874 lines (723 loc) · 37.8 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
import torch
import torch.nn.functional as F
import torchvision
import logging
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from packaging import version
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm
from tqdm._tqdm import trange
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from utils.attack_util import _create_model_training_folder
from transformers.data.data_collator import DataCollator, default_data_collator
from transformers.file_utils import is_torch_tpu_available
from transformers.training_args import TrainingArguments
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
EvalPrediction,
PredictionOutput,
TrainOutput,
is_wandb_available,
set_seed,
)
from torch.nn import CrossEntropyLoss
# File name under which BertBYOLTrainer.save_model stores its checkpoint dictionary.
WEIGHTS_NAME = "pytorch_model.bin"
# Mixed-precision backend flags; at most one of them is switched on below.
_use_native_amp = False
_use_apex = False
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available
    if is_apex_available():
        from apex import amp
        _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast
# The XLA helpers (xm, met, pl) only exist as module names when a TPU is available.
if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl
# Prefer the tensorboard writer bundled with torch, fall back to tensorboardX.
# NOTE(review): SummaryWriter is also imported unconditionally from torch.utils.tensorboard
# in the import section at the top of this file, so a missing tensorboard install would
# presumably fail there before this fallback is reached -- confirm.
try:
    from torch.utils.tensorboard import SummaryWriter
    _has_tensorboard = True
except ImportError:
    try:
        from tensorboardX import SummaryWriter
        _has_tensorboard = True
    except ImportError:
        _has_tensorboard = False
def is_tensorboard_available():
    """
    Tell whether one of the ``SummaryWriter`` imports attempted at module load succeeded.

    Returns:
        The value of the module-level ``_has_tensorboard`` flag.
    """
    tensorboard_found = _has_tensorboard
    return tensorboard_found
# wandb is optional: the name `wandb` is only bound when the package is reported as available.
if is_wandb_available():
    import wandb
# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for each local_master to do something.

    Processes whose ``local_rank`` is neither ``-1`` nor ``0`` block on a barrier before the body of the
    ``with`` statement runs; the process with ``local_rank == 0`` runs the body first and then enters the
    barrier itself. With ``local_rank == -1`` no barrier is used at all.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    try:
        yield
    finally:
        # Reach the barrier even when the body raised, otherwise the other ranks would wait forever
        # on a local master that never arrives.
        if local_rank == 0:
            torch.distributed.barrier()
class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        num_replicas (:obj:`int`, `optional`): Number of processes; defaults to the distributed world size.
        rank (:obj:`int`, `optional`): Index of this process; defaults to the distributed rank.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is omitted and ``torch.distributed`` is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Every replica gets the same number of samples (rounded up).
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible; the index list is cycled as often as needed
        # so this also works when the dataset is smaller than the amount of padding required
        padding = self.total_size - len(indices)
        if padding > 0 and indices:
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size
        # subsample: contiguous slice owned by this rank
        start = self.rank * self.num_samples
        indices = indices[start : start + self.num_samples]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
def get_tpu_sampler(dataset: Dataset):
    """
    Pick the training sampler for a TPU run.

    Args:
        dataset (:obj:`Dataset`): The dataset to sample from.

    Returns:
        A ``DistributedSampler`` keyed on the XLA world size and ordinal when more than one
        XLA process is reported, a ``RandomSampler`` otherwise.
    """
    world_size = xm.xrt_world_size()
    if world_size > 1:
        return DistributedSampler(dataset, num_replicas=world_size, rank=xm.get_ordinal())
    return RandomSampler(dataset)
class BertBYOLTrainer:
def __init__(
    self,
    args: TrainingArguments,
    online_network,
    target_network,
    predictor,
    data_collator=None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Dataset] = None,
    tb_writer: Optional["SummaryWriter"] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    **params
):
    """
    Store the training configuration, move the three networks to ``args.device`` and set up logging.

    Args:
        args (:obj:`TrainingArguments`): Training configuration (device, output/logging dirs, seed, fp16, ...).
        online_network: Network trained by gradient descent.
        target_network: Network updated as a moving average of ``online_network``.
        predictor: Head applied on top of the online projections.
        data_collator: Batch collation callable; ``default_data_collator`` is used when omitted.
        train_dataset (:obj:`Dataset`, `optional`): Dataset used by :meth:`train`.
        eval_dataset (:obj:`Dataset`, `optional`): Stored on the instance only.
        tb_writer (:obj:`SummaryWriter`, `optional`): Tensorboard writer; one is created on the
            world-zero process when tensorboard is available and none is given.
        optimizers: ``(optimizer, lr_scheduler)`` pair; ``None`` entries are created lazily by
            :meth:`create_optimizer_and_scheduler`.
        **params: Extra settings. ``params['m']`` is required (a missing key raises ``KeyError``) and
            is the momentum used by :meth:`_update_target_network_parameters`.
    """
    self.args = args
    device = self.args.device
    self.online_network = online_network.to(device)
    self.target_network = target_network.to(device)
    self.predictor = predictor.to(device)
    self.data_collator = default_data_collator if data_collator is None else data_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.tb_writer = tb_writer
    self.m = params['m']

    tensorboard_ok = is_tensorboard_available()
    if tb_writer is None and tensorboard_ok and self.is_world_process_zero():
        self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
    if not tensorboard_ok:
        logger.warning(
            "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
        )

    if not is_wandb_available():
        if os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
    else:
        self.setup_wandb()

    set_seed(self.args.seed)

    # Only the global main process creates the output directory.
    if self.is_world_process_zero():
        os.makedirs(self.args.output_dir, exist_ok=True)

    if is_torch_tpu_available():
        # Flag every network's config as running on an XLA device.
        for network in (self.online_network, self.predictor, self.target_network):
            network.config.xla_device = True

    # Legacy collator objects expose `collate_batch` instead of being callable themselves.
    if not callable(self.data_collator):
        legacy_collate = getattr(self.data_collator, "collate_batch", None)
        if callable(legacy_collate):
            self.data_collator = self.data_collator.collate_batch
            deprecation_message = (
                "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                "with a `collate_batch` are deprecated and won't be supported in a future version."
            )
            warnings.warn(deprecation_message, FutureWarning)

    self.global_step = None
    self.epoch = None
    if self.args.fp16 and _use_native_amp:
        self.scaler = torch.cuda.amp.GradScaler()

    # NOTE(review): self.tb_writer can still be None here (tensorboard missing, or a process that is
    # not world-zero); verify that the helper tolerates that.
    snapshot_files = ["./config/config.yaml", "roberta_main.py", 'trainer.py']
    _create_model_training_folder(self.tb_writer, files_to_same=snapshot_files)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
    """
    Choose the sampler for the training dataloader.

    Returns:
        ``None`` for an ``IterableDataset``, the TPU sampler when a TPU is available, otherwise a
        ``RandomSampler`` for single-process runs (``local_rank == -1``) or a ``DistributedSampler``.
    """
    dataset = self.train_dataset
    if isinstance(dataset, torch.utils.data.IterableDataset):
        return None
    if is_torch_tpu_available():
        return get_tpu_sampler(dataset)
    if self.args.local_rank == -1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset)
def get_train_dataloader(self) -> DataLoader:
    """
    Build the training :class:`~torch.utils.data.DataLoader`.

    The sampler comes from :meth:`_get_train_sampler`; batch size and ``drop_last`` are read from
    ``self.args`` and batches are assembled with ``self.data_collator``.

    Subclass and override this method if you want to inject some custom behavior.

    Raises:
        ValueError: If no training dataset was provided.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")
    sampler = self._get_train_sampler()
    loader_options = {
        "batch_size": self.args.train_batch_size,
        "sampler": sampler,
        "collate_fn": self.data_collator,
        "drop_last": self.args.dataloader_drop_last,
    }
    return DataLoader(self.train_dataset, **loader_options)
def create_optimizer_and_scheduler(self, num_training_steps: int):
    """
    Create the optimizer and the learning rate scheduler unless they were already provided.

    The default optimizer is AdamW over the parameters of the online network and the predictor;
    parameters whose name contains ``bias`` or ``LayerNorm.weight`` get no weight decay. The default
    scheduler is a linear schedule with warmup.

    Args:
        num_training_steps (:obj:`int`): Total number of optimization steps, passed to the scheduler.
    """
    if self.optimizer is None:
        no_decay_markers = ("bias", "LayerNorm.weight")
        decayed, not_decayed = [], []
        # Online-network parameters first, then predictor parameters, in both groups.
        for module in (self.online_network, self.predictor):
            for name, param in module.named_parameters():
                if any(marker in name for marker in no_decay_markers):
                    not_decayed.append(param)
                else:
                    decayed.append(param)
        parameter_groups = [
            {"params": decayed, "weight_decay": self.args.weight_decay},
            {"params": not_decayed, "weight_decay": 0.0},
        ]
        self.optimizer = AdamW(
            parameter_groups,
            lr=self.args.learning_rate,
            betas=(self.args.adam_beta1, self.args.adam_beta2),
            eps=self.args.adam_epsilon,
        )
    if self.lr_scheduler is None:
        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps,
        )
def setup_wandb(self):
    """
    Initialise the optional Weights & Biases (`wandb`) integration.

    Subclasses may override this method to customise the setup. More information is available
    `here <https://docs.wandb.com/huggingface>`__. The following environment variables are read:

    Environment:
        WANDB_WATCH:
            (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
            or "all" to log gradients and parameters
        WANDB_PROJECT:
            (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
        WANDB_DISABLED:
            (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
    """
    # Backward compatibility: defer to a legacy `_setup_wandb` hook when one exists.
    if hasattr(self, "_setup_wandb"):
        warnings.warn(
            "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
            FutureWarning,
        )
        return self._setup_wandb()

    # Only the global main process talks to wandb.
    if not self.is_world_process_zero():
        return

    logger.info(
        'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
    )
    project = os.getenv("WANDB_PROJECT", "huggingface")
    wandb.init(project=project, config=vars(self.args))
    # keep track of model topology and gradients, unsupported on TPU
    if is_torch_tpu_available() or os.getenv("WANDB_WATCH") == "false":
        return
    watched_models = [self.online_network,self.predictor,self.target_network]
    wandb.watch(
        watched_models,
        log=os.getenv("WANDB_WATCH", "gradients"),
        log_freq=max(100, self.args.logging_steps),
    )
def num_examples(self, dataloader: DataLoader) -> int:
    """
    Return the size of the dataset wrapped by a :class:`~torch.utils.data.DataLoader`.

    Args:
        dataloader (:obj:`DataLoader`): Loader whose ``dataset`` attribute supports ``len()``.
    """
    wrapped_dataset = dataloader.dataset
    return len(wrapped_dataset)
@torch.no_grad()
def _update_target_network_parameters(self):
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@staticmethod
def regression_loss(x, y):
x = F.normalize(x, dim=1)
y = F.normalize(y, dim=1)
return 2-2 * (x * y).sum(dim=-1)
def initializes_target_network(self):
    """
    Copy the online-network weights into the target network and freeze the target.

    Parameters are paired by iteration order of ``parameters()``; every target parameter gets
    ``requires_grad = False`` so it is never updated by gradient descent.
    """
    parameter_pairs = zip(self.online_network.parameters(), self.target_network.parameters())
    for online_param, target_param in parameter_pairs:
        target_param.data.copy_(online_param.data)  # start from the online weights
        target_param.requires_grad = False  # updated by momentum only
def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Initialises the target network from the online network, builds optimizer and scheduler, then runs
    the epoch/step loop: every batch goes through :meth:`update` (forward + backward), and once per
    accumulation window gradients are clipped, the optimizer and scheduler step, and the target
    network is moved towards the online network.

    Args:
        model_path (:obj:`str`, `optional`):
            Path of a previous checkpoint directory. When it contains ``optimizer.pt`` and
            ``scheduler.pt`` their states are restored, and the global step is parsed from the
            directory name (text after the last ``-``).

    Returns:
        :obj:`TrainOutput` holding the final global step and the accumulated loss divided by it
        (``0.0`` when no optimization step was performed).
    """
    train_dataloader = self.get_train_dataloader()
    self.initializes_target_network()

    # At least one update happens per epoch (see the "last step in epoch" clause below), so clamp
    # to 1; this also avoids a ZeroDivisionError when the loader is shorter than the accumulation window.
    num_update_steps_per_epoch = max(len(train_dataloader) // self.args.gradient_accumulation_steps, 1)
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + 1
    else:
        t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    self.create_optimizer_and_scheduler(num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (
        model_path is not None
        and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        self.optimizer.load_state_dict(
            torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
        )
        self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

    # NOTE(review): the (possibly wrapped) local names below are only used for gradient clipping and
    # the checkpoint sanity check; the forward passes in `update` go through self.online_network,
    # self.predictor and self.target_network directly.
    online_network = self.online_network
    predictor = self.predictor
    target_network = self.target_network
    if self.args.fp16 and _use_apex:
        if not is_apex_available():
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # amp.initialize takes a *list* of models and returns (list_of_models, optimizer).
        (online_network, predictor, target_network), self.optimizer = amp.initialize(
            [online_network, predictor, target_network], self.optimizer, opt_level=self.args.fp16_opt_level
        )

    # multi-gpu training (should be after apex fp16 initialization)
    if self.args.n_gpu > 1:
        online_network = torch.nn.DataParallel(online_network)
        predictor = torch.nn.DataParallel(predictor)
        target_network = torch.nn.DataParallel(target_network)

    # Distributed training (should be after apex fp16 initialization)
    if self.args.local_rank != -1:
        online_network = torch.nn.parallel.DistributedDataParallel(
            online_network,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True,
        )
        predictor = torch.nn.parallel.DistributedDataParallel(
            predictor,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True,
        )
        # The target network is deliberately not wrapped: all of its parameters were frozen by
        # initializes_target_network(), and DistributedDataParallel refuses modules that have no
        # parameter requiring a gradient.

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())
        self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

    # Train!
    if is_torch_tpu_available():
        total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
    else:
        total_train_batch_size = (
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
        )
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", self.num_examples(train_dataloader))
    logger.info(" Num Epochs = %d", num_train_epochs)
    logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
    logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    self.global_step = 0
    self.epoch = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to global_step of last saved checkpoint from model path
        try:
            self.global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = self.global_step // num_update_steps_per_epoch
            steps_trained_in_current_epoch = self.global_step % num_update_steps_per_epoch

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", self.global_step)
            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # The directory name does not end in a step number: start counting from zero.
            self.global_step = 0
            logger.info(" Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    self.optimizer.zero_grad()
    train_iterator = trange(
        epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
    )
    for epoch in train_iterator:
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                self.args.device
            )
            epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())

        # Reset the past mems state at the beginning of each epoch if necessary.
        if self.args.past_index >= 0:
            self._past = None

        for step, (online_inputs, online_labels, target_inputs) in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            online_inputs = online_inputs.to(self.args.device)
            online_labels = online_labels.to(self.args.device)
            target_inputs = target_inputs.to(self.args.device)
            masked_lm_loss, contrastive_loss, loss = self.update(online_inputs, online_labels, target_inputs)
            tr_loss += loss
            # The writer only exists on the world-zero process (and only when tensorboard is installed).
            if self.tb_writer is not None:
                self.tb_writer.add_scalars(
                    'loss',
                    {'masked_lm_loss': masked_lm_loss, 'contrastive_loss': contrastive_loss},
                    global_step=self.global_step,
                )

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(epoch_iterator) <= self.args.gradient_accumulation_steps
                and (step + 1) == len(epoch_iterator)
            ):
                if self.args.fp16 and _use_native_amp:
                    # Gradients must be unscaled before clipping.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(online_network.parameters(), self.args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(predictor.parameters(), self.args.max_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(online_network.parameters(), self.args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(predictor.parameters(), self.args.max_grad_norm)

                # Exactly one of the three stepping strategies applies per update.
                if is_torch_tpu_available():
                    xm.optimizer_step(self.optimizer)
                elif self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                self._update_target_network_parameters()  # update the key encoder
                self.global_step += 1
                self.epoch = epoch + (step + 1) / len(epoch_iterator)

                if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                    self.global_step == 1 and self.args.logging_first_step
                ):
                    logs: Dict[str, float] = {}
                    # max(1, ...) keeps `logging_first_step` usable when logging_steps is 0.
                    logs["loss"] = (tr_loss - logging_loss) / max(1, self.args.logging_steps)
                    # backward compatibility for pytorch schedulers
                    logs["learning_rate"] = (
                        self.lr_scheduler.get_last_lr()[0]
                        if version.parse(torch.__version__) >= version.parse("1.4")
                        else self.lr_scheduler.get_lr()[0]
                    )
                    logging_loss = tr_loss

                    self.log(logs)

                # NOTE(review): no `evaluate` method is defined on BertBYOLTrainer in this file;
                # enabling evaluate_during_training presumably requires a subclass that provides one.
                if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                    self.evaluate()

                if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                    # In all cases (even distributed/parallel), self.online_network is always a
                    # reference to the model we want to save.
                    if hasattr(online_network, "module"):
                        assert online_network.module is self.online_network
                    else:
                        assert online_network is self.online_network
                    # Save model checkpoint
                    output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
                    self.save_model(output_dir)

                    if self.is_world_process_zero():
                        self._rotate_checkpoints()

                    # Separate optimizer.pt / scheduler.pt files are only written on TPU; elsewhere
                    # the optimizer state lives inside the checkpoint written by save_model().
                    if is_torch_tpu_available():
                        xm.rendezvous("saving_optimizer_states")
                        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

            # Stop in the middle of an epoch once the step budget is used up.
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                epoch_iterator.close()
                break

        if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
            train_iterator.close()
            break
        if self.args.tpu_metrics_debug or self.args.debug:
            if is_torch_tpu_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            else:
                logger.warning("XLA debug metrics were requested but no TPU is available; skipping them.")

    if self.tb_writer:
        self.tb_writer.close()
    if self.args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    average_loss = tr_loss / self.global_step if self.global_step else 0.0
    return TrainOutput(self.global_step, average_loss)
def update(self, online_inputs, online_labels, target_inputs):
    """
    Run one forward/backward pass on a pair of views and return the loss values.

    The online network is applied to both views and is expected to return a
    ``(prediction_scores, projection)`` pair; the predictor maps each projection to a prediction
    that is regressed (via :meth:`regression_loss`) onto the target network's projection of the
    *other* view. A cross-entropy loss on the first view's prediction scores is added when labels
    are given. The backward pass is executed here; the optimizer step is left to the caller.

    Args:
        online_inputs: Batch for the first view.
        online_labels: Labels for the first view's prediction scores, or :obj:`None` to train on
            the contrastive term only.
        target_inputs: Batch for the second view.

    Returns:
        Tuple of Python floats ``(masked_lm_loss, contrastive_loss, loss)`` where ``loss`` is the sum
        of both terms divided by ``gradient_accumulation_steps`` when that is larger than 1.
    """
    use_native_amp = self.args.fp16 and _use_native_amp

    def compute_losses():
        # Shared by the autocast and the full-precision path so both stay in sync.
        prediction_scores_view_1, projection_view_1 = self.online_network(online_inputs)
        _, projection_view_2 = self.online_network(target_inputs)
        predictions_from_view_1 = self.predictor(projection_view_1)
        predictions_from_view_2 = self.predictor(projection_view_2)
        # compute key features; the target network never receives gradients
        with torch.no_grad():
            targets_to_view_2 = self.target_network(online_inputs)[1]
            targets_to_view_1 = self.target_network(target_inputs)[1]
        reg_loss = self.regression_loss(predictions_from_view_1, targets_to_view_1)
        reg_loss += self.regression_loss(predictions_from_view_2, targets_to_view_2)
        contrastive = reg_loss.mean()
        if online_labels is not None:
            loss_fct = CrossEntropyLoss()
            vocab_dim = prediction_scores_view_1.shape[-1]
            masked_lm = loss_fct(prediction_scores_view_1.view((-1, vocab_dim)), online_labels.view(-1))
        else:
            # Without labels the language-modelling term contributes nothing.
            masked_lm = contrastive.new_zeros(())
        return masked_lm, contrastive

    if use_native_amp:
        with autocast():
            masked_lm_loss, contrastive_loss = compute_losses()
    else:
        masked_lm_loss, contrastive_loss = compute_losses()

    if self.args.n_gpu > 1:
        masked_lm_loss = masked_lm_loss.mean()
        contrastive_loss = contrastive_loss.mean()  # mean() to average on multi-gpu parallel training
    loss = masked_lm_loss + contrastive_loss
    if self.args.gradient_accumulation_steps > 1:
        loss = loss / self.args.gradient_accumulation_steps

    if use_native_amp:
        self.scaler.scale(loss).backward()
    elif self.args.fp16 and _use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    return masked_lm_loss.item(), contrastive_loss.item(), loss.item()
def is_local_process_zero(self) -> bool:
    """
    Tell whether this process is the main process on its own machine.

    On TPU the answer comes from ``xm.is_master_ordinal(local=True)``; otherwise the process is
    the local main one when ``args.local_rank`` is ``-1`` or ``0``.
    """
    if is_torch_tpu_available():
        return xm.is_master_ordinal(local=True)
    local_rank = self.args.local_rank
    return local_rank in (-1, 0)
def is_world_master(self) -> bool:
    """
    Deprecated alias of :meth:`is_world_process_zero`.

    Emits a :obj:`FutureWarning` and returns whatever :meth:`is_world_process_zero` returns.

    .. warning::
        This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
    """
    deprecation_message = "This method is deprecated, use `Trainer.is_world_process_zero()` instead."
    warnings.warn(deprecation_message, FutureWarning)
    is_main = self.is_world_process_zero()
    return is_main
def is_world_process_zero(self) -> bool:
    """
    Tell whether this process is the global main process.

    On TPU the answer comes from ``xm.is_master_ordinal(local=False)``. Otherwise a run with
    ``args.local_rank == -1`` always counts as the main process, and a distributed run is the main
    process when its ``torch.distributed`` rank is 0.
    """
    if is_torch_tpu_available():
        return xm.is_master_ordinal(local=False)
    if self.args.local_rank == -1:
        return True
    return torch.distributed.get_rank() == 0
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
    """
    Log :obj:`logs` on the various objects watching training.

    The current epoch is added to ``logs`` in place, numeric values are written to the tensorboard
    writer (when there is one), everything is sent to wandb from the world-zero process (when wandb
    is available) and finally the values plus the step are written to ``iterator`` or printed.

    Subclass and override this method to inject custom behavior.

    Args:
        logs (:obj:`Dict[str, float]`):
            The values to log.
        iterator (:obj:`tqdm`, `optional`):
            A potential tqdm progress bar to write the logs on.
    """
    # Backward compatibility: defer to a legacy `_log` hook when one exists.
    if hasattr(self, "_log"):
        warnings.warn(
            "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
            FutureWarning,
        )
        return self._log(logs, iterator=iterator)

    if self.epoch is not None:
        logs["epoch"] = self.epoch
    if self.global_step is None:
        # when logging evaluation metrics without training
        self.global_step = 0
    if self.tb_writer:
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.tb_writer.add_scalar(k, v, self.global_step)
            else:
                logger.warning(
                    "Trainer is attempting to log a value of "
                    '"%s" of type %s for key "%s" as a scalar. '
                    "This invocation of Tensorboard's writer.add_scalar() "
                    "is incorrect so we dropped this attribute.",
                    v,
                    type(v),
                    k,
                )
        self.tb_writer.flush()
    if is_wandb_available():
        if self.is_world_process_zero():
            wandb.log(logs, step=self.global_step)
    output = {**logs, **{"step": self.global_step}}
    if iterator is not None:
        # tqdm.write() hands its argument straight to a text stream, so it needs a string;
        # str(output) is the same text print(output) would show.
        iterator.write(str(output))
    else:
        print(output)
def save_model(self, output_dir: Optional[str] = None):
    """
    Write the training arguments and a checkpoint dictionary to ``output_dir``.

    Two files are produced: ``training_args.bin`` (``self.args``) and ``WEIGHTS_NAME``, a dictionary
    holding the state dicts of the online network, the target network and the optimizer.

    Args:
        output_dir (:obj:`str`, `optional`): Destination directory, created if missing;
            defaults to ``self.args.output_dir``.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Saving model checkpoint to %s", output_dir)

    weights_path = os.path.join(output_dir, WEIGHTS_NAME)
    args_path = os.path.join(output_dir, "training_args.bin")
    torch.save(self.args, args_path)

    checkpoint = {
        'online_network_state_dict': self.online_network.state_dict(),
        'target_network_state_dict': self.target_network.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
    }
    torch.save(checkpoint, weights_path)
    logger.info("Model weights saved in {}".format(weights_path))
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
    """
    List the checkpoint paths found in ``args.output_dir``, oldest first.

    Args:
        checkpoint_prefix: Name prefix of checkpoint entries; ``<prefix>-*`` is globbed.
        use_mtime: Sort by modification time when true, otherwise by the number following
            ``<prefix>-`` in the path (entries without such a number are skipped).

    Returns:
        The matching paths as strings in ascending order of the chosen key.
    """
    keyed_paths = []
    number_pattern = f".*{checkpoint_prefix}-([0-9]+)"
    for candidate in [str(entry) for entry in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]:
        if use_mtime:
            keyed_paths.append((os.path.getmtime(candidate), candidate))
            continue
        match = re.match(number_pattern, candidate)
        if match and match.groups():
            keyed_paths.append((int(match.groups()[0]), candidate))
    return [path for _, path in sorted(keyed_paths)]
def _rotate_checkpoints(self, use_mtime=False) -> None:
    """
    Delete the oldest checkpoints so that at most ``args.save_total_limit`` remain.

    Nothing happens when the limit is :obj:`None` or not positive, or when the number of
    checkpoints does not exceed it.

    Args:
        use_mtime: Forwarded to :meth:`_sorted_checkpoints` to pick the ordering.
    """
    limit = self.args.save_total_limit
    if limit is None or limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    ordered = self._sorted_checkpoints(use_mtime=use_mtime)
    surplus = len(ordered) - limit
    if surplus <= 0:
        return

    for stale_checkpoint in ordered[:surplus]:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(stale_checkpoint))
        shutil.rmtree(stale_checkpoint)
class BYOLTrainer:
    """
    Minimal BYOL training loop: an online network plus predictor trained by gradient descent and a
    target network that follows the online network as a moving average.

    ``params`` must provide ``max_epochs``, ``m`` (momentum), ``batch_size``, ``num_workers`` and
    ``checkpoint_interval`` (stored, but not read by any method of this class).
    """

    def __init__(self, online_network, target_network, predictor, optimizer, device, **params):
        self.online_network = online_network
        self.target_network = target_network
        self.optimizer = optimizer
        self.device = device
        self.predictor = predictor
        self.max_epochs = params['max_epochs']
        self.writer = SummaryWriter()
        self.m = params['m']
        self.batch_size = params['batch_size']
        self.num_workers = params['num_workers']
        self.checkpoint_interval = params['checkpoint_interval']
        _create_model_training_folder(self.writer, files_to_same=["./config/config.yaml", "main.py", 'trainer.py'])

    @torch.no_grad()
    def _update_target_network_parameters(self):
        """
        Momentum update of the key encoder: ``target = m * target + (1 - m) * online``.
        """
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @staticmethod
    def regression_loss(x, y):
        """
        Return ``-2 * cos(x, y)`` computed along the last dimension.

        Inputs are normalised along the same (last) axis that is summed, so the value is a scaled
        cosine similarity for inputs of any rank; for 2-D inputs results are unchanged. Unlike
        ``BertBYOLTrainer.regression_loss`` there is no constant ``2`` offset.
        """
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return -2 * (x * y).sum(dim=-1)

    def initializes_target_network(self):
        """Copy the online weights into the target network and freeze the target parameters."""
        # init momentum network as encoder net
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

    def train(self, train_dataset):
        """
        Train for ``max_epochs`` epochs on ``train_dataset`` and save one checkpoint at the end.

        Args:
            train_dataset: Dataset yielding ``(view_1, view_2)`` pairs of tensors.
        """
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                  num_workers=self.num_workers, drop_last=False, shuffle=True)
        niter = 0
        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
        self.initializes_target_network()
        for epoch_counter in range(self.max_epochs):
            for batch_view_1, batch_view_2 in train_loader:
                batch_view_1 = batch_view_1.to(self.device)
                batch_view_2 = batch_view_2.to(self.device)
                if niter == 0:
                    # Log a grid of the first (up to 32) samples of each view once.
                    grid = torchvision.utils.make_grid(batch_view_1[:32])
                    self.writer.add_image('views_1', grid, global_step=niter)
                    grid = torchvision.utils.make_grid(batch_view_2[:32])
                    self.writer.add_image('views_2', grid, global_step=niter)
                loss = self.update(batch_view_1, batch_view_2)
                self.writer.add_scalar('loss', loss, global_step=niter)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self._update_target_network_parameters()  # update the key encoder
                niter += 1
            print("End of epoch {}".format(epoch_counter))
        # save checkpoints; make sure the folder exists so torch.save cannot fail on a missing directory
        os.makedirs(model_checkpoints_folder, exist_ok=True)
        self.save_model(os.path.join(model_checkpoints_folder, 'model.pth'))

    def update(self, batch_view_1, batch_view_2):
        """
        Compute the symmetric BYOL loss for a pair of views (no backward pass is run here).

        Returns:
            Scalar tensor: mean over the batch of the two cross-view regression losses.
        """
        # compute query feature
        predictions_from_view_1 = self.predictor(self.online_network(batch_view_1))
        predictions_from_view_2 = self.predictor(self.online_network(batch_view_2))
        # compute key features
        with torch.no_grad():
            targets_to_view_2 = self.target_network(batch_view_1)
            targets_to_view_1 = self.target_network(batch_view_2)
        loss = self.regression_loss(predictions_from_view_1, targets_to_view_1)
        loss += self.regression_loss(predictions_from_view_2, targets_to_view_2)
        return loss.mean()

    def save_model(self, PATH):
        """Save the online/target network and optimizer state dicts to ``PATH``."""
        torch.save({
            'online_network_state_dict': self.online_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, PATH)