Files
drecon-unity/Assets/3_MarathonEnvs/Scripts/Ragdoll004/RagDollAgent.cs
2024-04-29 15:52:09 +01:00

568 lines
20 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;
using System;
using Unity.Mathematics;
public class RagDollAgent : Agent
{
// ---- Inspector settings -------------------------------------------------
[Header("Settings")]
// Copied into Time.fixedDeltaTime during Initialize().
public float FixedDeltaTime = 1f/60f;
// Weight given to the newest action in SmoothActions(); the remainder (1 - SmoothBeta)
// is given to the previously smoothed value.
public float SmoothBeta = 0.2f;
[Header("Camera")]
// When set (and CameraTarget is assigned), Awake() points the scene camera's SmoothFollow at CameraTarget.
public bool RequestCamera;
// Not read anywhere in this class.
public bool CameraFollowMe;
public Transform CameraTarget;
[Header("... debug")]
// When false, OnActionReceived() passes the actions through SmoothActions().
// NOTE(review): despite the name, in this class it only gates *action* smoothing.
public bool SkipRewardSmoothing;
// When true, Update()/FixedUpdate() end the episode; cleared again in OnEpisodeBegin().
public bool debugCopyMocap;
// When true, every action is replaced by 0 before being applied to the motors.
public bool ignorActions;
// When true, the fail -> wait -> EndEpisode sequence in OnActionReceived() is skipped.
public bool dontResetOnZeroReward;
// When true, the mocap character is not snapped to the ragdoll when the reward is low.
public bool dontSnapMocapToRagdoll;
// Editor only: pauses the editor at the start of every episode.
public bool DebugPauseOnReset;
// When true, the policy output is added on top of the targets from GetMocapTargets().
public bool UsePDControl = true;
// ---- References resolved in Initialize() --------------------------------
// Rigidbodies of the mocap character; matched to motors by GameObject name in GetMocapTargets().
List<Rigidbody> _mocapBodyParts;
// Every ArticulationBody found under this agent.
List<ArticulationBody> _bodyParts;
SpawnableEnv _spawnableEnv;
DReConObservations _dReConObservations;
DReConRewards _dReConRewards;
// Supplies MusclePowers, Stiffness, Damping and ForceLimit for UpdateMotor().
RagDoll004 _ragDollSettings;
TrackBodyStatesInWorldSpace _trackBodyStatesInWorldSpace;
// Non-root ArticulationBodies whose joint type is SphericalJoint; these are the driven joints.
List<ArticulationBody> _motors;
// Optional test-bed controller; when one exists in the scene Initialize() disables
// reset-on-fail, mocap snapping and PD control.
MarathonTestBedController _debugController;
InputController _inputController;
SensorObservations _sensorObservations;
DecisionRequester _decisionRequester;
MocapAnimatorController _mocapAnimatorController;
// Set by Initialize(); asserted by the per-step callbacks.
bool _hasLazyInitialized;
// Last output of SmoothActions(); reset to null at the start of each episode.
float[] _smoothedActions;
// Buffer reused by GetMocapTargets(); one entry per LimitedMotion axis of the motors.
float[] _mocapTargets;
[Space(16)]
[SerializeField]
// Set at the end of Awake(); asserted by Initialize() and OnEpisodeBegin().
bool _hasAwake = false;
MocapControllerArtanim _mocapControllerArtanim;
// ---- Fail / wait state used by OnActionReceived() ------------------------
// Time accumulated since the wait started (advanced by Time.fixedDeltaTime per action step).
private float timer = 0f;
// How long the agent stays in the wait state before the episode is ended.
private float waitTime = 3f;
// True while the agent is in the wait state; motors are driven by Perlin noise meanwhile.
private bool waitStarted;
// Perlin-noise sample coordinates, advanced by Time.fixedDeltaTime on every action step.
private float x = 0;
private float y = 100;
// Optional ground collider whose excludeLayers is swapped for m_ExcludeLayer during the wait.
public TerrainCollider m_TerrainCollider;
// Layers assigned to m_TerrainCollider.excludeLayers while waiting.
public LayerMask m_ExcludeLayer;
// Holds the collider's previous excludeLayers so it can be restored after the wait.
public LayerMask m_Temp;
// Invoked after the wait has elapsed, immediately before EndEpisode().
public delegate void OnMoveToNextModelDelegate();
public static OnMoveToNextModelDelegate m_MoveToNextModel;
// Invoked once, at the moment the wait state is entered.
public delegate void OnAgentFailDelegate();
public static OnAgentFailDelegate m_AgentFail;
/// <summary>
/// Optionally hooks the scene camera's <c>SmoothFollow</c> up to <see cref="CameraTarget"/>
/// and records that Awake has run.
/// </summary>
void Awake()
{
    bool wantsCamera = RequestCamera && CameraTarget != null;
    if (wantsCamera)
    {
        // Will follow the last object to be spawned
        Camera sceneCamera = FindObjectOfType<Camera>();
        SmoothFollow follower = sceneCamera != null
            ? sceneCamera.GetComponent<SmoothFollow>()
            : null;
        if (follower != null)
        {
            follower.target = CameraTarget;
        }
    }

    _hasAwake = true;
}
/// <summary>
/// Per-frame housekeeping: honours the <see cref="debugCopyMocap"/> reset request and
/// recentres the mocap character when it leaves the environment bounds.
/// </summary>
void Update()
{
    if (debugCopyMocap)
    {
        EndEpisode();
    }
    Assert.IsTrue(_hasLazyInitialized);

    // Handle the mocap character going out of bounds.
    Transform mocapRoot = _mocapControllerArtanim.transform;
    bool insideBounds = _spawnableEnv.IsPointWithinBoundsInWorldSpace(mocapRoot.position);
    if (insideBounds)
    {
        return;
    }

    // Move it back to the environment origin, relink the tracked stats and restart the episode.
    mocapRoot.position = _spawnableEnv.transform.position;
    _trackBodyStatesInWorldSpace.LinkStatsToRigidBodies();
    EndEpisode();
}
/// <summary>
/// Steps the observation tracker and writes the observation vector.
/// The order of the <c>AddObservation</c> calls defines the vector layout and must not change.
/// </summary>
/// <param name="sensor">Vector sensor that receives the observations.</param>
public override void CollectObservations(VectorSensor sensor)
{
    Assert.IsTrue(_hasLazyInitialized);

    // Time covered by one decision: physics step times the decision period.
    float secondsPerDecision = Time.fixedDeltaTime * _decisionRequester.DecisionPeriod;
    _dReConObservations.OnStep(secondsPerDecision);

    var obs = _dReConObservations;

    // Centre-of-mass velocities and their difference.
    sensor.AddObservation(obs.MocapCOMVelocity);
    sensor.AddObservation(obs.RagDollCOMVelocity);
    sensor.AddObservation(obs.RagDollCOMVelocity - obs.MocapCOMVelocity);

    // User input.
    sensor.AddObservation(obs.InputDesiredHorizontalVelocity);
    sensor.AddObservation(obs.InputJump);
    sensor.AddObservation(obs.InputBackflip);
    sensor.AddObservation(obs.HorizontalVelocityDifference);

    // Mocap body stats (obs.MocapBodyStats) are intentionally not observed.

    // Ragdoll body parts.
    foreach (var bodyStat in obs.RagDollBodyStats)
    {
        sensor.AddObservation(bodyStat.Position);
        sensor.AddObservation(bodyStat.Velocity);
    }

    // Per-body-part difference stats.
    foreach (var differenceStat in obs.BodyPartDifferenceStats)
    {
        sensor.AddObservation(differenceStat.Position);
        sensor.AddObservation(differenceStat.Velocity);
    }

    sensor.AddObservation(obs.PreviousActions);

    // Contact sensors (feet etc).
    sensor.AddObservation(_sensorObservations.SensorIsInTouch);
}
/// <summary>
/// Applies one set of policy actions to the ragdoll motors, adds the step reward and
/// handles failure (wait, then end the episode) or low-reward mocap snapping.
/// </summary>
/// <remarks>
/// Pipeline: optional debug override, optional PD offset on top of the mocap targets,
/// optional smoothing, optional zeroing, then one drive update per motor.
/// While <c>waitStarted</c> is set, motor targets come from Perlin noise instead of the actions.
/// </remarks>
/// <param name="actions">Action buffers from the policy; only the continuous actions are read.</param>
public override void OnActionReceived(ActionBuffers actions)
{
    float[] vectorAction = actions.ContinuousActions.Select(a => a).ToArray();
    Assert.IsTrue(_hasLazyInitialized);

    // Actions arrive every physics step only when TakeActionsBetweenDecisions is on;
    // otherwise one call covers a whole decision period.
    float timeDelta = Time.fixedDeltaTime;
    if (!_decisionRequester.TakeActionsBetweenDecisions)
    {
        timeDelta = timeDelta * _decisionRequester.DecisionPeriod;
    }
    _dReConRewards.OnStep(timeDelta);

    // The debug controller only takes effect while it is active and enabled.
    bool shouldDebug = _debugController != null;
    bool dontUpdateMotor = false;
    if (_debugController != null)
    {
        bool controllerLive = _debugController.isActiveAndEnabled
            && _debugController.gameObject.activeInHierarchy;
        dontUpdateMotor = _debugController.DontUpdateMotor && controllerLive;
        shouldDebug = controllerLive;
    }

    if (shouldDebug)
    {
        vectorAction = GetDebugActions(vectorAction);
    }
    if (UsePDControl)
    {
        // The policy output is an offset around the mocap pose, clamped to the normalized range.
        var targets = GetMocapTargets();
        vectorAction = vectorAction
            .Zip(targets, (action, target) => Mathf.Clamp(target + action * 2f, -1f, 1f))
            .ToArray();
    }
    if (!SkipRewardSmoothing)
    {
        vectorAction = SmoothActions(vectorAction);
    }
    if (ignorActions)
    {
        vectorAction = vectorAction.Select(_ => 0f).ToArray();
    }

    // Advance the Perlin-noise sample point.
    var scale = 1f;
    x += Time.fixedDeltaTime * scale;
    y += Time.fixedDeltaTime * scale;

    // Noise in [-1, 1]; the per-axis offset samples a different point for each axis.
    float Noise(float offset) => (Mathf.PerlinNoise(x + offset, y + offset) * 2) - 1;

    if (!dontUpdateMotor)
    {
        int i = 0;
        bool useActions = !waitStarted;
        foreach (var m in _motors)
        {
            if (m.isRoot)
            {
                continue;
            }

            // An action is consumed only for LimitedMotion axes, and only outside the wait state.
            Vector3 targetNormalizedRotation = Vector3.zero;
            targetNormalizedRotation.x = (useActions && m.twistLock == ArticulationDofLock.LimitedMotion)
                ? vectorAction[i++]
                : Noise(7.5f);
            targetNormalizedRotation.y = (useActions && m.swingYLock == ArticulationDofLock.LimitedMotion)
                ? vectorAction[i++]
                : Noise(5f);
            targetNormalizedRotation.z = (useActions && m.swingZLock == ArticulationDofLock.LimitedMotion)
                ? vectorAction[i++]
                : Noise(22.3f);

            UpdateMotor(m, targetNormalizedRotation);
        }
    }

    _dReConObservations.PreviousActions = vectorAction;
    AddReward(_dReConRewards.Reward);

    bool failed = _dReConRewards.HeadHeightDistance > 0.5f || _dReConRewards.Reward <= 0f;
    if (failed)
    {
        if (!dontResetOnZeroReward)
        {
            if (!waitStarted)
            {
                waitStarted = true;
                // Disable the ground collider for the excluded layers while waiting.
                if (m_TerrainCollider)
                {
                    m_Temp = m_TerrainCollider.excludeLayers;
                    m_TerrainCollider.excludeLayers = m_ExcludeLayer;
                }
                m_AgentFail?.Invoke();
            }

            timer += Time.fixedDeltaTime;
            if (timer >= waitTime)
            {
                timer = 0;
                waitStarted = false;
                // The collider is optional (see the guard above). Restoring it unguarded threw
                // on every step when it was unassigned, so the episode could never end.
                if (m_TerrainCollider)
                {
                    m_TerrainCollider.excludeLayers = m_Temp;
                }
                m_MoveToNextModel?.Invoke();
                EndEpisode();
            }
        }
    }
    else if (_dReConRewards.Reward <= 0.1f && !dontSnapMocapToRagdoll)
    {
        // Low (but non-zero) reward: move the mocap character onto the ragdoll's
        // centre of mass at y = 0 and apply a penalty.
        Transform ragDollCom = _dReConObservations.GetRagDollCOM();
        Vector3 snapPosition = ragDollCom.position;
        snapPosition.y = 0f;
        _mocapControllerArtanim.SnapTo(snapPosition);
        AddReward(-.5f);
    }
}
/// <summary>
/// Builds an action vector from the <c>DebugMotor</c> component on each non-root motor,
/// adding the component where it is missing. Values are clamped to [-1, 1], written back
/// to the component and published on the debug controller.
/// </summary>
/// <param name="vectorAction">Policy actions; not used, the debug values replace them.</param>
/// <returns>One clamped value per LimitedMotion axis, in motor order (twist, swingY, swingZ).</returns>
float[] GetDebugActions(float[] vectorAction)
{
    var collected = new List<float>();
    foreach (ArticulationBody motor in _motors)
    {
        if (motor.isRoot)
        {
            continue;
        }

        DebugMotor source = motor.GetComponent<DebugMotor>();
        if (source == null)
        {
            source = motor.gameObject.AddComponent<DebugMotor>();
        }

        // clip to -1/+1 and store the clipped value back on the component
        Vector3 raw = source.Actions;
        Vector3 clipped = new Vector3(
            Mathf.Clamp(raw.x, -1f, 1f),
            Mathf.Clamp(raw.y, -1f, 1f),
            Mathf.Clamp(raw.z, -1f, 1f));
        source.Actions = clipped;

        if (motor.twistLock == ArticulationDofLock.LimitedMotion)
        {
            collected.Add(clipped.x);
        }
        if (motor.swingYLock == ArticulationDofLock.LimitedMotion)
        {
            collected.Add(clipped.y);
        }
        if (motor.swingZLock == ArticulationDofLock.LimitedMotion)
        {
            collected.Add(clipped.z);
        }
    }

    float[] result = collected
        .Select(value => Mathf.Clamp(value, -1f, 1f))
        .ToArray();

    // The controller gets its own copy so the returned array is not shared with it.
    _debugController.Actions = result.ToArray();
    return result;
}
/// <summary>
/// Exponentially smooths the actions: y_t = beta * a_t + (1 - beta) * y_(t-1),
/// where beta is <see cref="SmoothBeta"/> and y_(t-1) is the previous result
/// (all zeros on the first call of an episode).
/// </summary>
/// <param name="vectorAction">Current raw actions a_t.</param>
/// <returns>The smoothed actions, which are also kept for the next call.</returns>
float[] SmoothActions(float[] vectorAction)
{
    float[] previous = _smoothedActions ?? new float[vectorAction.Length];

    // Only the entries present in both arrays are blended.
    int count = Mathf.Min(vectorAction.Length, previous.Length);
    float[] blended = new float[count];
    for (int k = 0; k < count; k++)
    {
        blended[k] = SmoothBeta * vectorAction[k] + (1f - SmoothBeta) * previous[k];
    }

    _smoothedActions = blended;
    return _smoothedActions;
}
/// <summary>
/// One-time agent setup: resolves the collaborating components, configures the
/// articulation solver, selects the motor joints and initialises the helper components.
/// </summary>
public override void Initialize()
{
    Assert.IsTrue(_hasAwake);
    Assert.IsFalse(_hasLazyInitialized);
    _hasLazyInitialized = true;

    _decisionRequester = GetComponent<DecisionRequester>();
    _debugController = FindObjectOfType<MarathonTestBedController>();
    Time.fixedDeltaTime = FixedDeltaTime;
    _spawnableEnv = GetComponentInParent<SpawnableEnv>();

    // With a test bed in the scene the agent is driven manually: no resets, no snapping, no PD offset.
    bool hasTestBed = _debugController != null;
    if (hasTestBed)
    {
        dontResetOnZeroReward = true;
        dontSnapMocapToRagdoll = true;
        UsePDControl = false;
    }

    _mocapControllerArtanim = _spawnableEnv.GetComponentInChildren<MocapControllerArtanim>();
    _mocapBodyParts = _mocapControllerArtanim.GetRigidBodies();
    _bodyParts = GetComponentsInChildren<ArticulationBody>().ToList();
    _dReConObservations = GetComponent<DReConObservations>();
    _dReConRewards = GetComponent<DReConRewards>();
    _trackBodyStatesInWorldSpace = _mocapControllerArtanim.GetComponent<TrackBodyStatesInWorldSpace>();
    _ragDollSettings = GetComponent<RagDoll004>();
    _inputController = _spawnableEnv.GetComponentInChildren<InputController>();
    _sensorObservations = GetComponent<SensorObservations>();

    // Set the solver iteration counts on every articulation body.
    foreach (ArticulationBody part in _bodyParts)
    {
        part.solverIterations = 255;
        part.solverVelocityIterations = 255;
    }

    // Motors are the non-root spherical joints.
    _motors = _bodyParts
        .Where(part => part.jointType == ArticulationJointType.SphericalJoint && !part.isRoot)
        .Distinct()
        .ToList();

    // PreviousActions starts as zeros, one entry per LimitedMotion axis.
    int axisCount = 0;
    foreach (ArticulationBody motor in _motors)
    {
        if (motor.twistLock == ArticulationDofLock.LimitedMotion)
        {
            axisCount++;
        }
        if (motor.swingYLock == ArticulationDofLock.LimitedMotion)
        {
            axisCount++;
        }
        if (motor.swingZLock == ArticulationDofLock.LimitedMotion)
        {
            axisCount++;
        }
    }
    _dReConObservations.PreviousActions = new float[axisCount];

    _mocapAnimatorController = _mocapControllerArtanim.GetComponent<MocapAnimatorController>();

    // Initialisation order of the helpers is kept as-is.
    _mocapControllerArtanim.OnAgentInitialize();
    _dReConObservations.OnAgentInitialize();
    _dReConRewards.OnAgentInitialize();
    _trackBodyStatesInWorldSpace.OnAgentInitialize();
    _mocapAnimatorController.OnAgentInitialize();
    _inputController.OnReset();
}
/// <summary>
/// Resets per-episode state: clears action smoothing, resets the mocap character to face
/// the current input direction, copies its pose onto the ragdoll and re-primes the
/// observation and reward trackers.
/// </summary>
public override void OnEpisodeBegin()
{
    Assert.IsTrue(_hasAwake);

    _smoothedActions = null;
    debugCopyMocap = false;

    _mocapAnimatorController.OnReset();

    // Face the mocap character along the input direction (rotation about the up axis only).
    float headingDegrees = Vector3.SignedAngle(Vector3.forward, _inputController.HorizontalDirection, Vector3.up);
    Quaternion heading = Quaternion.Euler(0f, headingDegrees, 0f);
    _mocapControllerArtanim.OnReset(heading);
    _mocapControllerArtanim.CopyStatesTo(this.gameObject);

    // Reset both trackers, then step each once with the reset time delta.
    // NOTE(review): float.MinValue is the most negative float, not the smallest positive
    // one; presumably the trackers treat it as a "just reset" marker - confirm.
    float resetTimeDelta = float.MinValue;
    _dReConObservations.OnReset();
    _dReConRewards.OnReset();
    _dReConObservations.OnStep(resetTimeDelta);
    _dReConRewards.OnStep(resetTimeDelta);

#if UNITY_EDITOR
    if (DebugPauseOnReset)
    {
        UnityEditor.EditorApplication.isPaused = true;
    }
#endif

    bool testBedActive = _debugController != null && _debugController.isActiveAndEnabled;
    if (testBedActive)
    {
        _debugController.OnAgentEpisodeBegin();
    }
}
/// <summary>
/// Computes, for every LimitedMotion axis of every non-root motor, the mocap character's
/// joint rotation expressed relative to that drive's limits (0 = midpoint of the limits,
/// +/-1 = upper/lower limit).
/// </summary>
/// <returns>
/// The reused <c>_mocapTargets</c> buffer, ordered by motor and then twist, swingY, swingZ.
/// </returns>
float[] GetMocapTargets()
{
    if (_mocapTargets == null)
    {
        // Allocate once: one slot per LimitedMotion axis.
        int axisCount = 0;
        foreach (ArticulationBody motor in _motors)
        {
            if (motor.isRoot)
            {
                continue;
            }
            if (motor.twistLock == ArticulationDofLock.LimitedMotion)
            {
                axisCount++;
            }
            if (motor.swingYLock == ArticulationDofLock.LimitedMotion)
            {
                axisCount++;
            }
            if (motor.swingZLock == ArticulationDofLock.LimitedMotion)
            {
                axisCount++;
            }
        }
        _mocapTargets = new float[axisCount];
    }

    int slot = 0;
    foreach (ArticulationBody motor in _motors)
    {
        if (motor.isRoot)
        {
            continue;
        }

        // The mocap rigidbody is matched to the motor by GameObject name.
        string motorName = motor.name;
        Rigidbody mocapBody = _mocapBodyParts.First(body => body.name == motorName);

        // Mocap local rotation expressed in the joint's anchor frame, negated,
        // with each Euler component wrapped through Mathf.DeltaAngle.
        Quaternion jointSpace = Quaternion.Inverse(motor.anchorRotation)
            * Quaternion.Inverse(mocapBody.transform.localRotation)
            * motor.parentAnchorRotation;
        Vector3 euler = -jointSpace.eulerAngles;
        Vector3 wrapped = new Vector3(
            Mathf.DeltaAngle(0, euler.x),
            Mathf.DeltaAngle(0, euler.y),
            Mathf.DeltaAngle(0, euler.z));

        if (motor.twistLock == ArticulationDofLock.LimitedMotion)
        {
            _mocapTargets[slot++] = NormalizeToDriveLimits(motor.xDrive, wrapped.x);
        }
        if (motor.swingYLock == ArticulationDofLock.LimitedMotion)
        {
            _mocapTargets[slot++] = NormalizeToDriveLimits(motor.yDrive, wrapped.y);
        }
        if (motor.swingZLock == ArticulationDofLock.LimitedMotion)
        {
            _mocapTargets[slot++] = NormalizeToDriveLimits(motor.zDrive, wrapped.z);
        }
    }
    return _mocapTargets;
}

// Maps an angle to the drive's limit range: midpoint -> 0, upper limit -> +1, lower limit -> -1.
static float NormalizeToDriveLimits(ArticulationDrive drive, float angle)
{
    var scale = (drive.upperLimit - drive.lowerLimit) / 2f;
    var midpoint = drive.lowerLimit + scale;
    return (angle - midpoint) / scale;
}
/// <summary>
/// Writes drive targets and gains to one joint. Each component of
/// <paramref name="targetNormalizedRotation"/> is mapped from [-1, 1] onto that axis'
/// lower/upper limit range; axes that are not LimitedMotion are left untouched.
/// </summary>
/// <remarks>
/// Stiffness per axis is the joint's entry in <c>RagDoll004.MusclePowers</c> (matched by
/// GameObject name) multiplied by <c>Stiffness</c>. A joint with no entry gets zero
/// stiffness and a log message.
/// </remarks>
/// <param name="joint">Articulation body whose drives are updated.</param>
/// <param name="targetNormalizedRotation">Normalized targets for twist (x), swingY (y) and swingZ (z).</param>
void UpdateMotor(ArticulationBody joint, Vector3 targetNormalizedRotation)
{
    // Explicit lookup instead of First() inside a catch-all: a missing muscle is an expected
    // case, and the catch-all also swallowed unrelated exceptions.
    Vector3 power = Vector3.zero;
    bool muscleFound = false;
    string jointName = joint.name;
    var musclePowers = _ragDollSettings.MusclePowers;
    if (musclePowers != null)
    {
        foreach (var musclePower in musclePowers)
        {
            if (musclePower.Muscle == jointName)
            {
                power = musclePower.PowerVector;
                muscleFound = true;
                break;
            }
        }
    }
    if (!muscleFound)
    {
        Debug.Log("there is no muscle for joint " + jointName);
    }

    power *= _ragDollSettings.Stiffness;
    float damping = _ragDollSettings.Damping;
    float forceLimit = _ragDollSettings.ForceLimit;

    // ArticulationDrive is a struct: read it, modify the copy, assign it back.
    if (joint.twistLock == ArticulationDofLock.LimitedMotion)
    {
        joint.xDrive = ConfigureDrive(joint.xDrive, targetNormalizedRotation.x, power.x, damping, forceLimit);
    }
    if (joint.swingYLock == ArticulationDofLock.LimitedMotion)
    {
        joint.yDrive = ConfigureDrive(joint.yDrive, targetNormalizedRotation.y, power.y, damping, forceLimit);
    }
    if (joint.swingZLock == ArticulationDofLock.LimitedMotion)
    {
        joint.zDrive = ConfigureDrive(joint.zDrive, targetNormalizedRotation.z, power.z, damping, forceLimit);
    }
}

// Returns a copy of the drive with its target set from a normalized value
// (0 = midpoint of the limits, +/-1 = upper/lower limit) and the given gains applied.
static ArticulationDrive ConfigureDrive(ArticulationDrive drive, float normalizedTarget, float stiffness, float damping, float forceLimit)
{
    var scale = (drive.upperLimit - drive.lowerLimit) / 2f;
    var midpoint = drive.lowerLimit + scale;
    drive.target = midpoint + (normalizedTarget * scale);
    drive.stiffness = stiffness;
    drive.damping = damping;
    drive.forceLimit = forceLimit;
    return drive;
}
/// <summary>
/// Ends the episode on the physics tick while <see cref="debugCopyMocap"/> is set.
/// </summary>
void FixedUpdate()
{
    if (!debugCopyMocap)
    {
        return;
    }

    EndEpisode();
}
/// <summary>
/// Draws the input movement vector as a black arrow on the ground, starting under the
/// ragdoll body-stats transform. Does nothing until the required references exist.
/// </summary>
void OnDrawGizmos()
{
    // Gizmos are also drawn when the agent is not (fully) initialised, so every
    // reference used below is checked, not just the rewards component.
    if (_dReConRewards == null || _inputController == null)
    {
        return;
    }
    var ragDollBodyStats = _dReConRewards._ragDollBodyStats;
    if (ragDollBodyStats == null)
    {
        return;
    }

    var comTransform = ragDollBodyStats.transform;
    // The 2D movement input is mapped onto the XZ plane.
    var vector = new Vector3(_inputController.MovementVector.x, 0f, _inputController.MovementVector.y);
    // Drawn slightly above y = 0.
    var pos = new Vector3(comTransform.position.x, 0.001f, comTransform.position.z);
    DrawArrow(pos, vector, Color.black);
}
/// <summary>
/// Draws a gizmo ray from <paramref name="start"/> along <paramref name="vector"/> and,
/// when the vector is non-zero, two short head strokes at its tip.
/// </summary>
/// <param name="start">World-space origin of the arrow.</param>
/// <param name="vector">Direction and length of the arrow shaft.</param>
/// <param name="color">Gizmo colour used for all strokes.</param>
void DrawArrow(Vector3 start, Vector3 vector, Color color)
{
    const float headSize = 0.25f;
    const float headAngle = 20.0f;

    Gizmos.color = color;
    Gizmos.DrawRay(start, vector);

    // A zero vector has no direction, so there is no head to draw.
    if (vector == Vector3.zero)
    {
        return;
    }

    // Head strokes point backwards from the tip, rotated headAngle to either side about the up axis.
    Quaternion facing = Quaternion.LookRotation(vector);
    Vector3 tip = start + vector;
    Vector3 rightStroke = facing * Quaternion.Euler(0, 180 + headAngle, 0) * Vector3.forward;
    Vector3 leftStroke = facing * Quaternion.Euler(0, 180 - headAngle, 0) * Vector3.forward;
    Gizmos.DrawRay(tip, rightStroke * headSize);
    Gizmos.DrawRay(tip, leftStroke * headSize);
}
}