// Implementation of DReCon in Unity 2022 LTS, forked from: https://github.com/joanllobera/marathon-envs
// You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

// 108 lines
// 3.2 KiB

using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using Unity.Barracuda;
using Unity.Barracuda.ONNX;
using Unity.MLAgents.Policies;
using UnityEngine;
/// <summary>
/// Loads ONNX policies from a local directory, converts each one into a Barracuda
/// <see cref="NNModel"/>, and assigns them in turn to the <see cref="RagDollAgent"/>
/// on this GameObject. Each time <c>RagDollAgent.m_MoveToNextModel</c> is raised, the
/// next model is assigned, wrapping around to the first after the last.
/// </summary>
public class ModelManager : MonoBehaviour
{
    // Model directory used when m_DirectoryPath is left empty (the path this script
    // originally hard-coded).
    private const string DefaultDirectoryPath = "C:\\Users\\caile\\Desktop\\onnx";

    // Behavior name passed to RagDollAgent.SetModel for every model.
    private const string BehaviorName = "DReCon-v0";

    /// <summary>File name (extension stripped) of the currently assigned model.</summary>
    [SerializeField]
    public string m_ModelName;

    /// <summary>Fallback model, used only when the model directory yields no .onnx files.</summary>
    [Header("Model Settings")]
    public NNModel m_InitialModel;

    /// <summary>
    /// Directory scanned for <c>*.onnx</c> files. When empty, a built-in default path is used.
    /// </summary>
    public string m_DirectoryPath;

    private readonly List<NNModel> m_ModelList = new List<NNModel>();
    private int m_CurrentModel = 0;
    private BehaviorParameters m_Parameters;
    private RagDollAgent m_Agent;

    /// <summary>Listener signature; receives the numeric name of the newly assigned model.</summary>
    public delegate void onUpdateModelNameDelegate(float modelName);

    /// <summary>
    /// Raised after a model is assigned, with its file name parsed as a number.
    /// Not raised when the name is not numeric.
    /// </summary>
    public static onUpdateModelNameDelegate m_UpdateModelName;

    private void OnEnable()
    {
        RagDollAgent.m_MoveToNextModel += SwapModel;
        m_Parameters = GetComponent<BehaviorParameters>();
        m_Agent = GetComponent<RagDollAgent>();
    }

    private void OnDisable()
    {
        RagDollAgent.m_MoveToNextModel -= SwapModel;
    }

    private void Start()
    {
        LoadLocalModels();

        // Nothing usable on disk: fall back to the Inspector-assigned model, if any.
        if (m_ModelList.Count == 0 && m_InitialModel != null)
        {
            m_ModelList.Add(m_InitialModel);
        }

        if (m_ModelList.Count == 0)
        {
            Debug.LogWarning("ModelManager: no models available; the agent's model was not changed.");
            return;
        }

        SetModel(m_ModelList[m_CurrentModel]);
    }

    /// <summary>Advances to the next loaded model, wrapping to the first after the last.</summary>
    private void SwapModel()
    {
        // The subscription is made in OnEnable, so this can fire before Start has loaded
        // anything, or after loading found nothing.
        if (m_ModelList.Count == 0)
        {
            return;
        }

        m_CurrentModel = (m_CurrentModel + 1) % m_ModelList.Count;
        SetModel(m_ModelList[m_CurrentModel]);
    }

    /// <summary>Assigns <paramref name="model"/> to the agent (Burst inference) and publishes its name.</summary>
    private void SetModel(NNModel model)
    {
        m_Agent.SetModel(BehaviorName, model, InferenceDevice.Burst);
        UpdateModelName();
    }

    /// <summary>
    /// Converts every <c>*.onnx</c> file in the model directory and appends the results to
    /// the model list, ordered numerically by file name.
    /// </summary>
    private void LoadLocalModels()
    {
        string directoryPath = string.IsNullOrWhiteSpace(m_DirectoryPath) ? DefaultDirectoryPath : m_DirectoryPath;
        var directory = new DirectoryInfo(directoryPath);
        if (!directory.Exists)
        {
            Debug.LogWarning($"ModelManager: model directory '{directoryPath}' does not exist.");
            return;
        }

        FileInfo[] onnxFiles = directory.GetFiles("*.onnx");
        Array.Sort(onnxFiles, CompareModelFiles);
        ConvertNNModels(onnxFiles);
    }

