Implementation of DReCon in Unity 2022 LTS
forked from:
https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
386 lines
15 KiB
386 lines
15 KiB
10 months ago
|
using System;
|
||
|
using System.Collections;
|
||
|
using System.Collections.Generic;
|
||
|
using System.Linq;
|
||
|
using Unity.MLAgents;
|
||
|
using UnityEngine;
|
||
|
using UnityEngine.Assertions;
|
||
|
using ManyWorlds;
|
||
|
|
||
|
/// <summary>
/// Collects per-step pose statistics for the hierarchy under <see cref="ObjectToTrack"/>:
/// center-of-mass velocity, head/root positions, six sample points per collider (with
/// finite-difference velocities) and one rotation per tracked body.
/// Points are stored in this component's local space; this component's transform is moved to
/// the tracked center of mass on every call to <see cref="SetStatusForStep"/>.
/// </summary>
public class RewardStats : MonoBehaviour
{
    // Each collider contributes one sample point per face of its local bounding box.
    const int PointsPerCollider = 6;

    [Header("Settings")]

    // Component whose child bodies and colliders are tracked.
    public MonoBehaviour ObjectToTrack;

    [Header("Stats")]
    public Vector3 CenterOfMassVelocity;
    public float CenterOfMassVelocityMagnitude;
    // Same velocity expressed in this component's local space.
    public Vector3 CenterOfMassVelocityInRootSpace;
    public float CenterOfMassVelocityMagnitudeInRootSpace;

    [HideInInspector] public Vector3 LastCenterOfMassInWorldSpace;
    [HideInInspector] public Vector3 HeadPositionInWorldSpace;
    [HideInInspector] public Vector3 RootPositionInWorldSpace;
    // False until SetStatusForStep has stored a baseline to differentiate against.
    [HideInInspector] public bool LastIsSet;

    SpawnableEnv _spawnableEnv;
    List<Collider> _colliders;
    List<Rigidbody> _rigidbodyParts;
    List<ArticulationBody> _articulationBodyParts;
    List<GameObject> _bodyParts;
    GameObject _root;
    GameObject _head;
    MapAnim2Ragdoll _mapAnim2Ragdoll;
    List<GameObject> _trackRotations;
    // One entry per tracked body, in the order of RotationNames.
    public List<Quaternion> Rotations;
    // PointsPerCollider entries per collider, in the order of ColliderNames, in local space.
    public Vector3[] Points;
    Vector3[] _lastPoints;
    // Finite-difference velocity of each entry of Points.
    public Vector3[] PointVelocity;

    [Header("Stats")]
    public List<string> ColliderNames;
    public List<string> RotationNames;
    public List<string> BodyPartNames;
    bool _hasLazyInitialized;


    string rootName = "articulation:Hips";
    string headName = "";

    /// <summary>Sets the name of the body part used as root. Call before <see cref="OnAgentInitialize"/>.</summary>
    public void setRootName(string s)
    {
        rootName = s;
    }

    /// <summary>Sets the name of the body part used as head. Call before <see cref="OnAgentInitialize"/>.</summary>
    public void setHeadName(string s)
    {
        headName = s;
    }


