Files
drecon-unity/Assets/2_MarathonController/Scripts/Observations2Learn.cs
2024-03-28 14:55:33 +00:00

345 lines
14 KiB
C#
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;
public class Observations2Learn : MonoBehaviour
{
// ---- Observation values, refreshed by UpdateObservations() ----
[Header("Observations")]
[Tooltip("Kinematic character center of mass velocity, Vector3")]
public Vector3 MocapCOMVelocity;
[Tooltip("RagDoll character center of mass velocity, Vector3")]
public Vector3 RagDollCOMVelocity;
[Tooltip("User-input desired horizontal CM velocity. Vector2")]
public Vector2 InputDesiredHorizontalVelocity;
[Tooltip("User-input requests jump, bool")]
public bool InputJump;
[Tooltip("User-input requests backflip, bool")]
public bool InputBackflip;
[Tooltip("Difference between RagDoll character horizontal CM velocity and user-input desired horizontal CM velocity. Vector2")]
public Vector2 HorizontalVelocityDifference;
[Tooltip("Positions and velocities for subset of bodies")]
public List<BodyPartDifferenceStats> BodyPartDifferenceStats;
// Snapshot of the mocap tracker's per-body stats, rebuilt on every update.
public List<ObservationStats.Stat> MocapBodyStats;
// Ragdoll per-body stats, matched by name to the entries of MocapBodyStats.
public List<ObservationStats.Stat> RagDollBodyStats;
[Tooltip("Smoothed actions produced in the previous step of the policy are collected in t-1")]
public float[] PreviousActions;
// Disabled observations, kept for reference:
//[Tooltip("RagDoll ArticulationBody joint positions in reduced space")]
//public float[] RagDollJointPositions;
// [Tooltip("RagDoll ArticulationBody joint velocity in reduced space")]
// public float[] RagDollJointVelocities;
// [Tooltip("RagDoll ArticulationBody joint accelerations in reduced space")]
// public float[] RagDollJointAccelerations;
[Tooltip("RagDoll ArticulationBody joint forces in reduced space")]
public float[] RagDollJointForces;
[Tooltip("Mocap: average of joint angular velocity")]
public float EnergyAngularMocap;
[Tooltip("RagDoll: average of joint angular velocity")]
public float EnergyAngularRagDoll;
[Tooltip("RagDoll-Mocap: average of joint angular velocity")]
public float EnergyDifferenceAngular;
[Tooltip("Mocap: average of joint velocity in local space")]
public float EnergyPositionalMocap;
[Tooltip("RagDoll: average of joint velocity in local space")]
public float EnergyPositionalRagDoll;
[Tooltip("RagDoll-Mocap: average of joint velocity in local space")]
public float EnergyDifferencePositional;

// ---- Gizmo options, read by OnDrawGizmos() ----
[Header("Gizmos")]
public bool VelocityInWorldSpace = true;
public bool PositionInWorldSpace = true;
// Root name handed to both ObservationStats trackers via setRootName().
public string targetedRootName = "articulation:Hips";

// ---- Runtime references, resolved in OnAgentInitialize() ----
// Source of the Jump / Backflip flags; may be null, in which case the flags are left untouched.
InputController _inputController;
// Environment found in this object's parents; both trackers are parented under it.
SpawnableEnv _spawnableEnv;
// Tracker following the MapAnim2Ragdoll (mocap) character.
ObservationStats _mocapBodyStats;
// Tracker following this (ragdoll) character.
ObservationStats _ragDollBodyStats;
// Set once by OnAgentInitialize(); asserted by OnStep() and OnReset().
bool _hasLazyInitialized;
// Non-root spherical-joint ArticulationBody children whose joint forces are observed.
List<ArticulationBody> _motors;
/// <summary>
/// One-time setup. Resolves the environment and input controller, creates the
/// mocap and ragdoll <c>ObservationStats</c> trackers (both parented under the
/// environment), seeds the per-body difference list, collects the motor joints
/// and allocates the action / joint-force buffers.
/// Asserts that it has not been called before.
/// </summary>
public void OnAgentInitialize()
{
    Assert.IsFalse(_hasLazyInitialized);
    _hasLazyInitialized = true;

    _spawnableEnv = GetComponentInParent<SpawnableEnv>();
    _inputController = _spawnableEnv.GetComponentInChildren<InputController>();

    // Tracker for the kinematic (mocap) character.
    var mocapHolder = new GameObject("MocapDReConObservationStats");
    _mocapBodyStats = mocapHolder.AddComponent<ObservationStats>();
    _mocapBodyStats.setRootName(targetedRootName);
    _mocapBodyStats.ObjectToTrack = _spawnableEnv.GetComponentInChildren<MapAnim2Ragdoll>();
    _mocapBodyStats.transform.SetParent(_spawnableEnv.transform);
    _mocapBodyStats.OnAgentInitialize(_mocapBodyStats.ObjectToTrack.transform);

    // Tracker for the physics-driven (ragdoll) character, i.e. this object.
    var ragDollHolder = new GameObject("RagDollDReConObservationStats");
    _ragDollBodyStats = ragDollHolder.AddComponent<ObservationStats>();
    _ragDollBodyStats.setRootName(targetedRootName);
    _ragDollBodyStats.ObjectToTrack = this;
    _ragDollBodyStats.transform.SetParent(_spawnableEnv.transform);
    _ragDollBodyStats.OnAgentInitialize(transform);

    // One difference entry per body tracked on the mocap side, keyed by name.
    BodyPartDifferenceStats = _mocapBodyStats.Stats
        .Select(stat => new BodyPartDifferenceStats { Name = stat.Name })
        .ToList();

    // Motors are the non-root spherical joints below this object.
    _motors = GetComponentsInChildren<ArticulationBody>()
        .Where(body => body.jointType == ArticulationJointType.SphericalJoint && !body.isRoot)
        .Distinct()
        .ToList();

    // One buffer slot per degree of freedom that is set to LimitedMotion.
    int limitedDofCount = _motors.Sum(motor =>
        (motor.twistLock == ArticulationDofLock.LimitedMotion ? 1 : 0)
        + (motor.swingYLock == ArticulationDofLock.LimitedMotion ? 1 : 0)
        + (motor.swingZLock == ArticulationDofLock.LimitedMotion ? 1 : 0));

    // Both buffers start zero-filled.
    PreviousActions = new float[limitedDofCount];
    RagDollJointForces = new float[limitedDofCount];
}
/// <summary>
/// Lists the child colliders that take part in observations: enabled,
/// non-trigger colliders that do not carry an enabled
/// <c>IgnoreColliderForObservation</c> component.
/// </summary>
/// <returns>A new list of the matching colliders.</returns>
public List<Collider> EstimateBodyPartsForObservation()
{
    return GetComponentsInChildren<Collider>()
        .Where(candidate =>
        {
            if (!candidate.enabled || candidate.isTrigger)
                return false;
            // A disabled opt-out marker counts the same as no marker at all.
            var optOut = candidate.GetComponent<IgnoreColliderForObservation>();
            return optOut == null || !optOut.enabled;
        })
        .Distinct()
        .ToList();
}
/// <summary>
/// Lists the child colliders that take part in reward computation: enabled,
/// non-trigger colliders that do not carry an enabled
/// <c>IgnoreColliderForReward</c> component.
/// </summary>
/// <returns>A new list of the matching colliders.</returns>
public List<Collider> EstimateBodyPartsForReward()
{
    return GetComponentsInChildren<Collider>()
        .Where(candidate =>
        {
            if (!candidate.enabled || candidate.isTrigger)
                return false;
            // A disabled opt-out marker counts the same as no marker at all.
            var optOut = candidate.GetComponent<IgnoreColliderForReward>();
            return optOut == null || !optOut.enabled;
        })
        .Distinct()
        .ToList();
}
/// <summary>
/// Per-step update: advances the mocap tracker, then the ragdoll tracker, and
/// finally refreshes the public observation fields.
/// Asserts that <see cref="OnAgentInitialize"/> has run.
/// </summary>
/// <param name="timeDelta">Step duration, forwarded unchanged to both trackers and to <see cref="UpdateObservations"/>.</param>
public void OnStep(float timeDelta)
{
    Assert.IsTrue(_hasLazyInitialized);

    ObservationStats mocapTracker = _mocapBodyStats;
    ObservationStats ragDollTracker = _ragDollBodyStats;

    // Mocap first, ragdoll second: the call order is kept as-is.
    mocapTracker.SetStatusForStep(timeDelta);
    ragDollTracker.SetStatusForStep(timeDelta);

    UpdateObservations(timeDelta);
}
/// <summary>
/// Resets both trackers, snaps the ragdoll tracker's transform onto the mocap
/// tracker's position and rotation, then refreshes the observation fields.
/// Asserts that <see cref="OnAgentInitialize"/> has run.
/// </summary>
public void OnReset()
{
    Assert.IsTrue(_hasLazyInitialized);

    _mocapBodyStats.OnReset();
    _ragDollBodyStats.OnReset();

    Transform mocapFrame = _mocapBodyStats.transform;
    Transform ragDollFrame = _ragDollBodyStats.transform;
    ragDollFrame.position = mocapFrame.position;
    ragDollFrame.rotation = mocapFrame.rotation;

    // float.MinValue is the most negative finite float, not the smallest
    // positive one. It is forwarded to ObservationStats.GetAngularVelocity.
    // NOTE(review): confirm this sentinel is intended as the reset time delta.
    UpdateObservations(float.MinValue);
}
/// <summary>
/// Copies the current state of the mocap and ragdoll trackers into the public
/// observation fields: center-of-mass velocities, desired and actual horizontal
/// velocity, user input flags, per-body stats and their differences, joint
/// forces of the limited degrees of freedom, and the averaged "energy" values.
/// </summary>
/// <param name="timeDelta">
/// Passed to <c>ObservationStats.GetAngularVelocity</c> when computing each
/// body part's rotation difference.
/// </param>
/// <remarks>
/// Throws <see cref="System.InvalidOperationException"/> (from <c>First</c>)
/// if a body name has no counterpart in a tracker's stats. The energy averages
/// evaluate to 0 when there are no body stats instead of throwing.
/// </remarks>
public void UpdateObservations(float timeDelta)
{
    MocapCOMVelocity = _mocapBodyStats.CenterOfMassVelocity;
    RagDollCOMVelocity = _ragDollBodyStats.CenterOfMassVelocity;

    // Horizontal quantities keep only the x and z components.
    InputDesiredHorizontalVelocity = new Vector2(
        _ragDollBodyStats.DesiredCenterOfMassVelocity.x,
        _ragDollBodyStats.DesiredCenterOfMassVelocity.z);
    if (_inputController != null)
    {
        InputJump = _inputController.Jump;
        InputBackflip = _inputController.Backflip;
    }
    HorizontalVelocityDifference = new Vector2(
        _ragDollBodyStats.CenterOfMassVelocityDifference.x,
        _ragDollBodyStats.CenterOfMassVelocityDifference.z);

    // Ragdoll stats are re-ordered to line up, by name, with the mocap stats.
    MocapBodyStats = _mocapBodyStats.Stats.ToList();
    RagDollBodyStats = MocapBodyStats
        .Select(x => _ragDollBodyStats.Stats.First(y => y.Name == x.Name))
        .ToList();

    // Per-body differences, always mocap minus ragdoll.
    foreach (var differenceStats in BodyPartDifferenceStats)
    {
        var mocapStats = _mocapBodyStats.Stats.First(x => x.Name == differenceStats.Name);
        var ragDollStats = _ragDollBodyStats.Stats.First(x => x.Name == differenceStats.Name);
        differenceStats.Position = mocapStats.Position - ragDollStats.Position;
        differenceStats.Velocity = mocapStats.Velocity - ragDollStats.Velocity;
        differenceStats.AngualrVelocity = mocapStats.AngularVelocity - ragDollStats.AngularVelocity;
        differenceStats.Rotation = ObservationStats.GetAngularVelocity(mocapStats.Rotation, ragDollStats.Rotation, timeDelta);
    }

    // Joint forces: one output slot per LimitedMotion degree of freedom, visited
    // in the same motor / twist / swingY / swingZ order that OnAgentInitialize
    // used to size RagDollJointForces. Joint positions, velocities and
    // accelerations could be gathered the same way but are currently disabled.
    int forceIndex = 0;
    foreach (var motor in _motors)
    {
        // Read the reduced-space forces once per motor rather than once per DOF.
        var reducedForces = motor.jointForce;
        int dof = 0;
        if (motor.twistLock == ArticulationDofLock.LimitedMotion)
            RagDollJointForces[forceIndex++] = reducedForces[dof++];
        if (motor.swingYLock == ArticulationDofLock.LimitedMotion)
            RagDollJointForces[forceIndex++] = reducedForces[dof++];
        if (motor.swingZLock == ArticulationDofLock.LimitedMotion)
            RagDollJointForces[forceIndex++] = reducedForces[dof++];
    }

    // Averages over the tracked bodies. DefaultIfEmpty(0f) keeps Average() from
    // throwing InvalidOperationException when no bodies are tracked.
    EnergyAngularMocap = MocapBodyStats
        .Select(x => x.AngularVelocity.magnitude)
        .DefaultIfEmpty(0f)
        .Average();
    EnergyAngularRagDoll = RagDollBodyStats
        .Select(x => x.AngularVelocity.magnitude)
        .DefaultIfEmpty(0f)
        .Average();
    EnergyDifferenceAngular = RagDollBodyStats
        .Zip(MocapBodyStats, (x, y) => x.AngularVelocity.magnitude - y.AngularVelocity.magnitude)
        .DefaultIfEmpty(0f)
        .Average();
    EnergyPositionalMocap = MocapBodyStats
        .Select(x => x.Velocity.magnitude)
        .DefaultIfEmpty(0f)
        .Average();
    EnergyPositionalRagDoll = RagDollBodyStats
        .Select(x => x.Velocity.magnitude)
        .DefaultIfEmpty(0f)
        .Average();
    EnergyDifferencePositional = RagDollBodyStats
        .Zip(MocapBodyStats, (x, y) => x.Velocity.magnitude - y.Velocity.magnitude)
        .DefaultIfEmpty(0f)
        .Average();
}
/// <summary>
/// Returns the transform of the ragdoll <c>ObservationStats</c> tracker.
/// </summary>
public Transform GetRagDollCOM() => _ragDollBodyStats.transform;
/// <summary>
/// Returns the mocap tracker's center-of-mass velocity after passing it through
/// the tracker transform's <c>TransformVector</c>.
/// </summary>
public Vector3 GetMocapCOMVelocityInWorldSpace()
{
    Vector3 trackerSpaceVelocity = _mocapBodyStats.CenterOfMassVelocity;
    return _mocapBodyStats.transform.TransformVector(trackerSpaceVelocity);
}
/// <summary>
/// Editor visualisation of the observations. Draws the center-of-mass arrows
/// (mocap velocity in grey, ragdoll velocity in blue, desired horizontal
/// velocity in green, horizontal velocity difference in red from the tip of the
/// ragdoll arrow), then for each tracked body its velocity (cyan), position
/// difference (magenta ray) and velocity difference (red, from the tip of the
/// body's velocity arrow).
/// Does nothing until both trackers exist, and skips the per-body part until
/// the per-body lists have been populated.
/// </summary>
void OnDrawGizmos()
{
    if (_mocapBodyStats == null || _ragDollBodyStats == null)
        return;

    Transform mocapFrame = _mocapBodyStats.transform;
    Transform ragDollFrame = _ragDollBodyStats.transform;

    // All center-of-mass arrows start at this object's x/z, at a fixed height of 0.3.
    Vector3 origin = new Vector3(transform.position.x, .3f, transform.position.z);

    // MocapCOMVelocity
    Vector3 vector = MocapCOMVelocity;
    if (VelocityInWorldSpace)
        vector = mocapFrame.TransformVector(vector);
    DrawArrow(origin, vector, Color.grey);

    // RagDollCOMVelocity
    vector = RagDollCOMVelocity;
    if (VelocityInWorldSpace)
        vector = ragDollFrame.TransformVector(vector);
    DrawArrow(origin, vector, Color.blue);
    Vector3 ragDollVelocityTip = origin + vector;

    // InputDesiredHorizontalVelocity
    vector = new Vector3(InputDesiredHorizontalVelocity.x, 0f, InputDesiredHorizontalVelocity.y);
    if (VelocityInWorldSpace)
        vector = ragDollFrame.TransformVector(vector);
    DrawArrow(origin, vector, Color.green);

    // HorizontalVelocityDifference, drawn from the tip of the ragdoll velocity arrow
    vector = new Vector3(HorizontalVelocityDifference.x, 0f, HorizontalVelocityDifference.y);
    if (VelocityInWorldSpace)
        vector = ragDollFrame.TransformVector(vector);
    DrawArrow(ragDollVelocityTip, vector, Color.red);

    // RagDollBodyStats is only assigned by UpdateObservations, so it can still
    // be unset if gizmos are drawn right after initialization.
    if (RagDollBodyStats == null || BodyPartDifferenceStats == null)
        return;

    // Both lists are indexed together; never read past the shorter one.
    int bodyCount = Mathf.Min(RagDollBodyStats.Count, BodyPartDifferenceStats.Count);
    for (int i = 0; i < bodyCount; i++)
    {
        var stat = RagDollBodyStats[i];
        var differenceStat = BodyPartDifferenceStats[i];

        Vector3 bodyPos = stat.Position;
        if (PositionInWorldSpace)
            bodyPos = ragDollFrame.TransformPoint(bodyPos);

        // Body velocity
        Vector3 bodyVelocity = stat.Velocity;
        if (VelocityInWorldSpace)
            bodyVelocity = ragDollFrame.TransformVector(bodyVelocity);
        DrawArrow(bodyPos, bodyVelocity, Color.cyan);

        // Position difference. It is a positional quantity anchored at bodyPos,
        // so it follows PositionInWorldSpace to stay in the same space as its origin.
        Vector3 positionDifference = differenceStat.Position;
        if (PositionInWorldSpace)
            positionDifference = ragDollFrame.TransformVector(positionDifference);
        Gizmos.color = Color.magenta;
        Gizmos.DrawRay(bodyPos, positionDifference);

        // Velocity difference, drawn from the tip of the body's velocity arrow
        Vector3 velocityDifference = differenceStat.Velocity;
        if (VelocityInWorldSpace)
            velocityDifference = ragDollFrame.TransformVector(velocityDifference);
        DrawArrow(bodyPos + bodyVelocity, velocityDifference, Color.red);
    }
}
/// <summary>
/// Draws a gizmo ray from <paramref name="start"/> along
/// <paramref name="vector"/> in the given color and, when the vector has
/// non-zero length, two short head strokes at its tip.
/// </summary>
/// <param name="start">Origin of the arrow.</param>
/// <param name="vector">Direction and length of the arrow shaft.</param>
/// <param name="color">Gizmo color used for shaft and head.</param>
void DrawArrow(Vector3 start, Vector3 vector, Color color)
{
    const float headSize = 0.25f;
    const float headAngle = 20.0f;

    Gizmos.color = color;
    Gizmos.DrawRay(start, vector);

    // No head for a zero-length shaft.
    if (!(vector.magnitude > 0f))
        return;

    Quaternion heading = Quaternion.LookRotation(vector);
    Vector3 tip = start + vector;

    // Head strokes point back from the tip, headAngle degrees to either side.
    Vector3 rightStroke = heading * Quaternion.Euler(0, 180 + headAngle, 0) * Vector3.forward;
    Vector3 leftStroke = heading * Quaternion.Euler(0, 180 - headAngle, 0) * Vector3.forward;
    Gizmos.DrawRay(tip, rightStroke * headSize);
    Gizmos.DrawRay(tip, leftStroke * headSize);
}
}