#pragma kernel EstimatePlanePoints

// Number of bins in the offset histogram; bounds every per-bin loop in the kernel.
uint OfsHistBinLength;
// Number of entries scanned in PointCloudOfs / PointCloudMask when gathering inliers.
uint PointCloudOfsLength;
// Width of one histogram bin: the left edge of bin k is k * BinSize + OfsMinMax[0].
float BinSize;
// Number of consecutive bins combined into one candidate window.
uint BinAggregation;
// A window is accepted only when its histogram point count is strictly greater than this.
uint MinimumFloorPointCount;

// Inputs (read-only).
// Point count per histogram bin; summed into HistCumulativeCount by the kernel.
// Presumably filled by an earlier pass — confirm against the dispatching code.
StructuredBuffer<uint> OfsHistBinCount;
// Point positions, read as three consecutive floats (x, y, z) per point: [3*i], [3*i+1], [3*i+2].
StructuredBuffer<float> PointCloudPos;
// One scalar offset per point; compared against the bin edges to select inliers.
StructuredBuffer<float> PointCloudOfs;
// One flag per point; points whose flag is false are ignored when gathering inliers.
StructuredBuffer<bool> PointCloudMask;
// Element 0 is read as the left edge of bin 0. No other element is read by this kernel
// (presumably element 1 holds the maximum offset — confirm against the producer).
StructuredBuffer<float> OfsMinMax;

// Outputs (written by the kernel).
// Left edge of every histogram bin, in the same units as PointCloudOfs.
RWStructuredBuffer<float> OfsHistBinLeft;
// Inclusive running sum of OfsHistBinCount: element k holds the total count of bins 0..k.
RWStructuredBuffer<uint> HistCumulativeCount;
// Element 0 holds the number of inliers found; elements 1..count hold their point indices.
RWStructuredBuffer<uint> InlierIndices;
// Plane estimate. Elements 0 and 1 are zeroed at kernel start and, when a window is accepted, hold:
//   [0] inlier centroid, [1] plane normal (not normalised),
//   [2] (window start offset, window end offset, 0),
//   [3] (window start bin, window end bin, histogram point count of the window).
RWStructuredBuffer<float3> PlanePosNorm;


