Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ State Machine Library #56

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ ament_target_dependencies(lie rclcpp geometry_msgs tf2 tf2_ros)
mrover_add_header_only_library(units units)
mrover_add_header_only_library(loop_profiler loop_profiler)
mrover_add_header_only_library(parameter_utils parameter_utils)
mrover_add_header_only_library(state_machine state_machine)

# Simulator

Expand Down
38 changes: 19 additions & 19 deletions lie/lie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,28 @@ namespace mrover {
[[nodiscard]] static auto toTransformStamped(SE3d const& tf, std::string const& childFrame, std::string const& parentFrame, rclcpp::Time const& time) -> geometry_msgs::msg::TransformStamped;

/**
* \brief Pull the most recent transform or pose between two frames from the TF tree.
* The second and third parameters are named for the transform interpretation.
* Consider them named "a" and "b" respectively:
* For a transform this is a rotation and translation, i.e. aToB.
* For a pose this is a position and orientation, i.e. aInB.
* \param buffer ROS TF Buffer, make sure a listener is attached
* \param fromFrame From (transform) or child (pose) frame
* \param toFrame To (transform) or parent (pose) frame
* \param time Time to query the transform at, default is the latest
* \return The transform or pose represented by an SE3 lie group element
*/
* \brief Pull the most recent transform or pose between two frames from the TF tree.
* The second and third parameters are named for the transform interpretation.
* Consider them named "a" and "b" respectively:
* For a transform this is a rotation and translation, i.e. aToB.
* For a pose this is a position and orientation, i.e. aInB.
* \param buffer ROS TF Buffer, make sure a listener is attached
* \param fromFrame From (transform) or child (pose) frame
* \param toFrame To (transform) or parent (pose) frame
* \param time Time to query the transform at, default is the latest
* \return The transform or pose represented by an SE3 lie group element
*/
[[nodiscard]] static auto fromTfTree(tf2_ros::Buffer const& buffer, std::string const& fromFrame, std::string const& toFrame, rclcpp::Time const& time = rclcpp::Time{}) -> SE3d;

/**
* \brief Push a transform to the TF tree between two frames
* \see fromTfTree for more explanation of the frames
* \param broadcaster ROS TF Broadcaster
* \param fromFrame From (transform) or child (pose) frame
* \param toFrame To (transform) or parent (pose) frame
* \param transform The transform or pose represented by an SE3 lie group element
* \param time
*/
* \brief Push a transform to the TF tree between two frames
* \see fromTfTree for more explanation of the frames
* \param broadcaster ROS TF Broadcaster
* \param fromFrame From (transform) or child (pose) frame
* \param toFrame To (transform) or parent (pose) frame
* \param transform The transform or pose represented by an SE3 lie group element
* \param time
*/
static auto pushToTfTree(tf2_ros::TransformBroadcaster& broadcaster, std::string const& fromFrame, std::string const& toFrame, SE3d const& transform, rclcpp::Time const& time) -> void;
};

Expand Down
22 changes: 12 additions & 10 deletions scripts/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
from dataclasses import dataclass
from typing import Optional, List, Dict
import threading


STRUCTURE_TOPIC = "nav_structure"
STATUS_TOPIC = "nav_state"
import sys


@dataclass
Expand Down Expand Up @@ -85,7 +82,7 @@ def container_structure_callback(self, structure: StateMachineStructure):


class GUI(QWidget): # type: ignore
def __init__(self, state_machine_instance, *args, **kwargs):
def __init__(self, state_machine_instance, structure_topic, state_topic, *args, **kwargs):
super().__init__(*args, **kwargs)
self.label: QLabel = QLabel() # type: ignore
self.timer: QTimer = QTimer() # type: ignore
Expand All @@ -98,11 +95,11 @@ def __init__(self, state_machine_instance, *args, **kwargs):
self.viz = Node("Visualizer")

self.viz.create_subscription(
StateMachineStructure, STRUCTURE_TOPIC, self.state_machine.container_structure_callback, 1
StateMachineStructure, structure_topic, self.state_machine.container_structure_callback, 1
)

self.viz.create_subscription(
StateMachineStateUpdate, STATUS_TOPIC, self.state_machine.container_status_callback, 1
StateMachineStateUpdate, state_topic, self.state_machine.container_status_callback, 1
)

def paintEvent(self, event):
Expand Down Expand Up @@ -131,7 +128,7 @@ def update(self): # type: ignore[override]
self.repaint()


def main():
def main(structure_topic, state_topic):
try:
rclpy.init()

Expand All @@ -147,7 +144,7 @@ def main():
print("Subscriptions Created...")

app = QApplication([]) # type: ignore
g = GUI(state_machine)
g = GUI(state_machine, structure_topic, state_topic)
g.show()
app.exec_()

Expand All @@ -158,4 +155,9 @@ def main():