    /// <summary>
    /// Orders model files by their integer file name ("2.onnx" before "10.onnx"). Files whose
    /// names are not integers sort after all numeric ones; remaining ties use ordinal name order.
    /// </summary>
    private static int CompareModelFiles(FileInfo x, FileInfo y)
    {
        bool xIsNumeric = TryParseModelIndex(x.Name, out int xIndex);
        bool yIsNumeric = TryParseModelIndex(y.Name, out int yIndex);

        if (xIsNumeric && yIsNumeric && xIndex != yIndex)
        {
            return xIndex.CompareTo(yIndex);
        }
        if (xIsNumeric != yIsNumeric)
        {
            return xIsNumeric ? -1 : 1;
        }
        return string.CompareOrdinal(x.Name, y.Name);
    }

    /// <summary>Parses a file name such as "12.onnx" as an integer, independent of the machine's culture.</summary>
    private static bool TryParseModelIndex(string fileName, out int index)
    {
        return int.TryParse(
            Path.GetFileNameWithoutExtension(fileName),
            NumberStyles.Integer,
            CultureInfo.InvariantCulture,
            out index);
    }

    /// <summary>
    /// Converts each ONNX file into a runtime-created <see cref="NNModel"/> and appends it to
    /// the model list in the order given. Each model is named after its file, extension included.
    /// </summary>
    private void ConvertNNModels(FileInfo[] nnList)
    {
        foreach (FileInfo file in nnList)
        {
            // One converter per file. The argument is presumably Barracuda's optimizeModel flag;
            // NOTE(review): confirm against the ONNXModelConverter constructor docs.
            var converter = new ONNXModelConverter(true);
            byte[] onnxBytes = File.ReadAllBytes(file.FullName);
            Model model = converter.Convert(onnxBytes);

            // Serialize the converted model with ModelWriter.Save. NNModelData holds those bytes,
            // and the NNModel that wraps it is what gets passed to the agent.
            NNModelData modelData = ScriptableObject.CreateInstance<NNModelData>();
            using (var memoryStream = new MemoryStream())
            using (var writer = new BinaryWriter(memoryStream))
            {
                ModelWriter.Save(writer, model);
                writer.Flush(); // make sure every written byte reaches the stream before copying it out
                modelData.Value = memoryStream.ToArray();
            }
            modelData.name = "Data";
            modelData.hideFlags = HideFlags.HideInHierarchy;

            NNModel nnModel = ScriptableObject.CreateInstance<NNModel>();
            nnModel.modelData = modelData;
            nnModel.name = file.Name;
            m_ModelList.Add(nnModel);
        }
    }

    /// <summary>
    /// Stores the assigned model's file name (extension stripped) in m_ModelName. When that
    /// name is numeric, also raises m_UpdateModelName with its value.
    /// </summary>
    private void UpdateModelName()
    {
        // UnityEngine.Object overloads ==, so use it (not `is null`) to also catch destroyed objects.
        if (m_Parameters == null || m_Parameters.Model == null)
        {
            return;
        }

        m_ModelName = Path.GetFileNameWithoutExtension(m_Parameters.Model.name);

        // Listeners take a float, so only numeric names can be published. Parsing with the
        // invariant culture makes "1.5" mean the same value under every machine locale.
        if (float.TryParse(
                m_ModelName,
                NumberStyles.Float | NumberStyles.AllowThousands,
                CultureInfo.InvariantCulture,
                out float modelNumber))
        {
            m_UpdateModelName?.Invoke(modelNumber);
        }
        else
        {
            Debug.LogWarning($"ModelManager: model name '{m_ModelName}' is not numeric; m_UpdateModelName not raised.");
        }
    }
}