MB3_GrouperCluster.cs 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using UnityEngine;
  /// <summary>
  /// Partitions a set of GameObjects into a fixed number of spatial clusters using k-means
  /// on each object's <c>transform.position</c>.
  /// </summary>
  /// <remarks>
  /// Typical use: construct with the objects and the desired cluster count, call
  /// <see cref="Cluster"/>, then read each cluster back with <see cref="GetCluster"/>.
  /// Invalid input is reported through <c>Debug.LogError</c>/<c>Debug.LogWarning</c> rather than
  /// exceptions, and the cluster count is clamped to a usable value.
  /// </remarks>
  public class MB3_KMeansClustering
  {
      /// <summary>One object to be clustered; its position is captured once, at construction.</summary>
      class DataPoint
      {
          public Vector3 center;
          public GameObject gameObject;

          /// <summary>Index of the cluster this point is assigned to (0 until clustering runs).</summary>
          public int Cluster;

          public DataPoint(GameObject go)
          {
              gameObject = go;
              center = go.transform.position;
              if (go.GetComponent<Renderer>() == null) Debug.LogError("Object does not have a renderer " + go);
          }
      }

      // All non-null input objects, in input order.
      private readonly List<DataPoint> _normalizedDataToCluster = new List<DataPoint>();

      // Current centroid of each cluster, indexed by cluster id.
      private readonly Vector3[] _clusters;

      // Effective cluster count after clamping in the constructor; always >= 1.
      private readonly int _numberOfClusters;

      /// <summary>
      /// Collects the objects to cluster and fixes the number of clusters.
      /// </summary>
      /// <param name="gos">Objects to cluster. Null entries are skipped with a warning; a null list is treated as empty.</param>
      /// <param name="numClusters">
      /// Requested cluster count. Non-positive values become 1. Values not below the number of
      /// valid objects are reduced to (objects - 1). The result is never less than 1.
      /// </param>
      public MB3_KMeansClustering(List<GameObject> gos, int numClusters)
      {
          if (gos == null)
          {
              Debug.LogError("List of objects to cluster was null.");
              gos = new List<GameObject>();
          }

          for (int i = 0; i < gos.Count; i++)
          {
              // Unity overloads != so destroyed objects also compare equal to null.
              if (gos[i] != null)
              {
                  _normalizedDataToCluster.Add(new DataPoint(gos[i]));
              }
              else
              {
                  Debug.LogWarning(String.Format("Object {0} in list of objects to cluster was null.", i));
              }
          }

          if (numClusters <= 0)
          {
              Debug.LogError("Number of clusters must be positive.");
              numClusters = 1;
          }

          if (_normalizedDataToCluster.Count <= numClusters)
          {
              Debug.LogError("There must be fewer clusters than objects to cluster");
              numClusters = _normalizedDataToCluster.Count - 1;
          }

          // With zero or one valid object the reduction above goes non-positive; one cluster is the minimum.
          _numberOfClusters = Math.Max(numClusters, 1);
          _clusters = new Vector3[_numberOfClusters];
      }

      /// <summary>
      /// Builds the initial random partition. The first k points seed clusters 0..k-1, so no
      /// cluster starts empty when there are at least k points. Every remaining point gets a
      /// randomly chosen cluster id.
      /// </summary>
      private void InitializeCentroids()
      {
          // The constructor keeps k below the point count except when there are 0 or 1 points.
          // Clamp so an empty input does not index past the end of the list.
          int seeded = Math.Min(_numberOfClusters, _normalizedDataToCluster.Count);
          for (int i = 0; i < seeded; ++i)
          {
              _normalizedDataToCluster[i].Cluster = i;
          }
          for (int i = seeded; i < _normalizedDataToCluster.Count; i++)
          {
              _normalizedDataToCluster[i].Cluster = UnityEngine.Random.Range(0, _numberOfClusters);
          }
      }

      /// <summary>
      /// Recomputes each centroid as the mean position of the points currently assigned to it.
      /// </summary>
      /// <param name="force">When false, nothing is recomputed if any cluster has no members.</param>
      /// <returns>
      /// False if the update was skipped because a cluster is empty and <paramref name="force"/>
      /// is false; otherwise true.
      /// </returns>
      private bool UpdateDataPointMeans(bool force)
      {
          if (!force && AnyAreEmpty(_normalizedDataToCluster)) return false;

          Vector3[] sums = new Vector3[_numberOfClusters];
          int[] counts = new int[_numberOfClusters];
          foreach (DataPoint dp in _normalizedDataToCluster)
          {
              sums[dp.Cluster] += dp.center;
              counts[dp.Cluster]++;
          }

          for (int k = 0; k < _numberOfClusters; k++)
          {
              // An empty cluster has no mean. Dividing by zero would make its centroid NaN,
              // so it keeps its previous centroid instead.
              if (counts[k] > 0)
              {
                  _clusters[k] = sums[k] / counts[k];
              }
          }
          return true;
      }

      /// <summary>Returns true if at least one cluster id has no points in <paramref name="data"/>.</summary>
      private bool AnyAreEmpty(List<DataPoint> data)
      {
          int[] counts = new int[_numberOfClusters];
          for (int i = 0; i < data.Count; i++)
          {
              counts[data[i].Cluster]++;
          }
          return Array.IndexOf(counts, 0) >= 0;
      }

      /// <summary>
      /// Reassigns every point to its nearest centroid.
      /// </summary>
      /// <returns>True if at least one point changed cluster.</returns>
      private bool UpdateClusterMembership()
      {
          bool changed = false;
          float[] distances = new float[_numberOfClusters];
          foreach (DataPoint dp in _normalizedDataToCluster)
          {
              for (int k = 0; k < _numberOfClusters; ++k)
              {
                  distances[k] = EuclideanDistance(dp, _clusters[k]);
              }

              int nearest = MinIndex(distances);
              if (nearest != dp.Cluster)
              {
                  dp.Cluster = nearest;
                  changed = true;
              }
          }
          return changed;
      }

      /// <summary>Straight-line distance from a point's cached position to a centroid.</summary>
      private float EuclideanDistance(DataPoint dataPoint, Vector3 mean)
      {
          return Vector3.Distance(dataPoint.center, mean);
      }

      /// <summary>Index of the smallest value; on ties, the earliest index wins.</summary>
      private int MinIndex(float[] distances)
      {
          int best = 0;
          for (int k = 1; k < distances.Length; ++k)
          {
              if (distances[k] < distances[best])
              {
                  best = k;
              }
          }
          return best;
      }

      /// <summary>
      /// Returns the renderers of the objects assigned to one cluster, along with that
      /// cluster's centroid and radius. Centroids are refreshed from the current membership first.
      /// </summary>
      /// <param name="idx">Cluster index, expected in [0, cluster count).</param>
      /// <param name="mean">Cluster centroid; <c>Vector3.zero</c> if <paramref name="idx"/> is out of range.</param>
      /// <param name="size">
      /// Greatest distance from the centroid to any member; 1 if <paramref name="idx"/> is out of range.
      /// </param>
      /// <returns>
      /// Renderers of the cluster's members. Members without a Renderer are omitted
      /// (they were already reported when the clusterer was constructed).
      /// </returns>
      public List<Renderer> GetCluster(int idx, out Vector3 mean, out float size)
      {
          if (idx < 0 || idx >= _numberOfClusters)
          {
              Debug.LogError("idx is out of bounds");
              mean = Vector3.zero;
              size = 1;
              return new List<Renderer>();
          }

          // Force the update so centroids match the final membership, even if some cluster is empty.
          UpdateDataPointMeans(true);
          mean = _clusters[idx];

          List<Renderer> renderers = new List<Renderer>();
          float longestDist = 0;
          foreach (DataPoint dp in _normalizedDataToCluster)
          {
              if (dp.Cluster != idx) continue;

              float dist = Vector3.Distance(mean, dp.center);
              if (dist > longestDist) longestDist = dist;

              Renderer r = dp.gameObject.GetComponent<Renderer>();
              if (r != null)
              {
                  renderers.Add(r);
              }
          }

          size = longestDist;
          return renderers;
      }

      /// <summary>
      /// Runs k-means. It starts from a random partition, then alternates centroid and
      /// membership updates. It stops when membership stops changing, when a centroid update is
      /// skipped because a cluster went empty, or after (point count * 1000) iterations.
      /// If this is never called, every object stays in cluster 0.
      /// </summary>
      public void Cluster()
      {
          // Nothing to partition.
          if (_normalizedDataToCluster.Count == 0) return;

          InitializeCentroids();

          int maxIterations = _normalizedDataToCluster.Count * 1000;
          bool meansUpdated = true;
          bool membershipChanged = true;
          for (int iteration = 0; meansUpdated && membershipChanged && iteration < maxIterations; iteration++)
          {
              meansUpdated = UpdateDataPointMeans(false);
              // Membership is refreshed even when the centroid update was skipped,
              // matching the original loop's order of operations.
              membershipChanged = UpdateClusterMembership();
          }
      }
  }