Skip to content

Commit

Permalink
try training will bare minimum rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
plaidpants committed Feb 20, 2022
1 parent 615324c commit 1940e1b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 8 deletions.
Binary file added Assets/RocketAI71.onnx
Binary file not shown.
16 changes: 16 additions & 0 deletions Assets/RocketAI71.onnx.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 15 additions & 8 deletions Assets/RocketAIAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.Mathematics;
using Unity.MLAgents.Policies;
using Random = UnityEngine.Random;

public class RocketAIAgent : Agent
{
Expand All @@ -20,6 +22,7 @@ public override void Initialize()
base.Initialize();

lastPoints = 0;
GetComponent<BehaviorParameters>().TeamId = (int)Random.Range(0.0f, 100.0f);
}

public override void CollectObservations(VectorSensor sensor)
Expand Down Expand Up @@ -99,7 +102,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
if (rocket.points > lastPoints)
{
// big reward if we hit something
reward = 3.0f * (rocket.points - lastPoints);
reward = 1.0f; // 3.0f * (rocket.points - lastPoints);
SetReward(reward);
Debug.Log("Reward for points " + rocket.points + " last points " + lastPoints + " reward " + reward);
lastPoints = rocket.points;
Expand All @@ -123,7 +126,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
*/
}


/*
if (rocket.countRotations != lastCountRotations)
{
// negative reward for excessive rotations
Expand All @@ -136,7 +139,8 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
lastCountRotations = rocket.countRotations;
}

*/
/*
if (rb)
{
if (rb.angularVelocity.magnitude < 0.02f)
Expand Down Expand Up @@ -170,7 +174,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
SetReward(reward);
}
}

*/
if (raySensor)
{
if (raySensor.RaySensor != null)
Expand All @@ -190,7 +194,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
//Debug.Log("Reward for being near something " + raySensor.DetectableTags.Count + " reward " + reward);
}
*/

/*
if (raySensor.RaySensor.RayPerceptionOutput.RayOutputs[8].HasHit)
{
if (rocket.fireInput)
Expand Down Expand Up @@ -245,7 +249,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
//Debug.Log("reward thrusting away from something " + reward);
}
}

*/
/*
if (raySensor.RaySensor.RayPerceptionOutput.RayOutputs[0].HasHit
|| raySensor.RaySensor.RayPerceptionOutput.RayOutputs[1].HasHit
Expand Down Expand Up @@ -283,10 +287,13 @@ public override void OnEpisodeBegin()
{
base.OnEpisodeBegin();

// make sure everyone is on a different team, i.e. every man for himself
GetComponent<BehaviorParameters>().TeamId = (int)Random.Range(0.0f, 100.0f);

raySensor = transform.gameObject.GetComponentInChildren<RayPerceptionSensorComponent3D>();
rocket = transform.gameObject.GetComponent<RocketSphereAI>();
rb = transform.gameObject.GetComponent<Rigidbody>();

if (rocket)
{
lastPoints = rocket.points;
Expand All @@ -300,7 +307,7 @@ public override void OnEpisodeBegin()
public void EpisodeEndGood()
{
Debug.Log("reward survived for lifetime " + 0.2f);
SetReward(0.2f);
//SetReward(0.2f);
EndEpisode();
}

Expand Down
Binary file modified Assets/RocketSphereAI1.prefab
Binary file not shown.
Binary file modified Assets/RocketSphereAI2.prefab
Binary file not shown.

0 comments on commit 1940e1b

Please sign in to comment.