Implementation of DReCon in Unity 2022 LTS, forked from: https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

358 lines
13 KiB

9 months ago
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;
/// <summary>
/// Computes DReCon-style observation statistics for a tracked character: center of mass (COM)
/// velocities, the difference between desired and actual horizontal COM velocity, and per-body-part
/// position / rotation / velocity. Everything is expressed in a character frame ("f space") whose
/// origin is the COM and whose rotation keeps only the root body's yaw.
/// </summary>
public class DReConObservationStats : MonoBehaviour
{
    /// <summary>Observation values for a single tracked body part (collider), expressed in f space.</summary>
    [System.Serializable]
    public class Stat
    {
        public string Name;
        public Vector3 Position;
        public Quaternion Rotation;
        public Vector3 Velocity;
        // Misspelled name kept on purpose: renaming would break serialized data and callers.
        public Vector3 AngualrVelocity;
        [HideInInspector]
        public Vector3 LastLocalPosition;
        [HideInInspector]
        public Quaternion LastLocalRotation;
        [HideInInspector]
        public bool LastIsSet;
    }

    public MonoBehaviour ObjectToTrack;
    List<string> _bodyPartsToTrack;

    [Header("Anchor stats")]
    // Root yaw encoded as (0, eulerAngles.y / 180, 0), so the y component lies in [0, 2).
    // Despite the name this is NOT a unit direction vector.
    public Vector3 HorizontalDirection;
    // Frame-to-frame angular velocity of f space (yaw only), as Euler-angle radians per unit of timeDelta.
    public Vector3 AngualrVelocity;

    [Header("Stats, relative to HorizontalDirection & Center Of Mass")]
    public Vector3 CenterOfMassVelocity;
    public Vector3 CenterOfMassHorizontalVelocity;
    public float CenterOfMassVelocityMagnitude;
    public float CenterOfMassHorizontalVelocityMagnitude;
    public Vector3 DesiredCenterOfMassVelocity;
    public Vector3 CenterOfMassVelocityDifference;
    public List<Stat> Stats;

    [Header("Gizmos")]
    public bool VelocityInWorldSpace = true;
    public bool HorizontalVelocity = true;

    [HideInInspector]
    public Vector3 LastCenterOfMassInWorldSpace;
    [HideInInspector]
    public Quaternion LastRotation;
    [HideInInspector]
    public bool LastIsSet;

    SpawnableEnv _spawnableEnv;
    List<Collider> _bodyParts;
    internal List<Rigidbody> _rigidbodyParts;
    internal List<ArticulationBody> _articulationBodyParts;
    GameObject _root;
    InputController _inputController;
    bool _hasLazyInitialized;
    string rootName = "articulation:Hips";

    /// <summary>
    /// Sets the name of the collider whose yaw defines f space.
    /// Only has an effect when called before <see cref="OnAgentInitialize"/>, which resolves the root once.
    /// </summary>
    /// <param name="s">Collider (GameObject) name of the root body part.</param>
    public void setRootName(string s) {
        rootName = s;
    }

    /// <summary>
    /// One-time setup. Collects the colliders of the character's physics bodies, keeps only the
    /// requested ones, creates one <see cref="Stat"/> per tracked part, resolves the root body and
    /// moves this transform to <paramref name="defaultTransform"/>.
    /// </summary>
    /// <param name="bodyPartsToTrack">
    /// Collider names to track, in output order. Null or empty means "track every collider found".
    /// Names that match no collider are skipped; if several colliders share a name, the first is used.
    /// </param>
    /// <param name="defaultTransform">Initial pose for this component's transform.</param>
    /// <exception cref="System.InvalidOperationException">No tracked collider is named like the root.</exception>
    public void OnAgentInitialize(List<string> bodyPartsToTrack, Transform defaultTransform)
    {
        Assert.IsFalse(_hasLazyInitialized);
        _hasLazyInitialized = true;
        _bodyPartsToTrack = bodyPartsToTrack;
        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        _inputController = _spawnableEnv.GetComponentInChildren<InputController>();
        _rigidbodyParts = ObjectToTrack.GetComponentsInChildren<Rigidbody>().ToList();
        _articulationBodyParts = ObjectToTrack.GetComponentsInChildren<ArticulationBody>().ToList();

        // Collect every collider under the physics bodies; a Rigidbody rig takes precedence
        // over an ArticulationBody rig when both are present.
        if (_rigidbodyParts?.Count > 0)
        {
            _bodyParts = _rigidbodyParts
                .SelectMany(x => x.GetComponentsInChildren<Collider>())
                .Distinct()
                .ToList();
        }
        else
        {
            _bodyParts = _articulationBodyParts
                .SelectMany(x => x.GetComponentsInChildren<Collider>())
                .Distinct()
                .ToList();
        }

        // Restrict to the requested names, preserving the requested order.
        if (_bodyPartsToTrack?.Count > 0)
        {
            List<Collider> allParts = _bodyParts;
            var availableNames = new HashSet<string>(allParts.Select(x => x.name));
            _bodyParts = _bodyPartsToTrack
                .Where(partName => availableNames.Contains(partName))
                .Select(partName => allParts.First(y => y.name == partName))
                .ToList();
        }

        Stats = _bodyParts
            .Select(x => new Stat { Name = x.name })
            .ToList();

        if (_root == null)
        {
            Collider rootPart = _bodyParts.FirstOrDefault(x => x.name == rootName);
            if (rootPart == null)
            {
                throw new System.InvalidOperationException(
                    $"{nameof(DReConObservationStats)} on '{name}': root body part '{rootName}' was not found " +
                    "among the tracked colliders. Check setRootName() and the tracked body part names.");
            }
            _root = rootPart.gameObject;
        }

        transform.position = defaultTransform.position;
        transform.rotation = defaultTransform.rotation;
    }

