Implementation of DReCon in Unity 2022 LTS.
Forked from:
https://github.com/joanllobera/marathon-envs
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
218 lines
8.1 KiB
218 lines
8.1 KiB
10 months ago
|
using System.Collections;
|
||
|
using System.Collections.Generic;
|
||
|
using System.Linq;
|
||
|
using Unity.MLAgents;
|
||
|
using UnityEngine;
|
||
|
using UnityEngine.Assertions;
|
||
|
using ManyWorlds;
|
||
|
|
||
|
/// <summary>
/// Computes a per-step reward by comparing a motion-capture reference character
/// (<c>_mocap</c>) against the simulated ragdoll (<c>_ragDoll</c>).
/// Four sub-rewards (point positions, point velocities, local pose, center-of-mass
/// velocity) are each mapped through <c>exp(-k * error)</c>, summed into
/// <see cref="SumOfSubRewards"/>, and scaled by <see cref="DistanceFactor"/> to give
/// <see cref="Reward"/>. All intermediate values are public fields so they can be
/// inspected in the Unity editor.
/// </summary>
/// <remarks>
/// Call order expected by the assertions in this class:
/// <see cref="OnAgentInitialize"/> exactly once, then <see cref="OnReset"/> /
/// <see cref="OnStep"/> as often as needed.
/// </remarks>
public class DReConRewards : MonoBehaviour
{
    // Scale applied to the (per-object averaged) sum of squared point distances
    // inside the exponent of the position reward.
    const float PositionRewardScale = -7.37f;

    // Scale applied to the mean squared rotation difference inside the exponent
    // of the local pose reward.
    const float LocalPoseRewardScale = -6.5f;

    // DistanceFactor = clamp(DistanceFactorOffset - DistanceFactorScale * distance^2, 0, 1).
    const float DistanceFactorScale = 1.4f;
    const float DistanceFactorOffset = 1.01f;

    // Number of entries in the point arrays that belong to one tracked object.
    // NOTE(review): presumably six sampled points per body part -- confirm in DReConRewardStats.
    const int PointsPerObject = 6;

    [Header("Reward")]
    public float SumOfSubRewards;
    public float Reward;

    [Header("Position Reward")]
    public float SumOfDistances;
    public float SumOfSqrDistances;
    public float PositionReward;

    [Header("Velocity Reward")]
    public float PointsVelocityDifferenceSquared;
    public float PointsVelocityReward;

    [Header("Local Pose Reward")]
    public List<float> RotationDifferences;
    public float SumOfRotationDifferences;
    public float SumOfRotationSqrDifferences;
    public float LocalPoseReward;

    [Header("Center of Mass Velocity Reward")]
    public Vector3 MocapCOMVelocity;
    public Vector3 RagDollCOMVelocity;

    public float COMVelocityDifference;
    public float ComReward;

    [Header("Distance Factor")]
    public float ComDistance;
    public float DistanceFactor;

    // [Header("Direction Factor")]
    // public float DirectionDistance;
    // public float DirectionFactor;

    [Header("Misc")]
    public float HeadHeightDistance;

    [Header("Gizmos")]
    public int ObjectForPointDistancesGizmo;

    private SpawnableEnv _spawnableEnv;
    private MocapControllerArtanim _mocap;
    private GameObject _ragDoll;
    // Only consumed by the (currently disabled) direction factor in OnStep.
    private InputController _inputController;

    internal DReConRewardStats _mocapBodyStats;
    internal DReConRewardStats _ragDollBodyStats;

    // List<ArticulationBody> _mocapBodyParts;
    // List<ArticulationBody> _ragDollBodyParts;
    private Transform _mocapHead;
    private Transform _ragDollHead;

    private bool _hasLazyInitialized;

    [Header("Things to check for rewards")]
    public string headname = "head";

    public string targetedRootName = "articulation:Hips";

