// Implementation of DReCon in Unity 2022 LTS, forked from: https://github.com/joanllobera/marathon-envs
// You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
//
// 230 lines
// 7.4 KiB
//
// 9 months ago
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using System.Linq;
using ManyWorlds;
/// <summary>
/// Generates a randomised height profile on the <see cref="Terrain"/> found under this
/// object's parent, and answers height queries by raycasting against the physics scene
/// obtained from the parent <c>SpawnableEnv</c>.
/// </summary>
/// <remarks>
/// <see cref="Start"/> must have run before any of the raycast helpers are used, and
/// <see cref="Reset"/> must have run before <see cref="IsPointOffEdge"/> is used, because
/// those are the places where the physics scene and terrain references are acquired.
/// </remarks>
public class TerrainGenerator : MonoBehaviour
{
    internal const float _minHeight = 0f;
    internal const float _maxHeight = 10f;
    internal const float _minSpawnHeight = 0f;//2f;
    internal const float _maxSpawnHeight = 10f;//8f;
    const float _midHeight = 5f;

    // Raycasts consider every layer except layer 14.
    const int RaycastLayerMask = ~(1 << 14);

    // Flat columns written starting at the generator's own column (that column included).
    const int FlatColumnsFromSpawn = 6;

    // Exclusive upper bound handed to Random.Range, i.e. actions are drawn from 0..20.
    const int ActionCount = 21;

    // Height probes: 35 rays spaced 0.2 apart along x, the first one 2 units before pos.x.
    const int RayCount = 5 * 7;
    const float RaySpacing = .2f;
    const float RayStartOffset = 2f;

    Terrain terrain;
    Agent _agent;
    DecisionRequester _decisionRequester;
    public int posXInTerrain;
    public int posYInTerrain;
    public int heightIndex;
    public float curHeight;
    public float actionReward;
    float _mapScaleY;
    float[,] _heightMap;
    public List<float> debugLastHeights;
    public List<float> debugLastNormHeights;
    public float debugLastFraction;
    PhysicsScene physicsScene;

    /// <summary>Caches the physics scene exposed by the parent <c>SpawnableEnv</c>.</summary>
    void Start()
    {
        physicsScene = GetComponentInParent<SpawnableEnv>().GetPhysicsScene();
    }

    /// <summary>
    /// Lazily acquires the terrain and agent references, recomputes this object's
    /// position in heightmap coordinates and regenerates the whole height profile.
    /// </summary>
    public void Reset()
    {
        // Unity also invokes a method named Reset in the editor (component added / "Reset"
        // context menu). Regenerating there would replace the scene terrain's data, so the
        // work is restricted to play mode.
        if (!Application.isPlaying)
        {
            return;
        }

        // Unity objects must be compared with ==, which accounts for destroyed objects.
        if (terrain == null)
        {
            AcquireTerrainWithPrivateData();
        }
        if (_agent == null)
        {
            _agent = GetComponent<Agent>();
            _decisionRequester = GetComponent<DecisionRequester>();
        }

        TerrainData data = terrain.terrainData;
        _mapScaleY = data.heightmapScale.y;

        // Position of this object relative to the terrain origin, kept inside the terrain.
        Vector3 local = transform.position - terrain.gameObject.transform.position;
        float localX = Mathf.Clamp(local.x, 0f, data.size.x - 0.000001f);
        float localZ = Mathf.Clamp(local.z, 0f, data.size.z - 0.000001f);
        float normalizedX = (localX - 1) / data.size.x;
        float normalizedZ = localZ / data.size.z;

        // The "- 1" above can push the x index below zero at the terrain's edge, so both
        // indices are clamped to valid heightmap samples.
        int resolution = data.heightmapResolution;
        posXInTerrain = Mathf.Clamp((int)(normalizedX * resolution), 0, resolution - 1);
        posYInTerrain = Mathf.Clamp((int)(normalizedZ * resolution), 0, resolution - 1);

        curHeight = _midHeight;
        actionReward = 0f;
        ResetHeights();
    }

