diff --git a/models/efficientnet.py b/models/efficientnet.py index cea817c..fcc097c 100644 --- a/models/efficientnet.py +++ b/models/efficientnet.py @@ -24,6 +24,16 @@ def get_efficientnet(model_name, num_classes, pretrained=True): model = models.efficientnet_b1(weights=weights) elif model_name == "efficientnet_b2": model = models.efficientnet_b2(weights=weights) + elif model_name == "efficientnet_b3": + model = models.efficientnet_b3(weights=weights) + elif model_name == "efficientnet_b4": + model = models.efficientnet_b4(weights=weights) + elif model_name == "efficientnet_b5": + model = models.efficientnet_b5(weights=weights) + elif model_name == "efficientnet_b6": + model = models.efficientnet_b6(weights=weights) + elif model_name == "efficientnet_b7": + model = models.efficientnet_b7(weights=weights) else: raise ValueError("Unsupported EfficientNet version") diff --git a/models/resnet.py b/models/resnet.py index 063fb35..10bdd13 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -24,6 +24,10 @@ def get_resnet(model_name, num_classes, pretrained=True): model = models.resnet34(weights=weights) elif model_name == "resnet50": model = models.resnet50(weights=weights) + elif model_name == "resnet101": + model = models.resnet101(weights=weights) + elif model_name == "resnet152": + model = models.resnet152(weights=weights) else: raise ValueError("Unsupported ResNet version")