// kmeans.ts
  /**
   * A single cluster: a representative item plus the items grouped under it.
   */
  export interface Label<Type> {
    /** The item acting as the centre of this cluster. */
    centroid: Type;
    /** The items assigned to this cluster. */
    data: Array<Type>;
  }
  /**
   * Measures how far apart two items are; a smaller result means "closer".
   */
  export type DistanceFunc<Type> = (
    a: Type,
    b: Type,
  ) => number;
  /**
   * Pairs an item with a random number, e.g. to serve as a sort key when
   * shuffling a list.
   */
  export interface shuffleItem<Type> {
    /** The wrapped item. */
    item: Type;
    /** The random number attached to the item. */
    random: number;
  }
  /**
   * Randomly picks up to `size` distinct positions from `dataset`.
   *
   * The input array is never mutated and the returned array is always a new
   * one, so callers may modify the result freely.
   *
   * @param dataset Items to choose from.
   * @param size Number of items wanted. Values that are not greater than zero
   *   (including `NaN`) yield an empty array; fractional values are rounded
   *   down.
   * @returns A copy of the whole dataset (original order) when it has `size`
   *   items or fewer, otherwise `size` items chosen uniformly at random.
   */
  export function randomPickup<Type>(dataset: Type[], size: number): Type[] {
    const pool = [...dataset];
    if (pool.length <= size) return pool;

    // A negative size must not reach slice(): slice(0, -n) would drop the last
    // n items instead of returning nothing.
    const count = size > 0 ? Math.floor(size) : 0;

    // Partial Fisher–Yates shuffle: only the first `count` slots need settling.
    for (let i = 0; i < count; i++) {
      const j = i + Math.floor(Math.random() * (pool.length - i));
      const picked = pool[j];
      pool[j] = pool[i];
      pool[i] = picked;
    }
    return pool.slice(0, count);
  }
  /**
   * Picks the centre item of a dataset: the member whose summed distance to
   * every member (itself included) is the smallest. On ties the earliest
   * such member wins.
   *
   * @param dataset Items to inspect.
   * @param distanceFunc Distance between two items.
   * @returns The chosen item. For an empty dataset there is nothing to choose
   *   and the result is `undefined` at runtime.
   */
  export function pickupCentroid<Type>(dataset: Type[], distanceFunc: DistanceFunc<Type>): Type {
    // Start from Infinity rather than MAX_SAFE_INTEGER so that very large
    // distance sums are still compared against each other.
    let minDist = Infinity;
    let minDistIndex = 0;
    for (let i = 0; i < dataset.length; i++) {
      const total = dataset.reduce(
        (sum: number, other: Type) => sum + distanceFunc(other, dataset[i]),
        0,
      );
      if (total < minDist) {
        minDist = total;
        minDistIndex = i;
      }
    }
    return dataset[minDistIndex];
  }
  /**
   * Validates a k-means result: every item must be strictly closer to its
   * cluster's centroid than the given threshold.
   *
   * @param labels Clusters to check.
   * @param distanceFunc Distance between two items.
   * @param minDist Threshold; a cluster member whose distance to its centroid
   *   is greater than or equal to this value makes the result invalid.
   * @returns `true` when no member reaches the threshold (also for an empty
   *   list of clusters), otherwise `false`.
   */
  export function validateKmeans<Type>(labels: Label<Type>[], distanceFunc: DistanceFunc<Type>, minDist: number): boolean {
    return labels.every((label) =>
      // Written as !(d >= minDist) so a NaN distance does not count as a failure.
      label.data.every((item) => !(distanceFunc(label.centroid, item) >= minDist)),
    );
  }
  /**
   * Groups `dataset` into clusters around centroids that are themselves
   * members of the dataset.
   *
   * Starting from randomly picked centroids, each round assigns every item to
   * its nearest centroid and then re-picks each cluster's centroid, until the
   * clustering validates against `maxDist`, the centroids stop changing, or
   * the loop counter exceeds `maxLoops`.
   *
   * @param dataset Items to cluster.
   * @param k Number of clusters wanted; fewer are produced when the dataset
   *   is smaller. Values that are not greater than zero yield no clusters.
   * @param distanceFunc Distance between two items.
   * @param maxDist Validation threshold: iteration stops once every item is
   *   strictly closer than this to its centroid. Defaults to 4.
   * @param maxLoops Iteration stops once the loop counter exceeds this value.
   *   Defaults to 100.
   * @returns The clusters from the last assignment round.
   */
  export function kmeans<Type>(
    dataset: Type[],
    k: number,
    distanceFunc: DistanceFunc<Type>,
    maxDist: number = 4,
    maxLoops: number = 100,
  ): Label<Type>[] {
    if (!(k > 0)) return [];
    let centroids: Type[] = randomPickup<Type>(dataset, k);
    // Without any centroid there is nothing to assign items to.
    if (centroids.length === 0) return [];

    let loop = 0;
    let labels: Label<Type>[] = [];
    for (;;) {
      labels = centroids.map((centroid): Label<Type> => ({ centroid, data: [] }));

      // Assign every item to its nearest centroid (first one wins on ties).
      for (const item of dataset) {
        let nearest: Label<Type> | undefined;
        let nearestDist = Infinity;
        for (const label of labels) {
          const dist = distanceFunc(item, label.centroid);
          if (nearest === undefined || dist < nearestDist) {
            nearest = label;
            nearestDist = dist;
          }
        }
        if (nearest !== undefined) nearest.data.push(item);
      }

      // Validate the result.
      if (validateKmeans(labels, distanceFunc, maxDist) || loop > maxLoops) break;

      // Recompute the centroids. A cluster that received no items keeps its
      // previous centroid instead of producing an undefined one.
      const next = labels.map((label) =>
        label.data.length > 0 ? pickupCentroid(label.data, distanceFunc) : label.centroid,
      );
      const previous = centroids;
      const unchanged = next.every((centroid, index) => centroid === previous[index]);
      centroids = next;
      // Identical centroids would reproduce the same assignment on every
      // further round, so the current labels are already final.
      if (unchanged) break;
      loop++;
    }
    // Diagnostic output: number of completed re-centering rounds.
    console.log('loop:', loop);
    return labels;
  }
  /**
   * A dataset partitioned by whether each item has company within a distance
   * threshold.
   */
  export interface Split<Type> {
    /** Items that have more than one dataset entry within the threshold. */
    mergable: Array<Type>;
    /** Items that have at most one dataset entry within the threshold. */
    standalone: Array<Type>;
  }
  /**
   * Partitions a dataset into items that have company nearby and items that
   * do not.
   *
   * For each item, the dataset entries whose distance to it is below
   * `minDist` are counted (the item itself is part of the dataset, so it is
   * counted whenever its distance to itself is below the threshold).
   *
   * @param dataset Items to partition.
   * @param minDist Distance threshold; defaults to 2.3.
   * @param distanceFunc Distance between two items.
   * @returns Items with a count of at most one in `standalone`, all others in
   *   `mergable`, each in dataset order.
   */
  export function splitByMinDistance<Type>(dataset: Type[], minDist: number = 2.3, distanceFunc: DistanceFunc<Type>): Split<Type> {
    const mergable: Type[] = [];
    const standalone: Type[] = [];
    for (const current of dataset) {
      const closeCount = dataset.reduce(
        (count: number, other: Type) => (distanceFunc(other, current) < minDist ? count + 1 : count),
        0,
      );
      const bucket = closeCount > 1 ? mergable : standalone;
      bucket.push(current);
    }
    return { mergable, standalone };
  }
  /**
   * An item together with the dataset entries considered close to it.
   */
  export interface Neigbours<Type> {
    /** The item itself. */
    self: Type;
    /** The entries regarded as neighbours of `self`. */
    neigbours: Array<Type>;
  }
  /**
   * Generic merge by distance.
   *
   * Every item starts with the list of dataset entries closer to it than
   * `minDist`. The item with the most such neighbours is repeatedly taken as
   * a group representative; the items it covers are removed from further
   * consideration and from the neighbour lists of the remaining items.
   *
   * @param dataset Items to merge; the array is not modified.
   * @param minDist Distance threshold; defaults to 2.3.
   * @param distanceFunc Distance between two items.
   * @returns One entry per chosen representative, in the order chosen, each
   *   with the neighbours it still had when it was chosen. An empty dataset
   *   yields an empty array.
   */
  export function simpleMerge<Type>(dataset: Type[], minDist: number = 2.3, distanceFunc: DistanceFunc<Type>): Neigbours<Type>[] {
    const data = [...dataset];
    const result: Neigbours<Type>[] = [];

    // Find, for every item, the entries that lie within the threshold.
    let list: Neigbours<Type>[] = data.map((item: Type) => ({
      self: item,
      neigbours: data.filter((other) => distanceFunc(other, item) < minDist),
    }));

    while (list.length > 0) {
      // Ascending by neighbour count, so the best candidate ends up last.
      list.sort((a, b) => a.neigbours.length - b.neigbours.length);
      const best = list.pop();
      if (best === undefined) break;
      result.push(best);

      const covered = new Set<Type>(best.neigbours);
      list = list.filter((nei) => !covered.has(nei.self));
      if (best.neigbours.length > 1) {
        // Covered items no longer count as neighbours of anyone else.
        for (const nei of list) {
          nei.neigbours = nei.neigbours.filter((item) => !covered.has(item));
        }
      }
    }
    return result;
  }
  146. }