Implementation of DReCon in Unity 2022 LTS, forked from: https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

222 lines
8.4 KiB

10 months ago
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;
/// <summary>
/// Gathers the observation values exposed through this component's public fields.
/// Two <see cref="DReConObservationStats"/> trackers are created in
/// <see cref="OnAgentInitialize"/>: one following the <c>MocapControllerArtanim</c>
/// found under the parent <c>SpawnableEnv</c>, and one following this component.
/// Each call to <see cref="UpdateObservations"/> copies their values (and the
/// per-body-part differences between them) into the public fields.
/// </summary>
public class DReConObservations : MonoBehaviour
{
    [Header("Observations")]
    [Tooltip("Kinematic character center of mass velocity, Vector3")]
    public Vector3 MocapCOMVelocity;
    [Tooltip("RagDoll character center of mass velocity, Vector3")]
    public Vector3 RagDollCOMVelocity;
    [Tooltip("User-input desired horizontal CM velocity. Vector2")]
    public Vector2 InputDesiredHorizontalVelocity;
    [Tooltip("User-input requests jump, bool")]
    public bool InputJump;
    [Tooltip("User-input requests backflip, bool")]
    public bool InputBackflip;
    [Tooltip("Difference between RagDoll character horizontal CM velocity and user-input desired horizontal CM velocity. Vector2")]
    public Vector2 HorizontalVelocityDifference;
    [Tooltip("Positions and velocities for subset of bodies")]
    public List<BodyPartDifferenceStats> BodyPartDifferenceStats;
    public List<DReConObservationStats.Stat> MocapBodyStats;
    public List<DReConObservationStats.Stat> RagDollBodyStats;
    [Tooltip("Smoothed actions produced in the previous step of the policy are collected in t −1")]
    public float[] PreviousActions;

    [Header("Settings")]
    public List<string> BodyPartsToTrack;

    [Header("Gizmos")]
    // When true, gizmo vectors are passed through TransformVector of the relevant tracker transform before drawing.
    public bool VelocityInWorldSpace = true;
    // When true, gizmo body-part positions are passed through TransformPoint of the ragdoll tracker transform.
    public bool PositionInWorldSpace = true;
    // Root name forwarded to both trackers via setRootName().
    public string targetedRootName = "articulation:Hips";

    InputController _inputController;
    SpawnableEnv _spawnableEnv;
    DReConObservationStats _mocapBodyStats;
    DReConObservationStats _ragDollBodyStats;
    bool _hasLazyInitialized;

    /// <summary>
    /// One-time setup: resolves the parent <c>SpawnableEnv</c> and its <c>InputController</c>,
    /// creates one difference entry per name in <see cref="BodyPartsToTrack"/>, and creates the
    /// mocap and ragdoll trackers as children of the environment. Must be called exactly once,
    /// before <see cref="OnStep"/> or <see cref="OnReset"/>.
    /// </summary>
    public void OnAgentInitialize()
    {
        Assert.IsFalse(_hasLazyInitialized);
        _hasLazyInitialized = true;

        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        // Fail with a descriptive message instead of an anonymous NullReferenceException below.
        Assert.IsNotNull(_spawnableEnv, $"{nameof(DReConObservations)} requires a {nameof(SpawnableEnv)} on a parent object.");

        _inputController = _spawnableEnv.GetComponentInChildren<InputController>();
        Assert.IsNotNull(_inputController, $"{nameof(DReConObservations)} requires an {nameof(InputController)} under the {nameof(SpawnableEnv)}.");

        BodyPartDifferenceStats = BodyPartsToTrack
            .Select(partName => new BodyPartDifferenceStats { Name = partName })
            .ToList();

        // Tracker that follows the kinematic (mocap) character.
        _mocapBodyStats = new GameObject("MocapDReConObservationStats").AddComponent<DReConObservationStats>();
        _mocapBodyStats.setRootName(targetedRootName);
        _mocapBodyStats.ObjectToTrack = _spawnableEnv.GetComponentInChildren<MocapControllerArtanim>();
        Assert.IsNotNull(_mocapBodyStats.ObjectToTrack, $"{nameof(DReConObservations)} requires a {nameof(MocapControllerArtanim)} under the {nameof(SpawnableEnv)}.");
        _mocapBodyStats.transform.SetParent(_spawnableEnv.transform);
        _mocapBodyStats.OnAgentInitialize(BodyPartsToTrack, _mocapBodyStats.ObjectToTrack.transform);

        // Tracker that follows this component (the ragdoll).
        _ragDollBodyStats = new GameObject("RagDollDReConObservationStats").AddComponent<DReConObservationStats>();
        _ragDollBodyStats.setRootName(targetedRootName);
        _ragDollBodyStats.ObjectToTrack = this;
        _ragDollBodyStats.transform.SetParent(_spawnableEnv.transform);
        _ragDollBodyStats.OnAgentInitialize(BodyPartsToTrack, transform);
    }

