using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Linq;
using static BodyHelper002;

/// <summary>
/// ML-Agents agent that drives the body exposed by the <see cref="BodyManager002"/> component on the same
/// GameObject. The per-step reward is the x component of the body's normalized velocity, clamped to [0, 1].
/// The episode ends when the hips drop below y = 0 or when a disallowed body part touches a terrain.
/// </summary>
public class MarathonManAgent : Agent, IOnTerrainCollision
{
	BodyManager002 _bodyManager;

	// Set to true by OnEpisodeBegin and back to false once the next action has been received.
	// OnTerrainCollision checks it so that it does not end an episode that has just been reset.
	bool _isDone;

	// True once OnEpisodeBegin has fetched and initialized the BodyManager002 component.
	bool _hasLazyInitialized;

	/// <summary>
	/// Name-based mapping from body part / muscle object names to their groups.
	/// Matching is case-insensitive substring matching; the first matching rule wins, so rule order matters.
	/// Any name containing "mixamorig" is mapped to <c>None</c>.
	/// </summary>
	public static BodyConfig BodyConfig = new BodyConfig
	{
		GetBodyPartGroup = (name) =>
		{
			// Invariant lower-casing: culture-sensitive ToLower() can map 'I' to a dotless 'ı'
			// (e.g. Turkish), which would break the "hip" / "thigh" / "shin" style matches below.
			name = name.ToLowerInvariant();
			if (name.Contains("mixamorig"))
				return BodyPartGroup.None;

			if (name.Contains("butt"))
				return BodyPartGroup.Hips;
			if (name.Contains("torso"))
				return BodyPartGroup.Torso;
			if (name.Contains("head"))
				return BodyPartGroup.Head;
			if (name.Contains("waist"))
				return BodyPartGroup.Spine;

			if (name.Contains("thigh"))
				return BodyPartGroup.LegUpper;
			if (name.Contains("shin"))
				return BodyPartGroup.LegLower;
			if (name.Contains("right_right_foot") || name.Contains("left_left_foot"))
				return BodyPartGroup.Foot;
			if (name.Contains("upper_arm"))
				return BodyPartGroup.ArmUpper;
			if (name.Contains("larm"))
				return BodyPartGroup.ArmLower;
			if (name.Contains("hand"))
				return BodyPartGroup.Hand;

			return BodyPartGroup.None;
		},
		GetMuscleGroup = (name) =>
		{
			// See GetBodyPartGroup: invariant lower-casing keeps the matching culture-independent.
			name = name.ToLowerInvariant();
			if (name.Contains("mixamorig"))
				return MuscleGroup.None;
			if (name.Contains("butt"))
				return MuscleGroup.Hips;
			if (name.Contains("lower_waist")
				|| name.Contains("abdomen_y"))
				return MuscleGroup.Spine;
			if (name.Contains("thigh")
				|| name.Contains("hip"))
				return MuscleGroup.LegUpper;
			if (name.Contains("shin"))
				return MuscleGroup.LegLower;
			if (name.Contains("right_right_foot")
				|| name.Contains("left_left_foot")
				|| name.Contains("ankle_x"))
				return MuscleGroup.Foot;
			if (name.Contains("upper_arm"))
				return MuscleGroup.ArmUpper;
			if (name.Contains("larm"))
				return MuscleGroup.ArmLower;
			if (name.Contains("hand"))
				return MuscleGroup.Hand;

			return MuscleGroup.None;
		},
		GetRootBodyPart = () => BodyPartGroup.Hips,
		GetRootMuscle = () => MuscleGroup.Hips
	};

