Implementation of DReCon in Unity 2022 LTS, forked from: https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

345 lines
14 KiB

10 months ago
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;
/// <summary>
/// Collects the DReCon observation set for a ragdoll agent. It covers:
/// <list type="bullet">
/// <item>centre-of-mass velocities of the kinematic (mocap) character and the ragdoll;</item>
/// <item>user input;</item>
/// <item>per-body-part mocap/ragdoll differences;</item>
/// <item>motor joint forces;</item>
/// <item>simple velocity-magnitude "energy" averages.</item>
/// </list>
/// It also draws editor gizmos for these values.
/// Call order: OnAgentInitialize once, then OnReset / OnStep.
/// </summary>
public class Observations2Learn : MonoBehaviour
{
    [Header("Observations")]
    [Tooltip("Kinematic character center of mass velocity, Vector3")]
    public Vector3 MocapCOMVelocity;
    [Tooltip("RagDoll character center of mass velocity, Vector3")]
    public Vector3 RagDollCOMVelocity;
    [Tooltip("User-input desired horizontal CM velocity. Vector2")]
    public Vector2 InputDesiredHorizontalVelocity;
    [Tooltip("User-input requests jump, bool")]
    public bool InputJump;
    [Tooltip("User-input requests backflip, bool")]
    public bool InputBackflip;
    [Tooltip("Difference between RagDoll character horizontal CM velocity and user-input desired horizontal CM velocity. Vector2")]
    public Vector2 HorizontalVelocityDifference;
    [Tooltip("Positions and velocities for subset of bodies")]
    public List<BodyPartDifferenceStats> BodyPartDifferenceStats;
    public List<ObservationStats.Stat> MocapBodyStats;
    public List<ObservationStats.Stat> RagDollBodyStats;
    [Tooltip("Smoothed actions produced in the previous step of the policy are collected in t −1")]
    public float[] PreviousActions;
    // Only joint forces are observed; per-DOF joint positions, velocities and accelerations are not currently collected.
    [Tooltip("RagDoll ArticulationBody joint forces in reduced space")]
    public float[] RagDollJointForces;
    [Tooltip("Mocap: ave of joint angular velocity")]
    public float EnergyAngularMocap;
    [Tooltip("RagDoll: ave of joint angular velocity")]
    public float EnergyAngularRagDoll;
    [Tooltip("RagDoll-Mocap: ave of joint angular velocity")]
    public float EnergyDifferenceAngular;
    [Tooltip("Mocap: ave of joint velocity in local space")]
    public float EnergyPositionalMocap;
    [Tooltip("RagDoll: ave of joint velocity in local space")]
    public float EnergyPositionalRagDoll;
    [Tooltip("RagDoll-Mocap: ave of joint velocity in local space")]
    public float EnergyDifferencePositional;

    [Header("Gizmos")]
    public bool VelocityInWorldSpace = true;
    public bool PositionInWorldSpace = true;
    public string targetedRootName = "articulation:Hips";

    InputController _inputController;   // optional; may be null, in which case input flags are left untouched
    SpawnableEnv _spawnableEnv;
    ObservationStats _mocapBodyStats;   // tracks the MapAnim2Ragdoll (kinematic) character
    ObservationStats _ragDollBodyStats; // tracks this ragdoll
    bool _hasLazyInitialized;           // set once by OnAgentInitialize
    List<ArticulationBody> _motors;     // non-root spherical joints whose forces are observed

    /// <summary>
    /// One-time setup; asserts it has not already run.
    /// Creates one ObservationStats tracker for the MapAnim2Ragdoll object and one for this ragdoll.
    /// Both trackers are parented under the enclosing SpawnableEnv.
    /// Then sizes the per-DOF arrays from the motors' limited axes.
    /// </summary>
    public void OnAgentInitialize()
    {
        Assert.IsFalse(_hasLazyInitialized);
        _hasLazyInitialized = true;

        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        _inputController = _spawnableEnv.GetComponentInChildren<InputController>();

        _mocapBodyStats = new GameObject("MocapDReConObservationStats").AddComponent<ObservationStats>();
        _mocapBodyStats.setRootName(targetedRootName);
        _mocapBodyStats.ObjectToTrack = _spawnableEnv.GetComponentInChildren<MapAnim2Ragdoll>();
        _mocapBodyStats.transform.SetParent(_spawnableEnv.transform);
        _mocapBodyStats.OnAgentInitialize(_mocapBodyStats.ObjectToTrack.transform);

        _ragDollBodyStats = new GameObject("RagDollDReConObservationStats").AddComponent<ObservationStats>();
        _ragDollBodyStats.setRootName(targetedRootName);
        _ragDollBodyStats.ObjectToTrack = this;
        _ragDollBodyStats.transform.SetParent(_spawnableEnv.transform);
        _ragDollBodyStats.OnAgentInitialize(transform);

        // One difference entry per mocap-tracked body; entries are paired with stats by name in UpdateObservations.
        BodyPartDifferenceStats = _mocapBodyStats.Stats
            .Select(x => new BodyPartDifferenceStats { Name = x.Name })
            .ToList();

        _motors = GetComponentsInChildren<ArticulationBody>()
            .Where(x => x.jointType == ArticulationJointType.SphericalJoint)
            .Where(x => !x.isRoot)
            .Distinct()
            .ToList();

        // Same helper is used when filling RagDollJointForces, so size and fill always agree.
        int numJoints = _motors.Sum(m => LimitedDofCount(m));
        PreviousActions = new float[numJoints];
        RagDollJointForces = new float[numJoints];
    }

