diff --git a/datasets/car_dataset.py b/datasets/car_dataset.py index 60692c0..f0e3849 100644 --- a/datasets/car_dataset.py +++ b/datasets/car_dataset.py @@ -34,7 +34,7 @@ def _load_dataset(self): # Walk through the data directory for root, _, filenames in os.walk(self.data_dir): for filename in filenames: - if filename.endswith(".jpg"): + if filename.endswith(".jpg") or filename.endswith(".png"): label_name = os.path.basename(root) if label_name not in label_to_idx: label_to_idx[label_name] = current_label