    /// <summary>
    /// Finds the terrain under the parent transform and gives it (and its collider) a
    /// fresh <see cref="TerrainData"/> configured from the one it had before.
    /// </summary>
    void AcquireTerrainWithPrivateData()
    {
        var parent = gameObject.transform.parent;
        terrain = parent.GetComponentInChildren<Terrain>();
        var sharedTerrainData = terrain.terrainData;

        // Same assignment order as before: attach the new data first, then configure it.
        terrain.terrainData = new TerrainData();
        terrain.terrainData.heightmapResolution = sharedTerrainData.heightmapResolution;
        terrain.terrainData.baseMapResolution = sharedTerrainData.baseMapResolution;
        terrain.terrainData.SetDetailResolution(sharedTerrainData.detailResolution, sharedTerrainData.detailResolutionPerPatch);
        terrain.terrainData.size = sharedTerrainData.size;
        terrain.terrainData.terrainLayers = sharedTerrainData.terrainLayers;

        var terrainCollider = terrain.GetComponent<TerrainCollider>();
        terrainCollider.terrainData = terrain.terrainData;
    }

    /// <summary>
    /// Rebuilds the height map: flat columns from index 0 through the generator's column
    /// plus the following ones (<see cref="FlatColumnsFromSpawn"/> in total from the
    /// generator's column), then randomly stepped columns to the end, and uploads the
    /// result to the terrain.
    /// </summary>
    void ResetHeights()
    {
        int resolution = terrain.terrainData.heightmapResolution;
        if (_heightMap == null)
        {
            _heightMap = new float[resolution, resolution];
        }

        heightIndex = 0;

        // Never write past the last column, even when the generator sits near the far edge.
        int flatEnd = Mathf.Min(Mathf.Max(posXInTerrain, 0) + FlatColumnsFromSpawn, resolution);
        while (heightIndex < flatEnd)
        {
            SetNextHeight(0);
        }

        while (heightIndex < resolution)
        {
            // Integer Random.Range excludes the upper bound.
            SetNextHeight(Random.Range(0, ActionCount));
        }

        terrain.terrainData.SetHeights(0, 0, _heightMap);
    }

    /// <summary>
    /// Applies one height action to <see cref="curHeight"/> and writes the resulting
    /// height into column <see cref="heightIndex"/> of the height map, then advances
    /// <see cref="heightIndex"/>.
    /// </summary>
    /// <param name="action">
    /// 0 keeps the current height. Otherwise the step size is <c>((action + 1) / 2) * 0.1</c>
    /// (integer division); odd actions raise the height and even actions lower it.
    /// </param>
    void SetNextHeight(int action)
    {
        if (action != 0)
        {
            bool raise = (action - 1) % 2 == 0;
            float step = ((float)((action + 1) / 2)) * 0.1f;
            curHeight += raise ? step : -step;
            curHeight = Mathf.Clamp(curHeight, _minSpawnHeight, _maxSpawnHeight);
        }

        // Stored heights are world height divided by the heightmap's y scale.
        float normalizedHeight = curHeight / _mapScaleY;
        int rows = _heightMap.GetLength(0);
        for (int row = 0; row < rows; row++)
        {
            _heightMap[row, heightIndex] = normalizedHeight;
        }
        heightIndex++;
    }

    /// <summary>Returns <see cref="GetDistance2d"/> for every point, in input order.</summary>
    /// <param name="points">World-space points to probe.</param>
    /// <returns>One clamped downward distance per input point.</returns>
    public List<float> GetDistances2d(IEnumerable<Vector3> points)
    {
        return points.Select(GetDistance2d).ToList();
    }