    /// <summary>
    /// Returns the colliders to use for observations.
    /// A collider qualifies if it is enabled, is not a trigger, and has no enabled IgnoreColliderForObservation component.
    /// </summary>
    public List<Collider> EstimateBodyPartsForObservation()
    {
        return EstimateBodyParts(c =>
        {
            var ignore = c.GetComponent<IgnoreColliderForObservation>();
            return ignore != null && ignore.enabled;
        });
    }

    /// <summary>
    /// Returns the colliders to use for rewards.
    /// A collider qualifies if it is enabled, is not a trigger, and has no enabled IgnoreColliderForReward component.
    /// </summary>
    public List<Collider> EstimateBodyPartsForReward()
    {
        return EstimateBodyParts(c =>
        {
            var ignore = c.GetComponent<IgnoreColliderForReward>();
            return ignore != null && ignore.enabled;
        });
    }

    // Shared filter for the EstimateBodyParts* methods.
    // Keeps enabled, non-trigger colliders in this hierarchy for which isExcluded is false, de-duplicated.
    List<Collider> EstimateBodyParts(System.Func<Collider, bool> isExcluded)
    {
        return GetComponentsInChildren<Collider>()
            .Where(x => x.enabled)
            .Where(x => !x.isTrigger)
            .Where(x => !isExcluded(x))
            .Distinct()
            .ToList();
    }

    /// <summary>Advances both trackers by one step, then refreshes all observation fields.</summary>
    /// <param name="timeDelta">Step duration forwarded to the trackers and to UpdateObservations.</param>
    public void OnStep(float timeDelta)
    {
        Assert.IsTrue(_hasLazyInitialized);
        _mocapBodyStats.SetStatusForStep(timeDelta);
        _ragDollBodyStats.SetStatusForStep(timeDelta);
        UpdateObservations(timeDelta);
    }

