Implementation of DReCon in Unity 2022 LTS
forked from:
https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
358 lines
13 KiB
358 lines
13 KiB
9 months ago
|
using System.Collections;
|
||
|
using System.Collections.Generic;
|
||
|
using System.Linq;
|
||
|
using Unity.MLAgents;
|
||
|
using UnityEngine;
|
||
|
using ManyWorlds;
|
||
|
using UnityEngine.Assertions;
|
||
|
|
||
|
/// <summary>
/// Computes the DReCon observation statistics for one character: the centre of mass (COM), its velocity
/// relative to a yaw-only reference frame, the desired velocity reported by the <see cref="InputController"/>,
/// and per-body-part position/rotation/velocity expressed in that frame.
/// </summary>
/// <remarks>
/// Every call to <see cref="SetStatusForStep"/> moves this GameObject to the COM and rotates it to the root
/// body's yaw, so this object's <c>transform</c> is the reference frame for all "local" quantities.
/// </remarks>
public class DReConObservationStats : MonoBehaviour
{
    /// <summary>Observation for a single body part, expressed in this object's reference frame.</summary>
    [System.Serializable]
    public class Stat
    {
        /// <summary>Name of the observed collider's GameObject.</summary>
        public string Name;

        /// <summary>Collider centre in this object's local space.</summary>
        public Vector3 Position;

        /// <summary>Body-part rotation relative to this object, as computed by <see cref="FromToRotation"/>.</summary>
        public Quaternion Rotation;

        /// <summary>Finite-difference velocity of <see cref="Position"/> (per unit of timeDelta).</summary>
        public Vector3 Velocity;

        /// <summary>
        /// Finite-difference angular velocity of <see cref="Rotation"/>, in radians per unit of timeDelta.
        /// The field name's spelling is kept as-is because renaming it would break callers and serialized data.
        /// </summary>
        public Vector3 AngualrVelocity;

        /// <summary>Previous-step <see cref="Position"/>, used for the finite difference.</summary>
        [HideInInspector]
        public Vector3 LastLocalPosition;

        /// <summary>Previous-step <see cref="Rotation"/>, used for the finite difference.</summary>
        [HideInInspector]
        public Quaternion LastLocalRotation;

        /// <summary>False until the previous-step values have been seeded (after init or reset).</summary>
        [HideInInspector]
        public bool LastIsSet;
    }

    /// <summary>Component whose child Rigidbodies or ArticulationBodies are observed.</summary>
    public MonoBehaviour ObjectToTrack;

    // Collider names to observe, in output order; null or empty means "observe every collider".
    List<string> _bodyPartsToTrack;

    /// <summary>
    /// Root yaw encoded as (0, yaw / 180, 0), where yaw is the root body's Y Euler angle in degrees
    /// (Unity reports Euler angles in [0, 360)), so the Y component lies in [0, 2).
    /// Despite the name this is not a unit direction vector.
    /// </summary>
    [Header("Anchor stats")]
    public Vector3 HorizontalDirection;

    /// <summary>Angular velocity of this object's yaw-only frame, in radians per unit of timeDelta.</summary>
    public Vector3 AngualrVelocity;

    /// <summary>COM velocity, expressed in this object's frame.</summary>
    [Header("Stats, relative to HorizontalDirection & Center Of Mass")]
    public Vector3 CenterOfMassVelocity;

    /// <summary>COM velocity with its world-space vertical component removed, expressed in this object's frame.</summary>
    public Vector3 CenterOfMassHorizontalVelocity;

    /// <summary>Magnitude of <see cref="CenterOfMassVelocity"/>.</summary>
    public float CenterOfMassVelocityMagnitude;

    /// <summary>Magnitude of <see cref="CenterOfMassHorizontalVelocity"/>.</summary>
    public float CenterOfMassHorizontalVelocityMagnitude;

    /// <summary>Desired horizontal velocity from the InputController, expressed in this object's frame.</summary>
    public Vector3 DesiredCenterOfMassVelocity;

    /// <summary><see cref="DesiredCenterOfMassVelocity"/> minus <see cref="CenterOfMassHorizontalVelocity"/>.</summary>
    public Vector3 CenterOfMassVelocityDifference;

