-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_value_iteration.py
142 lines (109 loc) · 4.31 KB
/
test_value_iteration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from mdp.algorithms.value_iteration import ValueIteration
from mdp.environment.env import Environment
from file_manager import FileManager
import pygame
# To use your own grid, change the module imported below from env_config to custom_grid
from env_config import grid, actions, rewards, gw, gh
# Initialise all pygame modules up front: pygame.font.Font is used further
# down to render the policy/utility labels, and the font module must be
# initialised before any Font object is created.
pygame.init()
# --- Solver settings -----------------------------------------------------
GAMMA = 0.99              # passed to ValueIteration (presumably the discount factor)
C = 0.1
MAX_REWARD = 1.0
EPSILON = C * MAX_REWARD  # tolerance passed to ValueIteration.solve
PATH = 'analysis/'        # location handed to FileManager for the CSV output

# --- Rendering settings --------------------------------------------------
# Arrow glyph drawn for each policy action, keyed by (row delta, column delta):
# a positive row delta moves down the screen, a positive column delta right.
CONVERT_POLICY = {
    (1, 0): '↓',
    (-1, 0): '↑',
    (0, 1): '→',
    (0, -1): '←',
}
DISPLAY_GRID = True       # when False, the pygame windows are skipped entirely
# Font size and (x, y) pixel offset of the text drawn inside each cell.
UTILITY_FONT_SIZE, UTILITY_OFFSET = 15, (4, 14)
POLICY_FONT_SIZE, POLICY_OFFSET = 30, (17, 5)
ratio = 1                 # scale factor applied to the font sizes and text offsets
# Build the MDP from the grid configuration and run value iteration on it.
mdp = Environment(grid, actions, rewards, gw, gh)
value_iteration = ValueIteration(GAMMA)
results = value_iteration.solve(mdp, EPSILON)

# Unpack the solver output; 'utilities' is indexed [row][col] and exposes
# .shape, as used by the printing loop below.
num_iterations, values, policy = (
    results['iterations'],
    results['utilities'],
    results['policy'],
)

# Report every cell's utility, labelled as (column, row).
print(f'Number of iterations: {num_iterations}\n')
print('\n(Column, Row)')
for row_idx in range(values.shape[0]):
    for col_idx in range(values.shape[1]):
        print(f"{col_idx, row_idx}: {values[row_idx][col_idx]}")

# Persist the data recorded by the solver for later analysis.
file_mgr = FileManager(PATH)
file_mgr.write('value_iteration.csv', value_iteration.get_data())
# Display utility and policy plot
if DISPLAY_GRID:
    GREEN = (100, 200, 100)
    RED = (200, 100, 100)
    WHITE = (200, 200, 200)
    GREY = (50, 50, 50)
    BLACK = (0, 0, 0)
    # Cell background by grid symbol; any other symbol is drawn WHITE.
    CELL_COLORS = {'W': GREY, 'G': GREEN, 'R': RED}

    # Per-cell text labels and colours, indexed [row][col] like `grid`.
    directions = [[CONVERT_POLICY[cell] for cell in row] for row in policy]
    utilities = [[f"{cell:.3f}" for cell in row] for row in values]
    colors = [[CELL_COLORS.get(cell, WHITE) for cell in row] for row in grid]

    # Scale the cells by the same `ratio` as the fonts and text offsets so the
    # labels stay inside their cells when `ratio` != 1.
    block_size = int(50 * ratio)
    # Size the window from the grid (instead of a fixed 300x300) so grids
    # that are not 6x6 are neither clipped nor padded with black.
    n_rows = len(grid)
    n_cols = max((len(row) for row in grid), default=0)
    screen_dimensions = (n_cols * block_size, n_rows * block_size)
    screen_color = (0, 0, 0)
    policy_font = pygame.font.Font("assets/seguisym.ttf", int(POLICY_FONT_SIZE * ratio))
    utility_font = pygame.font.Font("assets/seguisym.ttf", int(UTILITY_FONT_SIZE * ratio))

    def _show_grid(caption, labels, font, offset, save_path=None):
        """Render the grid with one text label per non-wall cell and block
        until the user closes the window.

        caption:   window title.
        labels:    2-D list of strings indexed [row][col], same layout as `grid`.
        font:      pygame font used to render each label.
        offset:    (x, y) pixel offset of a label inside its cell, scaled by `ratio`.
        save_path: if not None, the rendered frame is also saved to this image file.
        """
        screen = pygame.display.set_mode(screen_dimensions)
        pygame.display.set_caption(caption)
        screen.fill(screen_color)
        # Walk each row's own length so non-square grids are drawn completely.
        for r, grid_row in enumerate(grid):
            for c, cell in enumerate(grid_row):
                rect = pygame.Rect(c * block_size, r * block_size, block_size, block_size)
                pygame.draw.rect(screen, colors[r][c], rect)
                pygame.draw.rect(screen, BLACK, rect, 1)  # 1-pixel cell border
                if cell == 'W':
                    continue  # walls carry no label
                message = font.render(labels[r][c], True, BLACK)
                screen.blit(message, (c * block_size + offset[0] * ratio,
                                      r * block_size + offset[1] * ratio))
        pygame.display.update()
        if save_path is not None:
            pygame.image.save(screen, save_path)
        # The picture is static: block on events instead of redrawing in a
        # busy loop, and just re-present the frame on non-quit events.
        while True:
            event = pygame.event.wait()
            if event.type == pygame.QUIT:
                break
            pygame.display.update()

    # Pass e.g. save_path="images/complex_maze/vi_policy.png" to export the image.
    _show_grid('Value Iteration - Policy', directions, policy_font, POLICY_OFFSET)
    # Pass e.g. save_path="images/complex_maze/vi_values.png" to export the image.
    _show_grid('Value Iteration - Utilities', utilities, utility_font, UTILITY_OFFSET)
    pygame.quit()