Implementation of DReCon in Unity 2022 LTS
forked from:
https://github.com/joanllobera/marathon-envs
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
107 lines
3.2 KiB
107 lines
3.2 KiB
using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using Unity.Barracuda;
using Unity.Barracuda.ONNX;
using Unity.MLAgents.Policies;
using UnityEngine;
|
|
|
|
/// <summary>
/// At startup, loads every <c>*.onnx</c> file found in <see cref="m_DirectoryPath"/>
/// (relative to <see cref="Application.dataPath"/>) and converts each one into an in-memory
/// <see cref="NNModel"/>. The models are assigned to the sibling <c>RagDollAgent</c> one at
/// a time, advancing to the next (with wrap-around) whenever
/// <c>RagDollAgent.m_MoveToNextModel</c> is raised.
/// </summary>
public class ModelManager : MonoBehaviour
{
    // Behavior name passed to Agent.SetModel for every model this component assigns.
    private const string k_BehaviorName = "DReCon-v0";

    /// <summary>
    /// File name (without extension) of the model currently held by the agent's
    /// BehaviorParameters; refreshed every time a model is assigned.
    /// </summary>
    [SerializeField]
    public string m_ModelName;

    [Header("Model Settings")]
    // Used only as a fallback when no .onnx files are found in m_DirectoryPath.
    public NNModel m_InitialModel;

    // Folder scanned for .onnx files, relative to Application.dataPath.
    public string m_DirectoryPath;

    // Models converted from disk, in the order they are cycled through.
    private List<NNModel> m_ModelList = new List<NNModel>();
    private int m_CurrentModel = 0;

    private BehaviorParameters m_Parameters;
    private RagDollAgent m_Agent;

    // Unity Events //

    /// <summary>Callback signature for model-name notifications; the name is passed as its numeric value.</summary>
    public delegate void onUpdateModelNameDelegate(float modelName);

    /// <summary>
    /// Invoked after a model is assigned, with the model's file name parsed as a number
    /// (invariant culture). Not invoked when the name is not numeric.
    /// </summary>
    public static onUpdateModelNameDelegate m_UpdateModelName;

    private void OnEnable()
    {
        RagDollAgent.m_MoveToNextModel += SwapModel;
        m_Parameters = GetComponent<BehaviorParameters>();
        m_Agent = GetComponent<RagDollAgent>();
    }

    private void OnDisable()
    {
        RagDollAgent.m_MoveToNextModel -= SwapModel;
    }

    /// <summary>
    /// Loads the models from disk and assigns the first one. If the folder contains no
    /// models, falls back to <see cref="m_InitialModel"/>; if that is unset as well, logs an
    /// error instead of throwing.
    /// </summary>
    private void Start()
    {
        LoadLocalModels();

        if (m_ModelList.Count > 0)
        {
            SetModel(m_ModelList[m_CurrentModel]);
        }
        else if (m_InitialModel != null)
        {
            Debug.LogWarning($"ModelManager: no .onnx models found in '{m_DirectoryPath}'; using {nameof(m_InitialModel)}.");
            SetModel(m_InitialModel);
        }
        else
        {
            Debug.LogError($"ModelManager: no .onnx models found in '{m_DirectoryPath}' and {nameof(m_InitialModel)} is not assigned.");
        }
    }

    /// <summary>
    /// Handler for <c>RagDollAgent.m_MoveToNextModel</c>: advances to the next loaded model,
    /// wrapping back to the first after the last.
    /// </summary>
    private void SwapModel()
    {
        // The static event may fire before Start has loaded anything, or the folder may be
        // empty (fallback model in use); there is nothing to cycle through in that case.
        if (m_ModelList.Count == 0)
        {
            return;
        }

        m_CurrentModel = (m_CurrentModel + 1) % m_ModelList.Count;
        SetModel(m_ModelList[m_CurrentModel]);
    }

    /// <summary>Assigns <paramref name="model"/> to the agent and publishes its name.</summary>
    private void SetModel(NNModel model)
    {
        m_Agent.SetModel(k_BehaviorName, model, InferenceDevice.Burst);
        UpdateModelName();
    }

