using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Collections.Generic;

/// <summary>
/// ML-Agents agent that outputs a binary direction decision (+1 / -1) and is
/// rewarded or penalised according to the tagged triggers it enters.
/// </summary>
[RequireComponent(typeof(Rigidbody2D))]
public class ObserverAI : Agent
{
    #region PROPERTIES

    /// <summary>
    /// Direction produced by the most recent action: 1 when the first continuous
    /// action is strictly positive, otherwise -1. Stays 0 until the first action
    /// has been received.
    /// </summary>
    public sbyte LastDirectionDecision { get; private set; }

    #endregion

    #region VARIABLES

    // NOTE(review): _playerRef and direction are never read or written in this class.
    private PlayerScript _playerRef;
    private Rigidbody2D rb;
    private sbyte direction;
    private SensorsScript sensorsScript;

    // Set during initialisation; only referenced by the disabled guard described in TakeDecision().
    private bool isReadyToStartEpisode = false;

    #endregion

    #region EVENTS

    /// <summary>
    /// Maps the first continuous action to a direction: strictly positive -> 1, otherwise -> -1.
    /// </summary>
    /// <param name="actions">Action buffers supplied by ML-Agents; only ContinuousActions[0] is read.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        LastDirectionDecision = actions.ContinuousActions[0] > 0f ? (sbyte)1 : (sbyte)-1;
    }

    #endregion

    #region METHODS

    /// <summary>
    /// Requests a new decision from ML-Agents and returns the current direction.
    /// </summary>
    /// <returns>
    /// <see cref="LastDirectionDecision"/>, or -1 while no action has been received yet
    /// (the property is still 0).
    /// </returns>
    /// <remarks>
    /// RequestDecision() only schedules a decision; OnActionReceived() is invoked later,
    /// when the Academy steps. The value returned here is therefore the result of the
    /// previously completed decision, not the one requested by this call.
    /// </remarks>
    public sbyte TakeDecision()
    {
        // A guard on isReadyToStartEpisode existed here but was disabled; decisions are always requested.
        RequestDecision();
        return LastDirectionDecision == 0 ? (sbyte)-1 : LastDirectionDecision;
    }

    /// <summary>
    /// Adds the agent's observations, in this order:
    /// local y position, y velocity, GameHandler.Instance.FrameDistance, then
    /// (hitLayer, distance) for every entry returned by SensorsScript.GetObserverRaycast().
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    /// <remarks>
    /// NOTE(review): the number of observations depends on whether sensorsScript is present
    /// and on how many rays GetObserverRaycast() returns. Confirm this matches the Vector
    /// Observation size configured in Behavior Parameters.
    /// </remarks>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(rb.transform.localPosition.y);
        sensor.AddObservation(rb.velocity.y);
        sensor.AddObservation(GameHandler.Instance.FrameDistance);

        // Unity's overloaded == also treats a destroyed component as null.
        if (sensorsScript == null)
        {
            Debug.LogWarning("SensorsScript component is not assigned.");
            return;
        }

        var sensorDataList = sensorsScript.GetObserverRaycast();
        foreach (SensorData data in sensorDataList)
        {
            sensor.AddObservation(data.hitLayer);
            sensor.AddObservation(data.distance);
        }
    }

    /// <summary>
    /// Manual (heuristic) control hook. It is intentionally left empty.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill; not written to.</param>
    /// <remarks>
    /// NOTE(review): because actionsOut is left untouched, in Heuristic mode the first
    /// continuous action is whatever the buffer already holds. That is presumably 0,
    /// which OnActionReceived maps to -1.
    /// </remarks>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }

    #endregion

    #region LIFECYCLE

    /// <summary>
    /// One-time setup, called by ML-Agents when the agent is first enabled.
    /// Caches the Rigidbody2D and SensorsScript components on this GameObject.
    /// </summary>
    /// <remarks>
    /// This is done in Initialize() rather than a Unity Awake() message. Some ML-Agents
    /// versions declare a virtual Agent.Awake(), and a private Awake() here would hide it,
    /// so the base implementation would never run.
    /// </remarks>
    public override void Initialize()
    {
        rb = GetComponent<Rigidbody2D>();
        sensorsScript = GetComponent<SensorsScript>();
        isReadyToStartEpisode = true;
        Debug.Log("Observer IA enabled !");
    }

    // Disabled reward shaping, kept for reference. It rewarded the agent for staying
    // near y = -0.59 (the reward peaks at 10 when verticalPosition == -0.59).
    // void Update()
    // {
    //     var verticalPosition = rb.transform.position.y;
    //     var reward = 0.1f / (0.01f + ((verticalPosition + 0.59f) * (verticalPosition + 0.59f)));
    //     AddReward(reward);
    // }

    /// <summary>
    /// Applies a reward based on the tag of the trigger collider entered:
    /// Laser -100, GoodCoin +100, BadCoin +50, Missile -500. Other tags are ignored.
    /// </summary>
    /// <param name="other">The trigger collider that was entered.</param>
    private void OnTriggerEnter2D(Collider2D other)
    {
        if (other.CompareTag("Laser"))
        {
            AddReward(-100f);
            // Ending the episode on a laser hit (EndEpisode()) is currently disabled.
        }
        else if (other.CompareTag("GoodCoin"))
        {
            AddReward(100f);
        }
        else if (other.CompareTag("BadCoin"))
        {
            // NOTE(review): a "bad" coin yields a positive reward; confirm this is intended
            // and not a missing minus sign.
            AddReward(50f);
        }
        else if (other.CompareTag("Missile"))
        {
            AddReward(-500f);
        }
    }

    #endregion
}