Feature/dynamic parameters #15

Open · wants to merge 12 commits into base: main
15 changes: 14 additions & 1 deletion config/image_object_detection.yaml
@@ -6,7 +6,20 @@ image_object_detection_node:
model.weights_file: yolov7-tiny.pt
model.device: '0'

selected_detections: ['person']
selected_detections: ['person'] # Classes to detect ['person', 'car']

show_image: False
publish_debug_image: True

# List of camera topics to subscribe to
camera_topics:
- '/camera/image_raw'
# - '/camera1/image_raw'
# - '/camera2/image_raw'
# - '/camera3/image_raw'

# QoS policy for the image subscriber
subscribers.qos_policy: 'best_effort'

# QoS policy for the image debug publisher
image_debug_publisher.qos_policy: 'best_effort'
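
Since this PR wires these parameters into a parameter callback, the camera list and detection classes can also be changed while the node is running. A rough sketch with the ros2 CLI, assuming the node name image_object_detection_node from this config and placeholder topic names:

# Hypothetical runtime reconfiguration; adjust node and topic names to your setup
ros2 param set /image_object_detection_node camera_topics "['/camera1/image_raw', '/camera2/image_raw']"
ros2 param set /image_object_detection_node selected_detections "['person', 'car']"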
1 change: 1 addition & 0 deletions setup.py
@@ -20,6 +20,7 @@
("share/" + package_name, ["yolov7-tiny.pt"]),
("share/" + package_name + "/launch", ["launch/image_object_detection_launch.py"]),
("share/" + package_name + "/config", ["config/image_object_detection.yaml"]),
("share/" + package_name + "/templates", ["src/image_object_detection/templates/index.html"]),
],
install_requires=["setuptools"],
zip_safe=True,
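
The added data_files entry installs the web template under the package share directory. One way to sanity-check the install after building and sourcing the workspace (a quick sketch, assuming the standard ament install layout):

ls "$(ros2 pkg prefix image_object_detection)/share/image_object_detection/templates/index.html"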
264 changes: 185 additions & 79 deletions src/image_object_detection/image_object_detection_node.py
@@ -19,6 +19,7 @@

import std_srvs.srv
from sensor_msgs.msg import CompressedImage, Image
from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose

import torch
import torch.backends.cudnn as cudnn
@@ -27,17 +28,20 @@
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging
from utils.plots import plot_one_box
from utils.torch_utils import select_device
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
from vision_msgs.msg import Detection2DArray, Detection2D
from ament_index_python.packages import get_package_share_directory

# Result type returned by the dynamic parameter callback
from rcl_interfaces.msg import SetParametersResult

PACKAGE_NAME = "image_object_detection"


class ImageDetectObjectNode(Node):
def __init__(self):
super().__init__("image_object_detection_node")

# parametros
# Model parameters
self.declare_parameter("model.image_size", 640)
self.model_image_size = (
self.get_parameter("model.image_size").get_parameter_value().integer_value
@@ -97,6 +101,8 @@ def __init__(self):
.string_value
)



self.declare_parameter("subscribers.qos_policy", "best_effort")
self.subscribers_qos = (
self.get_parameter("subscribers.qos_policy").get_parameter_value().string_value
@@ -129,60 +135,138 @@ def __init__(self):

self.bridge = cv_bridge.CvBridge()

self.image_sub = self.create_subscription(
msg_type=Image, topic="image", callback=self.image_callback, qos_profile=self.qos
)

self.image_compressed_sub = self.create_subscription(
msg_type=CompressedImage,
topic="image/compressed",
callback=self.image_compressed_callback,
qos_profile=self.qos,
)

self.detection_publisher = self.create_publisher(
msg_type=Detection2DArray, topic="detections", qos_profile=self.qos
)

if self.enable_publish_debug_image:
if self.qos_policy == "best_effort":
self.get_logger().info("Using best effort qos policy for debug image publisher")
self.qos = QoSProfile(
reliability=QoSReliabilityPolicy.BEST_EFFORT,
history=QoSHistoryPolicy.KEEP_LAST,
depth=1,
)
else:
self.get_logger().info("Using reliable qos policy for debug image publisher")
self.qos = QoSProfile(
reliability=QoSReliabilityPolicy.RELIABLE,
history=QoSHistoryPolicy.KEEP_LAST,
depth=1,
# Initialize camera topics parameter
self.declare_parameter("camera_topics", [
"/cameras/frontleft_fisheye_image/image",
"/cameras/frontright_fisheye_image/image",
"/cameras/left_fisheye_image/image",
"/cameras/right_fisheye_image/image"
])
self.camera_topics = self.get_parameter("camera_topics").get_parameter_value().string_array_value
self.get_logger().info(f"Subscribed to topics: {self.camera_topics}")

# Initialize empty containers for subscribers and publishers
self.subscribers = []
self.detection_publishers = {}
self.debug_image_publishers = {}

# Set up camera topics using the extracted method
self.setup_camera_topics()

self.initialize_model()

# Add the parameter callback handler
self.add_on_set_parameters_callback(self.parameters_callback)

def parameters_callback(self, params):
result = SetParametersResult(successful=True)

init_model = False
for param in params:
if param.name == 'camera_topics':
self.camera_topics = param.value
self.get_logger().info(f"Updated camera_topics: {self.camera_topics}")
# Recreate subscribers and publishers for new topics
self.setup_camera_topics()

elif param.name == 'selected_detections':
self.selected_detections = param.value
self.get_logger().info(f"Updated selected_detections: {self.selected_detections}")

elif param.name == 'model.iou_threshold':
self.iou_threshold = param.value
self.get_logger().info(f"Updated iou_threshold: {self.iou_threshold}")

elif param.name == 'model.confidence':
self.confidence = param.value
self.get_logger().info(f"Updated confidence: {self.confidence}")

elif param.name == 'model.weights_file':
self.model_weights_file = param.value
self.get_logger().info(f"Updated weights_file: {self.model_weights_file}")
init_model = True

elif param.name == 'model.publish_debug_image':
self.enable_publish_debug_image = param.value
self.get_logger().info(f"Updated publish_debug_image: {self.enable_publish_debug_image}")
self.setup_camera_topics() # Recreate publishers with new debug setting

elif param.name == 'model.image_size':
self.model_image_size = param.value
self.get_logger().info(f"Updated image_size: {self.model_image_size}")
init_model = True

if init_model:
self.initialize_model()

return result
def setup_camera_topics(self):
# Create sets of existing topics
existing_sub_topics = {sub.topic_name for sub in self.subscribers}
existing_pub_topics = set(self.detection_publishers.keys())
existing_debug_topics = set(self.debug_image_publishers.keys())

# Set of new topics
new_topics = set(self.camera_topics)

# Remove subscribers for topics that no longer exist
for sub in list(self.subscribers):
if sub.topic_name not in new_topics:
sub.destroy()
self.subscribers.remove(sub)

# Remove publishers for topics that no longer exist
for topic in list(self.detection_publishers.keys()):
if topic not in new_topics:
self.detection_publishers[topic].destroy()
del self.detection_publishers[topic]

# Remove debug image publishers for topics that no longer exist
for topic in list(self.debug_image_publishers.keys()):
if topic not in new_topics:
self.debug_image_publishers[topic].destroy()
del self.debug_image_publishers[topic]

# Add new topics
for topic in new_topics:
if topic not in existing_sub_topics:
self.subscribers.append(
self.create_subscription(
Image,
topic,
callback=self.image_callback_factory(topic),
qos_profile=self.qos,
)
)

self.debug_image_publisher = self.create_publisher(
msg_type=Image, topic="debug_image", qos_profile=self.qos
)
if topic not in existing_pub_topics:
detection_topic = f"{topic}/detections"
self.detection_publishers[topic] = self.create_publisher(
Detection2DArray, detection_topic, self.qos
)

self.initialize_model()
if self.enable_publish_debug_image and topic not in existing_debug_topics:
debug_image_topic = f"{topic}/debug_image"
self.debug_image_publishers[topic] = self.create_publisher(
Image, debug_image_topic, self.qos
)

def initialize_model(self):
with torch.no_grad():
# Initialize
set_logging()
self.device = select_device(self.device)
self.device = select_device(str(self.device))
self.half = self.device.type != "cpu"

# Load model
self.model = attempt_load(
self.model_weights_file, map_location=self.device
) # load FP32 model
)
self.stride = int(self.model.stride.max())

self.imgsz = check_img_size(self.model_image_size, s=self.stride)

if self.half:
self.model.half() # to FP16
self.model.half()

cudnn.benchmark = True

@@ -215,17 +299,13 @@ def accomodate_image_to_model(self, img0):
def image_compressed_callback(self, msg):
if not self.processing_enabled:
return

try:
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

detections_msg, debugimg = self.predict(img, self.cv_img)
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

self.detection_publisher.publish(detections_msg)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image: {e}")
return
detections_msg, debugimg = self.predict(img, self.cv_img)

self.detection_publisher.publish(detections_msg)

if debugimg is not None:
self.publish_debug_image(debugimg)
@@ -234,25 +314,43 @@ def image_compressed_callback(self, msg):
cv2.imshow("Compressed Image", debugimg)
cv2.waitKey(1)

def image_callback(self, msg):
if not self.processing_enabled:
return
def image_callback_factory(self, topic):
def callback(msg):
try:
cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
self.image_queue[topic] = cv_img
except CvBridgeError as e:
self.get_logger().error(f"Error converting image from {topic}: {e}")

self.cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
img = self.accomodate_image_to_model(self.cv_img)
return callback

def image_callback_factory(self, topic):
def callback(msg):
if not self.processing_enabled:
return

detections_msg, debugimg = self.predict(img, self.cv_img)
try:
cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
img = self.accomodate_image_to_model(cv_img)

self.detection_publisher.publish(detections_msg)
detections_msg, debugimg = self.predict(img, cv_img)

if debugimg is not None:
self.publish_debug_image(debugimg)
# Publish detections for the current camera
self.detection_publishers[topic].publish(detections_msg)

if self.show_image:
cv2.imshow("Detection", debugimg)
cv2.waitKey(1)
# Publish debug image for the current camera (if enabled)
if self.enable_publish_debug_image and topic in self.debug_image_publishers:
self.publish_debug_image(debugimg, topic)

if self.show_image:
cv2.imshow(f"Detection from {topic}", debugimg)
cv2.waitKey(1)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image from {topic}: {e}")

return callback

def publish_debug_image(self, debugimg):
def publish_debug_image(self, debugimg, topic):
if self.debug_image_output_format == "mono8":
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_RGB2GRAY)
elif self.debug_image_output_format == "rgb8":
@@ -261,11 +359,12 @@ def publish_debug_image(self, debugimg):
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA)
else:
self.get_logger().error(
"Unsupported debug image output format: {}".format(self.debug_image_output_format)
f"Unsupported debug image output format: {self.debug_image_output_format}"
)
return

self.debug_image_publisher.publish(
# Publish the debug image for the current camera
self.debug_image_publishers[topic].publish(
self.bridge.cv2_to_imgmsg(debugimg, self.debug_image_output_format)
)

@@ -294,7 +393,6 @@ def predict(self, model_img, original_image):
).round()

for *xyxy, conf, cls in reversed(det):
# clase clases deseadas
if self.names[int(cls)] in self.selected_detections:
detection2D_msg = Detection2D()
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
@@ -322,17 +420,25 @@ def predict(self, model_img, original_image):

return detections_msg, original_image


def main(args=None):
print(args)
rclpy.init(args=sys.argv)

minimal_publisher = ImageDetectObjectNode()
rclpy.spin(minimal_publisher)

minimal_publisher.destroy_node()
rclpy.shutdown()


if __name__ == "__main__":
main(sys.argv)
rclpy.init(args=args)

detection_node = ImageDetectObjectNode()
from image_object_detection.web_interface_node import WebInterfaceNode
web_interface = WebInterfaceNode(detection_node)

# Use MultiThreadedExecutor to handle both nodes
executor = rclpy.executors.MultiThreadedExecutor()
executor.add_node(detection_node)
executor.add_node(web_interface)

try:
executor.spin()
finally:
executor.shutdown()
detection_node.destroy_node()
web_interface.destroy_node()
rclpy.shutdown()

if __name__ == '__main__':
main()
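
Once setup_camera_topics has run, each subscribed camera topic gets its own detections and debug_image outputs (e.g. /camera/image_raw/detections for the sample config above). A quick, non-authoritative way to verify from the command line:

# Assumes the node is running with the default '/camera/image_raw' entry from the sample config
ros2 topic list | grep -E '/detections|/debug_image'
ros2 topic echo /camera/image_raw/detections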