Skip to content

Commit

Permalink
train GA: tournament with qualification and finale
Browse files Browse the repository at this point in the history
  • Loading branch information
stepanmracek committed Jun 14, 2024
1 parent 26e2a60 commit b4a03d8
Showing 1 changed file with 116 additions and 28 deletions.
144 changes: 116 additions & 28 deletions train_ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def evaluate_model(params: EvaluateParams):
global world
random.seed(params.seed)
fitness = 0.0
diamonds = 0
for run in range(params.runs):
world.reset()
car_keys = init_keys()
Expand All @@ -193,6 +194,8 @@ def evaluate_model(params: EvaluateParams):
world.blue_car, *step_outcome.blue_car, blue_prev_diamond
)
fitness += f1 + f2
diamonds += 1 if step_outcome.red_car[1].collected_diamond else 0
diamonds += 1 if step_outcome.blue_car[1].collected_diamond else 0

model_input = np.array(
[
Expand All @@ -216,7 +219,7 @@ def evaluate_model(params: EvaluateParams):
"r": model_output[1][3] > 0,
},
}
return params.order, fitness
return params.order, fitness, diamonds


@dataclass(slots=True)
Expand Down Expand Up @@ -334,11 +337,11 @@ def train():
desc=f"Evaluating generation {generation_index}",
)
)
results.sort(key=lambda index_fitness_pair: index_fitness_pair[0])
results.sort(key=lambda index_fitness_diamonds_tuple: index_fitness_diamonds_tuple[0])

# sort population by fitness (best individuals first)
sorted_population = sorted(
((model, fitness) for (model, (_, fitness)) in zip(population, results)),
((model, fitness) for (model, (_, fitness, _)) in zip(population, results)),
reverse=True,
key=lambda model_fitness_pair: model_fitness_pair[1],
)
Expand Down Expand Up @@ -503,39 +506,124 @@ def competition_pairs(competitors: list[T]):

def tournament():
arg_parser = ArgumentParser(prog="train_ga.py tournament")
arg_parser.add_argument("--model-dir", required=True)
arg_parser.add_argument("--processes", type=int)
arg_parser.add_argument("--models-dir", required=True)
arg_parser.add_argument("--level", default="park", choices=["park", "nyan"])
arg_parser.add_argument("--timelimit", default=60, type=int)
arg_parser.add_argument("--scorelimit", default=10, type=int)
arg_parser.add_argument("--qualification-runs", default=10, type=int)
arg_parser.add_argument("--qualification-seed", default=int(random.random() * 1e12), type=int)
arg_parser.add_argument("--tournament-max-models", default=50, type=int)
arg_parser.add_argument("--finale-runs", default=50, type=int)
args = arg_parser.parse_args(sys.argv[2:])

models_paths = sorted(glob.glob(args.model_dir + os.path.sep + "*.np"))
models_paths = sorted(glob.glob(args.models_dir + os.path.sep + "*.np"))
models = [(path, NumpyModel.load(path)) for path in tqdm(models_paths, desc="Loading models")]
pairs = list(competition_pairs(models))
params = [
CompetitionParams(
order=i,
seed=i,
timelimit=args.timelimit,
scorelimit=args.scorelimit,
red_model=red[1],
blue_model=blue[1],
)
for i, (red, blue) in enumerate(pairs)
]

with Pool(initializer=process_init, initargs=(args.level,)) as pool:
results = [r for r in tqdm(pool.imap(competition, params), total=len(params))]

rankings = Counter()
for (red, blue), result in zip(pairs, results):
if result == CompetitionResult.RED:
rankings[red[0]] += 1
elif result == CompetitionResult.BLUE:
rankings[blue[0]] += 1
with Pool(processes=args.processes, initializer=process_init, initargs=(args.level,)) as pool:
# If number of loaded models is greater than tournament_max_models do qualification
if len(models) > args.tournament_max_models:
# Run each model alon qualification_runs times
qualification_params = [
EvaluateParams(
order=i, seed=args.qualification_seed, model=model, runs=args.qualification_runs
)
for i, (_, model) in enumerate(models)
]
qualification_results = list(
tqdm(
pool.imap_unordered(evaluate_model, qualification_params),
total=len(models),
desc="Qualification",
)
)
# sort by index
qualification_results.sort(key=lambda i: i[0])

# merge qualification results and models
qualification_results = list(
(model, path, diamonds)
for ((model, path), diamonds) in zip(
models, (diamonds for (_, _, diamonds) in qualification_results)
)
)

for model, wins in rankings.most_common():
print(model, wins)
# sort qualification results by collected diamonds and select best tournament_max_models
qualification_results.sort(key=lambda r: r[2], reverse=True)
qualification_results = qualification_results[: args.tournament_max_models]

print("Qualification results")
for r in qualification_results:
print(f" {r[0]}: {r[2]} diamonds")

models = [r[:2] for r in qualification_results]

# Tournament: run one game with every model against each other
pairs = list(competition_pairs(models))
tournament_params = [
CompetitionParams(
order=i,
seed=i,
timelimit=args.timelimit,
scorelimit=args.scorelimit,
red_model=red[1],
blue_model=blue[1],
)
for i, (red, blue) in enumerate(pairs)
]
tournament_results = [
r
for r in tqdm(
pool.imap(competition, tournament_params),
total=len(tournament_params),
desc="Tournament",
)
]

# Count number of wins for each model
rankings = Counter()
for (red, blue), result in zip(pairs, tournament_results):
if result == CompetitionResult.RED:
rankings[red[0]] += 1
elif result == CompetitionResult.BLUE:
rankings[blue[0]] += 1

print("Tournament results:")
rankings_sorted = rankings.most_common()
for model, wins in rankings_sorted:
print(f" {model}: {wins} wins")

# Select best two models to finale
red = next((path, model) for path, model in models if path == rankings_sorted[0][0])
blue = next((path, model) for path, model in models if path == rankings_sorted[1][0])
finale_params = [
CompetitionParams(
order=i,
seed=i,
timelimit=args.timelimit,
scorelimit=args.scorelimit,
red_model=red[1],
blue_model=blue[1],
)
for i in range(args.finale_runs)
]
finale_results = [
r
for r in tqdm(
pool.imap(competition, finale_params), total=len(finale_params), desc="Finale"
)
]
print("Finale results:")
print(f" Red: {red[0]}")
print(f" Blue: {blue[0]}")
red_wins = 0
blue_wins = 0
for result in finale_results:
if result == CompetitionResult.RED:
red_wins += 1
elif result == CompetitionResult.BLUE:
blue_wins += 1
print(f" {red_wins}:{blue_wins}")


if __name__ == "__main__":
Expand Down

0 comments on commit b4a03d8

Please sign in to comment.