-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataLoader.py
30 lines (29 loc) · 1.15 KB
/
dataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torchvision import transforms
def get_dataloaders(input_size, batch_size, shuffle=True):
    """
    Build the image transform pipelines for the train, validation and test sets.

    Every split is resized with ``input_size``, center-cropped with
    ``input_size``, converted to a tensor and normalized per channel with a
    fixed mean and std. Only the training split is additionally rotated at
    random (``RandomRotation(25)``) for data augmentation; the validation and
    test pipelines are kept deterministic so that evaluation is repeatable.

    Args:
        input_size: Value passed to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.
        batch_size: Not used by the code in this function as written.
        shuffle: Not used by the code in this function as written.

    NOTE(review): as far as the visible code goes, this function only
    assembles ``data_transforms``; it does not construct any DataLoader and
    does not return a value. Confirm whether dataset/loader creation is still
    to be added here.
    """

    def _base_steps():
        # Fresh transform objects on every call so no split shares instances
        # with another, mirroring the original per-split construction.
        return [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            # Presumably the ImageNet per-channel mean/std used with
            # torchvision's pretrained models -- confirm against the model.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    data_transforms = {
        # Augmentation belongs to training only: the rotation runs before the
        # resize/crop, in the same order as the original pipeline.
        'train': transforms.Compose([transforms.RandomRotation(25)] + _base_steps()),
        # Evaluation splits get no random transform, so the same image always
        # produces the same tensor and metrics are comparable across runs.
        'val': transforms.Compose(_base_steps()),
        'test': transforms.Compose(_base_steps()),
    }