    /// <summary>
    /// One-time setup: locates the environment, the mocap controller, the ragdoll
    /// agent and both head transforms, then creates one <c>DReConRewardStats</c>
    /// object per character (parented to the environment) and checks that the two
    /// are compatible. Must be called exactly once, before <see cref="OnStep"/>
    /// or <see cref="OnReset"/>.
    /// </summary>
    public void OnAgentInitialize()
    {
        Assert.IsFalse(_hasLazyInitialized);

        _hasLazyInitialized = true;

        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        Assert.IsNotNull(_spawnableEnv);

        _mocap = _spawnableEnv.GetComponentInChildren<MocapControllerArtanim>();

        // Validate the component BEFORE touching .gameObject; otherwise a missing
        // RagDollAgent surfaces as a bare NullReferenceException instead of an
        // assertion failure.
        RagDollAgent ragDollAgent = _spawnableEnv.GetComponentInChildren<RagDollAgent>();
        Assert.IsNotNull(_mocap);
        Assert.IsNotNull(ragDollAgent);
        _ragDoll = ragDollAgent.gameObject;
        Assert.IsNotNull(_ragDoll);

        _inputController = _spawnableEnv.GetComponentInChildren<InputController>();
        // _mocapBodyParts = _mocap.GetComponentsInChildren<ArticulationBody>().ToList();
        // _ragDollBodyParts = _ragDoll.GetComponentsInChildren<ArticulationBody>().ToList();
        // Assert.AreEqual(_mocapBodyParts.Count, _ragDollBodyParts.Count);

        _mocapHead = FindChildByName(_mocap.transform, headname);
        _ragDollHead = FindChildByName(_ragDoll.transform, headname);

        _mocapBodyStats = new GameObject("MocapDReConRewardStats").AddComponent<DReConRewardStats>();
        _mocapBodyStats.setRootName(targetedRootName);
        _mocapBodyStats.ObjectToTrack = _mocap;
        _mocapBodyStats.transform.SetParent(_spawnableEnv.transform);
        _mocapBodyStats.OnAgentInitialize(_mocapBodyStats.ObjectToTrack.transform);

        _ragDollBodyStats = new GameObject("RagDollDReConRewardStats").AddComponent<DReConRewardStats>();
        _ragDollBodyStats.setRootName(targetedRootName);
        _ragDollBodyStats.ObjectToTrack = this;
        _ragDollBodyStats.transform.SetParent(_spawnableEnv.transform);
        // The ragdoll stats are initialized against the mocap stats created above.
        _ragDollBodyStats.OnAgentInitialize(transform, _mocapBodyStats);

        _mocapBodyStats.AssertIsCompatible(_ragDollBodyStats);
    }

    /// <summary>
    /// Refreshes both stats objects for the current step and recomputes every
    /// reward field of this component. Not a Unity message: the owner is expected
    /// to call it explicitly.
    /// </summary>
    /// <param name="timeDelta">Step duration forwarded to
    /// <c>DReConRewardStats.SetStatusForStep</c>.</param>
    public void OnStep(float timeDelta)
    {
        Assert.IsTrue(_hasLazyInitialized);

        _mocapBodyStats.SetStatusForStep(timeDelta);
        _ragDollBodyStats.SetStatusForStep(timeDelta);

        UpdatePositionReward();
        UpdateComVelocityReward();
        UpdatePointsVelocityReward();
        UpdateLocalPoseReward();
        UpdateDistanceFactor();

        // // direction factor
        // Vector3 desiredDirection = _inputController.HorizontalDirection;
        // var curDirection = _ragDollBodyStats.transform.forward;
        // // cosAngle
        // var directionDifference = Vector3.Dot(desiredDirection, curDirection);
        // DirectionDistance = (1f + directionDifference) /2f; // normalize the error
        // DirectionFactor = Mathf.Pow(DirectionDistance,2);
        // DirectionFactor = Mathf.Clamp(DirectionFactor, 0f, 1f);

        // misc: absolute vertical offset between the two heads (diagnostic only,
        // it does not feed into Reward).
        HeadHeightDistance = Mathf.Abs(_mocapHead.position.y - _ragDollHead.position.y);

        // reward
        SumOfSubRewards = PositionReward + ComReward + PointsVelocityReward + LocalPoseReward;
        Reward = DistanceFactor * SumOfSubRewards;
        // Reward = (DirectionFactor*SumOfSubRewards) * DistanceFactor;
    }

