From 727e31ff73ebc63c95626e12e454429921eb1e31 Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Mon, 22 Apr 2024 12:17:08 +0200 Subject: [PATCH] Whole families supported advances #21 --- models/efficientnet.py | 10 ++++++++++ models/resnet.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/models/efficientnet.py b/models/efficientnet.py index cea817c..fcc097c 100644 --- a/models/efficientnet.py +++ b/models/efficientnet.py @@ -24,6 +24,16 @@ def get_efficientnet(model_name, num_classes, pretrained=True): model = models.efficientnet_b1(weights=weights) elif model_name == "efficientnet_b2": model = models.efficientnet_b2(weights=weights) + elif model_name == "efficientnet_b3": + model = models.efficientnet_b3(weights=weights) + elif model_name == "efficientnet_b4": + model = models.efficientnet_b4(weights=weights) + elif model_name == "efficientnet_b5": + model = models.efficientnet_b5(weights=weights) + elif model_name == "efficientnet_b6": + model = models.efficientnet_b6(weights=weights) + elif model_name == "efficientnet_b7": + model = models.efficientnet_b7(weights=weights) else: raise ValueError("Unsupported EfficientNet version") diff --git a/models/resnet.py b/models/resnet.py index 063fb35..10bdd13 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -24,6 +24,10 @@ def get_resnet(model_name, num_classes, pretrained=True): model = models.resnet34(weights=weights) elif model_name == "resnet50": model = models.resnet50(weights=weights) + elif model_name == "resnet101": + model = models.resnet101(weights=weights) + elif model_name == "resnet152": + model = models.resnet152(weights=weights) else: raise ValueError("Unsupported ResNet version")