    /// <summary>One entry per observed collider, created in the same order as the observed colliders.</summary>
    public List<Stat> Stats;

    /// <summary>When true, gizmo arrows are converted from this object's frame to world space before drawing.</summary>
    [Header("Gizmos")]
    public bool VelocityInWorldSpace = true;

    /// <summary>When true, gizmos show the horizontal COM velocity; otherwise the full COM velocity.</summary>
    public bool HorizontalVelocity = true;

    /// <summary>Previous-step COM in world space, used for the COM velocity finite difference.</summary>
    [HideInInspector]
    public Vector3 LastCenterOfMassInWorldSpace;

    /// <summary>Previous-step rotation of this object, used for <see cref="AngualrVelocity"/>.</summary>
    [HideInInspector]
    public Quaternion LastRotation;

    /// <summary>False until the previous-step frame values have been seeded (after init or reset).</summary>
    [HideInInspector]
    public bool LastIsSet;

    SpawnableEnv _spawnableEnv;
    List<Collider> _bodyParts;
    internal List<Rigidbody> _rigidbodyParts;
    internal List<ArticulationBody> _articulationBodyParts;
    GameObject _root;
    InputController _inputController;
    bool _hasLazyInitialized;

    // Name of the collider whose GameObject supplies the reference-frame yaw.
    string rootName = "articulation:Hips";

    /// <summary>Sets the collider name used to locate the root body during <see cref="OnAgentInitialize"/>.</summary>
    /// <param name="s">Name of the root collider's GameObject.</param>
    public void setRootName(string s)
    {
        rootName = s;
    }

    /// <summary>
    /// One-time setup. Caches the environment, input controller and physics bodies, selects the colliders to
    /// observe, creates one <see cref="Stat"/> per collider, locates the root body, and finally snaps this
    /// object to <paramref name="defaultTransform"/>.
    /// </summary>
    /// <param name="bodyPartsToTrack">
    /// Collider names to observe, in output order. Names matching no collider (or null names) are skipped;
    /// when several colliders share a name the first one found is used. Null or empty observes every collider.
    /// </param>
    /// <param name="defaultTransform">Pose this object is moved to at the end of setup.</param>
    /// <exception cref="System.InvalidOperationException">
    /// The root collider (see <see cref="setRootName"/>) is not among the observed colliders.
    /// </exception>
    public void OnAgentInitialize(List<string> bodyPartsToTrack, Transform defaultTransform)
    {
        Assert.IsFalse(_hasLazyInitialized);
        _hasLazyInitialized = true;

        _bodyPartsToTrack = bodyPartsToTrack;
        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        _inputController = _spawnableEnv.GetComponentInChildren<InputController>();
        _rigidbodyParts = ObjectToTrack.GetComponentsInChildren<Rigidbody>().ToList();
        _articulationBodyParts = ObjectToTrack.GetComponentsInChildren<ArticulationBody>().ToList();

        // Prefer Rigidbodies; fall back to ArticulationBodies when the character has none.
        IEnumerable<Component> physicsBodies = _rigidbodyParts.Count > 0
            ? (IEnumerable<Component>)_rigidbodyParts
            : _articulationBodyParts;
        _bodyParts = physicsBodies
            .SelectMany(body => body.GetComponentsInChildren<Collider>())
            .Distinct()
            .ToList();

        if (_bodyPartsToTrack?.Count > 0)
        {
            // First collider wins when several share a name.
            var firstColliderByName = new Dictionary<string, Collider>();
            foreach (Collider candidate in _bodyParts)
            {
                if (!firstColliderByName.ContainsKey(candidate.name))
                    firstColliderByName.Add(candidate.name, candidate);
            }

            var tracked = new List<Collider>(_bodyPartsToTrack.Count);
            foreach (string partName in _bodyPartsToTrack)
            {
                if (partName != null && firstColliderByName.TryGetValue(partName, out Collider match))
                    tracked.Add(match);
            }
            _bodyParts = tracked;
        }

        Stats = _bodyParts
            .Select(x => new Stat { Name = x.name })
            .ToList();

        if (_root == null)
        {
            Collider rootCollider = _bodyParts.FirstOrDefault(x => x.name == rootName);
            if (rootCollider == null)
                throw new System.InvalidOperationException(
                    $"{nameof(DReConObservationStats)} on '{name}': root body part '{rootName}' is not among the observed colliders.");
            _root = rootCollider.gameObject;
        }

        transform.position = defaultTransform.position;
        transform.rotation = defaultTransform.rotation;
    }

