MB3_AgglomerativeClustering.cs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. using UnityEngine;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using System;
  namespace DigitalOpus.MB.Core
  {
      /// <summary>
      /// Bottom-up (agglomerative) hierarchical clustering of GameObjects by position.
      /// Every item starts as its own leaf cluster; the two live clusters whose centroids are
      /// closest are merged repeatedly until one cluster remains.
      ///
      /// Finding the closest pair quickly: a priority queue of (distance, pair) entries is kept.
      /// When it runs dry it is refilled, in O(n^2), with roughly the MAX_PRIORITY_Q_SIZE closest
      /// pairs, and the refill threshold is remembered. A newly merged cluster only queues its
      /// distances that are below that threshold (O(n)). Entries that refer to clusters that
      /// have already been merged are not removed from the queue (too slow); they are skipped
      /// when popped.
      /// </summary>
      [Serializable]
      public class MB3_AgglomerativeClustering
      {
          /// <summary>Input points. Filled by the caller (or by TestRun) before agglomerate runs.</summary>
          public List<item_s> items = new List<item_s>();

          /// <summary>
          /// Cluster tree nodes. Indexes [0, items.Count) are the leaves; each merge appends one
          /// node after them. When clustering completes, the last entry is the root.
          /// </summary>
          public ClusterNode[] clusters;

          /// <summary>True when the last agglomerate run was canceled through its progress callback.</summary>
          public bool wasCanceled;

          /// <summary>A node of the cluster tree: a leaf wrapping one item, or the merge of two children.</summary>
          [Serializable]
          public class ClusterNode
          {
              public item_s leaf;                 // wrapped item (leaf nodes only)
              public ClusterNode cha;             // first child (merged nodes only)
              public ClusterNode chb;             // second child (merged nodes only)
              public int height;                  // merge step that created this node; 0 for leaves
              public float distToMergedCentroid;  // distance between the two children's centroids when merged
              public Vector3 centroid;            // mean position of all leaves under this node
              public int[] leafs;                 // indexes (into clusters) of the leaves under this node
              public int idx;                     // index of this node in the clusters array
              public bool isUnclustered = true;   // false once this node has been merged into a parent

              /// <summary>Creates the leaf node for item <paramref name="ii"/>, stored at <paramref name="index"/>.</summary>
              public ClusterNode(item_s ii, int index)
              {
                  leaf = ii;
                  idx = index;
                  leafs = new int[] { index };
                  centroid = ii.coord;
                  height = 0;
              }

              /// <summary>
              /// Creates the node that merges <paramref name="a"/> and <paramref name="b"/>.
              /// The centroid is the mean of all leaf centroids under both children, looked up in
              /// <paramref name="clusters"/>. It is therefore weighted by leaf count rather than
              /// being the midpoint of the two child centroids.
              /// </summary>
              public ClusterNode(ClusterNode a, ClusterNode b, int index, int h, float dist, ClusterNode[] clusters)
              {
                  cha = a;
                  chb = b;
                  idx = index;
                  leafs = new int[a.leafs.Length + b.leafs.Length];
                  Array.Copy(a.leafs, leafs, a.leafs.Length);
                  Array.Copy(b.leafs, 0, leafs, a.leafs.Length, b.leafs.Length);
                  Vector3 c = Vector3.zero;
                  for (int i = 0; i < leafs.Length; i++)
                  {
                      c += clusters[leafs[i]].centroid;
                  }
                  centroid = c / leafs.Length;
                  height = h;
                  distToMergedCentroid = dist;
              }
          }

          /// <summary>One input point: a GameObject and the position it is clustered by.</summary>
          [Serializable]
          public class item_s
          {
              public GameObject go;
              public Vector3 coord; /* coordinate of the input data point */
          }

          float euclidean_distance(Vector3 a, Vector3 b)
          {
              return Vector3.Distance(a, b);
          }

          /// <summary>
          /// Clusters <see cref="items"/> and stores the resulting tree in <see cref="clusters"/>.
          /// </summary>
          /// <param name="progFunc">
          /// Optional progress callback (may be null). It receives a status message and a progress
          /// fraction in 0..1. Returning true requests cancellation; once requested, cancellation
          /// is not undone by later callbacks.
          /// </param>
          /// <returns>
          /// False when there are fewer than two items (clusters is then an empty array) or the run
          /// was canceled; true otherwise.
          /// </returns>
          public bool agglomerate(ProgressUpdateCancelableDelegate progFunc)
          {
              // Start clean: a missing callback must never be interpreted as a cancel request.
              wasCanceled = false;
              _ReportProgress(progFunc, "Filling Priority Queue:", 0);
              if (items.Count <= 1)
              {
                  clusters = new ClusterNode[0];
                  return false;
              }

              // n leaves plus exactly n - 1 merges.
              clusters = new ClusterNode[items.Count * 2 - 1];
              List<ClusterNode> unclustered = new List<ClusterNode>(items.Count);
              for (int i = 0; i < items.Count; i++)
              {
                  ClusterNode leafNode = new ClusterNode(items[i], i);
                  leafNode.isUnclustered = true;
                  clusters[i] = leafNode;
                  unclustered.Add(leafNode);
              }

              int numClusters = items.Count;
              int height = 0;
              // Every live pair closer than this distance is guaranteed to be in the queue.
              float largestDistInQ = 0;
              // Project type; presumably pops the smallest key first (the algorithm relies on it).
              PriorityQueue<float, ClusterDistance> pq = new PriorityQueue<float, ClusterDistance>();
              while (unclustered.Count > 1 && !wasCanceled)
              {
                  height++;

                  // Find the closest pair whose two clusters are both still unclustered.
                  if (pq.Count == 0 && !_RefillQueue(pq, unclustered, progFunc, "Refilling Q:", ref largestDistInQ))
                  {
                      break;
                  }
                  KeyValuePair<float, ClusterDistance> closestPair = pq.Dequeue();
                  // Entries that reference already-merged clusters are stale; skip them as they are popped.
                  bool foundPair = true;
                  while (!closestPair.Value.a.isUnclustered || !closestPair.Value.b.isUnclustered)
                  {
                      if (pq.Count == 0 && !_RefillQueue(pq, unclustered, progFunc, "Creating clusters:", ref largestDistInQ))
                      {
                          // Canceled, or nothing valid could be queued: never merge the stale pair.
                          foundPair = false;
                          break;
                      }
                      closestPair = pq.Dequeue();
                  }
                  if (!foundPair)
                  {
                      break;
                  }

                  ClusterNode a = closestPair.Value.a;
                  ClusterNode b = closestPair.Value.b;
                  int newIdx = numClusters;
                  if (newIdx >= clusters.Length)
                  {
                      // Unreachable while every merge consumes two live clusters; stop instead of throwing.
                      Debug.LogError("how did this happen");
                      break;
                  }
                  numClusters++;
                  ClusterNode cn = new ClusterNode(a, b, newIdx, height, closestPair.Key, clusters);

                  // Retire the children. Their remaining queue entries become stale and are
                  // skipped when popped, which is far cheaper than searching the queue for them.
                  unclustered.Remove(a);
                  unclustered.Remove(b);
                  a.isUnclustered = false;
                  b.isUnclustered = false;

                  clusters[newIdx] = cn;
                  cn.isUnclustered = true;
                  // Queue only the new cluster's distances below the current threshold, so the queue
                  // still holds every live pair closer than largestDistInQ without growing unboundedly.
                  for (int i = 0; i < unclustered.Count; i++)
                  {
                      float dist = euclidean_distance(cn.centroid, unclustered[i].centroid);
                      if (dist < largestDistInQ)
                      {
                          pq.Add(new KeyValuePair<float, ClusterDistance>(dist, new ClusterDistance(cn, unclustered[i])));
                      }
                  }
                  unclustered.Add(cn);

                  if (progFunc != null)
                  {
                      _ReportProgress(progFunc, _StatusMessage("Creating clusters:", unclustered.Count, pq.Count), _ClusteredFraction(unclustered.Count));
                  }
              }
              _ReportProgress(progFunc, "Finished clustering:", 1f);
              return !wasCanceled;
          }

          // Invokes progFunc when it is non-null and latches cancellation, so a later callback that
          // returns false cannot undo an earlier cancel request. Returns the (latched) wasCanceled.
          bool _ReportProgress(ProgressUpdateCancelableDelegate progFunc, string msg, float progress)
          {
              if (progFunc != null && progFunc(msg, progress))
              {
                  wasCanceled = true;
              }
              return wasCanceled;
          }

          // Fraction of the items that have been merged away so far.
          float _ClusteredFraction(int numUnclustered)
          {
              return ((float)(items.Count - numUnclustered)) / items.Count;
          }

          // Status line: percent done, live cluster count, queue size, managed memory in MB.
          string _StatusMessage(string prefix, int numUnclustered, int numInQ)
          {
              long usedMemory = GC.GetTotalMemory(false) / 1000000;
              return prefix + ((float)(items.Count - numUnclustered) * 100) / items.Count + " unclustered:" + numUnclustered + " inQ:" + numInQ + " usedMem:" + usedMemory;
          }

          // Reports progress, then refills the (empty) queue and updates largestDistInQ.
          // Returns false when the run was canceled or no pair could be queued; the caller must
          // then stop merging.
          bool _RefillQueue(PriorityQueue<float, ClusterDistance> pq, List<ClusterNode> unclustered, ProgressUpdateCancelableDelegate progFunc, string statusPrefix, ref float largestDistInQ)
          {
              if (progFunc != null)
              {
                  _ReportProgress(progFunc, _StatusMessage(statusPrefix, unclustered.Count, pq.Count), _ClusteredFraction(unclustered.Count));
              }
              if (wasCanceled)
              {
                  return false;
              }
              largestDistInQ = _RefillPriorityQWithSome(pq, unclustered, clusters, progFunc);
              return !wasCanceled && pq.Count > 0;
          }

          const int MAX_PRIORITY_Q_SIZE = 2048;

          // Loads pq with the closest pairs of unclustered nodes.
          // Pass A computes every pairwise distance and selects the MAX_PRIORITY_Q_SIZE-th smallest
          // (0-indexed) as a threshold. Pass B queues every pair at or below that threshold: about
          // MAX_PRIORITY_Q_SIZE + 1 pairs, more on ties, and all pairs when there are fewer.
          // Returns the threshold; 10f if canceled part-way; 10e10f if there are no pairs.
          // progFunc may be null.
          float _RefillPriorityQWithSome(PriorityQueue<float, ClusterDistance> pq, List<ClusterNode> unclustered, ClusterNode[] clusters, ProgressUpdateCancelableDelegate progFunc)
          {
              int count = unclustered.Count;

              List<float> allDist = new List<float>(MAX_PRIORITY_Q_SIZE);
              for (int i = 0; i < count; i++)
              {
                  for (int j = i + 1; j < count; j++)
                  {
                      allDist.Add(euclidean_distance(unclustered[i].centroid, unclustered[j].centroid));
                  }
                  if (_ReportProgress(progFunc, "Refilling Queue Part A:", i / (count * 2f)))
                  {
                      return 10f;
                  }
              }
              if (allDist.Count == 0)
              {
                  return 10e10f;
              }
              // NthSmallestElement reorders allDist in place; it is a scratch list, so that is fine.
              float nthSmallest = NthSmallestElement(allDist, MAX_PRIORITY_Q_SIZE);

              for (int i = 0; i < count; i++)
              {
                  for (int j = i + 1; j < count; j++)
                  {
                      float newDist = euclidean_distance(unclustered[i].centroid, unclustered[j].centroid);
                      if (newDist <= nthSmallest)
                      {
                          int idxa = unclustered[i].idx;
                          int idxb = unclustered[j].idx;
                          pq.Add(new KeyValuePair<float, ClusterDistance>(newDist, new ClusterDistance(clusters[idxa], clusters[idxb])));
                      }
                  }
                  if (_ReportProgress(progFunc, "Refilling Queue Part B:", (count + i) / (count * 2f)))
                  {
                      return 10f;
                  }
              }
              return nthSmallest;
          }

          /// <summary>
          /// Convenience entry point: clusters the positions of <paramref name="gos"/> without a
          /// progress callback. The result is left in <see cref="clusters"/>. Always returns 0.
          /// </summary>
          public int TestRun(List<GameObject> gos)
          {
              List<item_s> its = new List<item_s>(gos.Count);
              for (int i = 0; i < gos.Count; i++)
              {
                  item_s ii = new item_s();
                  ii.go = gos[i];
                  ii.coord = gos[i].transform.position;
                  its.Add(ii);
              }
              items = its;
              if (items.Count > 0)
              {
                  agglomerate(null);
              }
              return 0;
          }

          /// <summary>A queue entry: a candidate pair of clusters to merge.</summary>
          public class ClusterDistance
          {
              public ClusterNode a;
              public ClusterNode b;

              public ClusterDistance(ClusterNode aa, ClusterNode bb)
              {
                  a = aa;
                  b = bb;
              }
          }

          /// <summary>Ad-hoc smoke check: logs the smallest (0-indexed n = 0) value of a fixed list.</summary>
          public static void Main()
          {
              List<float> inputArray = new List<float>();
              inputArray.AddRange(new float[] { 19, 18, 17, 16, 15, 10, 11, 12, 13, 14 });
              Debug.Log("Loop quick select 10 times.");
              Debug.Log(NthSmallestElement(inputArray, 0));
          }

          /// <summary>
          /// Returns the n-th smallest element of <paramref name="array"/> (n is 0-indexed and
          /// clamped to [0, Count - 1]). The list is partially reordered in place.
          /// </summary>
          /// <exception cref="ArgumentException">The list is empty.</exception>
          public static T NthSmallestElement<T>(List<T> array, int n) where T : IComparable<T>
          {
              if (array.Count == 0)
              {
                  throw new ArgumentException("Array is empty.", "array");
              }
              if (n < 0)
              {
                  n = 0;
              }
              if (n > array.Count - 1)
              {
                  n = array.Count - 1;
              }
              if (array.Count == 1)
              {
                  return array[0];
              }
              return QuickSelectSmallest(array, n)[n];
          }

          // Quickselect. Reorders input IN PLACE so that input[n] holds the n-th smallest value,
          // and returns the same list instance. n must already be within [0, Count - 1].
          private static List<T> QuickSelectSmallest<T>(List<T> input, int n) where T : IComparable<T>
          {
              int startIndex = 0;
              int endIndex = input.Count - 1;
              // First pivot guess is n itself, which is ideal if the input happens to be sorted.
              int pivotIndex = n;
              System.Random r = new System.Random();
              // Invariant: n stays within [startIndex, endIndex]; every step shrinks the range or finds n.
              while (endIndex > startIndex)
              {
                  pivotIndex = QuickSelectPartition(input, startIndex, endIndex, pivotIndex);
                  if (pivotIndex == n)
                  {
                      break;
                  }
                  if (pivotIndex > n)
                  {
                      endIndex = pivotIndex - 1;
                  }
                  else
                  {
                      startIndex = pivotIndex + 1;
                  }
                  // Random pivot over the whole remaining range (Next's upper bound is exclusive).
                  pivotIndex = r.Next(startIndex, endIndex + 1);
              }
              return input;
          }

          // Lomuto partition of array[startIndex..endIndex] around array[pivotIndex]: values <= pivot
          // end up before it, larger values after it. Returns the pivot's final index.
          private static int QuickSelectPartition<T>(List<T> array, int startIndex, int endIndex, int pivotIndex) where T : IComparable<T>
          {
              T pivotValue = array[pivotIndex];
              // Park the pivot at the end so the scan below never touches it.
              Swap(array, pivotIndex, endIndex);
              int storeIndex = startIndex;
              for (int i = startIndex; i < endIndex; i++)
              {
                  if (array[i].CompareTo(pivotValue) > 0)
                  {
                      continue;
                  }
                  Swap(array, i, storeIndex);
                  storeIndex++;
              }
              // storeIndex is the first slot holding a value > pivot: the pivot's sorted position.
              Swap(array, endIndex, storeIndex);
              return storeIndex;
          }

          private static void Swap<T>(List<T> array, int index1, int index2)
          {
              if (index1 == index2)
              {
                  return;
              }
              T temp = array[index1];
              array[index1] = array[index2];
              array[index2] = temp;
          }
      }
  }