// Estimates a plane from the first sufficiently populated window of the offset histogram.
//
// Steps:
//   1. Fill OfsHistBinLeft with the left edge of every bin and HistCumulativeCount with the
//      inclusive running sum of OfsHistBinCount.
//   2. Slide a window of BinAggregation consecutive bins over the histogram, one bin at a time.
//      The first window whose histogram count is strictly greater than MinimumFloorPointCount
//      and which yields at least one masked inlier is used.
//   3. Gather the inlier point indices into InlierIndices, compute their centroid and scatter
//      matrix, and derive a plane normal from the largest 2x2 principal minor.
//
// Outputs: PlanePosNorm[0..3] (all zero when no window qualifies) and InlierIndices.
// The normal written to PlanePosNorm[1] is not normalised.
//
// The thread id is unused: every thread that runs this kernel performs the whole computation.
[numthreads(1, 1, 1)]
void EstimatePlanePoints(uint3 id : SV_DispatchThreadID)
{
	// Bin edges and inclusive cumulative counts.
	float minOfs = OfsMinMax[0];
	OfsHistBinLeft[0] = minOfs;
	HistCumulativeCount[0] = OfsHistBinCount[0];
	for (uint hi = 1; hi < OfsHistBinLength; hi++)
	{
		OfsHistBinLeft[hi] = hi * BinSize + minOfs;
		HistCumulativeCount[hi] = HistCumulativeCount[hi - 1] + OfsHistBinCount[hi];
	}

	// Clear every output slot so that a run which finds no plane does not leave the window
	// description of a previous dispatch behind in elements 2 and 3.
	PlanePosNorm[0] = float3(0, 0, 0);
	PlanePosNorm[1] = float3(0, 0, 0);
	PlanePosNorm[2] = float3(0, 0, 0);
	PlanePosNorm[3] = float3(0, 0, 0);

	// Windows overlap: the start advances by one bin per iteration, not by BinAggregation.
	// The start is at least 1 because the count below reads HistCumulativeCount[start - 1],
	// and the exclusive end must stay a valid index into OfsHistBinLeft.
	for (uint i = 1; (i + BinAggregation) < OfsHistBinLength; i++)
	{
		uint aggBinStart = i;                 // inclusive bin
		uint aggBinEnd = i + BinAggregation;  // exclusive bin
		uint inlierCount = HistCumulativeCount[aggBinEnd - 1] - HistCumulativeCount[aggBinStart - 1];

		if (inlierCount > MinimumFloorPointCount)
		{
			float offsetStart = OfsHistBinLeft[aggBinStart]; // inclusive
			float offsetEnd = OfsHistBinLeft[aggBinEnd];     // exclusive

			// Gather the indices of all masked points whose offset lies in [offsetStart, offsetEnd).
			// Slot 0 is reserved for the count, so indices are stored from slot 1 onwards.
			uint iiIndex = 0;
			for (uint j = 0; j < PointCloudOfsLength; j++)
			{
				float ofs = PointCloudOfs[j];
				if (PointCloudMask[j] && offsetStart <= ofs && ofs < offsetEnd)
				{
					InlierIndices[iiIndex + 1] = j;
					iiIndex++;
				}
			}

			InlierIndices[0] = iiIndex;

			// The histogram count and the masked scan can disagree. With no inliers the centroid
			// would be 0/0 (NaN) and the normal zero, so skip this window and keep searching.
			if (iiIndex == 0)
			{
				continue;
			}

			// Centroid of the inliers.
			float3 centr = float3(0, 0, 0);
			for (uint ii1 = 0; ii1 < iiIndex; ii1++)
			{
				uint idx3 = InlierIndices[ii1 + 1] * 3;
				centr += float3(PointCloudPos[idx3], PointCloudPos[idx3 + 1], PointCloudPos[idx3 + 2]);
			}

			centr = centr / (float)iiIndex;

			// Upper triangle of the symmetric 3x3 scatter matrix relative to the centroid
			// (sums of products of centred coordinates; not divided by the point count).
			float xx = 0, xy = 0, xz = 0, yy = 0, yz = 0, zz = 0;
			for (uint ii2 = 0; ii2 < iiIndex; ii2++)
			{
				uint idx3 = InlierIndices[ii2 + 1] * 3;
				float3 pcPos = float3(PointCloudPos[idx3], PointCloudPos[idx3 + 1], PointCloudPos[idx3 + 2]);
				float3 r = pcPos - centr;

				xx += r.x * r.x;
				xy += r.x * r.y;
				xz += r.x * r.z;
				yy += r.y * r.y;
				yz += r.y * r.z;
				zz += r.z * r.z;
			}

			// 2x2 principal minors; the axis with the largest one gives the best-conditioned
			// solution for the normal, whose component along that axis is the minor itself.
			float detX = yy * zz - yz * yz;
			float detY = xx * zz - xz * xz;
			float detZ = xx * yy - xy * xy;

			float detMax = max(detX, detY);
			detMax = max(detMax, detZ);

			float3 norm = float3(0, 0, 0);
			if (detMax == detX)
			{
				norm = float3(detX, xz * yz - xy * zz, xy * yz - xz * yy);
			}
			else if (detMax == detY)
			{
				norm = float3(xz * yz - xy * zz, detY, xy * xz - yz * xx);
			}
			else
			{
				norm = float3(xy * yz - xz * yy, xy * xz - yz * xx, detZ);
			}

			// Publish the centroid, the normal and a description of the accepted window.
			// Element 3 reports the histogram count of the window, not the gathered inlier count.
			PlanePosNorm[0] = centr;
			PlanePosNorm[1] = norm;
			PlanePosNorm[2] = float3(offsetStart, offsetEnd, 0.0);
			PlanePosNorm[3] = float3(aggBinStart, aggBinEnd, inlierCount);

			break;
		}
	}
}