Skip to content

Commit

Permalink
Added weighting support to the CarDataset advances #26
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Jun 24, 2024
1 parent 37c14da commit 6885884
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions datasets/car_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch
from PIL import Image
from torch.utils.data import Dataset

Expand All @@ -14,7 +15,8 @@ def __init__(self, data_dir, train=None, transform=None):
"""
self.data_dir = data_dir
self.transform = transform
self.images, self.labels, self.idx_to_class = self._load_dataset()
self.images, self.labels, self.idx_to_class, self.label_to_count = self._load_dataset()
self.class_weights = self._calculate_class_weights()

def _load_dataset(self):
"""
Expand All @@ -25,27 +27,30 @@ def _load_dataset(self):
labels (list): List of corresponding labels.
idx_to_class (dict): Mapping of label index to class name.
"""

images = []
labels = []
label_to_idx = {}
idx_to_class = {}
label_to_count = {}
current_label = 0

# Walk through the data directory
for root, _, filenames in os.walk(self.data_dir):
for filename in filenames:
if filename.endswith(".jpg") or filename.endswith(".png"):
label_name = os.path.basename(root)
if label_name not in label_to_idx:
label_to_idx[label_name] = current_label
idx_to_class[current_label] = label_name
label_to_count[current_label] = 0
current_label += 1
label = label_to_idx[label_name]
label_to_count[label] += 1

images.append(os.path.join(root, filename))
labels.append(label)

return images, labels, idx_to_class
return images, labels, idx_to_class, label_to_count

def __len__(self):
"""
Expand All @@ -57,17 +62,14 @@ def __len__(self):
return len(self.images)

def __getitem__(self, idx):
"""
Get a sample from the dataset.
Args:
idx (int): Index of the sample.
if isinstance(idx, tuple):
return [self.get_single_item(i) for i in idx]
else:
return self.get_single_item(idx)

Returns:
tuple: A tuple containing the image and its corresponding label.
"""
def get_single_item(self, idx):
image_path = self.images[idx]
image = Image.open(image_path)
image = Image.open(image_path).convert('RGB')
label = self.labels[idx]

if self.transform:
Expand All @@ -83,3 +85,12 @@ def get_classes(self):
list: List of class names.
"""
return [self.idx_to_class[idx] for idx in sorted(self.idx_to_class)]

def _calculate_class_weights(self):
    """
    Compute an inverse-frequency weight for every label.

    Each label's weight is the total number of counted samples divided
    by the number of samples recorded for that label in
    ``self.label_to_count``, so labels with fewer samples receive
    larger weights.

    Returns:
        dict: Mapping of label index to its weight.
    """
    counts = self.label_to_count
    n_samples = sum(counts.values())

    inverse_freq = {}
    for class_idx, n_in_class in counts.items():
        inverse_freq[class_idx] = n_samples / n_in_class
    return inverse_freq

def get_class_weights(self):
    """
    Return the class weights as a tensor.

    The values of ``self.class_weights`` are laid out in ascending
    order of their label index, so position ``i`` of the result holds
    the weight of the i-th smallest label.

    Returns:
        torch.Tensor: One-dimensional tensor of class weights.
    """
    per_class = self.class_weights
    ordered = []
    for class_idx in sorted(per_class):
        ordered.append(per_class[class_idx])
    return torch.tensor(ordered)

0 comments on commit 6885884

Please sign in to comment.