    /// <summary>
    /// Discovers the bodies, colliders and tracked rotations under <see cref="ObjectToTrack"/>,
    /// allocates the statistic buffers and places this transform at <paramref name="defaultTransform"/>.
    /// Must be called exactly once.
    /// </summary>
    /// <param name="defaultTransform">
    /// Initial pose for this component. If it carries a <c>MapAnim2Ragdoll</c>, later updates read
    /// center-of-mass data from that component instead of computing it.
    /// </param>
    /// <param name="orderToCopy">
    /// Optional, already initialized instance whose body part / rotation / collider ordering
    /// (matched by object name) is replicated here so both instances index the same parts.
    /// </param>
    /// <exception cref="InvalidOperationException">
    /// A required object (root, head, or a name taken from <paramref name="orderToCopy"/>) is not found.
    /// </exception>
    public void OnAgentInitialize(Transform defaultTransform, RewardStats orderToCopy = null)
    {
        Assert.IsFalse(_hasLazyInitialized);
        _hasLazyInitialized = true;

        _mapAnim2Ragdoll = defaultTransform.GetComponent<MapAnim2Ragdoll>();
        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        _articulationBodyParts = ObjectToTrack
            .GetComponentsInChildren<ArticulationBody>()
            .Distinct()
            .ToList();
        _rigidbodyParts = ObjectToTrack
            .GetComponentsInChildren<Rigidbody>()
            .Distinct()
            .ToList();

        // Rigidbodies take precedence; articulation bodies are used only when there are none.
        _bodyParts = _rigidbodyParts.Count > 0
            ? _rigidbodyParts.Select(x => x.gameObject).ToList()
            : _articulationBodyParts.Select(x => x.gameObject).ToList();

        _trackRotations = _bodyParts
            .SelectMany(x => x.GetComponentsInChildren<Transform>())
            .Select(x => x.gameObject)
            .Distinct()
            .Where(x => x.GetComponent<Rigidbody>() != null || x.GetComponent<ArticulationBody>() != null)
            .Where(x => x.name.StartsWith("articulation:", StringComparison.Ordinal))
            .ToList();

        _colliders = _bodyParts
            .SelectMany(x => x.GetComponentsInChildren<Collider>())
            .Where(x => x.enabled && !x.isTrigger)
            .Where(x =>
            {
                // A collider is skipped only when it carries an *enabled* IgnoreColliderForReward.
                var ignoreCollider = x.GetComponent<IgnoreColliderForReward>();
                return ignoreCollider == null || !ignoreCollider.enabled;
            })
            .Distinct()
            .ToList();

        if (orderToCopy != null)
        {
            // Re-order our own objects to follow the reference instance, matching by name.
            // NOTE(review): objects sharing a name all resolve to the first match.
            var ownBodyParts = _bodyParts;
            var ownTrackRotations = _trackRotations;
            var ownColliders = _colliders;
            _bodyParts = orderToCopy._bodyParts
                .Select(x => FindByName(ownBodyParts, x.name, "body part"))
                .ToList();
            _trackRotations = orderToCopy._trackRotations
                .Select(x => FindByName(ownTrackRotations, x.name, "tracked rotation object"))
                .ToList();
            _colliders = orderToCopy._colliders
                .Select(x => FindByName(ownColliders, x.name, "collider"))
                .ToList();
        }

        int pointCount = _colliders.Count * PointsPerCollider;
        Points = new Vector3[pointCount];
        _lastPoints = new Vector3[pointCount];
        PointVelocity = new Vector3[pointCount];
        Rotations = Enumerable.Repeat(Quaternion.identity, _trackRotations.Count).ToList();

        _root = FindByName(_bodyParts, rootName, "root body part");
        _head = FindByName(_bodyParts, headName, "head body part");

        transform.position = defaultTransform.position;
        transform.rotation = defaultTransform.rotation;
        LastCenterOfMassInWorldSpace = transform.position;
        HeadPositionInWorldSpace = _head.transform.position;
        RootPositionInWorldSpace = _root.transform.position;

        ColliderNames = _colliders.Select(x => x.name).ToList();
        RotationNames = _trackRotations.Select(x => x.name).ToList();
        BodyPartNames = _bodyParts.Select(x => x.name).ToList();
    }