    /// <summary>
    /// Advances both trackers by <paramref name="timeDelta"/> and refreshes the observation fields.
    /// </summary>
    /// <param name="timeDelta">Step duration forwarded to both trackers and to <see cref="UpdateObservations"/>.</param>
    public void OnStep(float timeDelta)
    {
        Assert.IsTrue(_hasLazyInitialized);
        _mocapBodyStats.SetStatusForStep(timeDelta);
        _ragDollBodyStats.SetStatusForStep(timeDelta);
        UpdateObservations(timeDelta);
    }

    /// <summary>
    /// Resets both trackers, copies the mocap tracker's position and rotation onto the
    /// ragdoll tracker's transform, and refreshes the observation fields.
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);
        _mocapBodyStats.OnReset();
        _ragDollBodyStats.OnReset();
        _ragDollBodyStats.transform.position = _mocapBodyStats.transform.position;
        _ragDollBodyStats.transform.rotation = _mocapBodyStats.transform.rotation;

        // NOTE(review): float.MinValue is forwarded to GetAngularVelocity as the time delta;
        // presumably a sentinel meaning "no elapsed time on reset" - confirm against DReConObservationStats.
        var timeDelta = float.MinValue;
        UpdateObservations(timeDelta);
    }

    /// <summary>
    /// Copies the current tracker values into the public observation fields and recomputes
    /// the per-body-part differences (mocap minus ragdoll).
    /// </summary>
    /// <param name="timeDelta">Time delta passed to <c>DReConObservationStats.GetAngularVelocity</c> for the rotation difference.</param>
    /// <exception cref="System.InvalidOperationException">A tracked body part name has no matching entry in one of the trackers.</exception>
    public void UpdateObservations(float timeDelta)
    {
        MocapCOMVelocity = _mocapBodyStats.CenterOfMassVelocity;
        RagDollCOMVelocity = _ragDollBodyStats.CenterOfMassVelocity;

        // Horizontal components only: x and z of the tracker's 3D vectors.
        InputDesiredHorizontalVelocity = new Vector2(
            _ragDollBodyStats.DesiredCenterOfMassVelocity.x,
            _ragDollBodyStats.DesiredCenterOfMassVelocity.z);
        InputJump = _inputController.Jump;
        InputBackflip = _inputController.Backflip;
        HorizontalVelocityDifference = new Vector2(
            _ragDollBodyStats.CenterOfMassVelocityDifference.x,
            _ragDollBodyStats.CenterOfMassVelocityDifference.z);

        MocapBodyStats = BodyPartsToTrack
            .Select(partName => FindStat(_mocapBodyStats, partName))
            .ToList();
        RagDollBodyStats = BodyPartsToTrack
            .Select(partName => FindStat(_ragDollBodyStats, partName))
            .ToList();

        foreach (var differenceStats in BodyPartDifferenceStats)
        {
            var mocapStats = FindStat(_mocapBodyStats, differenceStats.Name);
            var ragDollStats = FindStat(_ragDollBodyStats, differenceStats.Name);
            differenceStats.Position = mocapStats.Position - ragDollStats.Position;
            differenceStats.Velocity = mocapStats.Velocity - ragDollStats.Velocity;
            differenceStats.AngualrVelocity = mocapStats.AngualrVelocity - ragDollStats.AngualrVelocity;
            differenceStats.Rotation = DReConObservationStats.GetAngularVelocity(mocapStats.Rotation, ragDollStats.Rotation, timeDelta);
        }
    }

    /// <summary>
    /// Returns the first entry of <paramref name="source"/>'s <c>Stats</c> whose <c>Name</c> equals
    /// <paramref name="partName"/>. Rethrows the lookup failure with the body part and tracker named,
    /// keeping the original exception type.
    /// </summary>
    DReConObservationStats.Stat FindStat(DReConObservationStats source, string partName)
    {
        try
        {
            return source.Stats.First(stat => stat.Name == partName);
        }
        catch (System.InvalidOperationException ex)
        {
            throw new System.InvalidOperationException(
                $"Body part '{partName}' was not found in the stats of '{source.name}'.", ex);
        }
    }

    /// <summary>
    /// Returns the transform of the ragdoll tracker object created in <see cref="OnAgentInitialize"/>.
    /// </summary>
    public Transform GetRagDollCOM()
    {
        return _ragDollBodyStats.transform;
    }

    /// <summary>
    /// Forwards <paramref name="snapDistance"/> to the ragdoll tracker's <c>ShiftCOM</c>.
    /// </summary>
    /// <param name="snapDistance">Offset passed unchanged to <c>ShiftCOM</c>.</param>
    public void ShiftMocapCOM(Vector3 snapDistance)
    {
        // NOTE(review): despite the method name, this shifts the ragdoll tracker, not the mocap one - confirm intent.
        _ragDollBodyStats.ShiftCOM(snapDistance);
    }

    /// <summary>
    /// Draws the centre-of-mass observations as arrows anchored at this object's x/z position
    /// (y fixed at 0.3), then per-body-part velocity and difference gizmos.
    /// Does nothing until the trackers exist.
    /// </summary>
    void OnDrawGizmos()
    {
        // Unity's overloaded == also treats destroyed objects as null.
        if (_mocapBodyStats == null || _ragDollBodyStats == null)
            return;

        Transform mocapFrame = _mocapBodyStats.transform;
        Transform ragDollFrame = _ragDollBodyStats.transform;
        Vector3 origin = new Vector3(transform.position.x, .3f, transform.position.z);

        // MocapCOMVelocity
        DrawArrow(origin, ToGizmoVector(mocapFrame, MocapCOMVelocity), Color.grey);

        // RagDollCOMVelocity
        Vector3 ragDollVelocity = ToGizmoVector(ragDollFrame, RagDollCOMVelocity);
        DrawArrow(origin, ragDollVelocity, Color.blue);
        Vector3 ragDollVelocityTip = origin + ragDollVelocity;

        // InputDesiredHorizontalVelocity
        Vector3 desired = new Vector3(InputDesiredHorizontalVelocity.x, 0f, InputDesiredHorizontalVelocity.y);
        DrawArrow(origin, ToGizmoVector(ragDollFrame, desired), Color.green);

        // HorizontalVelocityDifference, drawn from the tip of the ragdoll velocity arrow.
        Vector3 difference = new Vector3(HorizontalVelocityDifference.x, 0f, HorizontalVelocityDifference.y);
        DrawArrow(ragDollVelocityTip, ToGizmoVector(ragDollFrame, difference), Color.red);

        // The per-body-part lists may be unset or of different lengths before the first
        // UpdateObservations call; only draw the entries that exist in both.
        if (RagDollBodyStats == null || BodyPartDifferenceStats == null)
            return;
        int drawable = Mathf.Min(RagDollBodyStats.Count, BodyPartDifferenceStats.Count);

        for (int i = 0; i < drawable; i++)
        {
            var stat = RagDollBodyStats[i];
            var differenceStat = BodyPartDifferenceStats[i];

            Vector3 partPosition = stat.Position;
            if (PositionInWorldSpace)
                partPosition = ragDollFrame.TransformPoint(partPosition);

            // Body part velocity.
            Vector3 partVelocity = ToGizmoVector(ragDollFrame, stat.Velocity);
            DrawArrow(partPosition, partVelocity, Color.cyan);
            Vector3 partVelocityTip = partPosition + partVelocity;

            // Position difference, drawn as a plain ray (no arrow head).
            // NOTE(review): this vector is gated by VelocityInWorldSpace, not PositionInWorldSpace - confirm intent.
            Gizmos.color = Color.magenta;
            Gizmos.DrawRay(partPosition, ToGizmoVector(ragDollFrame, differenceStat.Position));

            // Velocity difference, drawn from the tip of the body part velocity arrow.
            DrawArrow(partVelocityTip, ToGizmoVector(ragDollFrame, differenceStat.Velocity), Color.red);
        }
    }

    /// <summary>
    /// Returns <paramref name="vector"/> passed through <paramref name="frame"/>'s
    /// <c>TransformVector</c> when <see cref="VelocityInWorldSpace"/> is set, otherwise unchanged.
    /// </summary>
    Vector3 ToGizmoVector(Transform frame, Vector3 vector)
    {
        return VelocityInWorldSpace ? frame.TransformVector(vector) : vector;
    }

    /// <summary>
    /// Draws a ray from <paramref name="start"/> along <paramref name="vector"/> in
    /// <paramref name="color"/>, with a two-stroke arrow head at the tip.
    /// </summary>
    void DrawArrow(Vector3 start, Vector3 vector, Color color)
    {
        const float headSize = 0.25f;
        const float headAngle = 20.0f;

        Gizmos.color = color;
        Gizmos.DrawRay(start, vector);

        // Skip the head for degenerate directions: LookRotation cannot derive an orientation
        // from a (near-)zero vector. Written as a negated '>' so NaN vectors are skipped too.
        if (!(vector.sqrMagnitude > Vector3.kEpsilon * Vector3.kEpsilon))
            return;

        Quaternion facing = Quaternion.LookRotation(vector);
        Vector3 tip = start + vector;
        // Head strokes point backwards (180 degrees) from the tip, splayed by headAngle to each side.
        Vector3 right = facing * Quaternion.Euler(0, 180 + headAngle, 0) * Vector3.forward;
        Vector3 left = facing * Quaternion.Euler(0, 180 - headAngle, 0) * Vector3.forward;
        Gizmos.DrawRay(tip, right * headSize);
        Gizmos.DrawRay(tip, left * headSize);
    }
}