From 9633f547db606a2d61825bbaae53c98ace98590b Mon Sep 17 00:00:00 2001 From: cliebig2019 Date: Thu, 21 Jul 2022 12:24:33 +0200 Subject: [PATCH 1/4] custom neural network from json --- training/scripts/train_agent.py | 2 +- training/tools/argsparser.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py index 447787e2..c03cc9f1 100755 --- a/training/scripts/train_agent.py +++ b/training/scripts/train_agent.py @@ -133,7 +133,7 @@ def main(): verbose=1, ) elif args.agent is not None: - agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent) + agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent, path = args.path) if isinstance(agent, BaseAgent): model = PPO( agent.type.value, diff --git a/training/tools/argsparser.py b/training/tools/argsparser.py index 46645e7e..cb9110a9 100644 --- a/training/tools/argsparser.py +++ b/training/tools/argsparser.py @@ -23,6 +23,7 @@ def training_args(parser): import rosnav.model.custom_policy import rosnav.model.custom_sb3_policy from rosnav.model.agent_factory import AgentFactory + import rosnav.model.custom_policy_from_json group.add_argument( "--agent", @@ -60,6 +61,11 @@ def training_args(parser): parser.add_argument( "--tb", action="store_true", help="enables tensorboard logging" ) + parser.add_argument( + "--path", + type=str, + help="path to the json file containing" "the neural network", + ) def run_agent_args(parser): From fc4a1af04706c2a73002c27ce3a694fcd6fefe94 Mon Sep 17 00:00:00 2001 From: cliebig2019 <82951475+cliebig2019@users.noreply.github.com> Date: Thu, 21 Jul 2022 13:01:01 +0200 Subject: [PATCH 2/4] train_agent passing path to agentFactory --- training/scripts/train_agent.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 training/scripts/train_agent.py diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py old mode 100755 new mode 100644 From 83424a590e3ffaf4eec5085667428d980f43f603 Mon Sep 17 00:00:00 2001 From: cliebig2019 <82951475+cliebig2019@users.noreply.github.com> Date: Tue, 2 Aug 2022 14:44:38 +0200 Subject: [PATCH 3/4] rosparam to network file; new arg for custom network --- arena_bringup/launch/start_training.launch | 2 ++ training/scripts/train_agent.py | 2 +- training/tools/argsparser.py | 5 ----- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/arena_bringup/launch/start_training.launch b/arena_bringup/launch/start_training.launch index 6804fc9f..84ae6e3f 100644 --- a/arena_bringup/launch/start_training.launch +++ b/arena_bringup/launch/start_training.launch @@ -24,6 +24,8 @@ + + diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py index c03cc9f1..447787e2 100644 --- a/training/scripts/train_agent.py +++ b/training/scripts/train_agent.py @@ -133,7 +133,7 @@ def main(): verbose=1, ) elif args.agent is not None: - agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent, path = args.path) + agent: Union[Type[BaseAgent], Type[ActorCriticPolicy]] = AgentFactory.instantiate(args.agent) if isinstance(agent, BaseAgent): model = PPO( agent.type.value, diff --git a/training/tools/argsparser.py b/training/tools/argsparser.py index cb9110a9..127dc53f 100644 --- a/training/tools/argsparser.py +++ b/training/tools/argsparser.py @@ -61,11 +61,6 @@ def training_args(parser): parser.add_argument( "--tb", action="store_true", help="enables tensorboard logging" ) - parser.add_argument( - "--path", - type=str, - help="path to the json file containing" "the neural network", - ) def run_agent_args(parser): From 018dbb5680bb5b22990a18375d685e66f556a722 Mon Sep 17 00:00:00 2001 From: cliebig2019 <82951475+cliebig2019@users.noreply.github.com> Date: Thu, 25 Aug 2022 17:00:19 +0200 Subject: [PATCH 4/4] changed rosparam custom_network_path --- arena_bringup/launch/start_training.launch | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arena_bringup/launch/start_training.launch b/arena_bringup/launch/start_training.launch index 84ae6e3f..8a69b9e5 100644 --- a/arena_bringup/launch/start_training.launch +++ b/arena_bringup/launch/start_training.launch @@ -24,7 +24,8 @@ - + +