Implementation of DReCon in Unity 2022 LTS, forked from https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

198 lines
5.7 KiB

10 months ago
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Linq;
using static BodyHelper002;
/// <summary>
/// ML-Agents agent that drives a ragdoll through <see cref="BodyManager002"/>.
/// It supplies the name-based body-part / muscle classification, collects
/// orientation, contact and per-body-part observations, applies continuous
/// actions, rewards forward (+x) velocity, and ends the episode when the pelvis
/// drops below y = 0 or a disallowed body part touches the terrain.
/// </summary>
public class MarathonManAgent : Agent, IOnTerrainCollision
{
    // Body manager on the same GameObject; fetched on the first OnEpisodeBegin.
    BodyManager002 _bodyManager;

    // Set to true by OnEpisodeBegin and cleared by the first OnActionReceived of
    // the episode. OnTerrainCollision only ends the episode while this is false,
    // so a collision that arrives between a reset and the next action does not
    // immediately end the freshly started episode.
    bool _isDone;

    // True once _bodyManager has been fetched and initialized.
    bool _hasLazyInitialized;

    /// <summary>
    /// Classification rules shared by all MarathonManAgent instances. Each name is
    /// lower-cased and matched by substring; the first matching rule wins, so the
    /// order of the checks matters. Names containing "mixamorig" always map to None.
    /// </summary>
    public static BodyConfig BodyConfig = new BodyConfig
    {
        GetBodyPartGroup = (name) =>
        {
            // Invariant lower-casing keeps matching independent of the current culture.
            name = name.ToLowerInvariant();
            if (name.Contains("mixamorig"))
                return BodyPartGroup.None;
            if (name.Contains("butt"))
                return BodyPartGroup.Hips;
            if (name.Contains("torso"))
                return BodyPartGroup.Torso;
            if (name.Contains("head"))
                return BodyPartGroup.Head;
            if (name.Contains("waist"))
                return BodyPartGroup.Spine;
            if (name.Contains("thigh"))
                return BodyPartGroup.LegUpper;
            if (name.Contains("shin"))
                return BodyPartGroup.LegLower;
            if (name.Contains("right_right_foot") || name.Contains("left_left_foot"))
                return BodyPartGroup.Foot;
            if (name.Contains("upper_arm"))
                return BodyPartGroup.ArmUpper;
            if (name.Contains("larm"))
                return BodyPartGroup.ArmLower;
            if (name.Contains("hand"))
                return BodyPartGroup.Hand;
            return BodyPartGroup.None;
        },
        GetMuscleGroup = (name) =>
        {
            // Invariant lower-casing keeps matching independent of the current culture.
            name = name.ToLowerInvariant();
            if (name.Contains("mixamorig"))
                return MuscleGroup.None;
            if (name.Contains("butt"))
                return MuscleGroup.Hips;
            if (name.Contains("lower_waist")
                || name.Contains("abdomen_y"))
                return MuscleGroup.Spine;
            if (name.Contains("thigh")
                || name.Contains("hip"))
                return MuscleGroup.LegUpper;
            if (name.Contains("shin"))
                return MuscleGroup.LegLower;
            if (name.Contains("right_right_foot")
                || name.Contains("left_left_foot")
                || name.Contains("ankle_x"))
                return MuscleGroup.Foot;
            if (name.Contains("upper_arm"))
                return MuscleGroup.ArmUpper;
            if (name.Contains("larm"))
                return MuscleGroup.ArmLower;
            if (name.Contains("hand"))
                return MuscleGroup.Hand;
            return MuscleGroup.None;
        },
        GetRootBodyPart = () => BodyPartGroup.Hips,
        GetRootMuscle = () => MuscleGroup.Hips
    };

    /// <summary>
    /// Writes the observation vector. The order is: normalized velocity; pelvis
    /// forward/up; torso forward/up; contact sensors; then, per body part, local
    /// position, rotation, angular velocity and velocity; finally the body
    /// manager's sensor observations.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Initialize on demand in case observations are requested before any
        // OnEpisodeBegin has run.
        if (!_hasLazyInitialized)
        {
            OnEpisodeBegin();
        }

