-
Notifications
You must be signed in to change notification settings - Fork 68
/
data.py
26 lines (19 loc) · 1.22 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from os.path import exists, join, basename
from os import makedirs, remove
from six.moves import urllib
import tarfile
from torchvision.transforms import Compose, ToTensor
from dataset import DatasetFromFolderTest, DatasetFromFolder
def transform():
    """Return a ``Compose`` pipeline containing a single ``ToTensor`` step."""
    steps = [ToTensor()]
    return Compose(steps)
def get_training_set(data_dir, nFrames, upscale_factor, data_augmentation, file_list, other_dataset, patch_size, future_frame):
    """Create the training dataset.

    Prints ``file_list`` to stdout, then builds a ``DatasetFromFolder``
    from the given arguments (forwarded positionally, in the order they
    are received) together with the pipeline returned by ``transform()``.

    The meaning of each argument is defined by ``DatasetFromFolder`` in
    the ``dataset`` module; this function passes them through unchanged.

    Returns:
        The constructed ``DatasetFromFolder`` instance.
    """
    print("Training samples chosen:", file_list)
    dataset_args = (
        data_dir,
        nFrames,
        upscale_factor,
        data_augmentation,
        file_list,
        other_dataset,
        patch_size,
        future_frame,
    )
    return DatasetFromFolder(*dataset_args, transform=transform())
def get_eval_set(data_dir, nFrames, upscale_factor, data_augmentation, file_list, other_dataset, patch_size, future_frame):
    """Create the evaluation dataset.

    Builds a ``DatasetFromFolder`` from the given arguments (forwarded
    positionally, in the order they are received) together with the
    pipeline returned by ``transform()``.

    The meaning of each argument is defined by ``DatasetFromFolder`` in
    the ``dataset`` module; this function passes them through unchanged.

    Returns:
        The constructed ``DatasetFromFolder`` instance.
    """
    eval_set = DatasetFromFolder(
        data_dir,
        nFrames,
        upscale_factor,
        data_augmentation,
        file_list,
        other_dataset,
        patch_size,
        future_frame,
        transform=transform(),
    )
    return eval_set
def get_test_set(data_dir, nFrames, upscale_factor, file_list, other_dataset, future_frame, upscale_only):
    """Create the test dataset.

    Builds a ``DatasetFromFolderTest`` from the given arguments. The first
    six are forwarded positionally in the order received; the pipeline
    from ``transform()`` and ``upscale_only`` are passed by keyword.

    The meaning of each argument is defined by ``DatasetFromFolderTest``
    in the ``dataset`` module; this function passes them through unchanged.

    Returns:
        The constructed ``DatasetFromFolderTest`` instance.
    """
    return DatasetFromFolderTest(
        data_dir,
        nFrames,
        upscale_factor,
        file_list,
        other_dataset,
        future_frame,
        transform=transform(),
        upscale_only=upscale_only,
    )