// C# source listing (111 lines, 2.6 KiB)
using UnityEngine;
|
||
|
using Unity.MLAgents;
|
||
|
using Unity.MLAgents.Actuators;
|
||
|
using Unity.MLAgents.Sensors;
|
||
|
using System.Collections.Generic;
|
||
|
|
||
|
/// <summary>
/// ML-Agents agent that produces a binary direction decision (+1 or -1).
/// Other scripts call <see cref="TakeDecision"/> to request a decision and read the latest result;
/// rewards are granted from trigger contacts in <see cref="OnTriggerEnter2D"/>.
/// </summary>
public class ObserverAI : Agent
{
    #region PROPERTIES

    /// <summary>
    /// Most recent direction chosen by the policy: 1 when the first continuous action was
    /// strictly positive, -1 otherwise. Remains 0 until <see cref="OnActionReceived"/> has run once.
    /// </summary>
    public sbyte LastDirectionDecision { get; private set; }

    #endregion

    #region VARIABLES

    // Reward values are serialized so they can be tuned in the Inspector; the defaults are the
    // values this agent has always used.
    // NOTE(review): the ML-Agents design guide suggests keeping individual rewards small
    // (roughly within ±1) for stable training; these magnitudes may be worth rescaling.
    [Header("Rewards")]
    [Tooltip("Reward added when entering a trigger tagged \"Laser\".")]
    [SerializeField] private float _laserReward = -100f;

    [Tooltip("Reward added when entering a trigger tagged \"GoodCoin\".")]
    [SerializeField] private float _goodCoinReward = 100f;

    // NOTE(review): a "BadCoin" currently yields a POSITIVE reward (+50, as originally written).
    // Confirm this is intended and not a sign error.
    [Tooltip("Reward added when entering a trigger tagged \"BadCoin\".")]
    [SerializeField] private float _badCoinReward = 50f;

    [Tooltip("Reward added when entering a trigger tagged \"Missile\".")]
    [SerializeField] private float _missileReward = -500f;

    // Sibling components cached once in Initialize().
    private Rigidbody2D rb;
    private SensorsScript sensorsScript;

    #endregion

    #region EVENTS

    /// <summary>
    /// Converts the policy output into a direction: 1 if the first continuous action is
    /// strictly positive, otherwise -1 (so an all-zero action buffer maps to -1).
    /// </summary>
    /// <param name="actions">Action buffers from the policy; at least one continuous action is expected.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        LastDirectionDecision = actions.ContinuousActions[0] > 0 ? (sbyte)1 : (sbyte)-1;
    }

    #endregion

    #region METHODS

    /// <summary>
    /// Requests a new decision from the policy and returns the latest known direction.
    /// </summary>
    /// <remarks>
    /// <see cref="Agent.RequestDecision"/> does not run the policy synchronously; ML-Agents
    /// resolves the request on a later Academy step. The value returned here is therefore the
    /// result of a previous request, not of the one issued by this call.
    /// </remarks>
    /// <returns>1 or -1; -1 when no decision has been received yet.</returns>
    public sbyte TakeDecision()
    {
        RequestDecision();

        // LastDirectionDecision is 0 only before the first OnActionReceived; default to -1 then.
        return LastDirectionDecision == 0 ? (sbyte)-1 : LastDirectionDecision;
    }

    /// <summary>
    /// Adds the agent's vector observations: local y position, y velocity, the game's frame
    /// distance, then a (hitLayer, distance) pair for every entry returned by the sensors script.
    /// </summary>
    /// <remarks>
    /// The observation count is 3 + 2 * (number of sensor entries), or just 3 when no
    /// SensorsScript is present. It must match the vector observation size configured in the
    /// agent's Behavior Parameters; ML-Agents pads or truncates mismatched vectors with a warning.
    /// </remarks>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(rb.transform.localPosition.y);
        // NOTE(review): Rigidbody2D.velocity is obsolete in Unity 6 (use linearVelocity there).
        sensor.AddObservation(rb.velocity.y);
        sensor.AddObservation(GameHandler.Instance.FrameDistance);

        if (sensorsScript == null)
        {
            Debug.LogWarning("SensorsScript component is not assigned.");
            return;
        }

        // One (hitLayer, distance) pair per raycast reported by the sensors script.
        List<SensorData> sensorDataList = sensorsScript.GetObserverRaycast();
        foreach (SensorData data in sensorDataList)
        {
            sensor.AddObservation(data.hitLayer);
            sensor.AddObservation(data.distance);
        }
    }

    /// <summary>
    /// Intentionally empty: <paramref name="actionsOut"/> is not written. OnActionReceived then
    /// sees whatever the buffer holds (presumably zeros, which map to -1 — TODO confirm).
    /// Overriding also avoids the base class's "not implemented" heuristic fallback.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill when running in heuristic mode.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }

    #endregion

    #region LIFECYCLE

    /// <summary>
    /// One-time agent setup invoked by ML-Agents when the agent is first enabled.
    /// Caches the sibling Rigidbody2D and SensorsScript components.
    /// </summary>
    /// <remarks>
    /// Setup lives here rather than in a Unity <c>Awake()</c> message: recent ML-Agents releases
    /// declare <c>Agent.Awake</c> as virtual and expect subclasses to call it, so a subclass
    /// <c>Awake()</c> would hide it. <see cref="Initialize"/> is the hook ML-Agents provides for this.
    /// </remarks>
    public override void Initialize()
    {
        rb = GetComponent<Rigidbody2D>();
        sensorsScript = GetComponent<SensorsScript>();
        Debug.Log("Observer IA enabled !");
    }

    // Disabled experiment, kept for reference: a per-frame shaping reward in Update(),
    //   AddReward(0.1f / (0.01f + (y + 0.59f) * (y + 0.59f)));   where y = rb.transform.position.y,
    // which is largest when the body's world-space y is near -0.59.

    /// <summary>
    /// Unity trigger callback: adds the configured reward for the tag of the object entered.
    /// Objects with any other tag are ignored.
    /// </summary>
    /// <param name="other">Collider whose trigger volume was entered.</param>
    private void OnTriggerEnter2D(Collider2D other)
    {
        if (other.CompareTag("Laser"))
        {
            AddReward(_laserReward);
            // EndEpisode() is deliberately not called here (it was disabled); enable it to
            // terminate the episode on laser hits.
        }
        else if (other.CompareTag("GoodCoin"))
        {
            AddReward(_goodCoinReward);
        }
        else if (other.CompareTag("BadCoin"))
        {
            AddReward(_badCoinReward);
        }
        else if (other.CompareTag("Missile"))
        {
            AddReward(_missileReward);
        }
    }

    #endregion
}
|