using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents;
using UnityEngine;
using ManyWorlds;
using UnityEngine.Assertions;

/// <summary>
/// Gathers the observation values exposed by this component by comparing two
/// <see cref="DReConObservationStats"/> trackers: one bound to the
/// <c>MocapControllerArtanim</c> found under the owning <c>SpawnableEnv</c>, and one
/// bound to this component's own transform (the rag doll).
/// </summary>
/// <remarks>
/// Call order expected by the assertions in this class:
/// <see cref="OnAgentInitialize"/> exactly once, then <see cref="OnReset"/> /
/// <see cref="OnStep"/> as often as needed.
/// </remarks>
public class DReConObservations : MonoBehaviour
{
    [Header("Observations")]

    [Tooltip("Kinematic character center of mass velocity, Vector3")]
    public Vector3 MocapCOMVelocity;

    [Tooltip("RagDoll character center of mass velocity, Vector3")]
    public Vector3 RagDollCOMVelocity;

    [Tooltip("User-input desired horizontal CM velocity. Vector2")]
    public Vector2 InputDesiredHorizontalVelocity;

    [Tooltip("User-input requests jump, bool")]
    public bool InputJump;

    [Tooltip("User-input requests backflip, bool")]
    public bool InputBackflip;

    [Tooltip("Difference between RagDoll character horizontal CM velocity and user-input desired horizontal CM velocity. Vector2")]
    public Vector2 HorizontalVelocityDifference;

    [Tooltip("Positions and velocities for subset of bodies")]
    public List<BodyPartDifferenceStats> BodyPartDifferenceStats;
    public List<DReConObservationStats.Stat> MocapBodyStats;
    public List<DReConObservationStats.Stat> RagDollBodyStats;

    [Tooltip("Smoothed actions produced in the previous step of the policy are collected in t −1")]
    public float[] PreviousActions;

    [Header("Settings")]
    public List<string> BodyPartsToTrack;

    [Header("Gizmos")]
    public bool VelocityInWorldSpace = true;
    public bool PositionInWorldSpace = true;


    // Root name handed to both stats trackers via setRootName during initialization.
    public string targetedRootName = "articulation:Hips";

    // Length of each arrow-head stroke drawn by DrawArrow.
    const float ArrowHeadSize = 0.25f;
    // Angle, in degrees, between the arrow shaft and each head stroke.
    const float ArrowHeadAngle = 20.0f;

    InputController _inputController;
    SpawnableEnv _spawnableEnv;
    DReConObservationStats _mocapBodyStats;
    DReConObservationStats _ragDollBodyStats;
    bool _hasLazyInitialized;

    /// <summary>
    /// One-time setup: resolves the owning <c>SpawnableEnv</c> and its
    /// <c>InputController</c>, creates one difference entry per name in
    /// <see cref="BodyPartsToTrack"/>, and creates the two stats tracker GameObjects
    /// ("MocapDReConObservationStats" and "RagDollDReConObservationStats") as children
    /// of the environment.
    /// </summary>
    public void OnAgentInitialize()
    {
        Assert.IsFalse(_hasLazyInitialized);
        _hasLazyInitialized = true;

        _spawnableEnv = GetComponentInParent<SpawnableEnv>();
        Assert.IsNotNull(_spawnableEnv, "DReConObservations requires a SpawnableEnv on a parent object.");
        _inputController = _spawnableEnv.GetComponentInChildren<InputController>();
        Assert.IsNotNull(_inputController, "DReConObservations requires an InputController under the SpawnableEnv.");

        BodyPartDifferenceStats = BodyPartsToTrack
            .Select(x => new BodyPartDifferenceStats { Name = x })
            .ToList();

        // Tracker following the mocap (kinematic reference) character.
        var mocapController = _spawnableEnv.GetComponentInChildren<MocapControllerArtanim>();
        Assert.IsNotNull(mocapController, "DReConObservations requires a MocapControllerArtanim under the SpawnableEnv.");

        _mocapBodyStats = new GameObject("MocapDReConObservationStats").AddComponent<DReConObservationStats>();
        _mocapBodyStats.setRootName(targetedRootName);
        _mocapBodyStats.ObjectToTrack = mocapController;
        _mocapBodyStats.transform.SetParent(_spawnableEnv.transform);
        _mocapBodyStats.OnAgentInitialize(BodyPartsToTrack, _mocapBodyStats.ObjectToTrack.transform);

        // Tracker following this component's own hierarchy (the rag doll).
        _ragDollBodyStats = new GameObject("RagDollDReConObservationStats").AddComponent<DReConObservationStats>();
        _ragDollBodyStats.setRootName(targetedRootName);
        _ragDollBodyStats.ObjectToTrack = this;
        _ragDollBodyStats.transform.SetParent(_spawnableEnv.transform);
        _ragDollBodyStats.OnAgentInitialize(BodyPartsToTrack, transform);
    }

