Add multi-arena episodes #35

Merged: 3 commits, Apr 16, 2024
14 changes: 7 additions & 7 deletions Assets/Scripts/AAI3EnvironmentManager.cs
@@ -214,11 +214,6 @@ public void TriggerArenaChangeEvent(int currentArenaIndex, int totalArenas)
OnArenaChanged?.Invoke(currentArenaIndex, totalArenas);
}

public int getMaxArenaID()
{
return _arenasConfigurations.configurations.Count;
}

public bool GetRandomizeArenasStatus()
{
return _arenasConfigurations.randomizeArenas;
@@ -314,9 +309,14 @@ private Dictionary<string, int> RetrieveEnvironmentParameters()

#region Configuration Management Methods

public bool GetConfiguration(int arenaID, out ArenaConfiguration arenaConfiguration)
public ArenaConfiguration GetConfiguration(int arenaID)
{
return _arenasConfigurations.configurations.TryGetValue(arenaID, out arenaConfiguration);
ArenaConfiguration returnConfiguration;
if (!_arenasConfigurations.configurations.TryGetValue(arenaID, out returnConfiguration))
{
throw new KeyNotFoundException($"Tried to load arena {arenaID} but it did not exist");
}
return returnConfiguration;
}

public void AddConfiguration(int arenaID, ArenaConfiguration arenaConfiguration)
2 changes: 2 additions & 0 deletions Assets/Scripts/ArenasParameters.cs
@@ -141,6 +141,7 @@ public class ArenaConfiguration
public bool toUpdate = false;
public string protoString = "";
public int randomSeed = 0;
public bool mergeNextArena = false;
Member

I think applying this boolean locally (per arena definition in the config file) is inferior to applying it globally, like the canChangePerspective boolean. It would make sense to set it once at the top of the config file (true = all arenas are merged, false = no arenas are merged) so that it applies to every arena, rather than specifying it for each arena.

However, if the point of this parameter is the flexibility to choose exactly which arenas are merged, then setting it locally makes sense.

Can you elaborate on the actual requirements in this regard?

Contributor Author

I think your second point, about flexibility, is the important one for this change.

An example use case for this feature is a maze that the agent navigates several times during one episode. For example, the first run might be to learn the maze structure; for the second run we add a block so that the most efficient route is no longer possible, and the agent has to use what it learned in the previous run to reroute. In this example, if we had 3 mazes we wanted to test on, we would have 6 arenas that we want to merge in pairs (0-1, 2-3, 4-5).
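
As a rough sketch of that pairing (not code from this PR): assuming each arena's mergeNextArena flag is available as a plain list of booleans, the arenas that are only ever reached through a merge are the ones whose predecessor sets the flag. The class name `MergedArenaSketch` and the `Validate` helper are hypothetical; `Validate` illustrates one possible way to catch early the case noted in the TODO in LoadNextArena, where the final arena sets the flag.

```csharp
using System;
using System.Collections.Generic;

public static class MergedArenaSketch
{
    // mergeFlags[i] is true when arena i sets mergeNextArena.
    // This list is a hypothetical stand-in for the per-arena configuration objects.
    public static List<int> GetMergedArenas(IReadOnlyList<bool> mergeFlags)
    {
        var merged = new List<int>();
        for (int i = 1; i < mergeFlags.Count; i++)
        {
            // Arena i is a continuation if the arena before it merges into it.
            if (mergeFlags[i - 1]) { merged.Add(i); }
        }
        return merged;
    }

    // Hypothetical early check: the final arena has no following arena to merge into.
    public static void Validate(IReadOnlyList<bool> mergeFlags)
    {
        if (mergeFlags.Count > 0 && mergeFlags[mergeFlags.Count - 1])
        {
            throw new InvalidOperationException(
                "The final arena sets mergeNextArena but there is no next arena");
        }
    }

    public static void Main()
    {
        // Three mazes, each run twice: arenas 0, 2 and 4 merge into 1, 3 and 5.
        var flags = new List<bool> { true, false, true, false, true, false };
        Validate(flags);
        Console.WriteLine(string.Join(", ", GetMergedArenas(flags))); // prints "1, 3, 5"
    }
}
```

For the three-maze example, this leaves arenas 0, 2 and 4 as the only ones that can start an episode.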

Member

In that case, perhaps the docs could elaborate on the most probable use case for this feature. Maybe it was just me, but I did not understand why and how it would be used :)
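
For the docs, the behaviour at the end of an arena could be summarised with a small sketch like the following. It is a simplified, standalone restatement of the pass-mark branch in UpdateHealth, not the actual implementation: the `ArenaOutcome` enum and the `OnArenaCompleted` function are hypothetical names, and notifications, delays and the zero-health case are left out.

```csharp
using System;

// Hypothetical names used only for illustration.
public enum ArenaOutcome { LoadNextArena, EndEpisode }

public static class EpisodeFlowSketch
{
    // The episode continues into the next arena only when the agent has reached
    // the pass mark AND the current arena sets mergeNextArena.
    public static ArenaOutcome OnArenaCompleted(
        float cumulativeReward, float passMark, bool mergeNextArena)
    {
        bool passed = cumulativeReward >= passMark;
        return (passed && mergeNextArena)
            ? ArenaOutcome.LoadNextArena
            : ArenaOutcome.EndEpisode;
    }

    public static void Main()
    {
        Console.WriteLine(OnArenaCompleted(1.0f, 0.5f, true));  // prints "LoadNextArena"
        Console.WriteLine(OnArenaCompleted(0.2f, 0.5f, true));  // prints "EndEpisode"
        Console.WriteLine(OnArenaCompleted(1.0f, 0.5f, false)); // prints "EndEpisode"
    }
}
```

In the maze example, failing the first run ends the episode, so the modified second run is only reached by an agent that passed the first.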


public ArenaConfiguration() { }

@@ -175,6 +176,7 @@ internal ArenaConfiguration(YAMLDefs.Arena yamlArena)
toUpdate = true;
protoString = yamlArena.ToString();
randomSeed = yamlArena.randomSeed;
this.mergeNextArena = yamlArena.mergeNextArena;
alhasacademy96 marked this conversation as resolved.
}

/// <summary>
20 changes: 13 additions & 7 deletions Assets/Scripts/TrainingAgent.cs
@@ -47,7 +47,7 @@ public class TrainingAgent : Agent, IPrefab
private float _freezeDelay = 0f;
private bool _isFrozen = false;

private bool _nextUpdateEpisodeEnd = false;
private bool _nextUpdateCompleteArena = false;
alhasacademy96 marked this conversation as resolved.

[Header("Agent Notification")]
public bool showNotification = false;
@@ -214,20 +214,20 @@ public override void Heuristic(in ActionBuffers actionsOut)

#region Agent Health Methods

public void UpdateHealthNextStep(float updateAmount, bool andEndEpisode = false)
public void UpdateHealthNextStep(float updateAmount, bool andCompleteArena = false)
Member

See main comment above for the rest of the changes in this method.

{
/// <summary>
/// ML-Agents doesn't guarantee behaviour if an episode ends outside of OnActionReceived
/// Therefore we queue any health updates to happen on the next action step.
/// </summary>
_nextUpdateHealth += updateAmount;
if (andEndEpisode)
if (andCompleteArena)
{
_nextUpdateEpisodeEnd = true;
_nextUpdateCompleteArena = true;
}
}

public void UpdateHealth(float updateAmount, bool andEndEpisode = false)
public void UpdateHealth(float updateAmount, bool andCompleteArena = false)
{
if (NotificationManager.Instance == null && showNotification == true)
{
@@ -260,12 +260,19 @@ public void UpdateHealth(float updateAmount, bool andEndEpisode = false)
StartCoroutine(EndEpisodeAfterDelay());
return;
}
if (andEndEpisode || _nextUpdateEpisodeEnd)
if (andCompleteArena || _nextUpdateCompleteArena)
{
_nextUpdateCompleteArena = false;
float cumulativeReward = this.GetCumulativeReward();

if (cumulativeReward >= Arena.CurrentPassMark)
{
// If passed and the next arena is merged load that without ending the episode
if (_arena.mergeNextArena)
{
_arena.LoadNextArena();
return;
}
if (showNotification)
{
NotificationManager.Instance.ShowSuccessNotification();
@@ -278,7 +285,6 @@ public void UpdateHealth(float updateAmount, bool andEndEpisode = false)
NotificationManager.Instance.ShowFailureNotification();
}
}
_nextUpdateEpisodeEnd = false;
alhasacademy96 marked this conversation as resolved.
StartCoroutine(EndEpisodeAfterDelay());
}
}
91 changes: 76 additions & 15 deletions Assets/Scripts/TrainingArena.cs
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
@@ -40,8 +41,15 @@ public class TrainingArena : MonoBehaviour
private bool isFirstArenaReset = true;
private List<GameObject> spawnedRewards = new List<GameObject>();
private List<int> playedArenas = new List<int>();
private List<int> _mergedArenas = null;

public bool showNotification { get; set; }
public bool mergeNextArena
{
get {
return _arenaConfiguration.mergeNextArena;
}
}

public ArenaBuilder Builder
{
@@ -87,6 +95,25 @@ private void InitializeArenaComponents()
Spawner_InteractiveButton.RewardSpawned += OnRewardSpawned;
}

alhasacademy96 marked this conversation as resolved.
/// <summary>
/// Provides a list of the arenas in the current config file that are preceded by an arena with
/// the mergeNextArena property set, so that we can avoid loading them when arenas are randomised.
/// </summary>
private List<int> GetMergedArenas()
{
List<int> mergedArenas = new List<int>();
int totalArenas = _environmentManager.GetTotalArenas();
ArenaConfiguration currentArena = _environmentManager.GetConfiguration(0);
bool currentlyMerged = currentArena.mergeNextArena;
for (int i = 1; i < totalArenas; i++)
{
if (currentlyMerged) { mergedArenas.Add(i); }
currentArena = _environmentManager.GetConfiguration(i);
currentlyMerged = currentArena.mergeNextArena;
}
return mergedArenas;
}

#region Arena Handling Methods

/// <summary>
@@ -99,20 +126,39 @@ public void ResetArena()

CleanUpSpawnedObjects();

DetermineNextArenaID();
SetNextArenaID();

// Load the new configuration
ArenaConfiguration newConfiguration = _environmentManager.GetConfiguration(arenaID);

ApplyNewArenaConfiguration(newConfiguration);

CleanupRewards();

NotifyArenaChange();
}

if (!TryLoadArenaConfiguration(out ArenaConfiguration newConfiguration))
public void LoadNextArena()
{
// ResetArena() must be called at least once first, to initialise arenaID
if (isFirstArenaReset)
{
Debug.LogError("Failed to load arena configuration");
return;
throw new InvalidOperationException("LoadNextArena called before first reset");
}

Debug.Log($"Loading next arena. Previous: {arenaID}, next: {arenaID + 1}");
CleanUpSpawnedObjects();

arenaID += 1;
// Load the new configuration
// TODO: If mergeNextArena is put in the final arena this will throw. Add some validation to move this failure sooner in execution
ArenaConfiguration newConfiguration = _environmentManager.GetConfiguration(arenaID);

ApplyNewArenaConfiguration(newConfiguration);

CleanupRewards();

NotifyArenaChange();

}

private void CleanUpSpawnedObjects()
@@ -124,39 +170,55 @@ private void CleanUpSpawnedObjects()
}
}

private void DetermineNextArenaID()
private void SetNextArenaID()
{
int totalArenas = _environmentManager.getMaxArenaID();
int totalArenas = _environmentManager.GetTotalArenas();
bool randomizeArenas = _environmentManager.GetRandomizeArenasStatus();

if (isFirstArenaReset)
{
isFirstArenaReset = false;
arenaID = randomizeArenas ? Random.Range(0, totalArenas) : 0;
arenaID = randomizeArenas ? ChooseRandomArenaID(totalArenas) : 0;
}
else
{
arenaID = randomizeArenas ? ChooseRandomArenaID(totalArenas) : (arenaID + 1) % totalArenas;
if (randomizeArenas)
{
arenaID = ChooseRandomArenaID(totalArenas);
}
else
{
// If the next arena is merged, sequentially search for the next unmerged one
ArenaConfiguration preceedingArena = _arenaConfiguration;
arenaID = (arenaID + 1) % totalArenas;
while (preceedingArena.mergeNextArena)
{
preceedingArena = _environmentManager.GetConfiguration(arenaID);
arenaID = (arenaID + 1) % totalArenas;
}
}
}
}

private int ChooseRandomArenaID(int totalArenas)
{
// Populate the list of merged arenas if needed
if (_mergedArenas == null){ _mergedArenas = GetMergedArenas(); }

playedArenas.Add(arenaID);
if (playedArenas.Count >= totalArenas)
{
playedArenas = new List<int> { arenaID };
}

var availableArenas = Enumerable.Range(0, totalArenas).Except(playedArenas).ToList();
var availableArenas = Enumerable.Range(0, totalArenas).Except(playedArenas).Except(_mergedArenas).ToList();
alhasacademy96 marked this conversation as resolved.
return availableArenas[Random.Range(0, availableArenas.Count)];
}

private bool TryLoadArenaConfiguration(out ArenaConfiguration newConfiguration)
{
return _environmentManager.GetConfiguration(arenaID, out newConfiguration);
}

/* Note: to update the active arena to a new ID the following must be called in sequence
GetConfiguration, ApplyNewArenaConfiguration, CleanupRewards, NotifyArenaChange
*/
private void ApplyNewArenaConfiguration(ArenaConfiguration newConfiguration)
{
_arenaConfiguration = newConfiguration;
@@ -175,7 +237,6 @@ private void ApplyNewArenaConfiguration(ArenaConfiguration newConfiguration)
{
Random.InitState(_arenaConfiguration.randomSeed);
}

Debug.Log($"TimeLimit set to: {_arenaConfiguration.TimeLimit}");
}

1 change: 1 addition & 0 deletions Assets/Scripts/YAMLclasses.cs
@@ -51,6 +51,7 @@ public void SetCurrentPassMark()

public List<int> blackouts { get; set; } = new List<int>();
public int randomSeed { get; set; } = 0;
public bool mergeNextArena { get; set; } = false;
alhasacademy96 marked this conversation as resolved.
}

/// <summary>