ProbabilityTable.cs 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #if UNITY_EDITOR
  2. using UnityEngine;
  3. using System;
  4. using System.Collections.Generic;
  5. namespace O3DWB
  6. {
  /// <summary>
  /// One row of a probability table: an entity together with its raw weight and the
  /// values derived from that weight when the owning table is rebuilt.
  /// </summary>
  /// <remarks>
  /// WARNING: CAN NOT BE SERIALIZED (original author's note) even though it carries
  /// [Serializable]/[SerializeField].
  /// NOTE(review): presumably because older Unity versions (before 2020.1) do not
  /// serialize generic types — confirm against the target Unity version.
  /// </remarks>
  /// <typeparam name="EntityType">Reference type of the stored entity.</typeparam>
  [Serializable]
  public class ProbabilityTableEntry<EntityType> where EntityType : class
  {
      /// <summary>The entity this row refers to.</summary>
      [SerializeField] public EntityType Entity;

      /// <summary>Raw, un-normalized weight supplied by the caller.</summary>
      [SerializeField] public float Probability;

      /// <summary>
      /// <see cref="Probability"/> divided by the sum of all weights in the table;
      /// written by the table's rebuild step.
      /// </summary>
      [SerializeField] public float NormProbability;

      /// <summary>
      /// Running total of <see cref="NormProbability"/> up to and including this row,
      /// in the table's sorted order; compared against a random number when picking.
      /// </summary>
      [SerializeField] public float CumulativeProbability;
  }
  /// <summary>
  /// Weighted random selection over a set of distinct entities.
  /// </summary>
  /// <remarks>
  /// Usage: add entities with their weights, call <see cref="Rebuild"/>, then call
  /// <see cref="PickEntity"/> with a uniform random number in [0, 1].
  /// Adding, removing or re-weighting entities does not update the derived probabilities;
  /// call <see cref="Rebuild"/> again afterwards.
  /// Weights are expected to be non-negative; this is not validated.
  /// Entities are compared by reference (with the <c>class</c> constraint, <c>==</c> on
  /// <typeparamref name="EntityType"/> is a reference comparison).
  /// NOTE(review): this is a generic [Serializable] type; the file notes that its entry
  /// type cannot be serialized — confirm whether this table is actually serialized.
  /// </remarks>
  /// <typeparam name="EntityType">Reference type of the entities being picked.</typeparam>
  [Serializable]
  public class ProbabilityTable<EntityType> where EntityType : class
  {
      [SerializeField]
      private List<ProbabilityTableEntry<EntityType>> _entries = new List<ProbabilityTableEntry<EntityType>>();

      /// <summary>
      /// Adds <paramref name="entity"/> with the given raw weight.
      /// Null entities and entities already in the table are ignored.
      /// </summary>
      /// <param name="entity">Entity to add.</param>
      /// <param name="probability">Raw (un-normalized) weight of the entity.</param>
      public void AddEntity(EntityType entity, float probability)
      {
          if (entity == null || ContainsEntity(entity))
          {
              return;
          }

          _entries.Add(new ProbabilityTableEntry<EntityType>
          {
              Entity = entity,
              Probability = probability
          });
      }

      /// <summary>Returns true if <paramref name="entity"/> is in the table.</summary>
      public bool ContainsEntity(EntityType entity)
      {
          // Exists stops at the first match and does not allocate a temporary list.
          return _entries.Exists(item => item.Entity == entity);
      }

      /// <summary>
      /// Changes the raw weight of <paramref name="entity"/>; does nothing if the entity
      /// is not in the table. Call <see cref="Rebuild"/> for the change to affect picking.
      /// </summary>
      public void SetEntityProbability(EntityType entity, float probability)
      {
          var entry = _entries.Find(item => item.Entity == entity);
          if (entry != null)
          {
              entry.Probability = probability;
          }
      }

      /// <summary>
      /// Returns the first entity (in sorted order) whose cumulative probability is
      /// greater than or equal to <paramref name="randomNumber"/>, considering only entities
      /// with a positive normalized probability.
      /// </summary>
      /// <param name="randomNumber">
      /// Uniform random value in [0, 1]. Both ends may occur, e.g. with an inclusive generator
      /// such as UnityEngine.Random.value.
      /// </param>
      /// <returns>
      /// The picked entity, or null when the table is empty, all weights are zero, or
      /// <paramref name="randomNumber"/> lies above the last cumulative value (e.g. above 1).
      /// </returns>
      public EntityType PickEntity(float randomNumber)
      {
          foreach (var entry in _entries)
          {
              // Zero-weight entries (and entries added since the last Rebuild, whose derived
              // values are still 0) own an empty interval. Without this check a randomNumber
              // of exactly 0 would select one of them, because 0 >= 0.
              if (entry.NormProbability <= 0.0f)
              {
                  continue;
              }

              if (entry.CumulativeProbability >= randomNumber)
              {
                  return entry.Entity;
              }
          }

          return null;
      }

      /// <summary>Removes <paramref name="entity"/> from the table, if present.</summary>
      public void RemoveEntity(EntityType entity)
      {
          _entries.RemoveAll(item => item.Entity == entity);
      }

      /// <summary>Removes every entity from the table.</summary>
      public void RemoveAllEntities()
      {
          _entries.Clear();
      }

      /// <summary>
      /// Recomputes normalized and cumulative probabilities from the current raw weights.
      /// Must be called after any add/remove/re-weight and before <see cref="PickEntity"/>.
      /// </summary>
      public void Rebuild()
      {
          CalculateNormProbabilities();
          SortEntriesByNormProbability();
          CalculateCumulProbabilities();
      }

      // Divides every raw weight by the total weight.
      private void CalculateNormProbabilities()
      {
          float sum = GetProbabilitySum();
          foreach (var entry in _entries)
          {
              // Avoid 0/0 = NaN when the table has no total weight.
              // Every entry then gets 0, so nothing can be picked.
              entry.NormProbability = sum != 0.0f ? entry.Probability / sum : 0.0f;
          }
      }

      // Sum of all raw weights.
      private float GetProbabilitySum()
      {
          float sum = 0.0f;
          foreach (var entry in _entries)
          {
              sum += entry.Probability;
          }
          return sum;
      }

      // Orders entries by ascending normalized probability.
      private void SortEntriesByNormProbability()
      {
          _entries.Sort((e0, e1) => e0.NormProbability.CompareTo(e1.NormProbability));
      }

      // Stores the running total of normalized probabilities in each entry.
      private void CalculateCumulProbabilities()
      {
          float runningTotal = 0.0f;
          foreach (var entry in _entries)
          {
              runningTotal += entry.NormProbability;
              entry.CumulativeProbability = runningTotal;
          }

          // Float rounding can leave the final total slightly below 1 (e.g. 0.99999994),
          // in which case PickEntity(1.0f) would return null. Whenever the table carries
          // positive weight, the exact total is 1, so snap the last entry to it.
          // The list is sorted ascending, so the last entry has the largest weight.
          if (_entries.Count > 0)
          {
              var last = _entries[_entries.Count - 1];
              if (last.NormProbability > 0.0f)
              {
                  last.CumulativeProbability = 1.0f;
              }
          }
      }
  }
  100. }
  101. #endif