	/// <summary>
	/// Writes the observation vector. The order of the AddObservation calls defines the observation layout
	/// and must not be changed: normalized velocity, hips forward/up, torso forward/up, touch sensors,
	/// then per body part (local position, rotation, rotation velocity, velocity), then sensor observations.
	/// </summary>
	/// <param name="sensor">Vector sensor receiving the observations.</param>
	public override void CollectObservations(VectorSensor sensor)
	{
		if (!_hasLazyInitialized)
		{
			// Observations may be requested before the first episode began; OnEpisodeBegin performs
			// the one-time body manager setup (and also resets the body).
			OnEpisodeBegin();
		}

		Vector3 normalizedVelocity = _bodyManager.GetNormalizedVelocity();
		var pelvis = _bodyManager.GetFirstBodyPart(BodyPartGroup.Hips);
		var shoulders = _bodyManager.GetFirstBodyPart(BodyPartGroup.Torso);

		sensor.AddObservation(normalizedVelocity);
		sensor.AddObservation(pelvis.Rigidbody.transform.forward); // gyroscope
		sensor.AddObservation(pelvis.Rigidbody.transform.up);

		sensor.AddObservation(shoulders.Rigidbody.transform.forward); // gyroscope
		sensor.AddObservation(shoulders.Rigidbody.transform.up);

		sensor.AddObservation(_bodyManager.GetSensorIsInTouch());
		foreach (var bodyPart in _bodyManager.BodyParts)
		{
			// Refresh the cached Obs* values before reading them.
			bodyPart.UpdateObservations();
			sensor.AddObservation(bodyPart.ObsLocalPosition);
			sensor.AddObservation(bodyPart.ObsRotation);
			sensor.AddObservation(bodyPart.ObsRotationVelocity);
			sensor.AddObservation(bodyPart.ObsVelocity);
		}
		sensor.AddObservation(_bodyManager.GetSensorObservations());
	}

	/// <summary>
	/// Applies the continuous actions to the body, rewards the step and ends the episode if the hips
	/// have dropped below y = 0. Does nothing until the agent has been lazily initialized.
	/// </summary>
	/// <param name="actions">Action buffers; only the continuous actions are used.</param>
	public override void OnActionReceived(ActionBuffers actions)
	{
		if (!_hasLazyInitialized)
		{
			return;
		}

		float[] vectorAction = actions.ContinuousActions.ToArray();

		_isDone = false;
		// apply actions to body
		_bodyManager.OnAgentAction(vectorAction);

		// manage reward
		// The reward is the clamped x velocity only. Earlier shaping terms (not-at-limit bonus,
		// reduced-power bonus, action-difference bonus) were disabled and are no longer computed here.
		float reward = Mathf.Clamp(_bodyManager.GetNormalizedVelocity().x, 0f, 1f);

		// NOTE(review): the result is unused; the call is kept only because this file cannot show
		// whether GetActionDifference has side effects - remove once confirmed pure.
		_bodyManager.GetActionDifference();

		// Reward must be added before EndEpisode(): EndEpisode resets the agent, so a reward added
		// afterwards would be credited to the next episode instead of the one that just finished.
		AddReward(reward);
		_bodyManager.SetDebugFrameReward(reward);

		var pelvis = _bodyManager.GetFirstBodyPart(BodyPartGroup.Hips);
		if (pelvis.Transform.position.y < 0)
		{
			EndEpisode();
		}
	}

	/// <summary>
	/// Performs the one-time body manager setup on first use, then resets the body for a new episode.
	/// </summary>
	public override void OnEpisodeBegin()
	{
		if (!_hasLazyInitialized)
		{
			_bodyManager = GetComponent<BodyManager002>();
			_bodyManager.BodyConfig = MarathonManAgent.BodyConfig;
			_bodyManager.OnInitializeAgent();
			_hasLazyInitialized = true;
		}
		_isDone = true;
		_bodyManager.OnAgentReset();
	}

	/// <summary>
	/// Ends the episode when a body part other than a foot, lower leg, hand or ungrouped part touches
	/// an object carrying a <c>Terrain</c> component.
	/// </summary>
	/// <param name="other">The GameObject of this agent's body that collided.</param>
	/// <param name="terrain">The object that was hit; ignored unless it has a <c>Terrain</c> component.</param>
	public virtual void OnTerrainCollision(GameObject other, GameObject terrain)
	{
		// Unity's overloaded == is intentional here (it also treats destroyed objects as null).
		if (terrain.GetComponent<Terrain>() == null)
			return;
		// HACK - for when agent has not been initialized
		if (_bodyManager == null)
			return;
		var bodyPart = _bodyManager.BodyParts.FirstOrDefault(x => x.Transform.gameObject == other);
		if (bodyPart == null)
			return;
		switch (bodyPart.Group)
		{
			// Contacts that are allowed to touch the terrain.
			case BodyHelper002.BodyPartGroup.None:
			case BodyHelper002.BodyPartGroup.Foot:
			case BodyHelper002.BodyPartGroup.LegLower:
			case BodyHelper002.BodyPartGroup.Hand:
				break;
			// Every other group (including LegUpper, ArmLower and ArmUpper) ends the episode.
			default:
				if (!_isDone)
				{
					EndEpisode();
				}
				break;
		}
	}
}