    /// <summary>
    /// Resets both stats objects and snaps the ragdoll stats transform onto the
    /// mocap stats transform (position and rotation).
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);

        _mocapBodyStats.OnReset();
        _ragDollBodyStats.OnReset();
        _ragDollBodyStats.transform.position = _mocapBodyStats.transform.position;
        _ragDollBodyStats.transform.rotation = _mocapBodyStats.transform.rotation;
    }

    /// <summary>
    /// Forwards <paramref name="snapDistance"/> to <c>ShiftCOM</c> on the mocap stats.
    /// </summary>
    /// <param name="snapDistance">Offset passed through unchanged.</param>
    public void ShiftMocapCOM(Vector3 snapDistance)
    {
        _mocapBodyStats.ShiftCOM(snapDistance);
    }

    /// <summary>
    /// Editor gizmo hook: clamps <see cref="ObjectForPointDistancesGizmo"/> to the
    /// valid object range (-1 .. objectCount-1) and asks the ragdoll stats to draw
    /// its point distances to the mocap stats.
    /// </summary>
    void OnDrawGizmos()
    {
        // Unity's overloaded == also reports destroyed objects as null, so it is
        // used deliberately here instead of "is null".
        if (_ragDollBodyStats == null || _mocapBodyStats == null)
        {
            return;
        }

        int lastObjectIndex = (_ragDollBodyStats.Points.Length / PointsPerObject) - 1;
        ObjectForPointDistancesGizmo = Mathf.Clamp(ObjectForPointDistancesGizmo, -1, lastObjectIndex);
        // _mocapBodyStats.DrawPointDistancesFrom(_ragDollBodyStats, ObjectForPointDistancesGizmo);
        _ragDollBodyStats.DrawPointDistancesFrom(_mocapBodyStats, ObjectForPointDistancesGizmo);
    }

    // Returns the first transform under (and including) root whose name equals childName.
    // Throws InvalidOperationException (from LINQ First) when there is no match.
    static Transform FindChildByName(Transform root, string childName)
    {
        return root
            .GetComponentsInChildren<Transform>()
            .First(x => x.name == childName);
    }

    // Position reward: exp(scale / (pointCount / PointsPerObject) * sum of squared point distances).
    void UpdatePositionReward()
    {
        List<float> distances = _mocapBodyStats.GetPointDistancesFrom(_ragDollBodyStats);
        // An empty list would divide by zero below and turn the reward into NaN.
        Assert.IsTrue(distances.Count > 0);

        SumOfDistances = distances.Sum();
        SumOfSqrDistances = distances.Sum(x => x * x);

        float scale = PositionRewardScale / (distances.Count / (float)PointsPerObject);
        PositionReward = Mathf.Exp(scale * SumOfSqrDistances);
    }

    // Center of mass velocity reward: exp(-|v_mocap - v_ragdoll|^2).
    void UpdateComVelocityReward()
    {
        MocapCOMVelocity = _mocapBodyStats.CenterOfMassVelocity;
        RagDollCOMVelocity = _ragDollBodyStats.CenterOfMassVelocity;
        COMVelocityDifference = (MocapCOMVelocity - RagDollCOMVelocity).magnitude;
        ComReward = Mathf.Exp(-Mathf.Pow(COMVelocityDifference, 2));
    }

    // Points velocity reward: exp(-(sum of squared point velocity distances) / pointVelocityCount).
    void UpdatePointsVelocityReward()
    {
        List<float> velocityDistances = _mocapBodyStats.GetPointVelocityDistancesFrom(_ragDollBodyStats);
        int pointVelocityCount = _mocapBodyStats.PointVelocity.Length;
        // A zero count would divide by zero below and turn the reward into NaN.
        Assert.IsTrue(pointVelocityCount > 0);

        PointsVelocityDifferenceSquared = velocityDistances.Sum(x => x * x);
        PointsVelocityReward = Mathf.Exp((-1f / pointVelocityCount) * PointsVelocityDifferenceSquared);
    }

    // Local pose reward: exp(scale * mean squared normalized rotation difference).
    void UpdateLocalPoseReward()
    {
        int rotationCount = _mocapBodyStats.Rotations.Count;
        // A zero count would divide by zero below and turn the reward into NaN.
        Assert.IsTrue(rotationCount > 0);

        // RotationDifferences is a public (Unity-serialized) list, so it can arrive
        // here with any size. Keep it exactly as long as the number of rotations
        // compared, so that the inspector shows no stale entries and the averaging
        // below divides by the number of terms actually summed.
        if (RotationDifferences == null || RotationDifferences.Count != rotationCount)
        {
            RotationDifferences = Enumerable.Repeat(0f, rotationCount).ToList();
        }

        SumOfRotationDifferences = 0f;
        SumOfRotationSqrDifferences = 0f;
        for (int i = 0; i < rotationCount; i++)
        {
            float angle = Quaternion.Angle(_mocapBodyStats.Rotations[i], _ragDollBodyStats.Rotations[i]);
            Assert.IsTrue(angle <= 180f);
            angle = DReConObservationStats.NormalizedAngle(angle);

            RotationDifferences[i] = angle;
            SumOfRotationDifferences += angle;
            SumOfRotationSqrDifferences += angle * angle;
        }

        LocalPoseReward = Mathf.Exp((LocalPoseRewardScale / rotationCount) * SumOfRotationSqrDifferences);
    }

    // Distance factor: clamp(1.01 - 1.4 * d^2, 0, 1), where d is the distance between
    // the two stats objects' transforms.
    // NOTE(review): the field name suggests those transforms follow each character's
    // center of mass -- confirm in DReConRewardStats.
    void UpdateDistanceFactor()
    {
        ComDistance = (_mocapBodyStats.transform.position - _ragDollBodyStats.transform.position).magnitude;
        float penalty = DistanceFactorScale * Mathf.Pow(ComDistance, 2);
        DistanceFactor = Mathf.Clamp(DistanceFactorOffset - penalty, 0f, 1f);
    }
}
|