if __name__ == "__main__":
main()
argc = len(sys.argv)
if argc != 3:
print('Usage ros2 run mrover visualizer.py "[structure topic]" "[state topic]"')
sys.exit(1)

main(sys.argv[1], sys.argv[2])
17 changes: 17 additions & 0 deletions state_machine/state.hpp
jbrhm marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

/**
* \brief Virtual state class to describe how states in the state machine should act
* \see state_machine.hpp for reference on how the states will be used by the state machine
*/
class State {
public:
/**
* \brief Virtual destructor for the state class
*/
virtual ~State() = default;
/**
* \brief The function which will be called every loop in the state machine
*/
virtual auto onLoop() -> State* = 0;
};
162 changes: 162 additions & 0 deletions state_machine/state_machine.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#pragma once
#include "state.hpp"

#include <unordered_map>
#include <vector>
#include <typeinfo>
#include <stdexcept>
#include <format>
#include <algorithm>
#include <tuple>
#include <cxxabi.h>
#include <iostream>

/**
* \brief State Machine class that facilitates transitioning between different states which inherit from the State class
* \see state.hpp for reference on creating states
*/
class StateMachine{
private:
std::string mName;
State* currState;

using TypeHash = std::size_t;

std::unordered_map<TypeHash, std::vector<TypeHash>> mValidTransitions;
std::unordered_map<TypeHash, std::string> decoder;

/**
* \brief Ensures that transitioning from "from" to "to" is a valid transition
* \param from The runtime type of the from state
* \param to The runtime type of the to state
*/
void assertValidTransition(std::type_info const& from, std::type_info const& to) const {
auto it = mValidTransitions.find(from.hash_code());

if(it == mValidTransitions.end()){
throw std::runtime_error(std::format("{} is not in the state machine ", typeid(from).name()));
}

std::vector<TypeHash> const& toTransitions = it->second;
if(std::find(toTransitions.begin(), toTransitions.end(), to.hash_code()) == toTransitions.end()){
throw std::runtime_error(std::format("Invalid State Transition from {} to {}", typeid(currState).name(), typeid(to).name()));
}
}

/**
* \brief Adds the demangled name to the map in the corresponding type hash slot
* \param hash The type hash for the runtime type
* \param name The demangled name of the runtime type
*/
void addNameToDecoder(TypeHash hash, std::string const& name){
constexpr static std::string prefix{"mrover::"};
TypeHash index = name.find(prefix);
std::string _name{name};
_name.replace(index, index + prefix.size(), "");
decoder[hash] = _name;
}
public:
/**
* \brief Constructor for the StateMachine Class
* \param name The name of the state machine useful for visualization
* \param initialState The initial state which the state machine will begin execution in
*/
explicit StateMachine(std::string name, State* initialState) : mName{std::move(name)}, currState{initialState}{};

~StateMachine(){
delete currState;
}

/**
* \brief Makes a state which the state machine will use for execution.
* DO NOT CALL THIS FUNCTION AND NOT PASS THE STATE TO THE STATE MACHINE
* \param args The arguments that will be passed to the constructor of the state
*/
template<typename T, typename ...Args>
static auto make_state(Args... args) -> T*{
static_assert(std::derived_from<T, State>, "State Must Be Derived From The State Class");
return new T(args...);
}

/**
* \brief Returns the name of the state machine
* \return A constant reference to the name of the state machine
*/
auto getName() const -> std::string const& {
return mName;
}

/**
* \brief Returns the demangled name of the state at runtime
* \param state A pointer to a state derived object which will have its runtime type analyzed
* \return A constant reference to the demangled state name at runtime
*/
auto getStateName(State const* state) const -> std::string const&{
return decoder.find(typeid(*state).hash_code())->second;
}

/**
* \brief Returns the demangled name of the current state in the state machine
* \return A constant reference to the current state's demangled state name
*/
auto getCurrentState() const -> std::string const& {
jbrhm marked this conversation as resolved.
Show resolved Hide resolved
return getStateName(currState);
}

/**
* \brief Returns a map of type hashes to a vector of each type hash
* \return A constant reference to the map describing all valid state transitions
*/
auto getTransitionTable() const -> std::unordered_map<TypeHash, std::vector<TypeHash>> const&{
return mValidTransitions;
}

/**
* \brief Takes in a type hash and returns the demangled state name
* \return A constant reference to the demangled state name
*/
auto decodeTypeHash(TypeHash hash) const -> std::string const&{
return decoder.find(hash)->second;
}

/**
* \brief Enables the state transition from the first templated type to the subsequent templated types
*/
template<typename From, typename ...To>
void enableTransitions(){
static_assert(std::derived_from<From, State>, "From State Must Be Derived From The State Class");
static_assert((std::derived_from<To, State> && ...), "All States Must Be Derived From The State Class");
// Add From State To Decoder
int status = 0;
char* demangledName = abi::__cxa_demangle(typeid(From).name(), nullptr, nullptr, &status);

if(status){
throw std::runtime_error("C++ demangle failed!");
}

addNameToDecoder(typeid(From).hash_code(), demangledName);
free(demangledName);

mValidTransitions[typeid(From).hash_code()] = {typeid(To).hash_code()...};

std::vector<std::reference_wrapper<std::type_info const>> types{std::ref(typeid(To))...};
for(auto const& type : types){
demangledName = abi::__cxa_demangle(type.get().name(), nullptr, nullptr, &status);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wtf

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its c++filt... but in C++ 💯

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unholy. Just have states carry a static member for the name...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unholy is harsh... 😢 but it is p cool 😎

addNameToDecoder(type.get().hash_code(), demangledName);
jbrhm marked this conversation as resolved.
Show resolved Hide resolved
free(demangledName);
}
}

/**
* \brief Runs the onLoop function for the state and then transitions to the state returned from that function
*/
void update(){
State* newState = currState->onLoop();

assertValidTransition(typeid(*currState), typeid(*newState));
if(newState != currState){
delete currState;
}
currState = newState;
}
};
81 changes: 81 additions & 0 deletions state_machine/state_publisher_server.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#pragma once

