diff --git a/Assets/RocketAI71.onnx b/Assets/RocketAI71.onnx new file mode 100644 index 0000000..e0af490 Binary files /dev/null and b/Assets/RocketAI71.onnx differ diff --git a/Assets/RocketAI71.onnx.meta b/Assets/RocketAI71.onnx.meta new file mode 100644 index 0000000..e639c3b --- /dev/null +++ b/Assets/RocketAI71.onnx.meta @@ -0,0 +1,16 @@ +fileFormatVersion: 2 +guid: 34253930b5dfcd54c871fce67c8ce5b4 +ScriptedImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 2 + userData: + assetBundleName: + assetBundleVariant: + script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3} + optimizeModel: 1 + forceArbitraryBatchSize: 1 + treatErrorsAsWarnings: 0 + importMode: 1 + weightsTypeMode: 0 + activationTypeMode: 0 diff --git a/Assets/RocketAIAgent.cs b/Assets/RocketAIAgent.cs index 09c02da..5165286 100644 --- a/Assets/RocketAIAgent.cs +++ b/Assets/RocketAIAgent.cs @@ -6,6 +6,8 @@ using Unity.MLAgents.Actuators; using Unity.MLAgents.Sensors; using Unity.Mathematics; +using Unity.MLAgents.Policies; +using Random = UnityEngine.Random; public class RocketAIAgent : Agent { @@ -20,6 +22,7 @@ public override void Initialize() base.Initialize(); lastPoints = 0; + GetComponent().TeamId = (int)Random.Range(0.0f, 100.0f); } public override void CollectObservations(VectorSensor sensor) @@ -99,7 +102,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers) if (rocket.points > lastPoints) { // big reward if we hit something - reward = 3.0f * (rocket.points - lastPoints); + reward = 1.0f; // 3.0f * (rocket.points - lastPoints); SetReward(reward); Debug.Log("Reward for points " + rocket.points + " last points " + lastPoints + " reward " + reward); lastPoints = rocket.points; @@ -123,7 +126,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers) */ } - + /* if (rocket.countRotations != lastCountRotations) { // negative reward for excessive rotations @@ -136,7 +139,8 @@ public override void OnActionReceived(ActionBuffers actionBuffers) lastCountRotations = rocket.countRotations; } - + */ + /* if (rb) { if (rb.angularVelocity.magnitude < 0.02f) @@ -170,7 +174,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers) SetReward(reward); } } - + */ if (raySensor) { if (raySensor.RaySensor != null) @@ -190,7 +194,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers) //Debug.Log("Reward for being near something " + raySensor.DetectableTags.Count + " reward " + reward); } */ - + /* if (raySensor.RaySensor.RayPerceptionOutput.RayOutputs[8].HasHit) { if (rocket.fireInput) @@ -245,7 +249,7 @@ public override void OnActionReceived(ActionBuffers actionBuffers) //Debug.Log("reward thrusting away from something " + reward); } } - + */ /* if (raySensor.RaySensor.RayPerceptionOutput.RayOutputs[0].HasHit || raySensor.RaySensor.RayPerceptionOutput.RayOutputs[1].HasHit @@ -283,10 +287,13 @@ public override void OnEpisodeBegin() { base.OnEpisodeBegin(); + // make sure everyone is on a different team, i.e. every man for himself + GetComponent().TeamId = (int)Random.Range(0.0f, 100.0f); + raySensor = transform.gameObject.GetComponentInChildren(); rocket = transform.gameObject.GetComponent(); rb = transform.gameObject.GetComponent(); - + if (rocket) { lastPoints = rocket.points; @@ -300,7 +307,7 @@ public override void OnEpisodeBegin() public void EpisodeEndGood() { Debug.Log("reward survived for lifetime " + 0.2f); - SetReward(0.2f); + //SetReward(0.2f); EndEpisode(); } diff --git a/Assets/RocketSphereAI1.prefab b/Assets/RocketSphereAI1.prefab index 64961f1..f076bfa 100644 Binary files a/Assets/RocketSphereAI1.prefab and b/Assets/RocketSphereAI1.prefab differ diff --git a/Assets/RocketSphereAI2.prefab b/Assets/RocketSphereAI2.prefab index 24f09c4..870f0a4 100644 Binary files a/Assets/RocketSphereAI2.prefab and b/Assets/RocketSphereAI2.prefab differ