diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ecc2c6a..599f7654 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,6 +212,7 @@ ament_target_dependencies(lie rclcpp geometry_msgs tf2 tf2_ros) mrover_add_header_only_library(units units) mrover_add_header_only_library(loop_profiler loop_profiler) mrover_add_header_only_library(parameter_utils parameter_utils) +mrover_add_header_only_library(state_machine state_machine) # Simulator diff --git a/lie/lie.hpp b/lie/lie.hpp index d1013a10..d6511978 100644 --- a/lie/lie.hpp +++ b/lie/lie.hpp @@ -43,28 +43,28 @@ namespace mrover { [[nodiscard]] static auto toTransformStamped(SE3d const& tf, std::string const& childFrame, std::string const& parentFrame, rclcpp::Time const& time) -> geometry_msgs::msg::TransformStamped; /** - * \brief Pull the most recent transform or pose between two frames from the TF tree. - * The second and third parameters are named for the transform interpretation. - * Consider them named "a" and "b" respectively: - * For a transform this is a rotation and translation, i.e. aToB. - * For a pose this is a position and orientation, i.e. aInB. - * \param buffer ROS TF Buffer, make sure a listener is attached - * \param fromFrame From (transform) or child (pose) frame - * \param toFrame To (transform) or parent (pose) frame - * \param time Time to query the transform at, default is the latest - * \return The transform or pose represented by an SE3 lie group element - */ + * \brief Pull the most recent transform or pose between two frames from the TF tree. + * The second and third parameters are named for the transform interpretation. + * Consider them named "a" and "b" respectively: + * For a transform this is a rotation and translation, i.e. aToB. + * For a pose this is a position and orientation, i.e. aInB. + * \param buffer ROS TF Buffer, make sure a listener is attached + * \param fromFrame From (transform) or child (pose) frame + * \param toFrame To (transform) or parent (pose) frame + * \param time Time to query the transform at, default is the latest + * \return The transform or pose represented by an SE3 lie group element + */ [[nodiscard]] static auto fromTfTree(tf2_ros::Buffer const& buffer, std::string const& fromFrame, std::string const& toFrame, rclcpp::Time const& time = rclcpp::Time{}) -> SE3d; /** - * \brief Push a transform to the TF tree between two frames - * \see fromTfTree for more explanation of the frames - * \param broadcaster ROS TF Broadcaster - * \param fromFrame From (transform) or child (pose) frame - * \param toFrame To (transform) or parent (pose) frame - * \param transform The transform or pose represented by an SE3 lie group element - * \param time - */ + * \brief Push a transform to the TF tree between two frames + * \see fromTfTree for more explanation of the frames + * \param broadcaster ROS TF Broadcaster + * \param fromFrame From (transform) or child (pose) frame + * \param toFrame To (transform) or parent (pose) frame + * \param transform The transform or pose represented by an SE3 lie group element + * \param time + */ static auto pushToTfTree(tf2_ros::TransformBroadcaster& broadcaster, std::string const& fromFrame, std::string const& toFrame, SE3d const& transform, rclcpp::Time const& time) -> void; }; diff --git a/scripts/visualizer.py b/scripts/visualizer.py index 336f9e67..4ad471dd 100755 --- a/scripts/visualizer.py +++ b/scripts/visualizer.py @@ -20,10 +20,7 @@ from dataclasses import dataclass from typing import Optional, List, Dict import threading - - -STRUCTURE_TOPIC = "nav_structure" -STATUS_TOPIC = "nav_state" +import sys @dataclass @@ -85,7 +82,7 @@ def container_structure_callback(self, structure: StateMachineStructure): class GUI(QWidget): # type: ignore - def __init__(self, state_machine_instance, *args, **kwargs): + def __init__(self, state_machine_instance, structure_topic, state_topic, *args, **kwargs): super().__init__(*args, **kwargs) self.label: QLabel = QLabel() # type: ignore self.timer: QTimer = QTimer() # type: ignore @@ -98,11 +95,11 @@ def __init__(self, state_machine_instance, *args, **kwargs): self.viz = Node("Visualizer") self.viz.create_subscription( - StateMachineStructure, STRUCTURE_TOPIC, self.state_machine.container_structure_callback, 1 + StateMachineStructure, structure_topic, self.state_machine.container_structure_callback, 1 ) self.viz.create_subscription( - StateMachineStateUpdate, STATUS_TOPIC, self.state_machine.container_status_callback, 1 + StateMachineStateUpdate, state_topic, self.state_machine.container_status_callback, 1 ) def paintEvent(self, event): @@ -131,7 +128,7 @@ def update(self): # type: ignore[override] self.repaint() -def main(): +def main(structure_topic, state_topic): try: rclpy.init() @@ -147,7 +144,7 @@ def main(): print("Subscriptions Created...") app = QApplication([]) # type: ignore - g = GUI(state_machine) + g = GUI(state_machine, structure_topic, state_topic) g.show() app.exec_() @@ -158,4 +155,9 @@ def main(): if __name__ == "__main__": - main() + argc = len(sys.argv) + if argc != 3: + print('Usage ros2 run mrover visualizer.py "[structure topic]" "[state topic]"') + sys.exit(1) + + main(sys.argv[1], sys.argv[2]) diff --git a/state_machine/state.hpp b/state_machine/state.hpp new file mode 100644 index 00000000..30dd328b --- /dev/null +++ b/state_machine/state.hpp @@ -0,0 +1,17 @@ +#pragma once + +/** + * \brief Virtual state class to describe how states in the state machine should act + * \see state_machine.hpp for reference on how the states will be used by the state machine + */ +class State { +public: + /** + * \brief Virtual destructor for the state class + */ + virtual ~State() = default; + /** + * \brief The function which will be called every loop in the state machine + */ + virtual auto onLoop() -> State* = 0; +}; diff --git a/state_machine/state_machine.hpp b/state_machine/state_machine.hpp new file mode 100644 index 00000000..df364847 --- /dev/null +++ b/state_machine/state_machine.hpp @@ -0,0 +1,165 @@ +#pragma once +#include "state.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * \brief State Machine class that facilitates transitioning between different states which inherit from the State class + * \see state.hpp for reference on creating states + */ +class StateMachine{ +private: + std::string mName; + State* currState; + + using TypeHash = std::size_t; + + std::unordered_map> mValidTransitions; + std::unordered_map decoder; + + /** + * \brief Ensures that transitioning from "from" to "to" is a valid transition + * \param from The runtime type of the from state + * \param to The runtime type of the to state + */ + void assertValidTransition(std::type_info const& from, std::type_info const& to) const { + auto it = mValidTransitions.find(from.hash_code()); + + if(it == mValidTransitions.end()){ + throw std::runtime_error(std::format("{} is not in the state machine ", typeid(from).name())); + } + + std::vector const& toTransitions = it->second; + if(std::find(toTransitions.begin(), toTransitions.end(), to.hash_code()) == toTransitions.end()){ + throw std::runtime_error(std::format("Invalid State Transition from {} to {}", typeid(currState).name(), typeid(to).name())); + } + } + + /** + * \brief Adds the demangled name to the map in the corresponding type hash slot + * \param hash The type hash for the runtime type + * \param name The demangled name of the runtime type + */ + void addNameToDecoder(TypeHash hash, std::string const& name){ + constexpr static std::string prefix{"mrover::"}; + TypeHash index = name.find(prefix); + std::string _name{name}; + _name.replace(index, index + prefix.size(), ""); + decoder[hash] = _name; + } +public: + /** + * \brief Constructor for the StateMachine Class + * \param name The name of the state machine useful for visualization + * \param initialState The initial state which the state machine will begin execution in + */ + explicit StateMachine(std::string name, State* initialState) : mName{std::move(name)}, currState{initialState}{}; + + ~StateMachine(){ + delete currState; + } + + /** + * \brief Makes a state which the state machine will use for execution. + * DO NOT CALL THIS FUNCTION AND NOT PASS THE STATE TO THE STATE MACHINE + * \param args The arguments that will be passed to the constructor of the state + */ + template + static auto make_state(Args... args) -> T*{ + static_assert(std::derived_from, "State Must Be Derived From The State Class"); + return new T(args...); + } + + /** + * \brief Returns the name of the state machine + * \return A constant reference to the name of the state machine + */ + auto getName() const -> std::string const& { + return mName; + } + + /** + * \brief Returns the demangled name of the state at runtime + * \param state A pointer to a state derived object which will have its runtime type analyzed + * \return A constant reference to the demangled state name at runtime + */ + auto getStateName(State const* state) const -> std::string const&{ + return decoder.find(typeid(*state).hash_code())->second; + } + + /** + * \brief Returns the demangled name of the current state in the state machine + * \return A constant reference to the current state's demangled state name + */ + auto getCurrentStateName() const -> std::string const& { + return getStateName(currState); + } + + /** + * \brief Returns a map of type hashes to a vector of each type hash + * \return A constant reference to the map describing all valid state transitions + */ + auto getTransitionTable() const -> std::unordered_map> const&{ + return mValidTransitions; + } + + /** + * \brief Takes in a type hash and returns the demangled state name + * \return A constant reference to the demangled state name + */ + auto decodeTypeHash(TypeHash hash) const -> std::string const&{ + return decoder.find(hash)->second; + } + + /** + * \brief Enables the state transition from the first templated type to the subsequent templated types + */ + template + void enableTransitions(){ + static_assert(std::derived_from, "From State Must Be Derived From The State Class"); + static_assert((std::derived_from && ...), "All States Must Be Derived From The State Class"); + // Add From State To Decoder + int status = 0; + char* demangledName = abi::__cxa_demangle(typeid(From).name(), nullptr, nullptr, &status); + + if(status){ + throw std::runtime_error("C++ demangle failed!"); + } + + addNameToDecoder(typeid(From).hash_code(), demangledName); + free(demangledName); + + mValidTransitions[typeid(From).hash_code()] = {typeid(To).hash_code()...}; + + std::vector> types{std::ref(typeid(To))...}; + for(auto const& type : types){ + demangledName = abi::__cxa_demangle(type.get().name(), nullptr, nullptr, &status); + if(status){ + throw std::runtime_error("C++ demangle failed!"); + } + addNameToDecoder(type.get().hash_code(), demangledName); + free(demangledName); + } + } + + /** + * \brief Runs the onLoop function for the state and then transitions to the state returned from that function + */ + void update(){ + State* newState = currState->onLoop(); + + assertValidTransition(typeid(*currState), typeid(*newState)); + if(newState != currState){ + delete currState; + } + currState = newState; + } +}; diff --git a/state_machine/state_publisher_server.hpp b/state_machine/state_publisher_server.hpp new file mode 100644 index 00000000..1d8c21b6 --- /dev/null +++ b/state_machine/state_publisher_server.hpp @@ -0,0 +1,81 @@ +#pragma once + +// STL +#include +#include + +// Ros Client Library +#include +#include +#include +#include +#include + +// MRover +#include "mrover/msg/detail/state_machine_state_update__struct.hpp" +#include "mrover/msg/detail/state_machine_structure__struct.hpp" +#include "state_machine.hpp" +#include +#include + +namespace mrover{ + class StatePublisher{ + private: + StateMachine const& mStateMachine; + + rclcpp::Publisher::SharedPtr mStructurePub; + rclcpp::Publisher::SharedPtr mStatePub; + + rclcpp::TimerBase::SharedPtr mStructureTimer; + rclcpp::TimerBase::SharedPtr mStateTimer; + /** + * \brief Publishes the structure to be used by visualizer.py + * \see visualizer.py to see how these topic will be used + */ + void publishStructure(){ + auto structureMsg = mrover::msg::StateMachineStructure(); + structureMsg.machine_name = mStateMachine.getName(); + auto transitionTable = mStateMachine.getTransitionTable(); + + for(auto const&[from, tos] : transitionTable){ + auto transition = mrover::msg::StateMachineTransition(); + transition.origin = mStateMachine.decodeTypeHash(from); + for(auto& hash : tos){ + transition.destinations.push_back(mStateMachine.decodeTypeHash(hash)); + } + structureMsg.transitions.push_back(std::move(transition)); + } + + mStructurePub->publish(structureMsg); + } + + /** + * \brief Publishes the current state of the state machine + * \see visualizer.py to see how these topic will be used + */ + void publishState(){ + auto stateMachineUpdate = mrover::msg::StateMachineStateUpdate(); + stateMachineUpdate.state_machine_name = mStateMachine.getName(); + stateMachineUpdate.state = mStateMachine.getCurrentStateName(); + mStatePub->publish(stateMachineUpdate); + } + + public: + /** + * \brief Creates a State Publisher to facilitate the communications between visualizer.py and the state machine + * \param node The node which owns the state publisher + * \param stateMachine The state machine which the publisher will describe + * \param structureTopicName The topic which will publish the state machine's structure + * \param structureTopicHz The rate at which the structure topic will publish + * \param stateTopicName The topic which will publish the state machine's state + * \param stateTopicHz The rate at which the state topic will publish + */ + StatePublisher(rclcpp::Node* node, StateMachine const& stateMachine, std::string const& structureTopicName, double structureTopicHz, std::string const& stateTopicName, double stateTopicHz) : mStateMachine{stateMachine} { + mStructurePub = node->create_publisher(structureTopicName, 1); + mStatePub = node->create_publisher(stateTopicName, 1); + + mStructureTimer = node->create_wall_timer(std::chrono::milliseconds(static_cast(1000 / structureTopicHz)), [&](){publishStructure();}); + mStateTimer = node->create_wall_timer(std::chrono::milliseconds(static_cast(1000 / stateTopicHz)), [&](){publishState();}); + } + }; +}