    /// <summary>
    /// Re-samples all statistics from the current pose and clears <see cref="LastIsSet"/>, so the
    /// next <see cref="SetStatusForStep"/> call re-seeds its baselines and reports zero velocities.
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);
        ResetStatus();
        LastIsSet = false;
    }

    /// <summary>
    /// Discards the stored baselines and samples the current pose once.
    /// <see cref="LastIsSet"/> is true afterwards.
    /// </summary>
    public void ResetStatus()
    {
        LastIsSet = false;
        // No meaningful step duration exists here; with LastIsSet cleared the velocities
        // come out as zero regardless of this value.
        var timeDelta = float.MinValue;
        SetStatusForStep(timeDelta);
    }

    /// <summary>
    /// Samples the tracked object: moves this transform to the center of mass (rotated by the
    /// horizontal heading), then updates center-of-mass velocity, head/root positions, collider
    /// sample points, their velocities and the tracked rotations.
    /// </summary>
    /// <param name="timeDelta">
    /// Time elapsed since the previous sample. When it is not positive, or when no previous
    /// sample exists, finite-difference velocities are reported as zero.
    /// </param>
    public void SetStatusForStep(float timeDelta)
    {
        // Finite differences need both a stored baseline and a positive step; otherwise
        // dividing would produce NaN/Infinity (timeDelta == 0) or sign-flipped values.
        bool canDifferentiate = LastIsSet && timeDelta > 0f;

        Vector3 newCOM;
        if (_mapAnim2Ragdoll != null)
        {
            // Animated reference: center of mass, heading and velocity come from MapAnim2Ragdoll.
            newCOM = _mapAnim2Ragdoll.LastCenterOfMassInWorldSpace;
            var heading = Quaternion.Euler(_mapAnim2Ragdoll.HorizontalDirection);
            transform.position = newCOM;
            transform.rotation = heading;

            CenterOfMassVelocity = _mapAnim2Ragdoll.CenterOfMassVelocity;
            CenterOfMassVelocityMagnitude = _mapAnim2Ragdoll.CenterOfMassVelocityMagnitude;
            HeadPositionInWorldSpace = _mapAnim2Ragdoll.LastHeadPositionInWorldSpace;
            RootPositionInWorldSpace = _mapAnim2Ragdoll.LastRootPositionInWorldSpace;
        }
        else
        {
            // Physics bodies: compute the center of mass and differentiate it ourselves.
            newCOM = GetCenterOfMass();
            // Only the root's rotation about the world Y axis is kept as heading.
            var heading = Quaternion.Euler(new Vector3(0f, _root.transform.eulerAngles.y, 0f));
            transform.position = newCOM;
            transform.rotation = heading;

            CenterOfMassVelocity = canDifferentiate
                ? (newCOM - LastCenterOfMassInWorldSpace) / timeDelta
                : Vector3.zero;
            CenterOfMassVelocityMagnitude = CenterOfMassVelocity.magnitude;
            HeadPositionInWorldSpace = _head.transform.position;
            RootPositionInWorldSpace = _root.transform.position;
        }
        // Must run after the transform was updated above: it converts into the new local frame.
        CenterOfMassVelocityInRootSpace = transform.InverseTransformVector(CenterOfMassVelocity);
        CenterOfMassVelocityMagnitudeInRootSpace = CenterOfMassVelocityInRootSpace.magnitude;
        LastCenterOfMassInWorldSpace = newCOM;

        GetAllPoints(Points);
        for (int i = 0; i < Points.Length; i++)
        {
            PointVelocity[i] = canDifferentiate
                ? (Points[i] - _lastPoints[i]) / timeDelta
                : Vector3.zero;
        }
        Array.Copy(Points, _lastPoints, Points.Length);

        for (int i = 0; i < _trackRotations.Count; i++)
        {
            Transform tracked = _trackRotations[i].transform;
            // The root is expressed relative to this component's heading frame;
            // every other body uses its rotation relative to its own parent.
            Rotations[i] = _trackRotations[i].gameObject == _root
                ? Quaternion.Inverse(transform.rotation) * tracked.rotation
                : tracked.localRotation;
        }

        LastIsSet = true;
    }

    /// <summary>
    /// Returns, for every sample point, the distance between this instance's point and the
    /// point with the same index in <paramref name="target"/> (both in their own local space).
    /// </summary>
    public List<float> GetPointDistancesFrom(RewardStats target)
    {
        var distances = new List<float>(Points.Length);
        for (int i = 0; i < Points.Length; i++)
        {
            distances.Add((Points[i] - target.Points[i]).magnitude);
        }
        return distances;
    }

    /// <summary>
    /// Returns, for every sample point, the magnitude of the difference between this instance's
    /// point velocity and the velocity with the same index in <paramref name="target"/>.
    /// </summary>
    public List<float> GetPointVelocityDistancesFrom(RewardStats target)
    {
        var distances = new List<float>(PointVelocity.Length);
        for (int i = 0; i < PointVelocity.Length; i++)
        {
            distances.Add((PointVelocity[i] - target.PointVelocity[i]).magnitude);
        }
        return distances;
    }

    /// <summary>
    /// Asserts that this instance and <paramref name="target"/> track the same number of points,
    /// colliders, rotations and body parts, and that their names match index by index.
    /// </summary>
    public void AssertIsCompatible(RewardStats target)
    {
        Assert.AreEqual(Points.Length, target.Points.Length);
        Assert.AreEqual(_lastPoints.Length, target._lastPoints.Length);
        Assert.AreEqual(PointVelocity.Length, target.PointVelocity.Length);
        Assert.AreEqual(Points.Length, _lastPoints.Length);
        Assert.AreEqual(Points.Length, PointVelocity.Length);
        Assert.AreEqual(_colliders.Count, target._colliders.Count);
        for (int i = 0; i < _colliders.Count; i++)
        {
            string debugStr = $" _colliders.{_colliders[i].name} vs target._colliders.{target._colliders[i].name}";
            Assert.AreEqual(_colliders[i].name, target._colliders[i].name, $"name:{debugStr}");
        }
        Assert.AreEqual(ColliderNames.Count, target.ColliderNames.Count);
        Assert.AreEqual(RotationNames.Count, target.RotationNames.Count);
        Assert.AreEqual(BodyPartNames.Count, target.BodyPartNames.Count);
        for (int i = 0; i < ColliderNames.Count; i++)
        {
            Assert.AreEqual(ColliderNames[i], target.ColliderNames[i]);
        }
        for (int i = 0; i < RotationNames.Count; i++)
        {
            Assert.AreEqual(RotationNames[i], target.RotationNames[i]);
        }
        for (int i = 0; i < BodyPartNames.Count; i++)
        {
            Assert.AreEqual(BodyPartNames[i], target.BodyPartNames[i]);
        }
    }

    // Returns the first object in `candidates` whose name equals `name`; throws a descriptive
    // InvalidOperationException (the type Enumerable.First would throw) when there is none.
    static T FindByName<T>(IEnumerable<T> candidates, string name, string description)
        where T : UnityEngine.Object
    {
        foreach (var candidate in candidates)
        {
            if (candidate.name == name)
            {
                return candidate;
            }
        }
        throw new InvalidOperationException(
            $"RewardStats: no {description} named '{name}' was found.");
    }

    // Builds the collider's bounding box in the collider's own local space.
    // Supports capsule, box and sphere colliders only.
    static Bounds GetLocalBounds(Collider source)
    {
        var capsule = source as CapsuleCollider;
        if (capsule != null)
        {
            float diameter = capsule.radius * 2;
            // A capsule can never be shorter than its diameter.
            float length = Mathf.Max(diameter, capsule.height);
            switch (capsule.direction)
            {
                case 0:
                    return new Bounds(capsule.center, new Vector3(length, diameter, diameter));
                case 1:
                    return new Bounds(capsule.center, new Vector3(diameter, length, diameter));
                case 2:
                    return new Bounds(capsule.center, new Vector3(diameter, diameter, length));
                default:
                    throw new NotImplementedException(
                        $"Unsupported capsule direction {capsule.direction}.");
            }
        }

        var box = source as BoxCollider;
        if (box != null)
        {
            return new Bounds(box.center, box.size);
        }

        var sphere = source as SphereCollider;
        if (sphere != null)
        {
            float diameter = sphere.radius * 2;
            return new Bounds(sphere.center, new Vector3(diameter, diameter, diameter));
        }

        throw new NotImplementedException(
            $"Unsupported collider type {source.GetType().Name}.");
    }

    // Fills pointBuffer with PointsPerCollider points per collider: the centers of the six
    // faces of each collider's local bounds, converted into this component's local space.
    void GetAllPoints(Vector3[] pointBuffer)
    {
        int idx = 0;
        foreach (var tracked in _colliders)
        {
            Bounds b = GetLocalBounds(tracked);
            Vector3 c = b.center;
            Transform colliderTransform = tracked.transform;

            // collider local space -> world space -> this component's local space
            pointBuffer[idx++] = ToLocalSpace(colliderTransform, new Vector3(b.max.x, c.y, c.z));
            pointBuffer[idx++] = ToLocalSpace(colliderTransform, new Vector3(b.min.x, c.y, c.z));
            pointBuffer[idx++] = ToLocalSpace(colliderTransform, new Vector3(c.x, b.max.y, c.z));
            pointBuffer[idx++] = ToLocalSpace(colliderTransform, new Vector3(c.x, b.min.y, c.z));
            pointBuffer[idx++] = ToLocalSpace(colliderTransform, new Vector3(c.x, c.y, b.max.z));
            pointBuffer[idx++] = ToLocalSpace(colliderTransform, new Vector3(c.x, c.y, b.min.z));
        }
    }

    // Converts a point given in `source`'s local space into this component's local space.
    Vector3 ToLocalSpace(Transform source, Vector3 pointInSourceSpace)
    {
        return transform.InverseTransformPoint(source.TransformPoint(pointInSourceSpace));
    }

    // Mass-weighted average of the world-space centers of mass of the tracked bodies.
    // Articulation bodies are used when they carry mass; otherwise rigidbodies are used.
    // With no mass at all the current position is returned so no NaN leaks into the stats.
    Vector3 GetCenterOfMass()
    {
        var weightedSum = Vector3.zero;
        float totalMass = 0f;
        foreach (ArticulationBody ab in _articulationBodyParts)
        {
            weightedSum += ab.worldCenterOfMass * ab.mass;
            totalMass += ab.mass;
        }
        if (totalMass <= 0f)
        {
            weightedSum = Vector3.zero;
            totalMass = 0f;
            foreach (Rigidbody rb in _rigidbodyParts)
            {
                weightedSum += rb.worldCenterOfMass * rb.mass;
                totalMass += rb.mass;
            }
        }
        if (totalMass <= 0f)
        {
            return transform.position;
        }
        return weightedSum / totalMass;
    }

    /// <summary>
    /// Draws gizmos comparing this instance with <paramref name="target"/>: a white line between
    /// corresponding points, a blue ray for this instance's point velocity and a green ray for
    /// the target's. Intended to be called from a gizmo-drawing callback.
    /// </summary>
    /// <param name="target">Instance to compare against; must have the same point layout.</param>
    /// <param name="objIdex">
    /// Index of the collider whose points are drawn, or a negative value to draw all points.
    /// An index beyond the last collider draws nothing.
    /// </param>
    public void DrawPointDistancesFrom(RewardStats target, int objIdex)
    {
        int start = 0;
        int end = Points.Length; // exclusive
        if (objIdex >= 0)
        {
            start = objIdex * PointsPerCollider;
            end = Mathf.Min(start + PointsPerCollider, Points.Length);
        }
        for (int i = start; i < end; i++)
        {
            // Each instance's points live in its own local space; convert both to world space.
            Vector3 from = transform.TransformPoint(Points[i]);
            Vector3 toTarget = target.transform.TransformPoint(target.Points[i]);

            Gizmos.color = Color.white;
            Gizmos.DrawLine(from, toTarget);

            // show this object's velocity
            Gizmos.color = Color.blue;
            Gizmos.DrawRay(from, PointVelocity[i]);

            // show target's velocity
            Gizmos.color = Color.green;
            Gizmos.DrawRay(toTarget, target.PointVelocity[i]);
        }
    }
}
|