    /// <summary>
    /// Finds every <c>*.onnx</c> file in the configured folder, sorts them (see
    /// <see cref="CompareModelFiles"/>) and converts them into <see cref="m_ModelList"/>.
    /// A missing folder is logged and leaves the list empty.
    /// </summary>
    private void LoadLocalModels()
    {
        var dirInfo = new DirectoryInfo(Path.Combine(Application.dataPath, m_DirectoryPath ?? string.Empty));
        if (!dirInfo.Exists)
        {
            Debug.LogWarning($"ModelManager: model directory '{dirInfo.FullName}' does not exist.");
            return;
        }

        FileInfo[] nnList = dirInfo.GetFiles("*.onnx");

        // Numerically named files (e.g. "5000.onnx") are ordered by value; any other names
        // are placed after them instead of aborting the load with a FormatException.
        Array.Sort(nnList, CompareModelFiles);

        ConvertNNModels(nnList);
    }

    /// <summary>
    /// Sort order for model files. Names that parse as integers come first, by value (ties
    /// broken by ordinal name). All other names follow, in ordinal name order.
    /// </summary>
    private static int CompareModelFiles(FileInfo x, FileInfo y)
    {
        bool xNumeric = TryParseModelNumber(x.Name, out long xValue);
        bool yNumeric = TryParseModelNumber(y.Name, out long yValue);

        if (xNumeric && yNumeric)
        {
            int byValue = xValue.CompareTo(yValue);
            return byValue != 0 ? byValue : string.CompareOrdinal(x.Name, y.Name);
        }

        if (xNumeric != yNumeric)
        {
            return xNumeric ? -1 : 1;
        }

        return string.CompareOrdinal(x.Name, y.Name);
    }

    /// <summary>Parses the extension-less file name as an integer, independent of the current culture.</summary>
    private static bool TryParseModelNumber(string fileName, out long value)
    {
        return long.TryParse(
            Path.GetFileNameWithoutExtension(fileName),
            NumberStyles.Integer,
            CultureInfo.InvariantCulture,
            out value);
    }

    /// <summary>
    /// Converts each ONNX file into an in-memory <see cref="NNModel"/> (named after the file)
    /// and appends it to <see cref="m_ModelList"/> in the given order.
    /// </summary>
    private void ConvertNNModels(FileInfo[] nnList)
    {
        foreach (FileInfo element in nnList)
        {
            // A fresh converter per file, as in the original code, so no state is carried between conversions.
            var converter = new ONNXModelConverter(true);
            byte[] modelData = File.ReadAllBytes(element.FullName);
            Model model = converter.Convert(modelData);

            // Serialize the Barracuda model into the byte payload NNModelData carries.
            NNModelData modelD = ScriptableObject.CreateInstance<NNModelData>();
            using (var memoryStream = new MemoryStream())
            using (var writer = new BinaryWriter(memoryStream))
            {
                ModelWriter.Save(writer, model);
                modelD.Value = memoryStream.ToArray();
            }
            modelD.name = "Data";
            modelD.hideFlags = HideFlags.HideInHierarchy;

            NNModel result = ScriptableObject.CreateInstance<NNModel>();
            result.modelData = modelD;
            result.name = element.Name;

            // Add Model to Model List
            m_ModelList.Add(result);
        }
    }

    /// <summary>
    /// Refreshes <see cref="m_ModelName"/> from the agent's current model. If that name
    /// parses as a number, invokes <see cref="m_UpdateModelName"/> with it. Does nothing when
    /// there is no BehaviorParameters or no model.
    /// </summary>
    private void UpdateModelName()
    {
        if (m_Parameters == null || m_Parameters.Model == null)
        {
            return;
        }

        m_ModelName = Path.GetFileNameWithoutExtension(m_Parameters.Model.name);

        // Invariant culture so names like "1.5" parse the same on every machine; non-numeric
        // names (e.g. an inspector-assigned fallback model) are reported instead of throwing.
        if (float.TryParse(m_ModelName, NumberStyles.Float, CultureInfo.InvariantCulture, out float modelNumber))
        {
            m_UpdateModelName?.Invoke(modelNumber);
        }
        else
        {
            Debug.LogWarning($"ModelManager: model name '{m_ModelName}' is not numeric; {nameof(m_UpdateModelName)} not invoked.");
        }
    }
}
|
|
|