Projet-IA-Madelaine/Scripts/PlayerBehaviourScripts/Observer/ObserverAi.cs
2024-06-12 21:03:42 +02:00

111 lines
2.6 KiB
C#

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Collections.Generic;
/// <summary>
/// ML-Agents agent that picks a vertical direction (+1 / -1) for the observer player.
/// Observations: the body's local Y position, its Y velocity,
/// <c>GameHandler.Instance.FrameDistance</c>, then one (hitLayer, distance) pair per
/// entry returned by <c>SensorsScript.GetObserverRaycast()</c>.
/// Action: the first continuous action; positive means +1, zero or negative means -1.
/// Rewards are added when a 2D trigger contact begins with specifically tagged objects.
/// </summary>
[RequireComponent(typeof(Rigidbody2D))]
public class ObserverAI : Agent
{
    #region PROPERTIES
    /// <summary>
    /// Direction from the most recently received action: +1 or -1
    /// (0 until the first action has been received).
    /// </summary>
    public sbyte LastDirectionDecision { get; private set; }
    #endregion

    #region VARIABLES
    // Rewards are serialized so they can be tuned per prefab in the Inspector;
    // the defaults are the values previously hard-coded in OnTriggerEnter2D.
    [SerializeField] private float laserReward = -100f;
    [SerializeField] private float goodCoinReward = 100f;
    // NOTE(review): touching a "BadCoin" yields a *positive* reward — confirm this is intended.
    [SerializeField] private float badCoinReward = 50f;
    [SerializeField] private float missileReward = -500f;

    private Rigidbody2D rb;
    private SensorsScript sensorsScript;
    #endregion

    #region EVENTS
    /// <summary>
    /// Stores the chosen direction: a positive first continuous action maps to +1,
    /// zero or a negative value maps to -1.
    /// </summary>
    /// <param name="actions">Actions produced by the policy (or the heuristic).</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        LastDirectionDecision = actions.ContinuousActions[0] > 0 ? (sbyte)1 : (sbyte)-1;
    }
    #endregion

    #region METHODS
    /// <summary>
    /// Requests a new decision from the policy and returns the latest known direction.
    /// </summary>
    /// <remarks>
    /// Per the ML-Agents docs, <c>RequestDecision()</c> only schedules a decision; the
    /// resulting action reaches <see cref="OnActionReceived"/> when the Academy next
    /// steps. The value returned here may therefore come from the previous decision.
    /// </remarks>
    /// <returns>+1 or -1; -1 when no action has been received yet.</returns>
    public sbyte TakeDecision()
    {
        RequestDecision();
        return LastDirectionDecision == 0 ? (sbyte)-1 : LastDirectionDecision;
    }

    /// <summary>
    /// Adds the observations: local Y position, Y velocity, the game's FrameDistance,
    /// then a (hitLayer, distance) pair for every raycast reported by SensorsScript.
    /// </summary>
    /// <remarks>
    /// If SensorsScript is absent, only the first three values are added, so the vector
    /// is shorter than the configured observation size.
    /// </remarks>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(rb.transform.localPosition.y);
        sensor.AddObservation(rb.velocity.y);
        sensor.AddObservation(GameHandler.Instance.FrameDistance);

        // The missing-component warning is emitted once in Awake rather than on every
        // step, which would flood the console during training.
        if (sensorsScript == null)
        {
            return;
        }

        // Feed every raycast result into the observation vector.
        List<SensorData> sensorDataList = sensorsScript.GetObserverRaycast();
        foreach (SensorData data in sensorDataList)
        {
            sensor.AddObservation(data.hitLayer);
            sensor.AddObservation(data.distance);
        }
    }

    /// <summary>
    /// Intentionally writes nothing to <paramref name="actionsOut"/>. Overriding it
    /// prevents the base class's default heuristic from running.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }
    #endregion

    #region LIFECYCLE
    /// <summary>
    /// Chains to <c>Agent.Awake()</c>, then caches the Rigidbody2D and SensorsScript
    /// found on this GameObject.
    /// </summary>
    protected override void Awake()
    {
        // Agent declares a virtual Awake; override and chain to it instead of hiding it,
        // otherwise the base-class initialization never runs.
        base.Awake();

        rb = GetComponent<Rigidbody2D>();
        sensorsScript = GetComponent<SensorsScript>();
        if (sensorsScript == null)
        {
            Debug.LogWarning("SensorsScript component is not assigned.");
        }

        Debug.Log("Observer IA enabled !");
    }

    /// <summary>
    /// Adds a reward when a 2D trigger contact begins with an object tagged
    /// "Laser", "GoodCoin", "BadCoin" or "Missile"; any other tag is ignored.
    /// </summary>
    /// <param name="other">The collider this object started overlapping.</param>
    private void OnTriggerEnter2D(Collider2D other)
    {
        // CompareTag avoids the string allocation of reading other.tag.
        if (other.CompareTag("Laser"))
        {
            AddReward(laserReward);
        }
        else if (other.CompareTag("GoodCoin"))
        {
            AddReward(goodCoinReward);
        }
        else if (other.CompareTag("BadCoin"))
        {
            AddReward(badCoinReward);
        }
        else if (other.CompareTag("Missile"))
        {
            AddReward(missileReward);
        }
    }
    #endregion
}