diff --git a/datasets/car_dataset.py b/datasets/car_dataset.py index f0e3849..8e1fdf6 100644 --- a/datasets/car_dataset.py +++ b/datasets/car_dataset.py @@ -1,4 +1,5 @@ import os +import torch from PIL import Image from torch.utils.data import Dataset @@ -14,7 +15,8 @@ def __init__(self, data_dir, train=None, transform=None): """ self.data_dir = data_dir self.transform = transform - self.images, self.labels, self.idx_to_class = self._load_dataset() + self.images, self.labels, self.idx_to_class, self.label_to_count = self._load_dataset() + self.class_weights = self._calculate_class_weights() def _load_dataset(self): """ @@ -25,13 +27,14 @@ def _load_dataset(self): labels (list): List of corresponding labels. idx_to_class (dict): Mapping of label index to class name. """ + images = [] labels = [] label_to_idx = {} idx_to_class = {} + label_to_count = {} current_label = 0 - # Walk through the data directory for root, _, filenames in os.walk(self.data_dir): for filename in filenames: if filename.endswith(".jpg") or filename.endswith(".png"): @@ -39,13 +42,15 @@ def _load_dataset(self): if label_name not in label_to_idx: label_to_idx[label_name] = current_label idx_to_class[current_label] = label_name + label_to_count[current_label] = 0 current_label += 1 label = label_to_idx[label_name] + label_to_count[label] += 1 images.append(os.path.join(root, filename)) labels.append(label) - return images, labels, idx_to_class + return images, labels, idx_to_class, label_to_count def __len__(self): """ @@ -57,17 +62,14 @@ def __len__(self): return len(self.images) def __getitem__(self, idx): - """ - Get a sample from the dataset. - - Args: - idx (int): Index of the sample. + if isinstance(idx, tuple): + return [self.get_single_item(i) for i in idx] + else: + return self.get_single_item(idx) - Returns: - tuple: A tuple containing the image and its corresponding label. - """ + def get_single_item(self, idx): image_path = self.images[idx] - image = Image.open(image_path) + image = Image.open(image_path).convert('RGB') label = self.labels[idx] if self.transform: @@ -83,3 +85,12 @@ def get_classes(self): list: List of class names. """ return [self.idx_to_class[idx] for idx in sorted(self.idx_to_class)] + + def _calculate_class_weights(self): + total_count = sum(self.label_to_count.values()) + weights = {label: total_count / count for label, count in self.label_to_count.items()} + return weights + + def get_class_weights(self): + weights = torch.tensor([self.class_weights[label] for label in sorted(self.class_weights)]) + return weights