implementation of drecon in unity 2022 lts
forked from:
https://github.com/joanllobera/marathon-envs
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
363 lines
15 KiB
363 lines
15 KiB
10 months ago
|
using System;
|
||
|
using System.Collections;
|
||
|
using System.Collections.Generic;
|
||
|
using System.Linq;
|
||
|
|
||
|
|
||
|
using ManyWorlds;
|
||
|
using Unity.Collections;
|
||
|
using UnityEngine;
|
||
|
|
||
|
public class MarConObservationsStats : MonoBehaviour
|
||
|
{
|
||
|
[Header("Anchor stats")]
|
||
|
public Vector3 HorizontalDirection; // Normalized vector in direction of travel (assume right angle to floor)
|
||
|
public Vector3 AngualrVelocity;
|
||
|
|
||
|
[Header("Stats, relative to HorizontalDirection & Center Of Mass")]
|
||
|
public Vector3 CenterOfMassVelocity;
|
||
|
public Vector3 CenterOfMassHorizontalVelocity;
|
||
|
public float CenterOfMassVelocityMagnitude;
|
||
|
public Vector3 CenterOfMassVelocityInRootSpace;
|
||
|
public float CenterOfMassVelocityMagnitudeInRootSpace;
|
||
|
public float CenterOfMassHorizontalVelocityMagnitude;
|
||
|
public Vector3 DesiredCenterOfMassVelocity;
|
||
|
public Vector3 CenterOfMassVelocityDifference;
|
||
|
[HideInInspector] public Vector3 LastCenterOfMassInWorldSpace;
|
||
|
[HideInInspector] public Quaternion LastRotation;
|
||
|
[Header("Stats, in local joint space")]
|
||
|
public float[] DofRotationWithinRangeOfMotion;
|
||
|
public float[] DofAngularVelocity;
|
||
|
public float[] DofRotationInRad;
|
||
|
[HideInInspector] public float[] DofLastRotationInRad;
|
||
|
|
||
|
ArticulationBody[] _articulationBodyJoints;
|
||
|
Rigidbody[] _rigidBodyJoints;
|
||
|
GameObject[] _jointsToTrack;
|
||
|
Collider[] _collidersToTrack;
|
||
|
GameObject[] _jointForTrackedColliders;
|
||
|
public Vector3[] Positions;
|
||
|
public Quaternion[] Rotations;
|
||
|
public Vector3[] Velocities;
|
||
|
public Vector3[] AngualrVelocities;
|
||
|
[HideInInspector]
|
||
|
public Vector3[] LastLocalPositions;
|
||
|
[HideInInspector]
|
||
|
public Quaternion[] LastLocalRotations;
|
||
|
bool LastIsSet;
|
||
|
MapAnim2Ragdoll _mapAnim2Ragdoll;
|
||
|
GameObject _root;
|
||
|
IAnimationController _animationController;
|
||
|
SpawnableEnv _spawnableEnv;
|
||
|
|
||
|
public void OnAgentInitialize(
|
||
|
Transform defaultTransform,
|
||
|
ArticulationBody[] articulationBodyJoints,
|
||
|
ArticulationBody articulationBodyRoot)
|
||
|
{
|
||
|
_mapAnim2Ragdoll = defaultTransform.GetComponent<MapAnim2Ragdoll>();
|
||
|
_spawnableEnv = GetComponentInParent<SpawnableEnv>();
|
||
|
_animationController = _spawnableEnv.GetComponentInChildren<IAnimationController>();
|
||
|
|
||
|
var jointNames = articulationBodyJoints
|
||
|
.Select(x=>x.name)
|
||
|
.Select(x=>x.Replace("articulation:", ""))
|
||
|
.Select(x=>x.Replace("mixamorig:", ""))
|
||
|
.ToArray();
|
||
|
_articulationBodyJoints = articulationBodyJoints;
|
||
|
_rigidBodyJoints = GetComponentsInChildren<Rigidbody>();
|
||
|
if (_rigidBodyJoints.Length > 0)
|
||
|
{
|
||
|
_jointsToTrack = jointNames
|
||
|
.Select(x=>_rigidBodyJoints.First(y => y.name.EndsWith(x)))
|
||
|
.Select(x=>x.gameObject)
|
||
|
.ToArray();
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
_jointsToTrack = jointNames
|
||
|
.Select(x=>articulationBodyJoints.First(y => y.name.EndsWith(x)))
|
||
|
.Select(x=>x.gameObject)
|
||
|
.ToArray();
|
||
|
}
|
||
|
_collidersToTrack = _jointsToTrack
|
||
|
.SelectMany(x=>x.GetComponentsInChildren<Collider>())
|
||
|
.Where(x => x.enabled)
|
||
|
.Where(x => !x.isTrigger)
|
||
|
.Where(x=> {
|
||
|
var ignoreCollider = x.GetComponent<IgnoreColliderForObservation>();
|
||
|
if (ignoreCollider == null)
|
||
|
return true;
|
||
|
return !ignoreCollider.enabled;})
|
||
|
.Distinct()
|
||
|
.ToArray();
|
||
|
|
||
|
if (_rigidBodyJoints.Length > 0)
|
||
|
{
|
||
|
_jointForTrackedColliders = _collidersToTrack
|
||
|
.Select(x=>x.GetComponentsInParent<Rigidbody>().First())
|
||
|
.Select(x=>x.gameObject)
|
||
|
.ToArray();
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
_jointForTrackedColliders = _collidersToTrack
|
||
|
.Select(x=>x.GetComponentsInParent<ArticulationBody>().First())
|
||
|
.Select(x=>x.gameObject)
|
||
|
.ToArray();
|
||
|
}
|
||
|
Positions = Enumerable.Range(0, _collidersToTrack.Length).Select(x=>Vector3.zero).ToArray();
|
||
|
Rotations = Enumerable.Range(0, _collidersToTrack.Length).Select(x=>Quaternion.identity).ToArray();
|
||
|
Velocities = Enumerable.Range(0, _collidersToTrack.Length).Select(x=>Vector3.zero).ToArray();
|
||
|
AngualrVelocities = Enumerable.Range(0, _collidersToTrack.Length).Select(x=>Vector3.zero).ToArray();
|
||
|
LastLocalPositions = Enumerable.Range(0, _collidersToTrack.Length).Select(x=>Vector3.zero).ToArray();
|
||
|
LastLocalRotations = Enumerable.Range(0, _collidersToTrack.Length).Select(x=>Quaternion.identity).ToArray();
|
||
|
|
||
|
int dof = 0;
|
||
|
foreach (var m in _articulationBodyJoints)
|
||
|
{
|
||
|
if (m.twistLock == ArticulationDofLock.LimitedMotion)
|
||
|
dof++;
|
||
|
if (m.swingYLock == ArticulationDofLock.LimitedMotion)
|
||
|
dof++;
|
||
|
if (m.swingZLock == ArticulationDofLock.LimitedMotion)
|
||
|
dof++;
|
||
|
}
|
||
|
DofRotationWithinRangeOfMotion = Enumerable.Range(0, dof).Select(x=>0f).ToArray();
|
||
|
DofAngularVelocity = Enumerable.Range(0, dof).Select(x=>0f).ToArray();
|
||
|
DofRotationInRad = Enumerable.Range(0, dof).Select(x=>0f).ToArray();
|
||
|
DofLastRotationInRad = Enumerable.Range(0, dof).Select(x=>0f).ToArray();
|
||
|
|
||
|
if (_root == null)
|
||
|
{
|
||
|
Debug.Log("in game object: " + name + "my rootname is: " + articulationBodyRoot.name);
|
||
|
if (_rigidBodyJoints.Length > 0)
|
||
|
_root = _rigidBodyJoints.First(x => x.name == articulationBodyRoot.name).gameObject;
|
||
|
else
|
||
|
_root = articulationBodyRoot.gameObject;
|
||
|
}
|
||
|
transform.position = defaultTransform.position;
|
||
|
transform.rotation = defaultTransform.rotation;
|
||
|
LastIsSet = false;
|
||
|
}
|
||
|
public void OnStep(float timeDelta)
|
||
|
{
|
||
|
// get Center Of Mass velocity in f space
|
||
|
Vector3 newCOM;
|
||
|
// if Moocap, then get from anim2Ragdoll
|
||
|
if (_mapAnim2Ragdoll != null)
|
||
|
{
|
||
|
newCOM = _mapAnim2Ragdoll.LastCenterOfMassInWorldSpace;
|
||
|
var newHorizontalDirection = _mapAnim2Ragdoll.HorizontalDirection;
|
||
|
HorizontalDirection = newHorizontalDirection / 180f;
|
||
|
if (!LastIsSet)
|
||
|
{
|
||
|
LastCenterOfMassInWorldSpace = newCOM;
|
||
|
}
|
||
|
transform.position = newCOM;
|
||
|
transform.rotation = Quaternion.Euler(newHorizontalDirection);
|
||
|
CenterOfMassVelocity = _mapAnim2Ragdoll.CenterOfMassVelocity;
|
||
|
CenterOfMassVelocityMagnitude = _mapAnim2Ragdoll.CenterOfMassVelocityMagnitude;
|
||
|
CenterOfMassVelocityInRootSpace = transform.InverseTransformVector(CenterOfMassVelocity);
|
||
|
CenterOfMassVelocityMagnitudeInRootSpace = CenterOfMassVelocityInRootSpace.magnitude;
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
newCOM = GetCenterOfMass();
|
||
|
var newHorizontalDirection = new Vector3(0f, _root.transform.eulerAngles.y, 0f);
|
||
|
HorizontalDirection = newHorizontalDirection / 180f;
|
||
|
if (!LastIsSet)
|
||
|
{
|
||
|
LastCenterOfMassInWorldSpace = newCOM;
|
||
|
}
|
||
|
transform.position = newCOM;
|
||
|
transform.rotation = Quaternion.Euler(newHorizontalDirection);
|
||
|
var velocity = newCOM - LastCenterOfMassInWorldSpace;
|
||
|
velocity /= timeDelta;
|
||
|
CenterOfMassVelocity = velocity;
|
||
|
CenterOfMassVelocityMagnitude = CenterOfMassVelocity.magnitude;
|
||
|
CenterOfMassVelocityInRootSpace = transform.InverseTransformVector(CenterOfMassVelocity);
|
||
|
CenterOfMassVelocityMagnitudeInRootSpace = CenterOfMassVelocityInRootSpace.magnitude;
|
||
|
}
|
||
|
LastCenterOfMassInWorldSpace = newCOM;
|
||
|
|
||
|
// get Center Of Mass horizontal velocity in f space
|
||
|
var comHorizontalDirection = new Vector3(CenterOfMassVelocity.x, 0f, CenterOfMassVelocity.z);
|
||
|
CenterOfMassHorizontalVelocity = transform.InverseTransformVector(comHorizontalDirection);
|
||
|
CenterOfMassHorizontalVelocityMagnitude = CenterOfMassHorizontalVelocity.magnitude;
|
||
|
|
||
|
// get Desired Center Of Mass horizontal velocity in f space
|
||
|
Vector3 desiredCom = _animationController.GetDesiredVelocity();
|
||
|
DesiredCenterOfMassVelocity = transform.InverseTransformVector(desiredCom);
|
||
|
|
||
|
// get Desired Center Of Mass horizontal velocity in f space
|
||
|
CenterOfMassVelocityDifference = DesiredCenterOfMassVelocity - CenterOfMassHorizontalVelocity;
|
||
|
|
||
|
if (!LastIsSet)
|
||
|
{
|
||
|
LastRotation = transform.rotation;
|
||
|
}
|
||
|
AngualrVelocity = Utils.GetAngularVelocity(LastRotation, transform.rotation, timeDelta);
|
||
|
LastRotation = transform.rotation;
|
||
|
|
||
|
TrackUsingColliders(timeDelta);
|
||
|
// TrackUsingJointsForTrackedColliders(timeDelta);
|
||
|
TrackUsingDof(timeDelta);
|
||
|
|
||
|
LastIsSet = true;
|
||
|
}
|
||
|
|
||
|
void TrackUsingJointsForTrackedColliders(float timeDelta)
|
||
|
{
|
||
|
// track in local space
|
||
|
for (int i = 0; i < _jointForTrackedColliders.Length; i++)
|
||
|
{
|
||
|
Transform joint = _jointForTrackedColliders[i].transform;
|
||
|
Vector3 worldPosition = joint.position;
|
||
|
Quaternion worldRotation = transform.rotation;
|
||
|
Vector3 localPosition = transform.InverseTransformPoint(worldPosition);
|
||
|
Quaternion localRotation = Utils.FromToRotation(transform.rotation, worldRotation);
|
||
|
if (!LastIsSet)
|
||
|
{
|
||
|
LastLocalPositions[i] = localPosition;
|
||
|
LastLocalRotations[i] = localRotation;
|
||
|
}
|
||
|
|
||
|
Positions[i] = localPosition;
|
||
|
Rotations[i] = localRotation;
|
||
|
Velocities[i] = (localPosition - LastLocalPositions[i]) / timeDelta;
|
||
|
AngualrVelocities[i] = Utils.GetAngularVelocity(LastLocalRotations[i], localRotation, timeDelta);
|
||
|
LastLocalPositions[i] = localPosition;
|
||
|
LastLocalRotations[i] = localRotation;
|
||
|
}
|
||
|
}
|
||
|
void TrackUsingColliders(float timeDelta)
|
||
|
{
|
||
|
// track in local space
|
||
|
for (int i = 0; i < _collidersToTrack.Length; i++)
|
||
|
{
|
||
|
Vector3 c = Vector3.zero;
|
||
|
CapsuleCollider capsule = _collidersToTrack[i] as CapsuleCollider;
|
||
|
BoxCollider box = _collidersToTrack[i] as BoxCollider;
|
||
|
SphereCollider sphere = _collidersToTrack[i] as SphereCollider;
|
||
|
Bounds b = new Bounds(c, c);
|
||
|
if (capsule != null)
|
||
|
{
|
||
|
c = capsule.center;
|
||
|
var r = capsule.radius * 2;
|
||
|
var h = capsule.height;
|
||
|
h = Mathf.Max(r, h); // capsules height is clipped at r
|
||
|
if (capsule.direction == 0)
|
||
|
b = new Bounds(c, new Vector3(h, r, r));
|
||
|
else if (capsule.direction == 1)
|
||
|
b = new Bounds(c, new Vector3(r, h, r));
|
||
|
else if (capsule.direction == 2)
|
||
|
b = new Bounds(c, new Vector3(r, r, h));
|
||
|
else throw new NotImplementedException();
|
||
|
}
|
||
|
else if (box != null)
|
||
|
{
|
||
|
c = box.center;
|
||
|
b = new Bounds(c, box.size);
|
||
|
}
|
||
|
else if (sphere != null)
|
||
|
{
|
||
|
c = sphere.center;
|
||
|
var r = sphere.radius * 2;
|
||
|
b = new Bounds(c, new Vector3(r, r, r));
|
||
|
}
|
||
|
else
|
||
|
throw new NotImplementedException();
|
||
|
|
||
|
Vector3 worldPosition = _collidersToTrack[i].transform.TransformPoint(c);
|
||
|
|
||
|
Quaternion worldRotation = _collidersToTrack[i].transform.rotation;
|
||
|
Vector3 localPosition = transform.InverseTransformPoint(worldPosition);
|
||
|
Quaternion localRotation = Utils.FromToRotation(transform.rotation, worldRotation);
|
||
|
if (!LastIsSet)
|
||
|
{
|
||
|
LastLocalPositions[i] = localPosition;
|
||
|
LastLocalRotations[i] = localRotation;
|
||
|
}
|
||
|
|
||
|
Positions[i] = localPosition;
|
||
|
Rotations[i] = localRotation;
|
||
|
Velocities[i] = (localPosition - LastLocalPositions[i]) / timeDelta;
|
||
|
AngualrVelocities[i] = Utils.GetAngularVelocity(LastLocalRotations[i], localRotation, timeDelta);
|
||
|
LastLocalPositions[i] = localPosition;
|
||
|
LastLocalRotations[i] = localRotation;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void TrackUsingDof(float timeDelta)
|
||
|
{
|
||
|
// track in local space
|
||
|
int i = 0;
|
||
|
for (int j = 0; j < _jointForTrackedColliders.Length; j++)
|
||
|
{
|
||
|
var joint = _jointsToTrack[j];
|
||
|
var reference = _articulationBodyJoints[j];
|
||
|
Vector3 decomposedRotation = Utils.GetSwingTwist(joint.transform.localRotation);
|
||
|
if (reference.twistLock == ArticulationDofLock.LimitedMotion)
|
||
|
{
|
||
|
var drive = reference.xDrive;
|
||
|
var scale = (drive.upperLimit - drive.lowerLimit) / 2f;
|
||
|
var midpoint = drive.lowerLimit + scale;
|
||
|
var deg = decomposedRotation.x;
|
||
|
var pos = (deg - midpoint) / scale;
|
||
|
DofRotationInRad[i] = deg * Mathf.Deg2Rad;
|
||
|
DofRotationWithinRangeOfMotion[i++] = pos;
|
||
|
}
|
||
|
if (reference.swingYLock == ArticulationDofLock.LimitedMotion)
|
||
|
{
|
||
|
var drive = reference.yDrive;
|
||
|
var scale = (drive.upperLimit - drive.lowerLimit) / 2f;
|
||
|
var midpoint = drive.lowerLimit + scale;
|
||
|
var deg = decomposedRotation.y;
|
||
|
var pos = (deg - midpoint) / scale;
|
||
|
DofRotationInRad[i] = deg * Mathf.Deg2Rad;
|
||
|
DofRotationWithinRangeOfMotion[i++] = pos;
|
||
|
}
|
||
|
if (reference.swingZLock == ArticulationDofLock.LimitedMotion)
|
||
|
{
|
||
|
var drive = reference.zDrive;
|
||
|
var scale = (drive.upperLimit - drive.lowerLimit) / 2f;
|
||
|
var midpoint = drive.lowerLimit + scale;
|
||
|
var deg = decomposedRotation.z;
|
||
|
var pos = (deg - midpoint) / scale;
|
||
|
DofRotationInRad[i] = deg * Mathf.Deg2Rad;
|
||
|
DofRotationWithinRangeOfMotion[i++] = pos;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for (i = 0; i < DofRotationInRad.Length; i++)
|
||
|
{
|
||
|
if (!LastIsSet)
|
||
|
{
|
||
|
DofLastRotationInRad[i] = DofRotationInRad[i];
|
||
|
}
|
||
|
DofAngularVelocity[i] = DofRotationInRad[i] - DofLastRotationInRad[i];
|
||
|
DofAngularVelocity[i] /= timeDelta;
|
||
|
DofLastRotationInRad[i] = DofRotationInRad[i];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
public void OnReset()
|
||
|
{
|
||
|
OnStep(float.Epsilon);
|
||
|
LastIsSet = false;
|
||
|
}
|
||
|
Vector3 GetCenterOfMass()
|
||
|
{
|
||
|
var centerOfMass = Vector3.zero;
|
||
|
float totalMass = 0f;
|
||
|
foreach (ArticulationBody ab in _articulationBodyJoints)
|
||
|
{
|
||
|
centerOfMass += ab.worldCenterOfMass * ab.mass;
|
||
|
totalMass += ab.mass;
|
||
|
}
|
||
|
centerOfMass /= totalMass;
|
||
|
// centerOfMass -= _spawnableEnv.transform.position;
|
||
|
return centerOfMass;
|
||
|
}
|
||
|
}
|