// STL
#include <cstddef>
#include <memory>

// Ros Client Library
#include <ament_index_cpp/get_package_prefix.hpp>
#include <rclcpp/node.hpp>
#include <rclcpp/rclcpp.hpp>
#include <tf2_ros/transform_broadcaster.h>
#include <tf2_ros/transform_listener.h>

// MRover
#include "mrover/msg/detail/state_machine_state_update__struct.hpp"
#include "mrover/msg/detail/state_machine_structure__struct.hpp"
#include "state_machine.hpp"
#include <mrover/msg/state_machine_structure.hpp>
#include <mrover/msg/state_machine_state_update.hpp>

namespace mrover{
class StatePublisher{
private:
StateMachine const& mStateMachine;

rclcpp::Publisher<mrover::msg::StateMachineStructure>::SharedPtr mStructurePub;
rclcpp::Publisher<mrover::msg::StateMachineStateUpdate>::SharedPtr mStatePub;

rclcpp::TimerBase::SharedPtr mStructureTimer;
rclcpp::TimerBase::SharedPtr mStateTimer;
/**
* \brief Publishes the structure to be used by visualizer.py
* \see visualizer.py to see how these topic will be used
*/
void publishStructure(){
auto structureMsg = mrover::msg::StateMachineStructure();
structureMsg.machine_name = mStateMachine.getName();
auto transitionTable = mStateMachine.getTransitionTable();

for(auto const&[from, tos] : transitionTable){
auto transition = mrover::msg::StateMachineTransition();
transition.origin = mStateMachine.decodeTypeHash(from);
for(auto& hash : tos){
transition.destinations.push_back(mStateMachine.decodeTypeHash(hash));
}
structureMsg.transitions.push_back(std::move(transition));
}

mStructurePub->publish(structureMsg);
}

/**
* \brief Publishes the current state of the state machine
* \see visualizer.py to see how these topic will be used
*/
void publishState(){
auto stateMachineUpdate = mrover::msg::StateMachineStateUpdate();
stateMachineUpdate.state_machine_name = mStateMachine.getName();
stateMachineUpdate.state = mStateMachine.getCurrentState();
mStatePub->publish(stateMachineUpdate);
}

public:
/**
* \brief Creates a State Publisher to facilitate the communications between visualizer.py and the state machine
* \param node The node which owns the state publisher
* \param stateMachine The state machine which the publisher will describe
* \param structureTopicName The topic which will publish the state machine's structure
* \param structureTopicHz The rate at which the structure topic will publish
* \param stateTopicName The topic which will publish the state machine's state
* \param stateTopicHz The rate at which the state topic will publish
*/
StatePublisher(rclcpp::Node* node, StateMachine const& stateMachine, std::string const& structureTopicName, double structureTopicHz, std::string const& stateTopicName, double stateTopicHz) : mStateMachine{stateMachine} {
mStructurePub = node->create_publisher<mrover::msg::StateMachineStructure>(structureTopicName, 1);
mStatePub = node->create_publisher<mrover::msg::StateMachineStateUpdate>(stateTopicName, 1);

mStructureTimer = node->create_wall_timer(std::chrono::milliseconds(static_cast<std::size_t>(1 / structureTopicHz)), [&](){publishStructure();});
jbrhm marked this conversation as resolved.
Show resolved Hide resolved
mStateTimer = node->create_wall_timer(std::chrono::milliseconds(static_cast<std::size_t>(1 / stateTopicHz)), [&](){publishState();});
jbrhm marked this conversation as resolved.
Show resolved Hide resolved
}
};
}