Implementation of DReCon in Unity 2022 LTS, forked from: https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

348 lines
13 KiB

10 months ago
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;
public class ObservationStats : MonoBehaviour
{
// Per-collider snapshot filled in by SetStatusForStep. One instance is created
// for every tracked collider in OnAgentInitialize.
[System.Serializable]
public class Stat
{
// Name of the tracked collider (copied from the collider's name at initialization).
public string Name;
// Collider center expressed in the ObservationStats transform's local space.
public Vector3 Position;
// Collider rotation relative to the ObservationStats transform's rotation.
public Quaternion Rotation;
// Change of Position since the previous step, divided by the step's timeDelta.
public Vector3 Velocity;
// Change of Rotation since the previous step as Euler angles in radians, divided by timeDelta.
public Vector3 AngularVelocity;
// Position recorded on the previous step; used to compute Velocity.
[HideInInspector]
public Vector3 LastLocalPosition;
// Rotation recorded on the previous step; used to compute AngularVelocity.
[HideInInspector]
public Quaternion LastLocalRotation;
// True once this stat has been updated at least once by SetStatusForStep.
[HideInInspector]
public bool LastIsSet;
}
// Component whose children are searched for Rigidbody / ArticulationBody parts in OnAgentInitialize.
public MonoBehaviour ObjectToTrack;
[Header("Anchor stats")]
// NOTE: SetStatusForStep stores the heading's Euler angles divided by 180 here
// (only the y component is non-zero when no MapAnim2Ragdoll is present).
public Vector3 HorizontalDirection; // Normalized vector in direction of travel (assume right angle to floor)
// public Vector3 CenterOfMassInWorldSpace;
// Angular velocity of this transform's rotation between two steps (see GetAngularVelocity).
// The identifier is misspelled but public, so it is kept as-is.
public Vector3 AngualrVelocity;
[Header("Stats, relative to HorizontalDirection & Center Of Mass")]
// Center-of-mass velocity: copied from MapAnim2Ragdoll when present, otherwise
// (new center of mass - previous center of mass) / timeDelta using world-space positions.
public Vector3 CenterOfMassVelocity;
// CenterOfMassVelocity with its y component zeroed, passed through transform.InverseTransformVector.
public Vector3 CenterOfMassHorizontalVelocity;
// Magnitude of CenterOfMassVelocity (copied from MapAnim2Ragdoll when present).
public float CenterOfMassVelocityMagnitude;
// CenterOfMassVelocity passed through transform.InverseTransformVector.
public Vector3 CenterOfMassVelocityInRootSpace;
// Magnitude of CenterOfMassVelocityInRootSpace.
public float CenterOfMassVelocityMagnitudeInRootSpace;
// Magnitude of CenterOfMassHorizontalVelocity.
public float CenterOfMassHorizontalVelocityMagnitude;
// Result of the animation controller's GetDesiredVelocity(), passed through transform.InverseTransformVector.
public Vector3 DesiredCenterOfMassVelocity;
// DesiredCenterOfMassVelocity - CenterOfMassHorizontalVelocity.
public Vector3 CenterOfMassVelocityDifference;
// One entry per tracked collider, built in OnAgentInitialize and refreshed by SetStatusForStep.
public List<Stat> Stats;
// [Header("... for debugging")]
[Header("Gizmos")]
// When true, OnDrawGizmosSelected passes vectors through transform.TransformVector before drawing.
public bool VelocityInWorldSpace = true;
// When true, OnDrawGizmosSelected draws the horizontal velocity instead of the full velocity.
public bool HorizontalVelocity = true;
// Center of mass recorded on the previous step.
[HideInInspector]
public Vector3 LastCenterOfMassInWorldSpace;
// Rotation of this transform recorded on the previous step.
[HideInInspector]
public Quaternion LastRotation;
// False until SetStatusForStep has run since the last reset; while false, "last" values
// are seeded from the current ones so the computed deltas are zero.
[HideInInspector]
public bool LastIsSet;
// Environment found via GetComponentInParent in OnAgentInitialize.
SpawnableEnv _spawnableEnv;
// Enabled, non-trigger colliders that are tracked (see OnAgentInitialize for the filter).
List<Collider> _bodyParts;
// Rigidbody children of ObjectToTrack.
internal List<Rigidbody> _rigidbodyParts;
// ArticulationBody children of ObjectToTrack.
internal List<ArticulationBody> _articulationBodyParts;
// GameObject of the body whose name equals rootName; its yaw provides the heading
// when no MapAnim2Ragdoll is present.
GameObject _root;
// Source of the desired velocity, looked up under _spawnableEnv.
IAnimationController _animationController;
// Set by OnAgentInitialize; asserted by OnAgentInitialize (must be false) and OnReset (must be true).
bool _hasLazyInitialized;
// Optional component taken from the transform passed to OnAgentInitialize; when present
// it supplies center of mass, heading and velocity.
MapAnim2Ragdoll _mapAnim2Ragdoll;
// Name used to locate the root body in OnAgentInitialize; change it with setRootName.
string rootName = "articulation:Hips";
/// <summary>
/// Overrides the name used to locate the root body part.
/// The name is only read inside <c>OnAgentInitialize</c>, so call this before it.
/// </summary>
/// <param name="s">Name of the body part to treat as the root.</param>
public void setRootName(string s) => rootName = s;
/// <summary>
/// One-time setup: resolves the collaborating components, collects the colliders to
/// track, creates one <c>Stat</c> per collider, locates the root body and aligns this
/// transform with <paramref name="defaultTransform"/>.
/// </summary>
/// <param name="defaultTransform">
/// Transform whose pose is copied to this object. Its optional <c>MapAnim2Ragdoll</c>
/// component, when present, later drives the center-of-mass data in <c>SetStatusForStep</c>.
/// </param>
/// <exception cref="System.InvalidOperationException">
/// No tracked body is named like the configured root name.
/// </exception>
public void OnAgentInitialize(Transform defaultTransform)
{
    // This method must run exactly once per instance.
    Assert.IsFalse(_hasLazyInitialized);
    _hasLazyInitialized = true;

    _mapAnim2Ragdoll = defaultTransform.GetComponent<MapAnim2Ragdoll>();
    _spawnableEnv = GetComponentInParent<SpawnableEnv>();
    _animationController = _spawnableEnv.GetComponentInChildren<IAnimationController>();
    _rigidbodyParts = ObjectToTrack.GetComponentsInChildren<Rigidbody>().ToList();
    _articulationBodyParts = ObjectToTrack.GetComponentsInChildren<ArticulationBody>().ToList();

    // Rigidbodies take precedence; articulation bodies are used only when there are none.
    List<Component> trackedBodies = _rigidbodyParts.Count > 0
        ? _rigidbodyParts.Cast<Component>().ToList()
        : _articulationBodyParts.Cast<Component>().ToList();

    // Keep each enabled, non-trigger collider once, unless it carries an enabled
    // IgnoreColliderForObservation marker. Distinct() preserves first-seen order.
    _bodyParts = trackedBodies
        .SelectMany(body => body.GetComponentsInChildren<Collider>())
        .Distinct()
        .Where(collider =>
        {
            if (!collider.enabled || collider.isTrigger)
            {
                return false;
            }
            var ignoreMarker = collider.GetComponent<IgnoreColliderForObservation>();
            return ignoreMarker == null || !ignoreMarker.enabled;
        })
        .ToList();

    // Stats is built in the same order as _bodyParts (one entry per collider).
    Stats = _bodyParts
        .Select(collider => new Stat { Name = collider.name })
        .ToList();

    //TODO: this is quite sketchy, we should have a better way to deal with this
    if (_root == null)
    {
        Component rootBody = trackedBodies.FirstOrDefault(body => body.name == rootName);
        if (rootBody == null)
        {
            // Same exception type LINQ's First() would raise, but with an actionable message.
            throw new System.InvalidOperationException(
                $"ObservationStats on '{name}': no tracked body named '{rootName}' was found under '{ObjectToTrack.name}'. " +
                "Call setRootName with the correct name before OnAgentInitialize.");
        }
        _root = rootBody.gameObject;
    }

    transform.position = defaultTransform.position;
    transform.rotation = defaultTransform.rotation;
}
/// <summary>
/// Re-seeds all statistics at the start of an episode.
/// Requires that <c>OnAgentInitialize</c> has already run.
/// </summary>
public void OnReset()
{
    Assert.IsTrue(_hasLazyInitialized);

    // Recompute every stat from the current pose.
    ResetStatus();

    // SetStatusForStep (called by ResetStatus) marks the "last" values as set;
    // clear the flag again so the next step is treated as the first sample.
    LastIsSet = false;
}
/// <summary>
/// Forgets the previous sample and recomputes the statistics once.
/// Because <c>LastIsSet</c> is false, every finite difference taken during that call is
/// zero; it is divided by <c>float.MinValue</c>, which is passed as the time step.
/// </summary>
void ResetStatus()
{
    LastIsSet = false;
    SetStatusForStep(float.MinValue);
}
/// <summary>
/// Returns the rotation that takes <paramref name="from"/> to <paramref name="to"/>,
/// i.e. <c>to * Inverse(from)</c>.
/// </summary>
/// <param name="from">Starting rotation.</param>
/// <param name="to">Target rotation.</param>
/// <returns>
/// <c>Quaternion.identity</c> when the two rotations compare equal with the quaternion
/// <c>==</c> operator, otherwise the delta rotation.
/// </returns>
public static Quaternion FromToRotation(Quaternion from, Quaternion to)
{
    return to == from
        ? Quaternion.identity
        : to * Quaternion.Inverse(from);
}
/// <summary>
/// Converts an angle in degrees to radians in the range [-pi, +pi).
/// </summary>
/// <param name="angle">
/// Angle in degrees. Any finite value is accepted; it is first wrapped into [0, 360],
/// so inputs already in that range yield the same result as before.
/// </param>
/// <returns>The equivalent angle in radians; 180 degrees maps to -pi.</returns>
public static float NormalizedAngle(float angle)
{
    // Wrap first so that inputs such as -270 or 720 still end up inside the documented range.
    float wrapped = Mathf.Repeat(angle, 360f);
    if (wrapped < 180f)
    {
        return wrapped * Mathf.Deg2Rad;
    }
    // Upper half of the circle is reported as a negative angle.
    return (wrapped - 360f) * Mathf.Deg2Rad;
}
/// <summary>
/// Applies <c>NormalizedAngle</c> to each component of a set of Euler angles,
/// converting them from degrees to signed radians.
/// </summary>
/// <param name="eulerAngles">Euler angles in degrees.</param>
/// <returns>A vector holding the converted x, y and z angles.</returns>
public static Vector3 NormalizedEulerAngles(Vector3 eulerAngles)
{
    return new Vector3(
        NormalizedAngle(eulerAngles.x),
        NormalizedAngle(eulerAngles.y),
        NormalizedAngle(eulerAngles.z));
}
/// <summary>
/// Estimates angular velocity from two rotations: the delta rotation is expressed as
/// Euler angles, converted to signed radians and divided by the time step.
/// </summary>
/// <param name="from">Rotation at the previous sample.</param>
/// <param name="to">Rotation at the current sample.</param>
/// <param name="timeDelta">Time between the two samples; not checked for zero.</param>
/// <returns>Per-axis Euler-angle change in radians per unit of <paramref name="timeDelta"/>.</returns>
public static Vector3 GetAngularVelocity(Quaternion from, Quaternion to, float timeDelta)
{
    Quaternion deltaRotation = FromToRotation(from, to);
    Vector3 deltaRadians = NormalizedEulerAngles(deltaRotation.eulerAngles);
    return deltaRadians / timeDelta;
}
/// <summary>
/// Samples the tracked character once: moves this transform to the center of mass and
/// heading, then refreshes the center-of-mass velocities, the desired velocity, the
/// anchor's angular velocity and every per-collider <c>Stat</c>.
/// </summary>
/// <param name="timeDelta">
/// Time elapsed since the previous sample; used as the divisor of every finite
/// difference. It is not checked for zero.
/// </param>
public void SetStatusForStep(float timeDelta)
{
    // When a MapAnim2Ragdoll is available (mocap), it supplies center of mass and heading;
    // otherwise both are derived from the tracked bodies and the root's yaw.
    bool usesMocap = _mapAnim2Ragdoll != null;

    Vector3 newCOM;
    Vector3 newHorizontalDirection;
    if (usesMocap)
    {
        newCOM = _mapAnim2Ragdoll.LastCenterOfMassInWorldSpace;
        newHorizontalDirection = _mapAnim2Ragdoll.HorizontalDirection;
    }
    else
    {
        newCOM = GetCenterOfMass();
        newHorizontalDirection = new Vector3(0f, _root.transform.eulerAngles.y, 0f);
    }

    HorizontalDirection = newHorizontalDirection / 180f;
    if (!LastIsSet)
    {
        // First sample after a reset: seed the history so the velocity comes out as zero.
        LastCenterOfMassInWorldSpace = newCOM;
    }

    // Anchor this transform at the center of mass, facing the heading. Everything
    // computed below via InverseTransform* is relative to this pose.
    transform.position = newCOM;
    transform.rotation = Quaternion.Euler(newHorizontalDirection);

    if (usesMocap)
    {
        CenterOfMassVelocity = _mapAnim2Ragdoll.CenterOfMassVelocity;
        CenterOfMassVelocityMagnitude = _mapAnim2Ragdoll.CenterOfMassVelocityMagnitude;
    }
    else
    {
        CenterOfMassVelocity = (newCOM - LastCenterOfMassInWorldSpace) / timeDelta;
        CenterOfMassVelocityMagnitude = CenterOfMassVelocity.magnitude;
    }
    CenterOfMassVelocityInRootSpace = transform.InverseTransformVector(CenterOfMassVelocity);
    CenterOfMassVelocityMagnitudeInRootSpace = CenterOfMassVelocityInRootSpace.magnitude;
    LastCenterOfMassInWorldSpace = newCOM;

    // Horizontal (y = 0) part of the center-of-mass velocity, relative to the anchor.
    var comHorizontalDirection = new Vector3(CenterOfMassVelocity.x, 0f, CenterOfMassVelocity.z);
    CenterOfMassHorizontalVelocity = transform.InverseTransformVector(comHorizontalDirection);
    CenterOfMassHorizontalVelocityMagnitude = CenterOfMassHorizontalVelocity.magnitude;

    // Desired velocity from the animation controller, relative to the anchor.
    Vector3 desiredCom = _animationController.GetDesiredVelocity();
    DesiredCenterOfMassVelocity = transform.InverseTransformVector(desiredCom);

    // Gap between desired and actual horizontal velocity.
    CenterOfMassVelocityDifference = DesiredCenterOfMassVelocity - CenterOfMassHorizontalVelocity;

    if (!LastIsSet)
    {
        LastRotation = transform.rotation;
    }
    AngualrVelocity = GetAngularVelocity(LastRotation, transform.rotation, timeDelta);
    LastRotation = transform.rotation;

    // Per-collider stats relative to the anchor.
    for (int i = 0; i < _bodyParts.Count; i++)
    {
        Collider bodyPart = _bodyParts[i];
        Stat bodyPartStat = FindStatFor(i, bodyPart.name);

        // Local-space center of the collider; zero for collider types without a center field here.
        Vector3 c = Vector3.zero;
        CapsuleCollider capsule = bodyPart as CapsuleCollider;
        BoxCollider box = bodyPart as BoxCollider;
        SphereCollider sphere = bodyPart as SphereCollider;
        if (capsule != null)
        {
            c = capsule.center;
        }
        else if (box != null)
        {
            c = box.center;
        }
        else if (sphere != null)
        {
            c = sphere.center;
        }

        Vector3 worldPosition = bodyPart.transform.TransformPoint(c);
        Quaternion worldRotation = bodyPart.transform.rotation;
        Vector3 localPosition = transform.InverseTransformPoint(worldPosition);
        Quaternion localRotation = FromToRotation(transform.rotation, worldRotation);

        if (!bodyPartStat.LastIsSet || !LastIsSet)
        {
            // No usable history: seed it so both velocities come out as zero.
            bodyPartStat.LastLocalPosition = localPosition;
            bodyPartStat.LastLocalRotation = localRotation;
        }

        bodyPartStat.Position = localPosition;
        bodyPartStat.Rotation = localRotation;
        bodyPartStat.Velocity = (localPosition - bodyPartStat.LastLocalPosition) / timeDelta;
        bodyPartStat.AngularVelocity = GetAngularVelocity(bodyPartStat.LastLocalRotation, localRotation, timeDelta);
        bodyPartStat.LastLocalPosition = localPosition;
        bodyPartStat.LastLocalRotation = localRotation;
        bodyPartStat.LastIsSet = true;
    }

    LastIsSet = true;
}

/// <summary>
/// Returns the Stat that belongs to the collider at <paramref name="index"/>.
/// Stats is created in the same order as the collider list, so the entry at the same
/// index is used when its name matches. This keeps colliders that share a name from
/// all writing into the first matching Stat. If the list was altered externally, the
/// first Stat with a matching name is used instead.
/// </summary>
Stat FindStatFor(int index, string bodyPartName)
{
    if (index < Stats.Count && Stats[index].Name == bodyPartName)
    {
        return Stats[index];
    }
    return Stats.First(x => x.Name == bodyPartName);
}
/// <summary>
/// Computes the mass-weighted average of the tracked bodies' world centers of mass.
/// Articulation bodies are used when there are any; otherwise the rigidbodies are used,
/// so a rigidbody-only character no longer produces a 0/0 (NaN) result.
/// </summary>
/// <returns>
/// The combined center of mass, or this transform's current position when the total
/// mass is not positive (no bodies found).
/// </returns>
Vector3 GetCenterOfMass()
{
    Vector3 weightedSum = Vector3.zero;
    float totalMass = 0f;

    if (_articulationBodyParts != null && _articulationBodyParts.Count > 0)
    {
        foreach (ArticulationBody ab in _articulationBodyParts)
        {
            weightedSum += ab.worldCenterOfMass * ab.mass;
            totalMass += ab.mass;
        }
    }
    else if (_rigidbodyParts != null)
    {
        foreach (Rigidbody rb in _rigidbodyParts)
        {
            weightedSum += rb.worldCenterOfMass * rb.mass;
            totalMass += rb.mass;
        }
    }

    if (totalMass <= 0f)
    {
        // Nothing to average: keep the anchor where it is rather than returning NaN,
        // which the caller would assign to transform.position.
        return transform.position;
    }

    // centerOfMass -= _spawnableEnv.transform.position;
    return weightedSum / totalMass;
}
/// <summary>
/// Draws three debug arrows while the object is selected: desired velocity (green),
/// actual velocity (blue) and, starting at the tip of the blue arrow, the difference
/// between them (red). Nothing is drawn until body parts have been collected.
/// </summary>
/// <remarks>
/// All three vectors are taken in the anchor's local space and converted together by
/// <c>ToGizmoSpace</c>. The full-velocity mode therefore uses
/// <c>CenterOfMassVelocityInRootSpace</c>; the world-space <c>CenterOfMassVelocity</c>
/// would be transformed a second time and mixed with the local desired velocity.
/// </remarks>
void OnDrawGizmosSelected()
{
    if (_bodyParts == null || _bodyParts.Count == 0)
    {
        return;
    }

    // Arrows start at the anchor's x/z, at a fixed height of 0.3.
    // Vector3 pos = new Vector3(transform.position.x, transform.position.y, transform.position.z);
    Vector3 pos = new Vector3(transform.position.x, .3f, transform.position.z);

    Vector3 actualVelocity;
    Vector3 velocityDifference;
    if (HorizontalVelocity)
    {
        actualVelocity = CenterOfMassHorizontalVelocity;
        velocityDifference = CenterOfMassVelocityDifference;
    }
    else
    {
        actualVelocity = CenterOfMassVelocityInRootSpace;
        velocityDifference = DesiredCenterOfMassVelocity - CenterOfMassVelocityInRootSpace;
    }

    // Desired input velocity.
    DrawArrow(pos, ToGizmoSpace(DesiredCenterOfMassVelocity), Color.green);

    // Actual velocity.
    Vector3 actualArrow = ToGizmoSpace(actualVelocity);
    DrawArrow(pos, actualArrow, Color.blue);

    // Difference, drawn from the tip of the actual-velocity arrow.
    DrawArrow(pos + actualArrow, ToGizmoSpace(velocityDifference), Color.red);
}

/// <summary>
/// Converts an anchor-local vector for drawing: passed through
/// <c>transform.TransformVector</c> when <c>VelocityInWorldSpace</c> is set, otherwise returned as-is.
/// </summary>
Vector3 ToGizmoSpace(Vector3 rootSpaceVector)
{
    return VelocityInWorldSpace
        ? transform.TransformVector(rootSpaceVector)
        : rootSpaceVector;
}
/// <summary>
/// Draws a gizmo ray from <paramref name="start"/> along <paramref name="vector"/> and,
/// when the vector has a positive length, two short head strokes at its tip.
/// </summary>
/// <param name="start">Origin of the arrow.</param>
/// <param name="vector">Direction and length of the arrow shaft.</param>
/// <param name="color">Gizmo color used for the whole arrow.</param>
void DrawArrow(Vector3 start, Vector3 vector, Color color)
{
    const float headSize = 0.25f;
    const float headAngle = 20.0f;

    Gizmos.color = color;
    Gizmos.DrawRay(start, vector);

    // No head for a vector without positive length.
    if (!(vector.magnitude > 0f))
    {
        return;
    }

    Vector3 tip = start + vector;
    Quaternion facing = Quaternion.LookRotation(vector);

    // Each stroke points backwards (180 degrees) and is splayed by headAngle to one side.
    Vector3 right = facing * Quaternion.Euler(0, 180 + headAngle, 0) * Vector3.forward;
    Vector3 left = facing * Quaternion.Euler(0, 180 - headAngle, 0) * Vector3.forward;

    Gizmos.DrawRay(tip, right * headSize);
    Gizmos.DrawRay(tip, left * headSize);
}
}