Skip to content

Commit

Permalink
Merge pull request #18 from uug-ai/fix_bug
Browse files Browse the repository at this point in the history
Fix occasional crash bug due to different video resolutions
  • Loading branch information
cedricve authored Sep 11, 2024
2 parents 43103c3 + 4627371 commit e35b2a8
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 50 deletions.
32 changes: 32 additions & 0 deletions projects/base_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
)
from projects.ibase_project import IBaseProject
from utils.VariableClass import VariableClass
from ultralytics import YOLO

import yaml
import os
import torch


class BaseProject(IBaseProject):
Expand All @@ -24,6 +28,7 @@ def __init__(self):
self.proj_dir = None
self.mapping = None
self.device = None
self.models = []

def condition_func(self, total_results):
"""
Expand Down Expand Up @@ -59,6 +64,9 @@ def connect_models(self):
raise NotImplemented('Should override this!!!')

def __read_config__(self, path):
"""
See ibase_project.py
"""
with open(path, 'r') as file:
config = yaml.safe_load(file)

Expand All @@ -72,3 +80,27 @@ def __read_config__(self, path):

raise TypeError('Error while reading configuration file, '
'make sure models and allowed_classes have the same size')

def __connect_models__(self):
"""
See ibase_project.py
"""
_cur_dir = os.getcwd()
# initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
_cur_dir = pdirname(pabspath(__file__))
model_dir = pjoin(_cur_dir, f'../models')
model_dir = pabspath(model_dir) # normalise the link

models = []
for model_name in self._config.get('models'):
model = YOLO(pjoin(model_dir, model_name)).to(self.device)
models.append(model)

return models

def reset_models(self):
"""
See ibase_project.py
"""
self.models = self.__connect_models__()
27 changes: 3 additions & 24 deletions projects/helmet/helmet_project.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
from os.path import (
join as pjoin,
dirname as pdirname,
abspath as pabspath
)

from ultralytics import YOLO

from projects.base_project import BaseProject
from projects.helmet.ihelmet_project import IHelmetProject

import os
import torch

config_path = './projects/helmet/helmet_config.yaml'


Expand Down Expand Up @@ -117,24 +106,14 @@ def connect_models(self):
Initializes the YOLO models and connects them to the appropriate device (CPU or GPU).
Returns:
tuple: A tuple containing two YOLO models.
models: A tuple containing two YOLO models.
models_allowed_classes: List of corresponding allowed classes for each model.
Raises:
ModuleNotFoundError: If the models cannot be loaded.
"""

_cur_dir = os.getcwd()
# initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
_cur_dir = pdirname(pabspath(__file__))
model_dir = pjoin(_cur_dir, f'../../models')
model_dir = pabspath(model_dir) # normalise the link

models = []
for model_name in self._config.get('models'):
model = YOLO(pjoin(model_dir, model_name)).to(self.device)
models.append(model)

models = self.__connect_models__()
models_allowed_classes = self._config.get('allowed_classes')

if not models:
Expand Down
34 changes: 34 additions & 0 deletions projects/ibase_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,37 @@ def create_proj_save_dir(self):
Create project save directory after initializing the project.
"""
pass

@abstractmethod
def __read_config__(self, path):
"""
Read project's configuration file.
Returns:
tuple: Configuration file in dictionary format.
Raises:
TypeError: If the models cannot be loaded.
"""
pass

@abstractmethod
def __connect_models__(self):
"""
Initializes the YOLO models and connects them to the appropriate device (CPU or GPU).
Returns:
tuple: A tuple containing two YOLO models.
Raises:
ModuleNotFoundError: If the models cannot be loaded.
"""
pass

@abstractmethod
def reset_models(self):
"""
Reset model after processing video to avoid memory allocation error when the upcoming video comes in with
different resolution.
"""
pass
30 changes: 4 additions & 26 deletions projects/person/person_project.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
from os.path import (
join as pjoin,
dirname as pdirname,
abspath as pabspath
)

from ultralytics import YOLO

from projects.base_project import BaseProject
from projects.person.iperson_project import IPersonProject

import os
import torch

config_path = './projects/person/person_config.yaml'


Expand Down Expand Up @@ -111,30 +100,19 @@ def connect_models(self):
Initializes the YOLO models and connects them to the appropriate device (CPU or GPU).
Returns:
tuple: A tuple containing two YOLO models.
models: A tuple containing two YOLO models.
models_allowed_classes: List of corresponding allowed classes for each model.
Raises:
ModuleNotFoundError: If the models cannot be loaded.
"""

_cur_dir = os.getcwd()
# initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
_cur_dir = pdirname(pabspath(__file__))
model_dir = pjoin(_cur_dir, f'../../models')
model_dir = pabspath(model_dir) # normalise the link

models = []
for model_name in self._config.get('models'):
model = YOLO(pjoin(model_dir, model_name)).to(self.device)
models.append(model)

models = self.__connect_models__()
models_allowed_classes = self._config.get('allowed_classes')

if not models:
raise ModuleNotFoundError('Model not found!')

print(f'1. Using device: {self.device}')
print(
f"2. Using {len(models)} models: {[model_name for model_name in self._config.get('models')]}")
print(f"2. Using {len(models)} models: {[model_name for model_name in self._config.get('models')]}")
return models, models_allowed_classes
1 change: 1 addition & 0 deletions services/harvest_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def evaluate(self, video):
frame,
skip_frames_counter)
# Free all resources
self.project.reset_models()
cv2.destroyAllWindows()

return self.export.result_dir_path
Expand Down

0 comments on commit e35b2a8

Please sign in to comment.