    /// <summary>
    /// Resets both trackers and snaps the ragdoll tracker's transform to the mocap tracker's position and rotation.
    /// Then refreshes all observation fields.
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);
        _mocapBodyStats.OnReset();
        _ragDollBodyStats.OnReset();
        _ragDollBodyStats.transform.position = _mocapBodyStats.transform.position;
        _ragDollBodyStats.transform.rotation = _mocapBodyStats.transform.rotation;
        // NOTE(review): float.MinValue is the most negative finite float (about -3.4e38), not the
        // smallest positive one. If GetAngularVelocity divides by timeDelta, the rotation
        // difference computed on reset comes out as roughly -0. Confirm that is intended.
        var timeDelta = float.MinValue;
        UpdateObservations(timeDelta);
    }

    /// <summary>
    /// Refreshes every public observation field from three sources:
    /// the two ObservationStats trackers, the optional InputController, and the motor joints.
    /// </summary>
    /// <param name="timeDelta">Forwarded to ObservationStats.GetAngularVelocity for the per-body rotation difference.</param>
    public void UpdateObservations(float timeDelta)
    {
        MocapCOMVelocity = _mocapBodyStats.CenterOfMassVelocity;
        RagDollCOMVelocity = _ragDollBodyStats.CenterOfMassVelocity;
        // Horizontal observations keep only the x and z components.
        InputDesiredHorizontalVelocity = new Vector2(
            _ragDollBodyStats.DesiredCenterOfMassVelocity.x,
            _ragDollBodyStats.DesiredCenterOfMassVelocity.z);
        if (_inputController != null)
        {
            InputJump = _inputController.Jump;
            InputBackflip = _inputController.Backflip;
        }
        HorizontalVelocityDifference = new Vector2(
            _ragDollBodyStats.CenterOfMassVelocityDifference.x,
            _ragDollBodyStats.CenterOfMassVelocityDifference.z);

        MocapBodyStats = _mocapBodyStats.Stats.ToList();
        // Ragdoll stats re-ordered to match MocapBodyStats, paired by body name.
        RagDollBodyStats = MocapBodyStats
            .Select(x => _ragDollBodyStats.Stats.First(y => y.Name == x.Name))
            .ToList();

        // Per-body mocap-minus-ragdoll differences.
        foreach (var differenceStats in BodyPartDifferenceStats)
        {
            var mocapStats = _mocapBodyStats.Stats.First(x => x.Name == differenceStats.Name);
            var ragDollStats = _ragDollBodyStats.Stats.First(x => x.Name == differenceStats.Name);
            differenceStats.Position = mocapStats.Position - ragDollStats.Position;
            differenceStats.Velocity = mocapStats.Velocity - ragDollStats.Velocity;
            differenceStats.AngualrVelocity = mocapStats.AngularVelocity - ragDollStats.AngularVelocity;
            differenceStats.Rotation = ObservationStats.GetAngularVelocity(mocapStats.Rotation, ragDollStats.Rotation, timeDelta);
        }

        // Copy each motor's reduced-space joint forces, one slot per limited axis.
        // Assumes a motor's limited axes occupy reduced-space indices 0..dofs-1.
        // NOTE(review): a FreeMotion axis would also occupy a reduced-space slot and shift these
        // indices. Confirm no motor uses FreeMotion.
        int i = 0;
        foreach (var m in _motors)
        {
            int dofs = LimitedDofCount(m);
            for (int j = 0; j < dofs; j++)
                RagDollJointForces[i++] = m.jointForce[j];
        }

        // "Energy" terms: mean magnitude of per-body angular and linear velocities, plus the mean
        // ragdoll-minus-mocap difference of those magnitudes.
        // RagDollBodyStats is index-aligned with MocapBodyStats, so Zip pairs entries for the same body.
        EnergyAngularMocap = MocapBodyStats
            .Select(x => x.AngularVelocity.magnitude)
            .Average();
        EnergyAngularRagDoll = RagDollBodyStats
            .Select(x => x.AngularVelocity.magnitude)
            .Average();
        EnergyDifferenceAngular = RagDollBodyStats
            .Zip(MocapBodyStats, (x, y) => x.AngularVelocity.magnitude - y.AngularVelocity.magnitude)
            .Average();
        EnergyPositionalMocap = MocapBodyStats
            .Select(x => x.Velocity.magnitude)
            .Average();
        EnergyPositionalRagDoll = RagDollBodyStats
            .Select(x => x.Velocity.magnitude)
            .Average();
        EnergyDifferencePositional = RagDollBodyStats
            .Zip(MocapBodyStats, (x, y) => x.Velocity.magnitude - y.Velocity.magnitude)
            .Average();
    }

    /// <summary>
    /// Number of twist / swingY / swingZ axes of <paramref name="body"/> whose lock is LimitedMotion.
    /// </summary>
    static int LimitedDofCount(ArticulationBody body)
    {
        int count = 0;
        if (body.twistLock == ArticulationDofLock.LimitedMotion)
            count++;
        if (body.swingYLock == ArticulationDofLock.LimitedMotion)
            count++;
        if (body.swingZLock == ArticulationDofLock.LimitedMotion)
            count++;
        return count;
    }

    /// <summary>
    /// Returns the transform of the ragdoll ObservationStats object.
    /// Presumably ObservationStats keeps it at the ragdoll centre of mass (TODO confirm).
    /// </summary>
    public Transform GetRagDollCOM()
    {
        return _ragDollBodyStats.transform;
    }

    /// <summary>
    /// Returns the mocap COM velocity converted to world space.
    /// The velocity is treated as local to the mocap tracker's transform and converted via TransformVector.
    /// </summary>
    public Vector3 GetMocapCOMVelocityInWorldSpace()
    {
        var velocity = _mocapBodyStats.CenterOfMassVelocity;
        var velocityInWorldSpace = _mocapBodyStats.transform.TransformVector(velocity);
        return velocityInWorldSpace;
    }

    /// <summary>
    /// Editor visualisation, drawn in two groups.
    /// COM arrows, anchored at this object's x/z at height 0.3:
    /// grey = mocap COM velocity; blue = ragdoll COM velocity; green = desired horizontal velocity;
    /// red = horizontal velocity difference, drawn from the tip of the blue arrow.
    /// Per-body arrows: cyan = body velocity; magenta = mocap-minus-ragdoll position offset;
    /// red = velocity difference, drawn from the tip of the cyan arrow.
    /// </summary>
    void OnDrawGizmos()
    {
        // Nothing to draw until OnAgentInitialize has created both trackers.
        if (_mocapBodyStats == null || _ragDollBodyStats == null)
            return;
        Transform ragDollFrame = _ragDollBodyStats.transform;

        // MocapCOMVelocity
        Vector3 pos = new Vector3(transform.position.x, .3f, transform.position.z);
        Vector3 vector = MocapCOMVelocity;
        if (VelocityInWorldSpace)
            vector = _mocapBodyStats.transform.TransformVector(vector);
        DrawArrow(pos, vector, Color.grey);

        // RagDollCOMVelocity
        vector = RagDollCOMVelocity;
        if (VelocityInWorldSpace)
            vector = ragDollFrame.TransformVector(vector);
        DrawArrow(pos, vector, Color.blue);
        Vector3 actualPos = pos + vector;

        // InputDesiredHorizontalVelocity
        vector = new Vector3(InputDesiredHorizontalVelocity.x, 0f, InputDesiredHorizontalVelocity.y);
        if (VelocityInWorldSpace)
            vector = ragDollFrame.TransformVector(vector);
        DrawArrow(pos, vector, Color.green);

        // HorizontalVelocityDifference
        vector = new Vector3(HorizontalVelocityDifference.x, 0f, HorizontalVelocityDifference.y);
        if (VelocityInWorldSpace)
            vector = ragDollFrame.TransformVector(vector);
        DrawArrow(actualPos, vector, Color.red);

        // Per-body lists are populated only by UpdateObservations (OnStep / OnReset).
        if (RagDollBodyStats == null || BodyPartDifferenceStats == null)
            return;
        // Both lists follow the mocap Stats order; guard against any length mismatch.
        int count = Mathf.Min(RagDollBodyStats.Count, BodyPartDifferenceStats.Count);
        for (int i = 0; i < count; i++)
        {
            var stat = RagDollBodyStats[i];
            var differenceStat = BodyPartDifferenceStats[i];

            // Body velocity (cyan), drawn from the body position.
            pos = stat.Position;
            vector = stat.Velocity;
            if (PositionInWorldSpace)
                pos = ragDollFrame.TransformPoint(pos);
            if (VelocityInWorldSpace)
                vector = ragDollFrame.TransformVector(vector);
            DrawArrow(pos, vector, Color.cyan);
            Vector3 velocityPos = pos + vector;

            // Mocap-minus-ragdoll position offset (magenta ray).
            pos = stat.Position;
            vector = differenceStat.Position;
            if (PositionInWorldSpace)
                pos = ragDollFrame.TransformPoint(pos);
            if (VelocityInWorldSpace)
                vector = ragDollFrame.TransformVector(vector);
            Gizmos.color = Color.magenta;
            Gizmos.DrawRay(pos, vector);

            // Mocap-minus-ragdoll velocity difference (red), drawn from the tip of the velocity arrow.
            vector = differenceStat.Velocity;
            if (VelocityInWorldSpace)
                vector = ragDollFrame.TransformVector(vector);
            DrawArrow(velocityPos, vector, Color.red);
        }
    }

    /// <summary>
    /// Draws <paramref name="vector"/> as a gizmo ray from <paramref name="start"/>.
    /// When the vector is non-zero, it also draws a two-line arrow head at the tip.
    /// </summary>
    /// <param name="headSize">Length of each arrow-head line.</param>
    /// <param name="headAngle">Angle in degrees of each head line away from the reversed shaft, about the arrow's local up axis.</param>
    void DrawArrow(Vector3 start, Vector3 vector, Color color, float headSize = 0.25f, float headAngle = 20.0f)
    {
        Gizmos.color = color;
        Gizmos.DrawRay(start, vector);
        if (vector.magnitude > 0f)
        {
            Quaternion look = Quaternion.LookRotation(vector);
            Vector3 right = look * Quaternion.Euler(0, 180 + headAngle, 0) * Vector3.forward;
            Vector3 left = look * Quaternion.Euler(0, 180 - headAngle, 0) * Vector3.forward;
            Vector3 tip = start + vector;
            Gizmos.DrawRay(tip, right * headSize);
            Gizmos.DrawRay(tip, left * headSize);
        }
    }
}