implementation of drecon in unity 2022 lts forked from: https://github.com/joanllobera/marathon-envs
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

303 lines
11 KiB

10 months ago
// Implmentation of an Agent. Agent reads observations relevant to the reinforcement
// learning task at hand, acts based on the observations, and receives a reward
// based on its performance.
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Linq;
using ManyWorlds;
public class StyleTransfer002Agent : Agent, IOnSensorCollision, IOnTerrainCollision {
public float FrameReward;
public float AverageReward;
public List<float> Rewards;
public List<float> SensorIsInTouch;
StyleTransfer002Master _master;
StyleTransfer002Animator _localStyleAnimator;
StyleTransfer002Animator _styleAnimator;
DecisionRequester _decisionRequester;
List<GameObject> _sensors;
public bool ShowMonitor = false;
static int _startCount;
static ScoreHistogramData _scoreHistogramData;
int _totalAnimFrames;
bool _ignorScoreForThisFrame;
bool _isDone;
bool _hasLazyInitialized;
// Use this for initialization
void Start () {
_master = GetComponent<StyleTransfer002Master>();
_decisionRequester = GetComponent<DecisionRequester>();
var spawnableEnv = GetComponentInParent<SpawnableEnv>();
_localStyleAnimator = spawnableEnv.gameObject.GetComponentInChildren<StyleTransfer002Animator>();
_styleAnimator = _localStyleAnimator.GetFirstOfThisAnim();
_startCount++;
}
// Update is called once per frame
void Update () {
}
// Collect observations that are used by the Neural Network for training and inference.
override public void CollectObservations(VectorSensor sensor)
{
if (!_hasLazyInitialized)
{
OnEpisodeBegin();
}
sensor.AddObservation(_master.ObsPhase);
foreach (var bodyPart in _master.BodyParts)
{
sensor.AddObservation(bodyPart.ObsLocalPosition);
sensor.AddObservation(bodyPart.ObsRotation);
sensor.AddObservation(bodyPart.ObsRotationVelocity);
sensor.AddObservation(bodyPart.ObsVelocity);
}
foreach (var muscle in _master.Muscles)
{
if (muscle.ConfigurableJoint.angularXMotion != ConfigurableJointMotion.Locked)
sensor.AddObservation(muscle.TargetNormalizedRotationX);
if (muscle.ConfigurableJoint.angularYMotion != ConfigurableJointMotion.Locked)
sensor.AddObservation(muscle.TargetNormalizedRotationY);
if (muscle.ConfigurableJoint.angularZMotion != ConfigurableJointMotion.Locked)
sensor.AddObservation(muscle.TargetNormalizedRotationZ);
}
sensor.AddObservation(_master.ObsCenterOfMass);
sensor.AddObservation(_master.ObsVelocity);
sensor.AddObservation(_master.ObsAngularMoment);
sensor.AddObservation(SensorIsInTouch);
}
// A method that applies the vectorAction to the muscles, and calculates the rewards.
public override void OnActionReceived(ActionBuffers actions)
{
float[] vectorAction = actions.ContinuousActions.Select(x=>x).ToArray();
if (!_hasLazyInitialized)
{
return;
}
_isDone = false;
if (_styleAnimator == _localStyleAnimator)
_styleAnimator.OnAgentAction();
_master.OnAgentAction();
int i = 0;
foreach (var muscle in _master.Muscles)
{
if (muscle.ConfigurableJoint.angularXMotion != ConfigurableJointMotion.Locked)
muscle.TargetNormalizedRotationX = vectorAction[i++];
if (muscle.ConfigurableJoint.angularYMotion != ConfigurableJointMotion.Locked)
muscle.TargetNormalizedRotationY = vectorAction[i++];
if (muscle.ConfigurableJoint.angularZMotion != ConfigurableJointMotion.Locked)
muscle.TargetNormalizedRotationZ = vectorAction[i++];
}
// the scaler factors are picked empirically by calculating the MaxRotationDistance, MaxVelocityDistance achieved for an untrained agent.
var rotationDistance = _master.RotationDistance / 16f ;
var centerOfMassvelocityDistance = _master.CenterOfMassVelocityDistance / 6f ;
var endEffectorDistance = _master.EndEffectorDistance / 1f ;
var endEffectorVelocityDistance = _master.EndEffectorVelocityDistance / 170f;
var jointAngularVelocityDistance = _master.JointAngularVelocityDistance / 7000f;
var jointAngularVelocityDistanceWorld = _master.JointAngularVelocityDistanceWorld / 7000f;
var centerOfMassDistance = _master.CenterOfMassDistance / 0.3f;
var angularMomentDistance = _master.AngularMomentDistance / 150.0f;
var sensorDistance = _master.SensorDistance / 1f;
var rotationReward = 0.35f * Mathf.Exp(-rotationDistance);
var centerOfMassVelocityReward = 0.1f * Mathf.Exp(-centerOfMassvelocityDistance);
var endEffectorReward = 0.15f * Mathf.Exp(-endEffectorDistance);
var endEffectorVelocityReward = 0.1f * Mathf.Exp(-endEffectorVelocityDistance);
var jointAngularVelocityReward = 0.1f * Mathf.Exp(-jointAngularVelocityDistance);
var jointAngularVelocityRewardWorld = 0.0f * Mathf.Exp(-jointAngularVelocityDistanceWorld);
var centerMassReward = 0.05f * Mathf.Exp(-centerOfMassDistance);
var angularMomentReward = 0.15f * Mathf.Exp(-angularMomentDistance);
var sensorReward = 0.0f * Mathf.Exp(-sensorDistance);
var jointsNotAtLimitReward = 0.0f * Mathf.Exp(-JointsAtLimit());
//Debug.Log("---------------");
//Debug.Log("rotation reward: " + rotationReward);
//Debug.Log("endEffectorReward: " + endEffectorReward);
//Debug.Log("endEffectorVelocityReward: " + endEffectorVelocityReward);
//Debug.Log("jointAngularVelocityReward: " + jointAngularVelocityReward);
//Debug.Log("jointAngularVelocityRewardWorld: " + jointAngularVelocityRewardWorld);
//Debug.Log("centerMassReward: " + centerMassReward);
//Debug.Log("centerMassVelocityReward: " + centerOfMassVelocityReward);
//Debug.Log("angularMomentReward: " + angularMomentReward);
//Debug.Log("sensorReward: " + sensorReward);
//Debug.Log("joints not at limit rewards:" + jointsNotAtLimitReward);
float reward = rotationReward +
centerOfMassVelocityReward +
endEffectorReward +
endEffectorVelocityReward +
jointAngularVelocityReward +
jointAngularVelocityRewardWorld +
centerMassReward +
angularMomentReward +
sensorReward +
jointsNotAtLimitReward;
if (!_master.IgnorRewardUntilObservation)
AddReward(reward);
if (reward < 0.5)
EndEpisode();
if (!_isDone){
if (_master.IsDone()){
EndEpisode();
if (_master.StartAnimationIndex > 0)
_master.StartAnimationIndex--;
}
}
FrameReward = reward;
var stepCount = StepCount > 0 ? StepCount : 1;
AverageReward = GetCumulativeReward() / (float) stepCount;
}
// A helper function that calculates a fraction of joints at their limit positions
float JointsAtLimit(string[] ignorJoints = null)
{
int atLimitCount = 0;
int totalJoints = 0;
foreach (var muscle in _master.Muscles)
{
if(muscle.Parent == null)
continue;
var name = muscle.Name;
if (ignorJoints != null && ignorJoints.Contains(name))
continue;
if (Mathf.Abs(muscle.TargetNormalizedRotationX) >= 1f)
atLimitCount++;
if (Mathf.Abs(muscle.TargetNormalizedRotationY) >= 1f)
atLimitCount++;
if (Mathf.Abs(muscle.TargetNormalizedRotationZ) >= 1f)
atLimitCount++;
totalJoints++;
}
float fractionOfJointsAtLimit = (float)atLimitCount / (float)totalJoints;
return fractionOfJointsAtLimit;
}
// Sets reward
public void SetTotalAnimFrames(int totalAnimFrames)
{
_totalAnimFrames = totalAnimFrames;
if (_scoreHistogramData == null) {
var columns = _totalAnimFrames;
if (_decisionRequester?.DecisionPeriod > 1)
columns /= _decisionRequester.DecisionPeriod;
_scoreHistogramData = new ScoreHistogramData(columns, 30);
}
Rewards = _scoreHistogramData.GetAverages().Select(x=>(float)x).ToList();
}
// Resets the agent. Initialize the style animator and master if not initialized.
public override void OnEpisodeBegin()
{
if (!_hasLazyInitialized)
{
_master = GetComponent<StyleTransfer002Master>();
_master.BodyConfig = MarathonManAgent.BodyConfig;
_decisionRequester = GetComponent<DecisionRequester>();
var spawnableEnv = GetComponentInParent<SpawnableEnv>();
_localStyleAnimator = spawnableEnv.gameObject.GetComponentInChildren<StyleTransfer002Animator>();
_styleAnimator = _localStyleAnimator.GetFirstOfThisAnim();
_styleAnimator.BodyConfig = MarathonManAgent.BodyConfig;
_styleAnimator.OnInitializeAgent();
_master.OnInitializeAgent();
_hasLazyInitialized = true;
_localStyleAnimator.DestoryIfNotFirstAnim();
}
_isDone = true;
_ignorScoreForThisFrame = true;
_master.ResetPhase();
_sensors = GetComponentsInChildren<SensorBehavior>()
.Select(x=>x.gameObject)
.ToList();
SensorIsInTouch = Enumerable.Range(0,_sensors.Count).Select(x=>0f).ToList();
if (_scoreHistogramData != null) {
var column = _master.StartAnimationIndex;
if (_decisionRequester?.DecisionPeriod > 1)
column /= _decisionRequester.DecisionPeriod;
if (_ignorScoreForThisFrame)
_ignorScoreForThisFrame = false;
else
_scoreHistogramData.SetItem(column, AverageReward);
}
}
// A method called on terrain collision. Used for early stopping an episode
// on specific objects' collision with terrain.
public virtual void OnTerrainCollision(GameObject other, GameObject terrain)
{
if (string.Compare(terrain.name, "Terrain", true) != 0)
return;
if (!_styleAnimator.AnimationStepsReady)
return;
var bodyPart = _master.BodyParts.FirstOrDefault(x=>x.Transform.gameObject == other);
if (bodyPart == null)
return;
switch (bodyPart.Group)
{
case BodyHelper002.BodyPartGroup.None:
case BodyHelper002.BodyPartGroup.Foot:
case BodyHelper002.BodyPartGroup.LegUpper:
case BodyHelper002.BodyPartGroup.LegLower:
case BodyHelper002.BodyPartGroup.Hand:
case BodyHelper002.BodyPartGroup.ArmLower:
case BodyHelper002.BodyPartGroup.ArmUpper:
break;
default:
EndEpisode();
break;
}
}
// Sets the a flag in Sensors In Touch array when an object enters collision with terrain
public void OnSensorCollisionEnter(Collider sensorCollider, GameObject other)
{
if (string.Compare(other.name, "Terrain", true) !=0)
return;
if (_sensors == null || _sensors.Count == 0)
return;
var sensor = _sensors
.FirstOrDefault(x=>x == sensorCollider.gameObject);
if (sensor != null) {
var idx = _sensors.IndexOf(sensor);
SensorIsInTouch[idx] = 1f;
}
}
// Sets the a flag in Sensors In Touch array when an object stops colliding with terrain
public void OnSensorCollisionExit(Collider sensorCollider, GameObject other)
{
if (string.Compare(other.gameObject.name, "Terrain", true) !=0)
return;
if (_sensors == null || _sensors.Count == 0)
return;
var sensor = _sensors
.FirstOrDefault(x=>x == sensorCollider.gameObject);
if (sensor != null) {
var idx = _sensors.IndexOf(sensor);
SensorIsInTouch[idx] = 0f;
}
}
}