-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagentRunner.py
62 lines (53 loc) · 2.58 KB
/
agentRunner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
from tensorflow_agents.mountain_car_model_tester import MountainCarModelTester
from tensorflow_agents.mountain_car_mo_dqn import MultiObjectiveMountainCarDQN
from tensorflow_agents.mountain_car_mo_ddqn import MultiObjectiveMountainCarDDQN
from tensorflow_agents.mountain_car_mo_pddqn import MultiObjectiveMountainCarPDDQN
from tensorflow_agents.mountain_car_mo_wpddqn import MultiObjectiveWMountainCar
from tensorflow_agents.mountain_car_graphical_ddqn import MountainCarGraphicalDDQN
from tensorflow_agents.mountain_car_open_ai import OpenAIMountainCar
from tensorflow_agents.deep_sea_baseline_dqn import DeepSeaTreasureBaselineDQN
from tensorflow_agents.deep_sea_baseline_ddqn import DeepSeaTreasureBaselineDDQN
from tensorflow_agents.deep_sea_graphical_pddqn import DeepSeaTreasureGraphicalPDDQN
from tensorflow_agents.deep_sea_graphical_ddqn import DeepSeaTreasureGraphicalDDQN
from tensorflow_agents.deep_sea_graphical_dqn import DeepSeaTreasureGraphicalDQN
from tensorflow_agents.deep_sea_mo_wdqn import DeepSeaWAgent
from tensorflow_agents.deep_sea_graphical_wpddqn import MultiObjectiveDeepSeaW
from tensorflow_agents.mario_baseline import MarioBaseline
# Command-line entry point: construct the agent selected with -a/--agentArg
# and train it.

# Maps every accepted --agentArg value to the agent class it selects and the
# integer handed to that class's constructor.
# NOTE(review): the integer is presumably the number of training episodes --
# confirm against the agent classes.
_AGENT_REGISTRY = {
    'mountain_car_mo_dqn': (MultiObjectiveMountainCarDQN, 1001),
    'mountain_car_mo_ddqn': (MultiObjectiveMountainCarDDQN, 1001),
    'mountain_car_mo_pddqn': (MultiObjectiveMountainCarPDDQN, 1001),
    'mountain_car_mo_wpddqn': (MultiObjectiveWMountainCar, 5000),
    'mountain_car_graphical_ddqn': (MountainCarGraphicalDDQN, 5000),
    'mountain_car_open_ai': (OpenAIMountainCar, 2000),
    'deep_sea_baseline_ddqn': (DeepSeaTreasureBaselineDDQN, 350),
    'deep_sea_graphical_pddqn': (DeepSeaTreasureGraphicalPDDQN, 301),
    'deep_sea_baseline_dqn': (DeepSeaTreasureBaselineDQN, 300),
    'deep_sea_mo_wdqn': (DeepSeaWAgent, 2000),
    'deep_sea_graphical_ddqn': (DeepSeaTreasureGraphicalDDQN, 1501),
    'deep_sea_graphical_dqn': (DeepSeaTreasureGraphicalDQN, 2001),
    'deep_sea_graphical_wpddqn': (MultiObjectiveDeepSeaW, 301),
    'mario_baseline': (MarioBaseline, 2000),
}

parser = argparse.ArgumentParser(description='Run agentArg model for game')
parser.add_argument(
    "-a",
    "--agentArg",
    required=True,
    # Restricting the value to the registry keys makes argparse reject an
    # unknown agent name with a usage message listing the valid names,
    # instead of the script failing later with a NameError on `agent`.
    choices=list(_AGENT_REGISTRY),
    help="name of the agent to train",
)
args = parser.parse_args()
agentArg = args.agentArg

# Look up the selected agent, build it with its registered constructor
# argument, and start training.
agent_class, constructor_arg = _AGENT_REGISTRY[agentArg]
agent = agent_class(constructor_arg)
agent.train()

# Disabled alternative: run a saved Mountain Car checkpoint through the model
# tester instead of training an agent.
#
# agentArg = MountainCarModelTester("./mountain_car_wnet_54540.chkpt")
# agentArg.test()