implementation of drecon in unity 2022 lts forked from: https://github.com/joanllobera/marathon-envs
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

184 lines
5.8 KiB

10 months ago
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Linq;
using static BodyHelper002;
using ManyWorlds;
public class TerrainMarathonManAgent : Agent, IOnTerrainCollision
{
BodyManager002 _bodyManager;
TerrainGenerator _terrainGenerator;
SpawnableEnv _spawnableEnv;
int _stepCountAtLastMeter;
public int lastXPosInMeters;
public int maxXPosInMeters;
float _pain;
List<float> distances;
float fraction;
bool _hasLazyInitialized;
override public void CollectObservations(VectorSensor sensor)
{
if (!_hasLazyInitialized)
{
OnEpisodeBegin();
}
Vector3 normalizedVelocity = _bodyManager.GetNormalizedVelocity();
var pelvis = _bodyManager.GetFirstBodyPart(BodyPartGroup.Hips);
var shoulders = _bodyManager.GetFirstBodyPart(BodyPartGroup.Torso);
sensor.AddObservation(normalizedVelocity);
sensor.AddObservation(pelvis.Rigidbody.transform.forward); // gyroscope
sensor.AddObservation(pelvis.Rigidbody.transform.up);
sensor.AddObservation(shoulders.Rigidbody.transform.forward); // gyroscope
sensor.AddObservation(shoulders.Rigidbody.transform.up);
sensor.AddObservation(_bodyManager.GetSensorIsInTouch());
foreach (var bodyPart in _bodyManager.BodyParts)
{
bodyPart.UpdateObservations();
sensor.AddObservation(bodyPart.ObsLocalPosition);
sensor.AddObservation(bodyPart.ObsRotation);
sensor.AddObservation(bodyPart.ObsRotationVelocity);
sensor.AddObservation(bodyPart.ObsVelocity);
}
sensor.AddObservation(_bodyManager.GetSensorObservations());
(distances, fraction) =
_terrainGenerator.GetDistances2d(
pelvis.Rigidbody.transform.position, _bodyManager.ShowMonitor);
sensor.AddObservation(distances);
sensor.AddObservation(fraction);
// _bodyManager.OnCollectObservationsHandleDebug(GetInfo());
}
public override void OnActionReceived(ActionBuffers actions)
{
float[] vectorAction = actions.ContinuousActions.Select(x=>x).ToArray();
if (!_hasLazyInitialized)
{
return;
}
// apply actions to body
_bodyManager.OnAgentAction(vectorAction);
// manage reward
float velocity = Mathf.Clamp(_bodyManager.GetNormalizedVelocity().x, 0f, 1f);
var actionDifference = _bodyManager.GetActionDifference();
var actionsAbsolute = vectorAction.Select(x=>Mathf.Abs(x)).ToList();
var actionsAtLimit = actionsAbsolute.Select(x=> x>=1f ? 1f : 0f).ToList();
float actionaAtLimitCount = actionsAtLimit.Sum();
float notAtLimitBonus = 1f - (actionaAtLimitCount / (float) actionsAbsolute.Count);
float reducedPowerBonus = 1f - actionsAbsolute.Average();
// velocity *= 0.85f;
// reducedPowerBonus *=0f;
// notAtLimitBonus *=.1f;
// actionDifference *=.05f;
// var reward = velocity
// + notAtLimitBonus
// + reducedPowerBonus
// + actionDifference;
var reward = velocity;
AddReward(reward);
_bodyManager.SetDebugFrameReward(reward);
var pelvis = _bodyManager.GetFirstBodyPart(BodyPartGroup.Hips);
float xpos =
_bodyManager.GetBodyParts(BodyPartGroup.Foot)
.Average(x=>x.Transform.position.x);
int newXPosInMeters = (int) xpos;
if (newXPosInMeters > lastXPosInMeters) {
lastXPosInMeters = newXPosInMeters;
_stepCountAtLastMeter = this.StepCount;
}
if (newXPosInMeters > maxXPosInMeters)
maxXPosInMeters = newXPosInMeters;
var terminate = false;
// bool isInBounds = _spawnableEnv.IsPointWithinBoundsInWorldSpace(pelvis.Transform.position);
// if (!isInBounds)
// if (pelvis.Rigidbody.transform.position.y < 0f)
if (_terrainGenerator.IsPointOffEdge(pelvis.Transform.position)){
terminate = true;
AddReward(-1f);
}
if (this.StepCount-_stepCountAtLastMeter >= (200*5))
terminate = true;
else if (xpos < 4f && _pain > 1f)
terminate = true;
else if (xpos < 2f && _pain > 0f)
terminate = true;
else if (_pain > 2f)
terminate = true;
if (terminate){
EndEpisode();
}
_pain = 0f;
}
public override void OnEpisodeBegin()
{
if (!_hasLazyInitialized)
{
_bodyManager = GetComponent<BodyManager002>();
_bodyManager.BodyConfig = MarathonManAgent.BodyConfig;
_bodyManager.OnInitializeAgent();
_hasLazyInitialized = true;
}
if (_bodyManager == null)
_bodyManager = GetComponent<BodyManager002>();
_bodyManager.OnAgentReset();
if (_terrainGenerator == null)
_terrainGenerator = GetComponent<TerrainGenerator>();
if (_spawnableEnv == null)
_spawnableEnv = GetComponentInParent<SpawnableEnv>();
_terrainGenerator.Reset();
lastXPosInMeters = (int)
_bodyManager.GetBodyParts(BodyPartGroup.Foot)
.Average(x=>x.Transform.position.x);
_pain = 0f;
}
public virtual void OnTerrainCollision(GameObject other, GameObject terrain)
{
// if (string.Compare(terrain.name, "Terrain", true) != 0)
if (terrain.GetComponent<Terrain>() == null)
return;
// if (!_styleAnimator.AnimationStepsReady)
// return;
// HACK - for when agent has not been initialized
if (_bodyManager == null)
return;
var bodyPart = _bodyManager.BodyParts.FirstOrDefault(x=>x.Transform.gameObject == other);
if (bodyPart == null)
return;
switch (bodyPart.Group)
{
case BodyHelper002.BodyPartGroup.None:
case BodyHelper002.BodyPartGroup.Foot:
case BodyHelper002.BodyPartGroup.LegLower:
break;
case BodyHelper002.BodyPartGroup.LegUpper:
case BodyHelper002.BodyPartGroup.Hand:
case BodyHelper002.BodyPartGroup.ArmLower:
case BodyHelper002.BodyPartGroup.ArmUpper:
_pain += .1f;
break;
default:
// AddReward(-100f);
_pain += 5f;
break;
}
}
}