    /// <summary>
    /// Advances both stats trackers by one step and refreshes the public observation fields.
    /// </summary>
    /// <param name="timeDelta">Step duration forwarded to both trackers and to <see cref="UpdateObservations"/>.</param>
    public void OnStep(float timeDelta)
    {
        Assert.IsTrue(_hasLazyInitialized);
        _mocapBodyStats.SetStatusForStep(timeDelta);
        _ragDollBodyStats.SetStatusForStep(timeDelta);
        UpdateObservations(timeDelta);
    }

    /// <summary>
    /// Resets both stats trackers, copies the mocap tracker's position and rotation onto
    /// the rag-doll tracker, and refreshes the public observation fields.
    /// </summary>
    public void OnReset()
    {
        Assert.IsTrue(_hasLazyInitialized);
        _mocapBodyStats.OnReset();
        _ragDollBodyStats.OnReset();
        _ragDollBodyStats.transform.position = _mocapBodyStats.transform.position;
        _ragDollBodyStats.transform.rotation = _mocapBodyStats.transform.rotation;

        // NOTE(review): float.MinValue is used as a sentinel time delta on reset; presumably
        // it drives the rotation difference from GetAngularVelocity towards zero - confirm
        // against DReConObservationStats.GetAngularVelocity.
        var timeDelta = float.MinValue;
        UpdateObservations(timeDelta);
    }

    /// <summary>
    /// Copies the current tracker values into the public observation fields and recomputes
    /// every entry of <see cref="BodyPartDifferenceStats"/> as "mocap minus rag doll".
    /// </summary>
    /// <param name="timeDelta">Time delta passed to <c>DReConObservationStats.GetAngularVelocity</c> for the rotation difference.</param>
    /// <exception cref="System.InvalidOperationException">
    /// Thrown by <c>First</c> when a tracked body part name has no matching entry in a tracker's <c>Stats</c>.
    /// </exception>
    public void UpdateObservations(float timeDelta)
    {
        // Same precondition as OnStep/OnReset: the trackers must exist before they are read.
        Assert.IsTrue(_hasLazyInitialized);

        MocapCOMVelocity = _mocapBodyStats.CenterOfMassVelocity;
        RagDollCOMVelocity = _ragDollBodyStats.CenterOfMassVelocity;

        // Horizontal quantities keep only the x and z components of the 3D vectors.
        InputDesiredHorizontalVelocity = new Vector2(
            _ragDollBodyStats.DesiredCenterOfMassVelocity.x,
            _ragDollBodyStats.DesiredCenterOfMassVelocity.z);
        InputJump = _inputController.Jump;
        InputBackflip = _inputController.Backflip;
        HorizontalVelocityDifference = new Vector2(
            _ragDollBodyStats.CenterOfMassVelocityDifference.x,
            _ragDollBodyStats.CenterOfMassVelocityDifference.z);

        // Per-body stats, ordered like BodyPartsToTrack.
        MocapBodyStats = BodyPartsToTrack
            .Select(x => _mocapBodyStats.Stats.First(y => y.Name == x))
            .ToList();
        RagDollBodyStats = BodyPartsToTrack
            .Select(x => _ragDollBodyStats.Stats.First(y => y.Name == x))
            .ToList();

        foreach (var differenceStats in BodyPartDifferenceStats)
        {
            var mocapStats = _mocapBodyStats.Stats.First(x => x.Name == differenceStats.Name);
            var ragDollStats = _ragDollBodyStats.Stats.First(x => x.Name == differenceStats.Name);

            differenceStats.Position = mocapStats.Position - ragDollStats.Position;
            differenceStats.Velocity = mocapStats.Velocity - ragDollStats.Velocity;
            differenceStats.AngualrVelocity = mocapStats.AngualrVelocity - ragDollStats.AngualrVelocity;
            differenceStats.Rotation = DReConObservationStats.GetAngularVelocity(mocapStats.Rotation, ragDollStats.Rotation, timeDelta);
        }
    }

    /// <summary>
    /// Returns the transform of the rag-doll stats tracker object.
    /// </summary>
    public Transform GetRagDollCOM()
    {
        return _ragDollBodyStats.transform;
    }

    /// <summary>
    /// Forwards <paramref name="snapDistance"/> to <c>ShiftCOM</c> on the rag-doll stats tracker.
    /// </summary>
    /// <param name="snapDistance">Offset handed to <c>ShiftCOM</c>.</param>
    public void ShiftMocapCOM(Vector3 snapDistance)
    {
        // NOTE(review): the method name says "Mocap" but the rag-doll tracker is shifted;
        // left as-is because callers may rely on it - confirm which tracker is intended.
        _ragDollBodyStats.ShiftCOM(snapDistance);
    }