    /// <summary>
    /// Recomputes every stat for the current pose, then clears all "last" flags so that the next
    /// <see cref="SetStatusForStep"/> call reports zero velocities (for a non-zero timeDelta).
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);
        ResetStatus();
        // ResetStatus() leaves the flags set (SetStatusForStep sets them); clear them again.
        foreach (var bodyPart in Stats)
        {
            bodyPart.LastIsSet = false;
        }
        LastIsSet = false;
    }

    // Clears the "last" flags and runs a single step so every stat holds a value for the current pose.
    void ResetStatus()
    {
        foreach (var bodyPart in Stats)
        {
            bodyPart.LastIsSet = false;
        }
        LastIsSet = false;
        // With every flag cleared, all deltas in SetStatusForStep are exactly zero, so any non-zero
        // divisor yields zero velocities. float.MinValue (the most negative float) is just a non-zero placeholder.
        var timeDelta = float.MinValue;
        SetStatusForStep(timeDelta);
    }

    /// <summary>
    /// Returns <c>to * Inverse(from)</c>: the rotation that, multiplied on the left of
    /// <paramref name="from"/>, yields <paramref name="to"/>. Returns exactly
    /// <see cref="Quaternion.identity"/> when the two compare equal under Unity's approximate
    /// quaternion equality.
    /// </summary>
    public static Quaternion FromToRotation(Quaternion from, Quaternion to) {
        if (to == from) return Quaternion.identity;
        return to * Quaternion.Inverse(from);
    }

    /// <summary>
    /// Converts an angle in degrees, expected in [0, 360) as produced by
    /// <see cref="Quaternion.eulerAngles"/>, into radians in [-pi, pi).
    /// </summary>
    public static float NormalizedAngle(float angle) {
        if (angle < 180) {
            return angle * Mathf.Deg2Rad;
        }
        return (angle - 360) * Mathf.Deg2Rad;
    }

    /// <summary>Applies <see cref="NormalizedAngle"/> to each Euler component (degrees in, radians out).</summary>
    public static Vector3 NormalizedEulerAngles(Vector3 eulerAngles) {
        var x = NormalizedAngle(eulerAngles.x);
        var y = NormalizedAngle(eulerAngles.y);
        var z = NormalizedAngle(eulerAngles.z);
        return new Vector3(x, y, z);
    }

    /// <summary>
    /// Approximates angular velocity as the Euler angles of the delta rotation
    /// (<see cref="FromToRotation"/>), converted to radians in [-pi, pi), divided by
    /// <paramref name="timeDelta"/>. This is a per-axis Euler-angle rate, not an axis-angle velocity.
    /// </summary>
    public static Vector3 GetAngularVelocity(Quaternion from, Quaternion to, float timeDelta) {
        var rotationVelocity = FromToRotation(from, to);
        var angularVelocity = NormalizedEulerAngles(rotationVelocity.eulerAngles) / timeDelta;
        return angularVelocity;
    }

    /// <summary>
    /// Advances all stats by one step. Moves this transform to the COM, rotated by the root's yaw
    /// (f space). Then computes COM velocities, the desired-velocity difference and per-body-part
    /// stats in that frame. Velocities are finite differences against the values stored by the
    /// previous call; on the first call after a reset they are zero.
    /// </summary>
    /// <param name="timeDelta">Time since the previous call. Must be non-zero (it is used as a divisor).</param>
    public void SetStatusForStep(float timeDelta)
    {
        // Center of mass of whichever physics rig this character uses.
        Vector3 newCOM;
        if (_rigidbodyParts?.Count > 0)
        {
            newCOM = GetCenterOfMass(_rigidbodyParts);
        }
        else
        {
            newCOM = GetCenterOfMass(_articulationBodyParts);
        }
        if (!LastIsSet)
        {
            LastCenterOfMassInWorldSpace = newCOM;
        }

        // f space keeps only the root's yaw (pitch and roll are dropped).
        var newHorizontalDirection = new Vector3(0f, _root.transform.eulerAngles.y, 0f);
        HorizontalDirection = newHorizontalDirection / 180f;
        transform.position = newCOM;
        transform.rotation = Quaternion.Euler(newHorizontalDirection);

        // COM velocity: world-space finite difference, re-expressed in f space.
        var velocity = transform.position - LastCenterOfMassInWorldSpace;
        velocity /= timeDelta;
        CenterOfMassVelocity = transform.InverseTransformVector(velocity);
        CenterOfMassVelocityMagnitude = CenterOfMassVelocity.magnitude;

        // Horizontal (world XZ) part of the COM velocity, in f space.
        var comHorizontalDirection = new Vector3(velocity.x, 0f, velocity.z);
        CenterOfMassHorizontalVelocity = transform.InverseTransformVector(comHorizontalDirection);
        CenterOfMassHorizontalVelocityMagnitude = CenterOfMassHorizontalVelocity.magnitude;

        // Desired horizontal velocity: the input's x/y components are placed on world X/Z, then moved into f space.
        Vector3 desiredCom = new Vector3(
            _inputController.DesiredHorizontalVelocity.x,
            0f,
            _inputController.DesiredHorizontalVelocity.y);
        DesiredCenterOfMassVelocity = transform.InverseTransformVector(desiredCom);

        // Error between desired and actual horizontal COM velocity, in f space.
        CenterOfMassVelocityDifference = DesiredCenterOfMassVelocity - CenterOfMassHorizontalVelocity;

        if (!LastIsSet)
        {
            LastRotation = transform.rotation;
        }
        AngualrVelocity = GetAngularVelocity(LastRotation, transform.rotation, timeDelta);
        LastRotation = transform.rotation;
        LastCenterOfMassInWorldSpace = newCOM;
        LastIsSet = true;

        // Per-body-part stats: collider center and rotation relative to f space.
        foreach (var bodyPart in _bodyParts)
        {
            Stat bodyPartStat = Stats.First(x => x.Name == bodyPart.name);

            Vector3 worldPosition = bodyPart.transform.TransformPoint(GetColliderCenter(bodyPart));
            Quaternion worldRotation = bodyPart.transform.rotation;
            Vector3 localPosition = transform.InverseTransformPoint(worldPosition);
            // NOTE(review): this is worldRotation * Inverse(frame), not Inverse(frame) * worldRotation.
            // Kept as-is because changing it would alter the observation vector seen by trained policies.
            Quaternion localRotation = FromToRotation(transform.rotation, worldRotation);

            if (!bodyPartStat.LastIsSet)
            {
                bodyPartStat.LastLocalPosition = localPosition;
                bodyPartStat.LastLocalRotation = localRotation;
            }
            bodyPartStat.Position = localPosition;
            bodyPartStat.Rotation = localRotation;
            bodyPartStat.Velocity = (localPosition - bodyPartStat.LastLocalPosition) / timeDelta;
            bodyPartStat.AngualrVelocity = GetAngularVelocity(bodyPartStat.LastLocalRotation, localRotation, timeDelta);
            bodyPartStat.LastLocalPosition = localPosition;
            bodyPartStat.LastLocalRotation = localRotation;
            bodyPartStat.LastIsSet = true;
        }
    }

    // Local-space center of a capsule, box or sphere collider; Vector3.zero for any other collider type.
    static Vector3 GetColliderCenter(Collider collider)
    {
        CapsuleCollider capsule = collider as CapsuleCollider;
        if (capsule != null)
        {
            return capsule.center;
        }
        BoxCollider box = collider as BoxCollider;
        if (box != null)
        {
            return box.center;
        }
        SphereCollider sphere = collider as SphereCollider;
        if (sphere != null)
        {
            return sphere.center;
        }
        return Vector3.zero;
    }

    /// <summary>
    /// Mass-weighted average of the bodies' world-space centers of mass.
    /// If the total mass is zero (e.g. an empty collection), the result has NaN components.
    /// </summary>
    Vector3 GetCenterOfMass(IEnumerable<Rigidbody> bodies)
    {
        var centerOfMass = Vector3.zero;
        float totalMass = 0f;
        foreach (Rigidbody body in bodies)
        {
            centerOfMass += body.worldCenterOfMass * body.mass;
            totalMass += body.mass;
        }
        centerOfMass /= totalMass;
        return centerOfMass;
    }

    /// <summary>
    /// Mass-weighted average of the articulation bodies' world-space centers of mass.
    /// If the total mass is zero (e.g. an empty collection), the result has NaN components.
    /// </summary>
    Vector3 GetCenterOfMass(IEnumerable<ArticulationBody> bodies)
    {
        var centerOfMass = Vector3.zero;
        float totalMass = 0f;
        foreach (ArticulationBody body in bodies)
        {
            centerOfMass += body.worldCenterOfMass * body.mass;
            totalMass += body.mass;
        }
        centerOfMass /= totalMass;
        return centerOfMass;
    }

    /// <summary>
    /// Draws, at a fixed height of 0.3 above the COM's XZ position:
    /// the desired velocity (green), the actual velocity (blue) and the difference (red).
    /// The actual velocity is horizontal or full depending on <see cref="HorizontalVelocity"/>.
    /// </summary>
    void OnDrawGizmosSelected()
    {
        // Nothing meaningful to draw until OnAgentInitialize has run. (Checking _bodyPartsToTrack
        // instead would hide the gizmos when every body part is tracked via a null list.)
        if (!_hasLazyInitialized)
        {
            return;
        }

        Vector3 pos = new Vector3(transform.position.x, .3f, transform.position.z);

        // Desired input velocity.
        Vector3 vector = DesiredCenterOfMassVelocity;
        if (VelocityInWorldSpace)
        {
            vector = transform.TransformVector(vector);
        }
        DrawArrow(pos, vector, Color.green);

        if (HorizontalVelocity)
        {
            // Actual horizontal velocity.
            vector = CenterOfMassHorizontalVelocity;
            if (VelocityInWorldSpace)
            {
                vector = transform.TransformVector(vector);
            }
            DrawArrow(pos, vector, Color.blue);
            Vector3 actualPos = pos + vector;

            // Difference, drawn from the tip of the actual-velocity arrow.
            vector = CenterOfMassVelocityDifference;
            if (VelocityInWorldSpace)
            {
                vector = transform.TransformVector(vector);
            }
            DrawArrow(actualPos, vector, Color.red);
        }
        else
        {
            // Actual full (3D) velocity.
            vector = CenterOfMassVelocity;
            if (VelocityInWorldSpace)
            {
                vector = transform.TransformVector(vector);
            }
            DrawArrow(pos, vector, Color.blue);
            Vector3 actualPos = pos + vector;

            // Difference, drawn from the tip of the actual-velocity arrow.
            vector = DesiredCenterOfMassVelocity - CenterOfMassVelocity;
            if (VelocityInWorldSpace)
            {
                vector = transform.TransformVector(vector);
            }
            DrawArrow(actualPos, vector, Color.red);
        }
    }

    /// <summary>Draws a gizmo ray from <paramref name="start"/> along <paramref name="vector"/> with a two-line arrow head.</summary>
    /// <param name="headSize">Length of each arrow-head line.</param>
    /// <param name="headAngle">Angle in degrees between each arrow-head line and the shaft.</param>
    void DrawArrow(Vector3 start, Vector3 vector, Color color, float headSize = 0.25f, float headAngle = 20.0f)
    {
        Gizmos.color = color;
        Gizmos.DrawRay(start, vector);
        // A zero vector has no direction, so only the (degenerate) ray is drawn.
        if (vector.sqrMagnitude > 0f)
        {
            Quaternion look = Quaternion.LookRotation(vector);
            Vector3 right = look * Quaternion.Euler(0, 180 + headAngle, 0) * new Vector3(0, 0, 1);
            Vector3 left = look * Quaternion.Euler(0, 180 - headAngle, 0) * new Vector3(0, 0, 1);
            Gizmos.DrawRay(start + vector, right * headSize);
            Gizmos.DrawRay(start + vector, left * headSize);
        }
    }

    /// <summary>
    /// Offsets the stored previous COM and this transform by <paramref name="snapDistance"/>.
    /// Presumably called when the character is teleported, so the next COM velocity does not
    /// include the jump — confirm against callers.
    /// </summary>
    public void ShiftCOM (Vector3 snapDistance)
    {
        Vector3 newCOM = LastCenterOfMassInWorldSpace + snapDistance;
        LastCenterOfMassInWorldSpace = newCOM;
        transform.position = newCOM;
    }
}