diff --git a/README.md b/README.md index a8f1ccf..5f7b32e 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ _That's it_! You have ObjectCut running on port 80 routing traffic using _traefi ### Change underlying model -This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also [U^2-Net](https://github.com/NathanUA/U-2-Net), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml): +This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also the different versions of [U^2-Net](https://github.com/NathanUA/U-2-Net) (`U2NET`, `U2NETP` and `U2NETPORTRAIT`), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml): ```yaml inference: @@ -130,7 +130,7 @@ inference: - object_cut restart: always environment: - - MODEL=BASNet # Can also be `U2NET` + - MODEL=BASNet # Can also be `U2NET`, `U2NETP` or `U2NETPORTRAIT` ``` ### Integrations diff --git a/docker-compose.yml b/docker-compose.yml index 6246bd0..7d159bf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -44,7 +44,7 @@ services: - object_cut restart: always environment: - - MODEL=BASNet # Can also be `U2NET` + - MODEL=U2NETP # Can also be `BASNet`, `U2NET` or `U2NETPORTRAIT` labels: - 'traefik.enable=true' - 'traefik.docker.network=object_cut' diff --git a/inference/Dockerfile b/inference/Dockerfile index 5ccb1c8..b5f32c2 100644 --- a/inference/Dockerfile +++ b/inference/Dockerfile @@ -46,6 +46,7 @@ ADD ./requirements.lock ${HOME}/requirements.lock RUN ${HOME}/gdrive_download.sh 1s52ek_4YTDRt_EOkx1FS53u-vJa0c4nu ${HOME}/data/basnet.pth RUN ${HOME}/gdrive_download.sh 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ ${HOME}/data/u2net.pth RUN ${HOME}/gdrive_download.sh 1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy ${HOME}/data/u2netp.pth +RUN ${HOME}/gdrive_download.sh 1IG3HdpcRiDoWNookbncQjeaPN28t90yW ${HOME}/data/u2netportrait.pth # Install dependencies RUN python3 -m pip install pip --upgrade diff --git a/inference/src/main.py b/inference/src/main.py index b135447..c99cc08 100644 --- a/inference/src/main.py +++ b/inference/src/main.py @@ -26,6 +26,7 @@ # Load model model_name = os.environ.get('MODEL', Model.BASNet.name) # BASNet as default +log.info('Model name: [{}]'.format(model_name)) assert model_name in Model.list() model_path = os.path.join('data', '{}.pth'.format(model_name.lower())) log.info('Model path: [{}]'.format(model_path)) diff --git a/inference/src/utils/model_enum.py b/inference/src/utils/model_enum.py index 3b9de6b..8f34845 100644 --- a/inference/src/utils/model_enum.py +++ b/inference/src/utils/model_enum.py @@ -1,12 +1,14 @@ from enum import Enum -from src.u2_net.model import U2NET +from src.u2_net.model import U2NET, U2NETP from src.bas_net.model import BASNet class Model(Enum): U2NET = U2NET # U2NET + U2NETP = U2NETP # U2NETP + U2NETPORTRAIT = U2NET # U2NETPORTRAIT BASNet = BASNet # BASNet def __str__(self): @@ -14,4 +16,4 @@ def __str__(self): @staticmethod def list(): - return [m.name for m in Model] + return [m for m in Model.__members__.keys()] diff --git a/multiplexer/test/api/test_remove.py b/multiplexer/test/api/test_remove.py index b5f33ac..7450ba1 100644 --- a/multiplexer/test/api/test_remove.py +++ b/multiplexer/test/api/test_remove.py @@ -8,7 +8,7 @@ class MultiplexerRemoveTest(BaseTestClass): def setUp(self): self.secret_access = env.get_secret_access() - self.img_url = 'https://objectcut.com/docs/images/object-cut.png' + self.img_url = 'https://objectcut.com/assets/img/raven.jpg' self.img_url_wrong = 'https://example.com/not-existing.jpg' self.img_base64_wrong = 'not-a-base64' self.img_path = os.path.join('test', 'data', 'person.jpg')