-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtemplate.py
56 lines (33 loc) · 1.49 KB
/
template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from dataloaders import init_dataloaders
def main(args):
    """Skeleton entry point: optional pretraining, then a continual-learning loop.

    This is a template: every loop body is a placeholder meant to be filled in
    with model / optimisation code.

    Args:
        args: Parsed command-line namespace. This function reads
            ``pretrain_model``, ``num_epochs`` and ``n_runs``, overwrites
            ``device`` with ``'cpu'``, and forwards the whole namespace to
            ``init_dataloaders``.
    """
    #--------------------------- BOILERPLATE ------------------------#
    # Hard-coded to CPU in this template; change here to target another device.
    args.device = 'cpu'

    #--------------------------- DATASETS ---------------------------#
    meta_train_dataloader, meta_val_dataloader, cl_dataloader = init_dataloaders(args)

    #---------------------- PRETRAINING TIME ------------------------#
    # Pretraining only runs when args.pretrain_model is None
    # (presumably meaning "no pretrained model was provided" -- confirm).
    if args.pretrain_model is None:
        print('Pretraining time')
        # No explicit num_epochs == 0 check is needed: range() yields nothing
        # when num_epochs <= 0, so the loop body is simply skipped.
        for epoch in range(args.num_epochs):
            for batch in meta_train_dataloader:
                # Batch layout, as described by the template author:
                #   batch = {'train', 'test'}
                #   batch['train'][0] = batch-size x num_shots*num_ways x input_dim
                #   batch['train'][1] = batch-size x num_shots*num_ways x output_dim
                #   batch['test'][0]  = batch-size x num_shots-test*num_ways x input_dim
                #   batch['test'][1]  = batch-size x num_shots-test*num_ways x output_dim
                pass
            for batch in meta_val_dataloader:
                pass

    #-------------------------- CL TIME -----------------------------#
    print('Continual learning time')
    for run in range(args.n_runs):
        for i, batch in enumerate(cl_dataloader):
            # Each continual-learning batch unpacks into exactly four items.
            data, labels, task_switch, mode = batch
if __name__ == "__main__":
    # The argument parser is imported only here, so importing this module
    # (rather than running it as a script) does not pull in the `args` module.
    from args import parse_args as _parse_args

    args = _parse_args()
    main(args)