    /// <summary>
    /// Re-seeds all statistics for the current pose and marks the previous-step values as unset, so the next
    /// <see cref="SetStatusForStep"/> (with a non-zero timeDelta) reports zero velocities.
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);
        ResetStatus();
        // ResetStatus finishes with SetStatusForStep, which marks everything as set again; clear the flags
        // so the next real step seeds its finite differences from its own pose.
        foreach (var bodyPart in Stats)
        {
            bodyPart.LastIsSet = false;
        }
        LastIsSet = false;
    }

    // Clears the previous-step flags and runs one step so every public stat holds a value for the current pose.
    void ResetStatus()
    {
        foreach (var bodyPart in Stats)
        {
            bodyPart.LastIsSet = false;
        }
        LastIsSet = false;
        // With the flags cleared every finite difference is zero, so any non-zero divisor yields zero
        // velocities; float.MinValue (the most negative finite float) is presumably used instead of 0
        // to avoid 0/0 = NaN.
        var timeDelta = float.MinValue;
        SetStatusForStep(timeDelta);
    }

    /// <summary>
    /// Returns the world-frame rotation <c>delta</c> such that <c>delta * from == to</c>,
    /// or identity when Unity's approximate quaternion equality considers the two rotations equal.
    /// </summary>
    public static Quaternion FromToRotation(Quaternion from, Quaternion to)
    {
        if (to == from) return Quaternion.identity;

        return to * Quaternion.Inverse(from);
    }

    /// <summary>
    /// Converts an angle in degrees to radians in [-pi, pi). The angle is first wrapped into [0, 360),
    /// so any finite input is accepted; Unity Euler angles already lie in that range and are unaffected.
    /// </summary>
    public static float NormalizedAngle(float angle)
    {
        float wrapped = Mathf.Repeat(angle, 360f);
        return (wrapped < 180f ? wrapped : wrapped - 360f) * Mathf.Deg2Rad;
    }

    /// <summary>Applies <see cref="NormalizedAngle"/> to each component of a degree Euler-angle vector.</summary>
    public static Vector3 NormalizedEulerAngles(Vector3 eulerAngles)
    {
        var x = NormalizedAngle(eulerAngles.x);
        var y = NormalizedAngle(eulerAngles.y);
        var z = NormalizedAngle(eulerAngles.z);
        return new Vector3(x, y, z);
    }

    /// <summary>
    /// Approximates angular velocity as the per-axis Euler angles (radians, each in [-pi, pi)) of the rotation
    /// taking <paramref name="from"/> to <paramref name="to"/>, divided by <paramref name="timeDelta"/>.
    /// </summary>
    public static Vector3 GetAngularVelocity(Quaternion from, Quaternion to, float timeDelta)
    {
        var rotationVelocity = FromToRotation(from, to);
        var angularVelocity = NormalizedEulerAngles(rotationVelocity.eulerAngles) / timeDelta;
        return angularVelocity;
    }

    /// <summary>
    /// Advances all statistics by one step. Moves this object to the current COM with the root's yaw, then
    /// updates the COM velocities, desired velocity, frame angular velocity and every <see cref="Stat"/> by
    /// finite differences against the previous step.
    /// </summary>
    /// <param name="timeDelta">
    /// Time since the previous call; the divisor for every velocity. On the first step after init/reset the
    /// previous values are seeded from the current ones, so velocities are zero for a non-zero timeDelta.
    /// </param>
    public void SetStatusForStep(float timeDelta)
    {
        // Centre of mass of whichever body type the character uses.
        Vector3 newCOM = _rigidbodyParts?.Count > 0
            ? GetCenterOfMass(_rigidbodyParts)
            : GetCenterOfMass(_articulationBodyParts);
        if (!LastIsSet)
        {
            LastCenterOfMassInWorldSpace = newCOM;
        }

        // Reference frame: yaw-only rotation taken from the root body.
        var newHorizontalDirection = new Vector3(0f, _root.transform.eulerAngles.y, 0f);
        HorizontalDirection = newHorizontalDirection / 180f;

        // Make this object the reference frame for every "local" quantity below.
        transform.position = newCOM;
        transform.rotation = Quaternion.Euler(newHorizontalDirection);

        // COM velocity from the world-space finite difference, expressed in the reference frame.
        var velocity = transform.position - LastCenterOfMassInWorldSpace;
        velocity /= timeDelta;
        CenterOfMassVelocity = transform.InverseTransformVector(velocity);
        CenterOfMassVelocityMagnitude = CenterOfMassVelocity.magnitude;

        // Horizontal COM velocity: drop the world-space vertical component, then express in the frame.
        var comHorizontalDirection = new Vector3(velocity.x, 0f, velocity.z);
        CenterOfMassHorizontalVelocity = transform.InverseTransformVector(comHorizontalDirection);
        CenterOfMassHorizontalVelocityMagnitude = CenterOfMassHorizontalVelocity.magnitude;

        // Desired horizontal velocity: the controller's (x, y) is treated as world (x, z).
        Vector3 desiredCom = new Vector3(
            _inputController.DesiredHorizontalVelocity.x,
            0f,
            _inputController.DesiredHorizontalVelocity.y);
        DesiredCenterOfMassVelocity = transform.InverseTransformVector(desiredCom);

        // Error between desired and actual horizontal COM velocity.
        CenterOfMassVelocityDifference = DesiredCenterOfMassVelocity - CenterOfMassHorizontalVelocity;

        // Angular velocity of the reference frame itself.
        if (!LastIsSet)
        {
            LastRotation = transform.rotation;
        }
        AngualrVelocity = GetAngularVelocity(LastRotation, transform.rotation, timeDelta);
        LastRotation = transform.rotation;
        LastCenterOfMassInWorldSpace = newCOM;
        LastIsSet = true;

        // Per-body-part stats in the reference frame.
        foreach (var bodyPart in _bodyParts)
        {
            Stat bodyPartStat = Stats.First(x => x.Name == bodyPart.name);

            Vector3 worldPosition = bodyPart.transform.TransformPoint(ColliderLocalCenter(bodyPart));
            Quaternion worldRotation = bodyPart.transform.rotation;
            Vector3 localPosition = transform.InverseTransformPoint(worldPosition);
            // NOTE(review): this is worldRotation * Inverse(transform.rotation), a world-frame delta;
            // Quaternion.Inverse(transform.rotation) * worldRotation would be the rotation expressed in this
            // object's frame. Left unchanged so the observation values stay as they were.
            Quaternion localRotation = FromToRotation(transform.rotation, worldRotation);

            if (!bodyPartStat.LastIsSet)
            {
                bodyPartStat.LastLocalPosition = localPosition;
                bodyPartStat.LastLocalRotation = localRotation;
            }

            bodyPartStat.Position = localPosition;
            bodyPartStat.Rotation = localRotation;
            bodyPartStat.Velocity = (localPosition - bodyPartStat.LastLocalPosition) / timeDelta;
            bodyPartStat.AngualrVelocity = GetAngularVelocity(bodyPartStat.LastLocalRotation, localRotation, timeDelta);
            bodyPartStat.LastLocalPosition = localPosition;
            bodyPartStat.LastLocalRotation = localRotation;
            bodyPartStat.LastIsSet = true;
        }
    }

    // Centre of a capsule/box/sphere collider in its own transform's local space; zero for other collider types.
    static Vector3 ColliderLocalCenter(Collider shape)
    {
        switch (shape)
        {
            case CapsuleCollider capsule: return capsule.center;
            case BoxCollider box: return box.center;
            case SphereCollider sphere: return sphere.center;
            default: return Vector3.zero;
        }
    }

    /// <summary>
    /// Mass-weighted average of the bodies' world-space centres of mass.
    /// Returns <see cref="Vector3.zero"/> when the total mass is not positive (e.g. no bodies).
    /// </summary>
    Vector3 GetCenterOfMass(IEnumerable<Rigidbody> bodies)
    {
        var weightedSum = Vector3.zero;
        float totalMass = 0f;
        foreach (Rigidbody body in bodies)
        {
            weightedSum += body.worldCenterOfMass * body.mass;
            totalMass += body.mass;
        }
        // Dividing by a zero total would yield NaN and propagate it into transform.position.
        return totalMass > 0f ? weightedSum / totalMass : Vector3.zero;
    }

    /// <summary>
    /// Mass-weighted average of the bodies' world-space centres of mass.
    /// Returns <see cref="Vector3.zero"/> when the total mass is not positive (e.g. no bodies).
    /// </summary>
    Vector3 GetCenterOfMass(IEnumerable<ArticulationBody> bodies)
    {
        var weightedSum = Vector3.zero;
        float totalMass = 0f;
        foreach (ArticulationBody body in bodies)
        {
            weightedSum += body.worldCenterOfMass * body.mass;
            totalMass += body.mass;
        }
        // Dividing by a zero total would yield NaN and propagate it into transform.position.
        return totalMass > 0f ? weightedSum / totalMass : Vector3.zero;
    }

    // Converts a vector from this object's frame to world space when VelocityInWorldSpace is enabled.
    Vector3 ToGizmoSpace(Vector3 localVector) =>
        VelocityInWorldSpace ? transform.TransformVector(localVector) : localVector;

    // Draws the desired (green), actual (blue) and desired-minus-actual (red) COM velocity arrows.
    void OnDrawGizmosSelected()
    {
        // Only draw once OnAgentInitialize has supplied a body-part list.
        if (_bodyPartsToTrack == null)
            return;

        // Arrows start at world height 0.3 under/over the COM's horizontal position.
        Vector3 origin = new Vector3(transform.position.x, .3f, transform.position.z);

        DrawArrow(origin, ToGizmoSpace(DesiredCenterOfMassVelocity), Color.green);

        Vector3 actual = HorizontalVelocity ? CenterOfMassHorizontalVelocity : CenterOfMassVelocity;
        Vector3 difference = HorizontalVelocity
            ? CenterOfMassVelocityDifference
            : DesiredCenterOfMassVelocity - CenterOfMassVelocity;

        Vector3 actualVector = ToGizmoSpace(actual);
        DrawArrow(origin, actualVector, Color.blue);
        // The difference arrow starts at the tip of the actual-velocity arrow.
        DrawArrow(origin + actualVector, ToGizmoSpace(difference), Color.red);
    }

    // Draws a ray from start along vector, with a two-line arrow head at its tip.
    void DrawArrow(Vector3 start, Vector3 vector, Color color)
    {
        float headSize = 0.25f;
        float headAngle = 20.0f;
        Gizmos.color = color;
        Gizmos.DrawRay(start, vector);

        // A zero vector has no direction for LookRotation, so only the (empty) shaft is drawn.
        if (vector.magnitude > 0f)
        {
            Quaternion look = Quaternion.LookRotation(vector);
            Vector3 tip = start + vector;
            Vector3 right = look * Quaternion.Euler(0, 180 + headAngle, 0) * Vector3.forward;
            Vector3 left = look * Quaternion.Euler(0, 180 - headAngle, 0) * Vector3.forward;
            Gizmos.DrawRay(tip, right * headSize);
            Gizmos.DrawRay(tip, left * headSize);
        }
    }

    /// <summary>
    /// Translates the stored previous-step COM and this object by <paramref name="snapDistance"/>.
    /// Presumably called when the character is snapped/teleported, so the shift is not measured as COM velocity.
    /// </summary>
    public void ShiftCOM(Vector3 snapDistance)
    {
        Vector3 newCOM = LastCenterOfMassInWorldSpace + snapDistance;
        LastCenterOfMassInWorldSpace = newCOM;
        transform.position = newCOM;
    }
}
|