This commit is contained in:
2024-06-12 21:03:42 +02:00
parent 4685d9942b
commit aef3b3ab97
1548 changed files with 5615 additions and 72 deletions

View File

@@ -0,0 +1,110 @@
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Collections.Generic;
public class ObserverAI : Agent
{
#region PROPERTIES
public sbyte LastDirectionDecision { get; private set; }
#endregion
#region VARIABLES
private PlayerScript _playerRef;
private Rigidbody2D rb;
private sbyte direction;
SensorsScript sensorsScript;
private bool isReadyToStartEpisode = false;
#endregion
#region EVENTS
public override void OnActionReceived(ActionBuffers actions)
{
LastDirectionDecision = actions.ContinuousActions[0] > 0 ? (sbyte)1 : (sbyte)-1;
}
#endregion
#region METHODS
public sbyte TakeDecision()
{
//if(isReadyToStartEpisode){
RequestDecision();
//}
return LastDirectionDecision == 0 ? (sbyte)-1 : LastDirectionDecision;
}
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(rb.transform.localPosition.y);
sensor.AddObservation(rb.velocity.y);
sensor.AddObservation(GameHandler.Instance.FrameDistance);
if (sensorsScript != null)
{
List<SensorData> sensorDataList = sensorsScript.GetObserverRaycast();
// Display raycast data
foreach (SensorData data in sensorDataList)
{
sensor.AddObservation(data.hitLayer);
sensor.AddObservation(data.distance);
}
}
else
{
Debug.LogWarning("SensorsScript component is not assigned.");
}
}
public override void Heuristic(in ActionBuffers actionsOut)
{
// Heuristic method
}
#endregion
#region LIFECYCLE
void Awake()
{
rb = GetComponent<Rigidbody2D>();
sensorsScript = GetComponent<SensorsScript>();
isReadyToStartEpisode = true;
Debug.Log("Observer IA enabled !");
}
/*void Update()
{
var verticalPosition = rb.transform.position.y;
var reward = 0.1f / (0.01f + ((verticalPosition + 0.59f) * (verticalPosition + 0.59f)));
AddReward(reward);
}*/
private void OnTriggerEnter2D(Collider2D other)
{
if (other.CompareTag("Laser"))
{
AddReward(-100f);
//EndEpisode();
}
else if (other.CompareTag("GoodCoin"))
{
AddReward(100f);
}
else if (other.CompareTag("BadCoin"))
{
AddReward(50f);
}
else if (other.CompareTag("Missile"))
{
AddReward(-500f);
}
}
#endregion
}

View File

@@ -0,0 +1,60 @@
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Collections.Generic;
using Unity.MLAgents.Policies;
using Unity.VisualScripting;
using Unity.Barracuda;
public class ObserverBehaviour : PlayerBehaviour
{
private ObserverAI _observerAi;
private BehaviorParameters _behaviourParameters;
protected override void Awake()
{
base.Awake();
_observerAi = _playerRef.gameObject.AddComponent<ObserverAI>();
_observerAi.enabled = false;
_behaviourParameters = _playerRef.gameObject.GetComponent<BehaviorParameters>();// ?? _playerRef.gameObject.AddComponent<BehaviorParameters>();
_behaviourParameters.BehaviorName = "Observer";
_behaviourParameters.BrainParameters.VectorObservationSize = 29;
_behaviourParameters.BrainParameters.ActionSpec =
//new ActionSpec(
// numContinuousActions: 1, // Une seule action continue
// discreteBranchSizes: new int[] { 0 } // Une action discrète avec 0 choix possibles
// );
ActionSpec.MakeContinuous(1);
_behaviourParameters.Model = Resources.Load<NNModel>("Observer");
_observerAi.EndEpisode();
_observerAi.enabled = true;
}
protected override sbyte ChooseDirection()
{
var direction = _observerAi.TakeDecision();
return direction;
}
protected override string ChoosePlayerName()
{
return "Bot Observateur";
}
protected override Color ChoosePlayerColor()
{
return new Color(0, 255, 0);
}
protected override bool ChooseIfPlayerIsHuman()
{
return false;
}
}