model swap manager added

This commit is contained in:
2024-04-29 13:21:37 +01:00
parent 024200a8a7
commit a4de1e1104
8 changed files with 118 additions and 18 deletions

View File

@@ -1,5 +1,9 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using Unity.Barracuda;
using Unity.Barracuda.ONNX;
using Unity.MLAgents.Policies;
using UnityEngine;
@@ -8,22 +12,96 @@ public class ModelManager : MonoBehaviour
[SerializeField]
public string m_ModelName;
[Header("Model Settings")]
public NNModel m_InitialModel;
public string m_DirectoryPath;
private List<NNModel> m_ModelList = new List<NNModel>();
private int m_CurrentModel = 0;
private BehaviorParameters m_Parameters;
private RagDollAgent m_Agent;
// Unity Events //
public delegate void onUpdateModelNameDelegate(string modelName);
public delegate void onUpdateModelNameDelegate(float modelName);
public static onUpdateModelNameDelegate m_UpdateModelName;
private void OnEnable()
{
RagDollAgent.m_MoveToNextModel += SwapModel;
m_Parameters = GetComponent<BehaviorParameters>();
m_Agent = GetComponent<RagDollAgent>();
}
private void FixedUpdate()
private void OnDisable()
{
RagDollAgent.m_MoveToNextModel -= SwapModel;
}
private void Start()
{
LoadLocalModels();
SetModel(m_ModelList[m_CurrentModel]);
}
private void SwapModel()
{
m_CurrentModel++;
if (m_CurrentModel >= m_ModelList.Count)
m_CurrentModel = 0;
SetModel(m_ModelList[m_CurrentModel]);
}
private void SetModel(NNModel model)
{
m_Agent.SetModel("DReCon-v0", model, InferenceDevice.Burst);
UpdateModelName();
}
private void LoadLocalModels()
{
DirectoryInfo dirInfo = new DirectoryInfo(Path.Combine(Application.dataPath, m_DirectoryPath));
FileInfo[] nnList = dirInfo.GetFiles("*.onnx");
// Sort files by filename (assuming filenames are numeric)
Array.Sort(nnList, (x, y) => {
int num1 = int.Parse(Path.GetFileNameWithoutExtension(x.Name));
int num2 = int.Parse(Path.GetFileNameWithoutExtension(y.Name));
return num1.CompareTo(num2);
});
ConvertNNModels(nnList);
}
private void ConvertNNModels(FileInfo[] nnList)
{
foreach (FileInfo element in nnList)
{
var converter = new ONNXModelConverter(true);
byte[] modelData = File.ReadAllBytes(element.FullName.ToString());
Model model = converter.Convert(modelData);
NNModelData modelD = ScriptableObject.CreateInstance<NNModelData>();
using (var memoryStream = new MemoryStream())
using (var writer = new BinaryWriter(memoryStream))
{
ModelWriter.Save(writer, model);
modelD.Value = memoryStream.ToArray();
}
modelD.name = "Data";
modelD.hideFlags = HideFlags.HideInHierarchy;
NNModel result = ScriptableObject.CreateInstance<NNModel>();
result.modelData = modelD;
result.name = element.Name;
// Add Model to Model List
m_ModelList.Add(result);
}
}
private void UpdateModelName()
{
if (m_Parameters != null)
{
m_ModelName = m_Parameters.Model.name;
m_UpdateModelName?.Invoke(m_ModelName);
m_ModelName = Path.GetFileNameWithoutExtension(m_Parameters.Model.name);
m_UpdateModelName?.Invoke(float.Parse(m_ModelName));
}
}
}

View File

@@ -34,11 +34,11 @@ public class UIManager : MonoBehaviour
}
}
private void UpdateModelName(string modelName)
private void UpdateModelName(float modelName)
{
if( m_ModelNameText != null )
{
m_ModelNameText.text = modelName;
m_ModelNameText.text = modelName.ToString("N0");
}
}
}