From 40d38e9d0155bfdd9e80188d07206f266f0a080c Mon Sep 17 00:00:00 2001 From: pabloinigoblasco Date: Wed, 11 Sep 2024 22:09:30 +0200 Subject: [PATCH] multiple input images implementation --- config/image_object_detection.yaml | 15 +- .../image_object_detection_node.py | 139 ++++++++++-------- 2 files changed, 90 insertions(+), 64 deletions(-) diff --git a/config/image_object_detection.yaml b/config/image_object_detection.yaml index 7fead2c..81187b4 100644 --- a/config/image_object_detection.yaml +++ b/config/image_object_detection.yaml @@ -6,7 +6,20 @@ image_object_detection_node: model.weights_file: yolov7-tiny.pt model.device: '0' - selected_detections: ['person'] + selected_detections: ['person'] # Classes to detect, e.g. ['person', 'car'] show_image: False publish_debug_image: True + + # List of camera topics to subscribe to + camera_topics: + - '/camera/image_raw' + # - '/camera1/image_raw' + # - '/camera2/image_raw' + # - '/camera3/image_raw' + + # QoS policy for the image subscriber + subscribers.qos_policy: 'best_effort' + + # QoS policy for the image debug publisher + image_debug_publisher.qos_policy: 'best_effort' diff --git a/src/image_object_detection/image_object_detection_node.py b/src/image_object_detection/image_object_detection_node.py index 855ddab..b566efa 100644 --- a/src/image_object_detection/image_object_detection_node.py +++ b/src/image_object_detection/image_object_detection_node.py @@ -19,6 +19,7 @@ import std_srvs.srv from sensor_msgs.msg import CompressedImage, Image +from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose import torch import torch.backends.cudnn as cudnn @@ -27,7 +28,7 @@ from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging from utils.plots import plot_one_box from utils.torch_utils import select_device -from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose +from vision_msgs.msg import Detection2DArray, Detection2D from ament_index_python.packages
import get_package_share_directory PACKAGE_NAME = "image_object_detection" @@ -35,9 +36,9 @@ class ImageDetectObjectNode(Node): def __init__(self): - super().__init__("image_object_detection_node") + super().__init__("image_object_detection_node") - # parametros + # Model parameters self.declare_parameter("model.image_size", 640) self.model_image_size = ( self.get_parameter("model.image_size").get_parameter_value().integer_value @@ -129,60 +130,59 @@ def __init__(self): self.bridge = cv_bridge.CvBridge() - self.image_sub = self.create_subscription( - msg_type=Image, topic="image", callback=self.image_callback, qos_profile=self.qos + # Get the list of camera topics from the config file + self.declare_parameter("camera_topics", []) + self.camera_topics = ( + self.get_parameter("camera_topics").get_parameter_value().string_array_value ) - - self.image_compressed_sub = self.create_subscription( - msg_type=CompressedImage, - topic="image/compressed", - callback=self.image_compressed_callback, - qos_profile=self.qos, - ) - - self.detection_publisher = self.create_publisher( - msg_type=Detection2DArray, topic="detections", qos_profile=self.qos - ) - - if self.enable_publish_debug_image: - if self.qos_policy == "best_effort": - self.get_logger().info("Using best effort qos policy for debug image publisher") - self.qos = QoSProfile( - reliability=QoSReliabilityPolicy.BEST_EFFORT, - history=QoSHistoryPolicy.KEEP_LAST, - depth=1, - ) - else: - self.get_logger().info("Using reliable qos policy for debug image publisher") - self.qos = QoSProfile( - reliability=QoSReliabilityPolicy.RELIABLE, - history=QoSHistoryPolicy.KEEP_LAST, - depth=1, + self.get_logger().info(f"Subscribed to topics: {self.camera_topics}") + + # Initialize subscribers and publishers for each camera topic + self.subscribers = [] + self.detection_publishers = {} + self.debug_image_publishers = {} + + for topic in self.camera_topics: + # Create a subscriber for each camera topic + self.subscribers.append( + 
self.create_subscription( + Image, + topic, + callback=self.image_callback_factory(topic), + qos_profile=self.qos, ) + ) - self.debug_image_publisher = self.create_publisher( - msg_type=Image, topic="debug_image", qos_profile=self.qos + # Create a detection publisher for each camera + detection_topic = f"{topic}/detections" + self.detection_publishers[topic] = self.create_publisher( + Detection2DArray, detection_topic, self.qos ) + # Create a debug image publisher for each camera (if enabled) + if self.enable_publish_debug_image: + debug_image_topic = f"{topic}/debug_image" + self.debug_image_publishers[topic] = self.create_publisher( + Image, debug_image_topic, self.qos + ) + self.initialize_model() def initialize_model(self): with torch.no_grad(): - # Initialize set_logging() self.device = select_device(self.device) self.half = self.device.type != "cpu" - # Load model self.model = attempt_load( self.model_weights_file, map_location=self.device - ) # load FP32 model + ) self.stride = int(self.model.stride.max()) self.imgsz = check_img_size(self.model_image_size, s=self.stride) if self.half: - self.model.half() # to FP16 + self.model.half() cudnn.benchmark = True @@ -215,17 +215,13 @@ def accomodate_image_to_model(self, img0): def image_compressed_callback(self, msg): if not self.processing_enabled: return - - try: - self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format) - img = self.accomodate_image_to_model(self.cv_img) - detections_msg, debugimg = self.predict(img, self.cv_img) + self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format) + img = self.accomodate_image_to_model(self.cv_img) - self.detection_publisher.publish(detections_msg) - except CvBridgeError as e: - self.get_logger().error(f"Error converting image: {e}") - return + detections_msg, debugimg = self.predict(img, self.cv_img) + + self.detection_publisher.publish(detections_msg) if debugimg is not None: 
self.publish_debug_image(debugimg) @@ -234,25 +230,43 @@ def image_compressed_callback(self, msg): cv2.imshow("Compressed Image", debugimg) cv2.waitKey(1) - def image_callback(self, msg): - if not self.processing_enabled: - return + def image_callback_factory(self, topic): + def callback(msg): + try: + cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8") + self.image_queue[topic] = cv_img + except CvBridgeError as e: + self.get_logger().error(f"Error converting image from {topic}: {e}") - self.cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8") - img = self.accomodate_image_to_model(self.cv_img) + return callback + + def image_callback_factory(self, topic): + def callback(msg): + if not self.processing_enabled: + return - detections_msg, debugimg = self.predict(img, self.cv_img) + try: + cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8") + img = self.accomodate_image_to_model(cv_img) - self.detection_publisher.publish(detections_msg) + detections_msg, debugimg = self.predict(img, cv_img) - if debugimg is not None: - self.publish_debug_image(debugimg) + # Publish detections for the current camera + self.detection_publishers[topic].publish(detections_msg) - if self.show_image: - cv2.imshow("Detection", debugimg) - cv2.waitKey(1) + # Publish debug image for the current camera (if enabled) + if self.enable_publish_debug_image and topic in self.debug_image_publishers: + self.publish_debug_image(debugimg, topic) - def publish_debug_image(self, debugimg): + if self.show_image: + cv2.imshow(f"Detection from {topic}", debugimg) + cv2.waitKey(1) + except CvBridgeError as e: + self.get_logger().error(f"Error converting image from {topic}: {e}") + + return callback + + def publish_debug_image(self, debugimg, topic): if self.debug_image_output_format == "mono8": debugimg = cv2.cvtColor(debugimg, cv2.COLOR_RGB2GRAY) elif self.debug_image_output_format == "rgb8": @@ -261,11 +275,12 @@ def publish_debug_image(self, debugimg): debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA) else: 
self.get_logger().error( - "Unsupported debug image output format: {}".format(self.debug_image_output_format) + f"Unsupported debug image output format: {self.debug_image_output_format}" ) return - self.debug_image_publisher.publish( + # Publish the debug image for the current camera + self.debug_image_publishers[topic].publish( self.bridge.cv2_to_imgmsg(debugimg, self.debug_image_output_format) ) @@ -294,7 +309,6 @@ def predict(self, model_img, original_image): ).round() for *xyxy, conf, cls in reversed(det): - # clase clases deseadas if self.names[int(cls)] in self.selected_detections: detection2D_msg = Detection2D() xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() @@ -322,7 +336,6 @@ def predict(self, model_img, original_image): return detections_msg, original_image - def main(args=None): print(args) rclpy.init(args=sys.argv)