using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.Barracuda;
using Unity.MLAgents.Policies;
using System;
using Unity.Barracuda.ONNX;
using System.IO;
using System.Linq;

/// <summary>
/// Swaps the NNModel driving the attached <see cref="Walker"/> agent at runtime and adjusts
/// the agent's target walking speed and selected brain to match. Swaps are triggered by
/// <see cref="WaypointManager"/> events or by calling <see cref="SwitchModel"/> /
/// <see cref="SetInitialModel"/> directly.
/// </summary>
public class ModelSwap : MonoBehaviour
{
    // Behavior name passed to Agent.SetModel for every model swap.
    private const string BehaviorName = "c_Walker";

    // While this model is active, override requests are deferred instead of applied.
    private const string GetupModelName = "Getup";

    // Target walking speeds applied per model.
    private const float DefaultWalkingSpeed = 8f;
    private const float TreadmillWalkingSpeed = 30f;

    // Models

    /// <summary>Model applied by <see cref="SetInitialModel"/>.</summary>
    public NNModel m_InitialModel;

    /// <summary>Candidate models; <see cref="SwitchModel"/> looks entries up by asset name.</summary>
    public List<NNModel> nnModelList;

    /// <summary>Walker on the same GameObject; resolved in <see cref="Start"/>.</summary>
    [HideInInspector] public Walker agent;

    /// <summary>Index into <see cref="nnModelList"/> of the most recently selected model.</summary>
    [HideInInspector] public int currentModel = 0;

    /// <summary>Name of the last model applied through one of the known cases in <see cref="SwitchModel"/>.</summary>
    [HideInInspector] public string m_currentModelName = string.Empty;

    /// <summary>
    /// Model requested via an override while the get-up model was active.
    /// NOTE(review): presumably consumed elsewhere to resume after get-up — confirm.
    /// </summary>
    [HideInInspector] public string m_PastModel = string.Empty;

    // Path to .onnx files (not read anywhere in this class).
    string m_relPath = string.Empty;

    private void OnEnable()
    {
        WaypointManager.SwapModelOnWaypointReached += SwapModelOnReachingWaypoint;
        WaypointManager.OnOverrideWaypoint += SwitchModelOverride;
    }

    private void OnDisable()
    {
        // Mirror OnEnable exactly: these are static events, so any handler left attached
        // keeps this component reachable and gets invoked again after every re-enable.
        WaypointManager.SwapModelOnWaypointReached -= SwapModelOnReachingWaypoint;
        WaypointManager.OnOverrideWaypoint -= SwitchModelOverride;
    }

    private void Start()
    {
        agent = GetComponent<Walker>();
    }

    /// <summary>
    /// Applies the model whose asset name is <paramref name="modelName"/> to <paramref name="inst"/>
    /// and configures the cached <see cref="agent"/>'s target speed and brain for that model.
    /// </summary>
    /// <param name="modelName">Asset name of an entry in <see cref="nnModelList"/>.</param>
    /// <param name="inst">Agent whose policy model is replaced.</param>
    /// <remarks>
    /// If no entry in <see cref="nnModelList"/> has that name, a warning is logged and nothing
    /// changes, so the running model and the selected brain cannot get out of sync.
    /// Speed and brain are set on <see cref="agent"/>, not on <paramref name="inst"/>; every
    /// caller inside this class passes <see cref="agent"/>.
    /// </remarks>
    public void SwitchModel(string modelName, Agent inst)
    {
        if (!FindModelByName(modelName))
        {
            Debug.LogWarning("ModelSwap: no model named '" + modelName + "' in nnModelList; keeping current model.");
            return;
        }

        inst.SetModel(BehaviorName, nnModelList[currentModel]);

        // Names not listed here still get their model applied, but leave speed, brain and
        // m_currentModelName untouched.
        switch (modelName)
        {
            case "Walker":
                agent.MTargetWalkingSpeed = DefaultWalkingSpeed;
                agent.m_SelectedBrain = Walker.Brain.Walker;
                m_currentModelName = modelName;
                break;
            case "Stairs":
                agent.MTargetWalkingSpeed = DefaultWalkingSpeed;
                agent.m_SelectedBrain = Walker.Brain.DMScrambler;
                m_currentModelName = modelName;
                break;
            case "Climber":
                agent.MTargetWalkingSpeed = DefaultWalkingSpeed;
                agent.m_SelectedBrain = Walker.Brain.Climber;
                m_currentModelName = modelName;
                break;
            case GetupModelName:
                agent.MTargetWalkingSpeed = DefaultWalkingSpeed;
                agent.m_SelectedBrain = Walker.Brain.Getup;
                m_currentModelName = modelName;
                break;
            case "Sitting":
                agent.MTargetWalkingSpeed = DefaultWalkingSpeed;
                agent.m_SelectedBrain = Walker.Brain.Sitting;
                m_currentModelName = modelName;
                break;
            case "Treadmill":
                agent.MTargetWalkingSpeed = TreadmillWalkingSpeed;
                agent.m_SelectedBrain = Walker.Brain.Treadmill;
                m_currentModelName = modelName;
                break;
        }

        Debug.Log("Current Model: " + nnModelList[currentModel].name);
    }

    /// <summary>
    /// Sets <see cref="currentModel"/> to the index of the first entry in <see cref="nnModelList"/>
    /// whose asset name equals <paramref name="modelName"/> (ordinal comparison).
    /// </summary>
    /// <returns>True if a match was found; false otherwise, leaving <see cref="currentModel"/> unchanged.</returns>
    private bool FindModelByName(string modelName)
    {
        if (nnModelList == null)
        {
            return false;
        }

        for (int i = 0; i < nnModelList.Count; i++)
        {
            NNModel candidate = nnModelList[i];

            // Skip null (e.g. unassigned) entries instead of throwing on candidate.name.
            if (candidate != null && candidate.name == modelName)
            {
                currentModel = i;
                return true;
            }
        }

        return false;
    }

    /// <summary>Handler for <c>WaypointManager.SwapModelOnWaypointReached</c>: swaps the cached agent to <paramref name="modelName"/>.</summary>
    private void SwapModelOnReachingWaypoint(string modelName)
    {
        SwitchModel(modelName, agent);
    }

    /// <summary>Applies <see cref="m_InitialModel"/> to the cached agent.</summary>
    public void SetInitialModel()
    {
        if (m_InitialModel == null)
        {
            Debug.LogWarning("ModelSwap: m_InitialModel is not assigned.");
            return;
        }

        SwitchModel(m_InitialModel.name, agent);
    }

    /// <summary>
    /// Handler for <c>WaypointManager.OnOverrideWaypoint</c>. Swaps immediately unless the
    /// get-up model is active, in which case the request is stored in <see cref="m_PastModel"/>.
    /// </summary>
    private void SwitchModelOverride(string modelName)
    {
        if (m_currentModelName == GetupModelName)
        {
            // Don't interrupt get-up; remember what was requested instead.
            m_PastModel = modelName;
            return;
        }

        SwapModelOnReachingWaypoint(modelName);
    }
}