From 728575b4c1b7ae5e64b2d1f07fcbcd35480bcaaa Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Tue, 27 Aug 2024 16:32:26 +0200
Subject: [PATCH 1/7] update dockerfile

---
 Dockerfile | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 6a15d20..1bda777 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -43,8 +43,19 @@ COPY . .
 ENV MEDIA_SAVEPATH "/ml/data/input/input_video.mp4"
 
 # Model parameters
-ENV MODEL_NAME: "helmet_dectector_1k_16b_150e.pt"
+ENV MODEL_NAME: "yolov8n.pt"
 ENV MODEL_NAME_2: "helmet_dectector_1k_16b_150e.pt"
+ENV MODEL_ALLOWED_CLASSES="0"
+ENV MODEL_2_ALLOWED_CLASSES="0"
+
+# Dataset parameters
+ENV DATASET_FORMAT="base"
+ENV DATASET_VERSION="1"
+ENV DATASET_UPLOAD="True"
+
+# Forwarding
+ENV FORWARDING_MEDIA="True"
+ENV REMOVE_AFTER_PROCESSED="True"
 
 # Queue parameters
 ENV QUEUE_NAME ""
@@ -59,12 +70,20 @@
 ENV STORAGE_URI ""
 ENV STORAGE_ACCESS_KEY ""
 ENV STORAGE_SECRET_KEY ""
 
+# Integration parameters
+ENV INTEGRATION_NAME=""
+
 # Roboflow parameters
-ENV RBF_UPLOAD: ""
 ENV RBF_API_KEY: ""
 ENV RBF_WORKSPACE: ""
 ENV RBF_PROJECT: ""
 
+# S3 parameters
+ENV S3_ENDPOINT=""
+ENV S3_ACCESS_KEY=""
+ENV S3_SECRET_KEY=""
+ENV S3_BUCKET=""
+
 # Feature parameters
 ENV PLOT "False"

From 53176bbb7903a642df52379b86e13a2580d0550b Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Tue, 27 Aug 2024 16:49:01 +0200
Subject: [PATCH 2/7] update k8s-deployment.yaml

---
 k8s-deployment.yaml | 48 ++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/k8s-deployment.yaml b/k8s-deployment.yaml
index e1d21f5..2b6c3b9 100644
--- a/k8s-deployment.yaml
+++ b/k8s-deployment.yaml
@@ -23,8 +23,21 @@ spec:
       env:
         - name: MODEL_NAME
           value: "yolov8n.pt"
-        - name: CONDITION
-          value: "1 persons detected"
+        - name: MODEL_NAME_2
+          value: "helmet_dectector_1k_16b_150e.pt"
+        - name: MODEL_ALLOWED_CLASSES
+          value: "0"
+        - name: MODEL_2_ALLOWED_CLASSES
+          value: "0,1,2"
+
+        - name: DATASET_FORMAT
+          value: "base"
+        - name: DATASET_VERSION
+          value: "1"
+        - name: DATASET_UPLOAD
+          value: "True"
+        - name: MODEL_NAME
+          value: "yolov8n.pt"
 
         - name: QUEUE_NAME
           value: "data-harvesting" # This is the topic of kafka we will read messages from.
@@ -39,9 +52,28 @@ spec:
         - name: STORAGE_URI
           value: "http://vault-lb.kerberos-vault/api"
         - name: STORAGE_ACCESS_KEY
-          value: "52gyELgxutOXUWhF"
+          value: "YOUR_KEY"
        - name: STORAGE_SECRET_KEY
-          value: "k8DrcB@hQ5XfxDENzDKcnkxBHx"
+          value: "YOUR_SECRET_KEY"
+
+        - name: INTEGRATION_NAME
+          value: "s3"
+
+        - name: RBF_API_KEY
+          value: "YOUR KEY"
+        - name: RBF_WORKSPACE
+          value: "YOUR_WS"
+        - name: RBF_PROJECT
+          value: "YOUR_PROJ"
+
+        - name: S3_ENDPOINT
+          value: "YOUR_ENDPOINT"
+        - name: S3_ACCESS_KEY
+          value: "YOUR_KEY"
+        - name: S3_SECRET_KEY
+          value: "YOUR_SECRET_KEY"
+        - name: S3_BUCKET
+          value: "YOUR_BUCKET"
 
         - name: LOGGING
           value: "True"
@@ -60,11 +92,17 @@ spec:
         - name: CLASSIFICATION_FPS
           value: "3"
         - name: CLASSIFICATION_THRESHOLD
-          value: "0.3"
+          value: "0.25"
         - name: MAX_NUMBER_OF_PREDICTIONS
           value: "100"
+        - name: FRAMES_SKIP_AFTER_DETECT
+          value: "50"
         - name: ALLOWED_CLASSIFICATIONS
           value: "0, 1, 2, 3, 5, 7, 14, 15, 16, 24, 26, 28"
+        - name: MIN_DETECTIONS
+          value: "1"
+        - name: IOU
+          value: "0.85"
         - name: REMOVE_AFTER_PROCESSED
           value: "True"
\ No newline at end of file

From 219864fc559aa4934633e2388e0ad339355a39b8 Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Tue, 27 Aug 2024 16:50:01 +0200
Subject: [PATCH 3/7] update k8s-deployment.yaml p2

---
 k8s-deployment.yaml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/k8s-deployment.yaml b/k8s-deployment.yaml
index 2b6c3b9..c33ec43 100644
--- a/k8s-deployment.yaml
+++ b/k8s-deployment.yaml
@@ -36,8 +36,6 @@ spec:
           value: "1"
         - name: DATASET_UPLOAD
           value: "True"
-        - name: MODEL_NAME
-          value: "yolov8n.pt"
 
         - name: QUEUE_NAME
           value: "data-harvesting" # This is the topic of kafka we will read messages from.

From 64898c4ca8f5de4e54272be33089a4a9fc706fa4 Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Tue, 27 Aug 2024 16:50:59 +0200
Subject: [PATCH 4/7] revert key

---
 k8s-deployment.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/k8s-deployment.yaml b/k8s-deployment.yaml
index c33ec43..8a3dbb8 100644
--- a/k8s-deployment.yaml
+++ b/k8s-deployment.yaml
@@ -50,9 +50,9 @@ spec:
         - name: STORAGE_URI
           value: "http://vault-lb.kerberos-vault/api"
         - name: STORAGE_ACCESS_KEY
-          value: "YOUR_KEY"
+          value: "52gyELgxutOXUWhF"
         - name: STORAGE_SECRET_KEY
-          value: "YOUR_SECRET_KEY"
+          value: "k8DrcB@hQ5XfxDENzDKcnkxBHx"
 
         - name: INTEGRATION_NAME
           value: "s3"

From 329ffe4d30baa4f80971702debfe6b77c6168e79 Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Tue, 27 Aug 2024 16:54:44 +0200
Subject: [PATCH 5/7] align code

---
 single-shot.py | 79 +++++++++-----------------------------------------
 1 file changed, 14 insertions(+), 65 deletions(-)

diff --git a/single-shot.py b/single-shot.py
index 7d74563..62b7625 100644
--- a/single-shot.py
+++ b/single-shot.py
@@ -1,7 +1,7 @@
 # This script is used to look for objects under a specific condition (at least 5 persons etc)
 # The script reads a video from a message queue, classifies the objects in the video, and does a condition check.
 # If condition is met, the video is being forwarded to a remote vault.
-
+from integrations.integration_factory import IntegrationFactory
 from projects.project_factory import ProjectFactory
 from services.harvest_service import HarvestService
 from integrations.roboflow_integration import RoboflowIntegration
@@ -25,68 +25,27 @@ def init():
     project = ProjectFactory().init('helmet')
 
-    # Perform object classification on the media
+    # Mapping classes of 2 models
+    mapping = project.class_mapping(model1, model2)
+    integration = IntegrationFactory().init()
 
     # Open video-capture/recording using the video-path. Throw FileNotFoundError if cap is unable to open.
     cap = harvest_service.open_video()
     time_verbose = TimeVerbose()
 
-    # Initialize the classification process.
-    # 2 lists are initialized:
-    # Classification objects
-    # Additional list for easy access to the ids.
-
-    # frame_number -> The current frame number. Depending on the frame_skip_factor this can make jumps.
-    # predicted_frames -> The number of frames, that were used for the prediction. This goes up by one each prediction iteration.
-    # frame_skip_factor is the factor by which the input video frames are skipped.
-    frame_number, predicted_frames = 0, 0
-    frame_skip_factor = int(cap.get(cv2.CAP_PROP_FPS) / var.CLASSIFICATION_FPS)
-
-    # Loop over the video frames, and perform object classification.
-    # The classification process is done until the counter reaches the MAX_NUMBER_OF_PREDICTIONS or the last frame is reached.
-    MAX_FRAME_NUMBER = cap.get(cv2.CAP_PROP_FRAME_COUNT)
     if var.LOGGING:
-        print(f'5) Classifying frames')
+        print(f'5. Classifying frames')
     if var.TIME_VERBOSE:
         time_verbose.add_preprocessing_time()
+    save_dir = harvest_service.process(
+        cap,
+        model1,
+        model2,
+        project.condition_func,
+        mapping)
 
-    skip_frames_counter = 0
-
-    result_dir_path, image_dir_path, label_dir_path, yaml_path = project.create_result_save_dir()
-    mapping = project.class_mapping(model1, model2)
-
-    while (predicted_frames < var.MAX_NUMBER_OF_PREDICTIONS) and (frame_number < MAX_FRAME_NUMBER):
-        success, frame, skip_frames_counter = harvest_service.get_video_frame(cap, skip_frames_counter)
-
-        if success and frame is None:
-            continue
-        if not success:
-            break
-
-        # Process frame
-        skip_frames_counter = harvest_service.process_frame(
-            frame_skip_factor,
-            skip_frames_counter,
-            model1,
-            model2,
-            project.condition_func,
-            mapping,
-            result_dir_path,
-            image_dir_path,
-            label_dir_path,
-            frame,
-            None)
-
-    # Upload to roboflow
-    project.upload_dataset(result_dir_path, yaml_path, model2)
-
-    # Upload to roboflow after processing frames if any
-    if os.path.exists(result_dir_path) and var.RBF_UPLOAD:
-        rb = RoboflowIntegration()
-        if rb:
-            rb.upload_dataset(result_dir_path)
-        else:
-            print('Nothing to upload!!')
+    if var.DATASET_UPLOAD:
+        integration.upload_dataset(save_dir)
 
     if var.TIME_VERBOSE:
         time_verbose.add_preprocessing_time()
@@ -94,8 +53,7 @@ def init():
     # Depending on the TIME_VERBOSE parameter, the time it took to classify the objects is printed.
     if var.TIME_VERBOSE:
         time_verbose.show_result()
-    # If the videowriter was active, the videowriter is released.
-    # Close the video-capture and destroy all windows.
+
     if var.LOGGING:
         print('8) Releasing video writer and closing video capture')
         print("\n\n")
@@ -106,14 +64,5 @@ def init():
     # cv2.destroyAllWindows()
 
 
-def create_yaml(file_path, label_names):
-    with open(file_path, 'w') as my_file:
-        content = 'names:\n'
-        for name in label_names:
-            content += f'- {name}\n'  # class mapping for helmet detection project
-        content += f'nc: {len(label_names)}'
-        my_file.write(content)
-
-
 # Run the init function.
 init()

From 0f19cf1e718d5df6dbb45d847fbf84045a49cb7d Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Tue, 27 Aug 2024 16:55:43 +0200
Subject: [PATCH 6/7] align code p2

---
 single-shot.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/single-shot.py b/single-shot.py
index 62b7625..c4f04ba 100644
--- a/single-shot.py
+++ b/single-shot.py
@@ -4,15 +4,8 @@
 from integrations.integration_factory import IntegrationFactory
 from projects.project_factory import ProjectFactory
 from services.harvest_service import HarvestService
-from integrations.roboflow_integration import RoboflowIntegration
-# Local imports
-from utils.VariableClass import VariableClass
-
-# External imports
-import os
-
-import cv2
+from utils.VariableClass import VariableClass
 from utils.time_verbose_object import TimeVerbose
 
 # Initialize the VariableClass object, which contains all the necessary environment variables.

From 64f923a1a0690229741a1b1c0d3c15d4e9f85f56 Mon Sep 17 00:00:00 2001
From: TannyLe <130630658+tannyle289@users.noreply.github.com>
Date: Wed, 28 Aug 2024 12:26:03 +0200
Subject: [PATCH 7/7] Add docstring comments and move models to dedicated folder

---
 exports/base_export.py                | 18 +++++++-
 exports/export_factory.py             |  1 +
 exports/ibase_export.py               | 16 +++++++
 exports/iyolov8_export.py             | 22 +++++++++-
 exports/yolov8_export.py              | 15 ++++++-
 integrations/integration_factory.py   |  3 ++
 integrations/iroboflow_integration.py | 13 ++++++
 integrations/is3_integration.py       | 30 +++++++++++++
 integrations/roboflow_integration.py  | 33 ++++++++++++---
 integrations/s3_integration.py        | 40 ++++++++++++------
 .../helmet_dectector_1k_16b_150e.pt   | Bin
 yolov8n.pt => models/yolov8n.pt       | Bin
 projects/base_project.py              | 16 -------
 services/harvest_service.py           | 13 +++++-
 14 files changed, 176 insertions(+), 44 deletions(-)
 rename helmet_dectector_1k_16b_150e.pt => models/helmet_dectector_1k_16b_150e.pt (100%)
 rename yolov8n.pt => models/yolov8n.pt (100%)

diff --git a/exports/base_export.py b/exports/base_export.py
index 6957583..2fc02c8 100644
--- a/exports/base_export.py
+++ b/exports/base_export.py
@@ -10,7 +10,15 @@
 
 
 class BaseExport(IBaseExport):
+    """
+    Base Export class that implements functions for
+    initializing and saving frames under a specific format.
+    """
+
     def __init__(self, proj_dir_name):
+        """
+        Constructor.
+        """
         self._var = VariableClass()
         _cur_dir = pdirname(pabspath(__file__))
         self.proj_dir = pjoin(_cur_dir, f'../data/{proj_dir_name}')
@@ -19,10 +27,10 @@ def __init__(self, proj_dir_name):
 
     def initialize_save_dir(self):
         """
-        See ibase_project.py
+        See ibase_export.py
 
         Returns:
-            None
+            True on success, False otherwise.
         """
         self.result_dir_path = pjoin(self.proj_dir, f'{self._var.DATASET_FORMAT}-v{self._var.DATASET_VERSION}')
         os.makedirs(self.result_dir_path, exist_ok=True)
@@ -35,6 +43,12 @@ def initialize_save_dir(self):
         return False
 
     def save_frame(self, frame, predicted_frames, cv2, labels_and_boxes):
+        """
+        See ibase_export.py
+
+        Returns:
+            Predicted frame counter.
+        """
         print(f'5.1. Condition met, processing valid frame: {predicted_frames}')
         # Save original frame
         unix_time = int(time.time())
diff --git a/exports/export_factory.py b/exports/export_factory.py
index 6d2ebf9..ab0eb10 100644
--- a/exports/export_factory.py
+++ b/exports/export_factory.py
@@ -7,6 +7,7 @@ class ExportFactory:
     """
     Export Factory initializes specific export types.
     """
+
     def __init__(self):
         self._var = VariableClass()
         self.save_format = self._var.DATASET_FORMAT
diff --git a/exports/ibase_export.py b/exports/ibase_export.py
index e208810..55893b7 100644
--- a/exports/ibase_export.py
+++ b/exports/ibase_export.py
@@ -2,11 +2,27 @@
 
 
 class IBaseExport(ABC):
+    """
+    Interface for Base Export.
+    """
 
     @abstractmethod
     def initialize_save_dir(self):
+        """
+        Initializes save directory for Base export format
+        """
         pass
 
     @abstractmethod
     def save_frame(self, frame, predicted_frames, cv2, labels_and_boxes):
+        """
+        Saves a single frame as well as its predicted annotation.
+        It should save 2 separate files under the same name, 1 .png for the raw frame and 1 .txt for the annotations.
+
+        Args:
+            frame: The current frame to be saved.
+            predicted_frames: Frames with predictions that might need to be saved alongside the original.
+            cv2: The OpenCV module used for image processing, passed in to avoid tight coupling.
+            labels_and_boxes: A list containing labels and their corresponding bounding boxes for the frame.
+        """
         pass
diff --git a/exports/iyolov8_export.py b/exports/iyolov8_export.py
index 977dca8..dae787f 100644
--- a/exports/iyolov8_export.py
+++ b/exports/iyolov8_export.py
@@ -2,15 +2,35 @@
 
 
 class IYolov8Export(ABC):
-
+    """
+    Interface for Yolov8 Export.
+    """
 
     @abstractmethod
     def initialize_save_dir(self):
+        """
+        Initializes save directory for Yolov8 export format
+        """
         pass
 
     @abstractmethod
     def save_frame(self, frame, predicted_frames, cv2, labels_and_boxes):
+        """
+        Saves a single frame as well as its predicted annotation.
+        It should save 2 separate files under the same name:
+        - 1 .png for the raw frame, saved in the images subdirectory.
+        - 1 .txt for the annotations, saved in the labels subdirectory.
+
+        Args:
+            frame: The current frame to be saved.
+            predicted_frames: Frames with predictions that might need to be saved alongside the original.
+            cv2: The OpenCV module used for image processing, passed in to avoid tight coupling.
+            labels_and_boxes: A list containing labels and their corresponding bounding boxes for the frame.
+        """
         pass
 
     @abstractmethod
     def create_yaml(self, model2):
+        """
+        Create .yaml file to map annotation labels to their corresponding names.
+        """
         pass
diff --git a/exports/yolov8_export.py b/exports/yolov8_export.py
index 01dcc0d..bfb8719 100644
--- a/exports/yolov8_export.py
+++ b/exports/yolov8_export.py
@@ -10,6 +10,11 @@
 
 
 class Yolov8Export(IYolov8Export):
+    """
+    Yolov8 Export class that implements functions for
+    initializing, saving frames and creating a yaml file under a specific format.
+    """
+
     def __init__(self, proj_dir_name):
         """
         Constructor.
@@ -25,10 +30,10 @@ def __init__(self, proj_dir_name):
 
     def initialize_save_dir(self):
         """
-        See ibase_project.py
+        See iyolov8_export.py
 
         Returns:
-            None
+            True on success, False otherwise.
         """
         self.result_dir_path = pjoin(self.proj_dir, f'{self._var.DATASET_FORMAT}-v{self._var.DATASET_VERSION}')
         os.makedirs(self.result_dir_path, exist_ok=True)
@@ -51,6 +56,12 @@ def initialize_save_dir(self):
         return False
 
     def save_frame(self, frame, predicted_frames, cv2, labels_and_boxes):
+        """
+        See iyolov8_export.py
+
+        Returns:
+            Predicted frame counter.
+        """
         print(f'5.1. Condition met, processing valid frame: {predicted_frames}')
         # Save original frame
         unix_time = int(time.time())
diff --git a/integrations/integration_factory.py b/integrations/integration_factory.py
index 4e3124b..b66caab 100644
--- a/integrations/integration_factory.py
+++ b/integrations/integration_factory.py
@@ -4,6 +4,9 @@
 
 
 class IntegrationFactory:
+    """
+    Integration Factory initializes specific integration types.
+    """
     def __init__(self):
         self._var = VariableClass()
         self.name = self._var.INTEGRATION_NAME
diff --git a/integrations/iroboflow_integration.py b/integrations/iroboflow_integration.py
index d58b2c7..076374b 100644
--- a/integrations/iroboflow_integration.py
+++ b/integrations/iroboflow_integration.py
@@ -2,11 +2,24 @@
 
 
 class IRoboflowIntegration(ABC):
+    """
+    Interface for Roboflow Integration class.
+    """
 
     @abstractmethod
     def upload_dataset(self, src_project_path):
+        """
+        Upload dataset to Roboflow platform.
+
+        Args:
+            src_project_path: Project save path
+        """
         pass
 
     @abstractmethod
     def __connect__(self):
+        """
+        Connect to Roboflow agent.
+        You need to provide Roboflow parameters in .env file.
+        """
         pass
diff --git a/integrations/is3_integration.py b/integrations/is3_integration.py
index 2858a87..2b5f4b3 100644
--- a/integrations/is3_integration.py
+++ b/integrations/is3_integration.py
@@ -2,19 +2,49 @@
 
 
 class IS3Integration(ABC):
+    """
+    Interface for S3 Integration class.
+    """
 
     @abstractmethod
     def upload_file(self, source_path, output_path):
+        """
+        Upload a single file to S3 compatible platform.
+
+        Args:
+            source_path: File save path
+            output_path: Desired path we want to save in S3
+        """
         pass
 
     @abstractmethod
     def upload_dataset(self, src_project_path):
+        """
+        Upload dataset to S3 compatible platform.
+
+        Args:
+            src_project_path: Project save path
+        """
         pass
 
     @abstractmethod
     def __connect__(self):
+        """
+        Connect to S3 compatible agent.
+        You need to provide S3 parameters in .env file.
+        """
         pass
 
     @abstractmethod
     def __check_bucket_exists__(self, bucket_name):
+        """
+        Check if input bucket exists after connecting to S3 compatible agent.
+        You need to provide S3 parameters in .env file.
+
+        Args:
+            bucket_name: Bucket name.
+
+        Returns:
+            True or False
+        """
         pass
diff --git a/integrations/roboflow_integration.py b/integrations/roboflow_integration.py
index f04c764..c9cdad9 100644
--- a/integrations/roboflow_integration.py
+++ b/integrations/roboflow_integration.py
@@ -7,11 +7,27 @@
 
 
 class RoboflowIntegration:
+    """
+    Roboflow Integration class that implements functions for connecting and uploading datasets
+    to the Roboflow platform.
+    """
+
     def __init__(self):
+        """
+        Constructor.
+        """
         self._var = VariableClass()
         self.agent, self.ws, self.project = self.__connect__()
 
     def __connect__(self):
+        """
+        See iroboflow_integration.py
+
+        Returns:
+            agent: Connected agent.
+            workspace: Selected workspace in that agent.
+            project: Selected project in that workspace.
+        """
         try:
             # Attempt to initialize Roboflow with the API key
             agent = roboflow.Roboflow(api_key=self._var.ROBOFLOW_API_KEY)
@@ -29,15 +45,18 @@ def __connect__(self):
             raise ConnectionRefusedError(f'Error during Roboflow login: {e}')
 
     def upload_dataset(self, src_project_path):
+        """
+        See iroboflow_integration.py
+        """
         # Upload data set to an existing project
         self.ws.upload_dataset(
-                src_project_path,
-                pbasename(self.project.id),
-                num_workers=10,
-                project_license="MIT",
-                project_type="object-detection",
-                batch_name=None,
-                num_retries=0
+            src_project_path,
+            pbasename(self.project.id),
+            num_workers=10,
+            project_license="MIT",
+            project_type="object-detection",
+            batch_name=None,
+            num_retries=0
         )
         print('Uploaded')
diff --git a/integrations/s3_integration.py b/integrations/s3_integration.py
index d341ecf..eee0953 100644
--- a/integrations/s3_integration.py
+++ b/integrations/s3_integration.py
@@ -5,18 +5,33 @@
 
 
 class S3Integration:
+    """
+    S3 Integration class that implements functions for connecting and uploading a single file or dataset
+    to an S3 compatible platform.
+    """
+
     def __init__(self):
+        """
+        Constructor.
+        """
         self._var = VariableClass()
         self.session, self.agent = self.__connect__()
         self.bucket = self._var.S3_BUCKET
         self.__check_bucket_exists__(self.bucket)
 
     def __connect__(self):
+        """
+        See is3_integration.py
+
+        Returns:
+            session: Connected session.
+            agent: Connected agent.
+        """
         session = boto3.session.Session()
-        # Connect to Wasabi S3
+        # Connect to S3 Compatible
         agent = session.client(
             self._var.INTEGRATION_NAME,
-            endpoint_url=self._var.S3_ENDPOINT,  # Wasabi endpoint URL
+            endpoint_url=self._var.S3_ENDPOINT,
             aws_access_key_id=self._var.S3_ACCESS_KEY,
             aws_secret_access_key=self._var.S3_SECRET_KEY,
         )
@@ -25,25 +40,19 @@ def __connect__(self):
 
         return session, agent
 
     def upload_file(self, source_path, output_path):
+        """
+        See is3_integration.py
+        """
         try:
             self.agent.upload_file(source_path, self.bucket, output_path)
             print(f"Successfully uploaded '{source_path}' to 's3://{self.bucket}/{output_path}'")
         except Exception as e:
             print(f"Failed to upload '{source_path}' to 's3://{self.bucket}/{output_path}': {e}")
 
-    # def upload_dataset(self, src_project_path):
-    #     # Iterate over all the files in the folder
-    #     for root, dirs, files in os.walk(src_project_path):
-    #         for filename in files:
-    #             # Construct the full file path
-    #             source_path = os.path.join(root, filename)
-    #
-    #             output_path = f'{self._var.DATASET_FORMAT}-v{self._var.DATASET_VERSION}/{filename}'
-    #             # Upload the file
-    #             self.upload_file(source_path, output_path)
-    #             print(f'Uploaded: {source_path} to s3://{self.bucket}/{output_path}')
-
     def upload_dataset(self, src_project_path):
+        """
+        See is3_integration.py
+        """
         # Iterate over all the files in the folder, including sub folders
         for root, dirs, files in os.walk(src_project_path):
             for filename in files:
@@ -62,6 +71,9 @@
             print(f'Uploaded: {source_path} to s3://{self.bucket}/{output_path}')
 
     def __check_bucket_exists__(self, bucket_name):
+        """
+        See is3_integration.py
+        """
         try:
             self.agent.head_bucket(Bucket=bucket_name)
             print(f"Bucket '{bucket_name}' found.")
diff --git a/helmet_dectector_1k_16b_150e.pt b/models/helmet_dectector_1k_16b_150e.pt
similarity index 100%
rename from helmet_dectector_1k_16b_150e.pt
rename to models/helmet_dectector_1k_16b_150e.pt
diff --git a/yolov8n.pt b/models/yolov8n.pt
similarity index 100%
rename from yolov8n.pt
rename to models/yolov8n.pt
diff --git a/projects/base_project.py b/projects/base_project.py
index dbe113b..04d8c10 100644
--- a/projects/base_project.py
+++ b/projects/base_project.py
@@ -42,19 +42,3 @@ def create_proj_save_dir(self, dir_name):
         self.proj_dir = pjoin(_cur_dir, f'../data/{dir_name}')
         self.proj_dir = pabspath(self.proj_dir)  # normalise the link
         print(f'1. Created/Found project folder under {self.proj_dir} path')
-
-    # def create_result_save_dir(self):
-    #     """
-    #     See ibase_project.py
-    #
-    #     Returns:
-    #         None
-    #     """
-    #     if self._var.DATASET_FORMAT == 'yolov8':
-    #         result_dir_path = pjoin(self.proj_dir, f'{datetime.now().strftime("%d-%m-%Y_%H-%M-%S")}')
-    #         image_dir_path = pjoin(result_dir_path, 'images')
-    #         label_dir_path = pjoin(result_dir_path, 'labels')
-    #         yaml_path = pjoin(result_dir_path, 'data.yaml')
-    #         return result_dir_path, image_dir_path, label_dir_path, yaml_path
-    #     else:
-    #         raise TypeError('Unsupported dataset format!')
diff --git a/services/harvest_service.py b/services/harvest_service.py
index 001c6d2..1c8d4e1 100644
--- a/services/harvest_service.py
+++ b/services/harvest_service.py
@@ -2,6 +2,11 @@
 from uugai_python_kerberos_vault.KerberosVault import KerberosVault
 
 from exports.export_factory import ExportFactory
+from os.path import (
+    join as pjoin,
+    dirname as pdirname,
+    abspath as pabspath,
+)
 from services.iharvest_service import IHarvestService
 from utils.VariableClass import VariableClass
 import time
@@ -115,8 +120,12 @@ def connect_models(self):
         _cur_dir = os.getcwd()
         # initialise the yolo model, additionally use the device parameter to specify the device to run the model on.
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        model = YOLO(self._var.MODEL_NAME).to(device)
-        model2 = YOLO(self._var.MODEL_NAME_2).to(device)
+        _cur_dir = pdirname(pabspath(__file__))
+        model_dir = pjoin(_cur_dir, f'../models')
+        model_dir = pabspath(model_dir)  # normalise the link
+
+        model = YOLO(pjoin(model_dir, self._var.MODEL_NAME)).to(device)
+        model2 = YOLO(pjoin(model_dir, self._var.MODEL_NAME_2)).to(device)
         if model and model2:
             print(f'2. Using device: {device}')
             print(f'3. Using models: {self._var.MODEL_NAME} and {self._var.MODEL_NAME_2}')
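
---

A note for reviewers on the factory wiring introduced above: single-shot.py now
obtains its uploader via IntegrationFactory().init(), but this series only adds
a docstring to integrations/integration_factory.py; the dispatch body itself is
not shown. A minimal sketch of what init() presumably does, assuming it switches
on the INTEGRATION_NAME variable set in the Dockerfile and k8s-deployment.yaml,
could look like this (hypothetical, for illustration only):

    from integrations.roboflow_integration import RoboflowIntegration
    from integrations.s3_integration import S3Integration
    from utils.VariableClass import VariableClass

    class IntegrationFactory:
        def __init__(self):
            self._var = VariableClass()
            self.name = self._var.INTEGRATION_NAME

        def init(self):
            # "s3" (the value used in k8s-deployment.yaml) selects the
            # boto3-backed S3Integration; "roboflow" selects the Roboflow
            # uploader. Both connect to their backend in the constructor.
            if self.name == 's3':
                return S3Integration()
            if self.name == 'roboflow':
                return RoboflowIntegration()
            raise ValueError(f'Unknown INTEGRATION_NAME: {self.name}')

Either object exposes upload_dataset(save_dir), which is all that single-shot.py
relies on after harvest_service.process() returns the save directory.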
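Similarly, project.class_mapping(model1, model2) is called in single-shot.py but
its implementation is outside this diff. Assuming it reconciles the class IDs of
the two YOLO models by matching class names (ultralytics models expose an
{id: name} dict via model.names), a sketch of such a helper might be:

    def class_mapping(model1, model2):
        # Build a name -> id lookup for the first model, then map every
        # class id of the second model onto the matching id of the first,
        # so detections from both models share one label space.
        name_to_id = {name: idx for idx, name in model1.names.items()}
        return {
            idx2: name_to_id[name]
            for idx2, name in model2.names.items()
            if name in name_to_id
        }

This is only an illustration of the mapping concept; the real helper lives in
the project classes (projects/) and may differ.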