    /// <summary>
    /// Draws debug arrows for the centre-of-mass observations and for each tracked body part.
    /// Colors: grey = mocap COM velocity, blue = rag-doll COM velocity, green = desired
    /// horizontal velocity, red = velocity differences, cyan = body part velocity,
    /// magenta = body part position difference.
    /// </summary>
    void OnDrawGizmos()
    {
        // Unity's overloaded == also reports destroyed objects as null, so this covers both
        // "not initialized yet" and "tracker object destroyed".
        if (_mocapBodyStats == null || _ragDollBodyStats == null)
        {
            return;
        }

        Transform mocapFrame = _mocapBodyStats.transform;
        Transform ragDollFrame = _ragDollBodyStats.transform;

        // All centre-of-mass arrows start at this object's x/z position, at a fixed height of 0.3.
        Vector3 origin = new Vector3(transform.position.x, .3f, transform.position.z);

        DrawArrow(origin, ToGizmoVector(mocapFrame, MocapCOMVelocity), Color.grey);

        Vector3 ragDollVelocity = ToGizmoVector(ragDollFrame, RagDollCOMVelocity);
        DrawArrow(origin, ragDollVelocity, Color.blue);
        Vector3 ragDollVelocityTip = origin + ragDollVelocity;

        Vector3 desiredVelocity = new Vector3(InputDesiredHorizontalVelocity.x, 0f, InputDesiredHorizontalVelocity.y);
        DrawArrow(origin, ToGizmoVector(ragDollFrame, desiredVelocity), Color.green);

        // The difference arrow starts at the tip of the rag-doll velocity arrow.
        Vector3 velocityDifference = new Vector3(HorizontalVelocityDifference.x, 0f, HorizontalVelocityDifference.y);
        DrawArrow(ragDollVelocityTip, ToGizmoVector(ragDollFrame, velocityDifference), Color.red);

        // The per-body lists are only filled by UpdateObservations / OnAgentInitialize and are
        // editable in the inspector, so they may be missing or of different lengths here.
        if (RagDollBodyStats == null || BodyPartDifferenceStats == null)
        {
            return;
        }

        int bodyCount = Mathf.Min(RagDollBodyStats.Count, BodyPartDifferenceStats.Count);
        for (int i = 0; i < bodyCount; i++)
        {
            var stat = RagDollBodyStats[i];
            var differenceStat = BodyPartDifferenceStats[i];

            Vector3 bodyPos = stat.Position;
            if (PositionInWorldSpace)
            {
                bodyPos = ragDollFrame.TransformPoint(bodyPos);
            }

            Vector3 bodyVelocity = ToGizmoVector(ragDollFrame, stat.Velocity);
            DrawArrow(bodyPos, bodyVelocity, Color.cyan);
            Vector3 bodyVelocityTip = bodyPos + bodyVelocity;

            // The position difference is a direction vector, so it follows the
            // VelocityInWorldSpace toggle; it is drawn as a plain ray without an arrow head.
            Gizmos.color = Color.magenta;
            Gizmos.DrawRay(bodyPos, ToGizmoVector(ragDollFrame, differenceStat.Position));

            // The velocity difference starts at the tip of the body's velocity arrow.
            DrawArrow(bodyVelocityTip, ToGizmoVector(ragDollFrame, differenceStat.Velocity), Color.red);
        }
    }

    /// <summary>
    /// Returns <paramref name="vector"/> passed through <c>frame.TransformVector</c> when
    /// <see cref="VelocityInWorldSpace"/> is set, otherwise returns it unchanged.
    /// </summary>
    Vector3 ToGizmoVector(Transform frame, Vector3 vector)
    {
        return VelocityInWorldSpace ? frame.TransformVector(vector) : vector;
    }

    /// <summary>
    /// Draws a ray from <paramref name="start"/> along <paramref name="vector"/> in the given
    /// color and, when the vector has non-zero length, two short strokes forming an arrow head at its tip.
    /// </summary>
    void DrawArrow(Vector3 start, Vector3 vector, Color color)
    {
        Gizmos.color = color;
        Gizmos.DrawRay(start, vector);

        // A zero vector has no direction to orient the head with.
        if (vector.magnitude > 0f)
        {
            Quaternion look = Quaternion.LookRotation(vector);
            Vector3 tip = start + vector;

            // Rotating forward by (180 +/- head angle) about the local up axis yields two
            // directions pointing back from the tip, splayed to either side of the shaft.
            Vector3 right = look * Quaternion.Euler(0, 180 + ArrowHeadAngle, 0) * new Vector3(0, 0, 1);
            Vector3 left = look * Quaternion.Euler(0, 180 - ArrowHeadAngle, 0) * new Vector3(0, 0, 1);
            Gizmos.DrawRay(tip, right * ArrowHeadSize);
            Gizmos.DrawRay(tip, left * ArrowHeadSize);
        }
    }
}