        Vector3 normalizedVelocity = _bodyManager.GetNormalizedVelocity();
        var pelvis = _bodyManager.GetFirstBodyPart(BodyPartGroup.Hips);
        var shoulders = _bodyManager.GetFirstBodyPart(BodyPartGroup.Torso);

        sensor.AddObservation(normalizedVelocity);
        // Orientation of pelvis and torso as forward/up vectors (gyroscope-like).
        sensor.AddObservation(pelvis.Rigidbody.transform.forward);
        sensor.AddObservation(pelvis.Rigidbody.transform.up);
        sensor.AddObservation(shoulders.Rigidbody.transform.forward);
        sensor.AddObservation(shoulders.Rigidbody.transform.up);
        sensor.AddObservation(_bodyManager.GetSensorIsInTouch());

        foreach (var bodyPart in _bodyManager.BodyParts)
        {
            bodyPart.UpdateObservations();
            sensor.AddObservation(bodyPart.ObsLocalPosition);
            sensor.AddObservation(bodyPart.ObsRotation);
            sensor.AddObservation(bodyPart.ObsRotationVelocity);
            sensor.AddObservation(bodyPart.ObsVelocity);
        }

        sensor.AddObservation(_bodyManager.GetSensorObservations());
    }

    /// <summary>
    /// Applies the continuous actions to the body and rewards forward velocity.
    /// The reward is the x component of the normalized velocity, clamped to [0, 1].
    /// The episode ends once the pelvis drops below y = 0. Actions are ignored
    /// until the agent has been lazily initialized.
    /// </summary>
    /// <param name="actions">Action buffers; only the continuous actions are used.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        if (!_hasLazyInitialized)
        {
            return;
        }
        _isDone = false;

        // Fresh copy of the continuous actions for the body manager.
        float[] vectorAction = actions.ContinuousActions.ToArray();
        _bodyManager.OnAgentAction(vectorAction);

        // Clamping at 0 means moving backwards is not penalized, only unrewarded.
        float reward = Mathf.Clamp01(_bodyManager.GetNormalizedVelocity().x);
        AddReward(reward);
        _bodyManager.SetDebugFrameReward(reward);

        // Terminate AFTER rewarding: EndEpisode() resets the agent immediately, so
        // a reward added afterwards would be credited to the next episode.
        var pelvis = _bodyManager.GetFirstBodyPart(BodyPartGroup.Hips);
        if (pelvis.Transform.position.y < 0f)
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// On the first call, fetches the <see cref="BodyManager002"/>, applies
    /// <see cref="BodyConfig"/> and initializes it. On every call, marks the
    /// episode as just-reset and resets the body.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        if (!_hasLazyInitialized)
        {
            _bodyManager = GetComponent<BodyManager002>();
            _bodyManager.BodyConfig = MarathonManAgent.BodyConfig;
            _bodyManager.OnInitializeAgent();
            _hasLazyInitialized = true;
        }
        _isDone = true;
        _bodyManager.OnAgentReset();
    }

    /// <summary>
    /// Ends the episode when a body part other than a foot, lower leg or hand
    /// (or an unclassified part) touches a GameObject that has a
    /// <see cref="Terrain"/> component.
    /// </summary>
    /// <param name="other">The body-part GameObject involved in the collision.</param>
    /// <param name="terrain">The GameObject that was hit.</param>
    public virtual void OnTerrainCollision(GameObject other, GameObject terrain)
    {
        // '== null' (not 'is null') so Unity's overloaded null check applies.
        if (terrain.GetComponent<Terrain>() == null)
        {
            return;
        }
        // HACK - for when agent has not been initialized
        if (_bodyManager == null)
        {
            return;
        }

        var bodyPart = _bodyManager.BodyParts.FirstOrDefault(x => x.Transform.gameObject == other);
        if (bodyPart == null)
        {
            return;
        }

        switch (bodyPart.Group)
        {
            // Contacts allowed without ending the episode.
            case BodyPartGroup.None:
            case BodyPartGroup.Foot:
            case BodyPartGroup.LegLower:
            case BodyPartGroup.Hand:
                break;
            default:
                // Skip if the episode has just been reset and no action has run yet.
                if (!_isDone)
                {
                    EndEpisode();
                }
                break;
        }
    }
}