Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Di Lu committed Dec 18, 2018
1 parent 72412d3 commit 3f115b4
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
7 changes: 3 additions & 4 deletions pypomdp/parsers/env_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,12 @@ def __get_T(self, i):
# %f %f ... %f
# ...
# %f %f ... %f
# TODO: rollback
for j in range(len(self.states)):
for j, sj in enumerate(self.states):
probs = next_line.split()
assert len(probs) == len(self.states)
for k in range(len(probs)):
for k, sk in enumerate(self.states):
prob = float(probs[k])
self.T[(action, j, k)] = prob
self.T[(action, sj, sk)] = prob
next_line = self.contents[i+2+j]
return i+1+len(self.states)
else:
Expand Down
13 changes: 8 additions & 5 deletions pypomdp/solvers/pomcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ def __init__(self, model):
def add_configs(self, budget=float('inf'), initial_belief=None, simulation_time=0.5,
max_particles=350, reinvigorated_particles_ratio=0.1, utility_fn='ucb1', C=0.5):
# acquire the utility function used to choose the most desirable action to try
self.utility_fn = {
'ucb1': UtilityFunction.ucb1(C),
'mab_bv1': UtilityFunction.mab_bv1(min(self.model.costs), C),
'sa_ucb': UtilityFunction.sa_ucb(C)
}[utility_fn]
if utility_fn == 'ucb1':
self.utility_fn = UtilityFunction.ucb1(C)
elif utility_fn == 'sa_ucb':
self.utility_fn = UtilityFunction.sa_ucb(C)
elif utility_fn == 'mab_bv1':
if self.model.costs is None:
raise ValueError('Must specify action costs if utility function is MAB_BV1')
self.utility_fn = UtilityFunction.mab_bv1(min(self.model.costs), C)

# other configs
self.simulation_time = simulation_time
Expand Down
6 changes: 4 additions & 2 deletions pypomdp/util/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ def gen_distribution(n):


def draw_arg(probs):
    """Draw a random index according to a discrete probability distribution.

    :param probs: sequence of probabilities, one per index; they must sum
        to 1 within a tolerance of 1e-8.
    :return: an index in ``range(len(probs))`` sampled with weights ``probs``.
    :raises AssertionError: if the probabilities do not sum to 1.
    """
    probs = np.asarray(probs, dtype=float)
    total = probs.sum()
    # Two-sided check: without abs() a distribution summing to far less
    # than 1 would slip through, because the difference is negative.
    assert abs(total - 1.0) < 0.00000001
    # Do a second normalisation to avoid the problem described here:
    # https://stackoverflow.com/questions/46539431/np-random-choice-probabilities-do-not-sum-to-1
    return np.random.choice(len(probs), p=probs / total)


def elem_distribution(arr):
Expand Down

0 comments on commit 3f115b4

Please sign in to comment.