    /// <summary>
    /// Casts a ray straight down from <paramref name="point"/> (at most
    /// <see cref="_maxHeight"/> long) and returns the hit distance clamped to [-1, 1].
    /// Returns 1 when nothing is hit.
    /// </summary>
    float GetDistance2d(Vector3 point)
    {
        if (!physicsScene.Raycast(point, Vector3.down, out RaycastHit hit, _maxHeight, RaycastLayerMask))
        {
            return 1f;
        }
        return Mathf.Clamp(hit.distance, -1f, 1f);
    }

    /// <summary>
    /// Reports whether <paramref name="point"/> lies outside the terrain along the z axis
    /// (before the terrain origin or at/after its z size). The x axis is not checked.
    /// </summary>
    public bool IsPointOffEdge(Vector3 point)
    {
        Vector3 localPos = point - terrain.gameObject.transform.position;
        return localPos.z < 0f || localPos.z >= terrain.terrainData.size.z;
    }

    /// <summary>
    /// Probes the ground with a row of downward rays along x around <paramref name="pos"/>.
    /// </summary>
    /// <param name="pos">
    /// Reference position. Rays start 2 units before <c>pos.x</c>, use <c>pos.z</c>, and
    /// are cast from y = <see cref="_maxHeight"/>.
    /// </param>
    /// <param name="showDebug">When true (and running in the editor) the probes are drawn.</param>
    /// <returns>
    /// Item1: per ray, <c>pos.y</c> minus the height of the hit point (or
    /// <see cref="_maxHeight"/> on a miss), clamped to [-10, 10] and divided by 10.
    /// Item2: where the first ray falls inside its 0.2-wide cell, scaled to 0..1.
    /// </returns>
    public (List<float>, float) GetDistances2d(Vector3 pos, bool showDebug)
    {
        float xpos = pos.x - RayStartOffset;
        float fraction = (xpos - (Mathf.Floor(xpos * 5) / 5)) * 5;
        float ypos = pos.y;

        List<Ray> rays = Enumerable.Range(0, RayCount)
            .Select(i => new Ray(new Vector3(xpos + (i * RaySpacing), _maxHeight, pos.z), Vector3.down))
            .ToList();

        List<float> distances = rays
            .Select(ray =>
            {
                if (!physicsScene.Raycast(ray.origin, ray.direction, out RaycastHit hit, _maxHeight, RaycastLayerMask))
                {
                    return _maxHeight;
                }
                // The ray starts at y = _maxHeight, so the hit height is _maxHeight - hit.distance.
                return ypos - (_maxHeight - hit.distance);
            })
            .ToList();

        if (Application.isEditor && showDebug)
        {
            DrawDebugRays(rays, distances, ypos);
        }

        List<float> normalizedDistances = distances
            .Select(d => Mathf.Clamp(d, -10f, 10f) / 10f)
            .ToList();

        debugLastNormHeights = normalizedDistances;
        debugLastHeights = distances;
        debugLastFraction = fraction;
        return (normalizedDistances, fraction);
    }

    /// <summary>
    /// Draws one debug ray per probe from height <paramref name="ypos"/>: yellow downward
    /// for positive distances, red upward otherwise.
    /// </summary>
    void DrawDebugRays(List<Ray> rays, List<float> distances, float ypos)
    {
        // Keep the lines visible until the next decision. An explicit null check is used
        // because ?. does not go through Unity's overloaded null comparison.
        float duration = Time.deltaTime;
        if (_decisionRequester != null && _decisionRequester.DecisionPeriod > 1)
        {
            duration *= _decisionRequester.DecisionPeriod;
        }

        for (int i = 0; i < rays.Count; i++)
        {
            float distance = distances[i];
            // Draw at the z the ray was actually cast at instead of a fixed z = 0.
            var origin = new Vector3(rays[i].origin.x, ypos, rays[i].origin.z);
            Vector3 direction = distance > 0 ? Vector3.down : Vector3.up;
            Color color = distance > 0 ? Color.yellow : Color.red;
            Debug.DrawRay(origin, direction * Mathf.Abs(distance), color, duration, false);
        }
    }
}