Skip to content

Commit

Permalink
Whole families supported advances #21
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 22, 2024
1 parent f461afc commit 727e31f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
10 changes: 10 additions & 0 deletions models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def get_efficientnet(model_name, num_classes, pretrained=True):
model = models.efficientnet_b1(weights=weights)
elif model_name == "efficientnet_b2":
model = models.efficientnet_b2(weights=weights)
elif model_name == "efficientnet_b3":
model = models.efficientnet_b3(weights=weights)
elif model_name == "efficientnet_b4":
model = models.efficientnet_b4(weights=weights)
elif model_name == "efficientnet_b5":
model = models.efficientnet_b5(weights=weights)
elif model_name == "efficientnet_b6":
model = models.efficientnet_b6(weights=weights)
elif model_name == "efficientnet_b7":
model = models.efficientnet_b7(weights=weights)
else:
raise ValueError("Unsupported EfficientNet version")

Expand Down
4 changes: 4 additions & 0 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def get_resnet(model_name, num_classes, pretrained=True):
model = models.resnet34(weights=weights)
elif model_name == "resnet50":
model = models.resnet50(weights=weights)
elif model_name == "resnet101":
model = models.resnet101(weights=weights)
elif model_name == "resnet152":
model = models.resnet152(weights=weights)
else:
raise ValueError("Unsupported ResNet version")

Expand Down

0 comments on commit 727e31f

Please sign in to comment.