diff --git a/datasets/dataset.py b/datasets/dataset.py index b469466..40e94e2 100644 --- a/datasets/dataset.py +++ b/datasets/dataset.py @@ -1,4 +1,4 @@ -from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FakeData +from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FakeData, ImageFolder from datasets.car_dataset import CarDataset def get_dataset(name, root_dir, train=None, transform=None): @@ -23,6 +23,8 @@ def get_dataset(name, root_dir, train=None, transform=None): return CIFAR100(root=root_dir, train=train, download=True, transform=transform) elif name == 'MNIST': return MNIST(root=root_dir, train=train, download=True, transform=transform) + elif name == 'Imagenette': + return ImageFolder(root=root_dir, transform=transform) elif name == 'FakeData': return FakeData(size=200, image_size=(3, 32, 32), num_classes=10, transform=transform) elif name == 'CarDataset':