diff --git a/CHANGELOG.md b/CHANGELOG.md index 76f23c02f..101142105 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ### 🐣新特性 * 【core 】 SyncFinisher增加setExecutorService方法(issue#IANKQ1@Gitee) * 【http 】 HttpConfig增加`setUseDefaultContentTypeIfNull`方法(issue#3719@Github) +* 【core 】 用ArrayList重新实现权重随机类:WeightListRandom(pr#3720@Github) ### 🐞Bug修复 * 【json 】 修复JSONConfig.setDateFormat设置后toBean无效问题(issue#3713@Github) diff --git a/hutool-core/src/main/java/cn/hutool/core/lang/WeightListPool.java b/hutool-core/src/main/java/cn/hutool/core/lang/WeightListPool.java deleted file mode 100644 index 0c490e6cc..000000000 --- a/hutool-core/src/main/java/cn/hutool/core/lang/WeightListPool.java +++ /dev/null @@ -1,136 +0,0 @@ -package cn.hutool.core.lang; - -import cn.hutool.core.util.RandomUtil; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; - -/** - * 动态按权重随机的随机池,底层是list实现。 - * - * @param 元素类型 - * @author 王叶峰 - * @date 2024-07-29 - */ -public class WeightListPool { - - /** - * 随机元素池 - */ - private final ArrayList> randomPool; - - public WeightListPool() { - randomPool = new ArrayList<>(); - } - - public WeightListPool(int poolSize) { - randomPool = new ArrayList<>(poolSize); - } - - public void add(E e, double weight) { - Assert.isTrue(weight > 0, "权重必须大于0!"); - randomPool.add(new EWeight<>(e, sumWeight() + weight)); - } - - public boolean remove(E e) { - boolean removed = false; - double weight = 0; - int i = 0; - Iterator> iterator = randomPool.iterator(); - while (iterator.hasNext()) { - EWeight ew = iterator.next(); - if (!removed && ew.e.equals(e)) { - iterator.remove(); - weight = ew.sumWeight - (i == 0 ? 0 : randomPool.get(i - 1).sumWeight);// 权重=当前权重-上一个权重 - removed = true; - } - if (removed) { - // 重新计算后续权重 - ew.sumWeight -= weight; - } - i++; - } - return removed; - } - - private double sumWeight() { - if (randomPool.isEmpty()) { - return 0; - } - return randomPool.get(randomPool.size() - 1).sumWeight; - } - - private void checkEmptyPool() { - if (isEmpty()) { - throw new IllegalArgumentException("随机池为空!"); - } - } - - public E random() { - checkEmptyPool(); - - if (randomPool.size() == 1) { - return randomPool.get(0).e; - } - ThreadLocalRandom random = RandomUtil.getRandom(); - double randVal = random.nextDouble() * sumWeight(); - return binarySearch(randVal); - } - - /** - * 二分查找小于等于key的最大值的元素 - * - * @param key 目标值 - * @return 随机池的一个元素或者null 当key大于所有元素的总权重时,返回null - */ - private E binarySearch(double key) { - int low = 0; - int high = randomPool.size() - 1; - - while (low <= high) { - int mid = (low + high) >>> 1; - double midVal = randomPool.get(mid).sumWeight; - - if (midVal < key) { - low = mid + 1; - } else if (midVal > key) { - high = mid - 1; - } else { - return randomPool.get(mid).e; - } - } - return randomPool.get(low).e; - } - - /** - * 按照给定的总权重随机出一个元素 - * - * @param weight 总权重 - * @return 随机池的一个元素或者null - */ - public E randomByWeight(double weight) { - Assert.isTrue(weight >= sumWeight(), "权重必须大于当前总权重!"); - ThreadLocalRandom random = RandomUtil.getRandom(); - double randVal = random.nextDouble() * sumWeight(); - if (randVal > sumWeight()) { - return null; - } - return binarySearch(randVal); - } - - public boolean isEmpty() { - return randomPool.isEmpty(); - } - - private static class EWeight { - final E e; - double sumWeight; - - public EWeight(E e, double sumWeight) { - this.e = e; - this.sumWeight = sumWeight; - } - } -} diff --git a/hutool-core/src/main/java/cn/hutool/core/lang/WeightListRandom.java b/hutool-core/src/main/java/cn/hutool/core/lang/WeightListRandom.java new file mode 100644 index 000000000..dff63753e --- /dev/null +++ b/hutool-core/src/main/java/cn/hutool/core/lang/WeightListRandom.java @@ -0,0 +1,167 @@ +package cn.hutool.core.lang; + +import cn.hutool.core.util.RandomUtil; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.concurrent.ThreadLocalRandom; + +/** + * 动态按权重随机的随机池,底层是list实现。 + * + * @param 元素类型 + * @author 王叶峰 + * @since 5.8.33 + */ +public class WeightListRandom implements Serializable { + private static final long serialVersionUID = 6902006276975764032L; + + /** + * 随机元素池 + */ + private final ArrayList> randomPool; + + /** + * 构造 + */ + public WeightListRandom() { + randomPool = new ArrayList<>(); + } + + /** + * 构造 + * + * @param poolSize 容量 + */ + public WeightListRandom(int poolSize) { + randomPool = new ArrayList<>(poolSize); + } + + /** + * 添加元素 + * + * @param e 元素 + * @param weight 权重 + */ + public void add(E e, double weight) { + Assert.isTrue(weight > 0, "权重必须大于0!"); + randomPool.add(new EWeight<>(e, sumWeight() + weight)); + } + + /** + * 移除元素 + * + * @param e 元素 + * @return 是否移除成功 + */ + public boolean remove(E e) { + boolean removed = false; + double weight = 0; + int i = 0; + Iterator> iterator = randomPool.iterator(); + while (iterator.hasNext()) { + EWeight ew = iterator.next(); + if (!removed && ew.e.equals(e)) { + iterator.remove(); + weight = ew.sumWeight - (i == 0 ? 0 : randomPool.get(i - 1).sumWeight);// 权重=当前权重-上一个权重 + removed = true; + } + if (removed) { + // 重新计算后续权重 + ew.sumWeight -= weight; + } + i++; + } + return removed; + } + + /** + * 随机出一个元素 + * + * @return 随机池的一个元素 + */ + public E next() { + checkEmptyPool(); + + if (randomPool.size() == 1) { + return randomPool.get(0).e; + } + ThreadLocalRandom random = RandomUtil.getRandom(); + double randVal = random.nextDouble() * sumWeight(); + return binarySearch(randVal); + } + + /** + * 按照给定的总权重随机出一个元素 + * + * @param weight 总权重 + * @return 随机池的一个元素或者null + */ + public E randomByWeight(double weight) { + Assert.isTrue(weight >= sumWeight(), "权重必须大于当前总权重!"); + ThreadLocalRandom random = RandomUtil.getRandom(); + double randVal = random.nextDouble() * sumWeight(); + if (randVal > sumWeight()) { + return null; + } + return binarySearch(randVal); + } + + /** + * 判断随机池是否为空 + * + * @return 是否为空 + */ + public boolean isEmpty() { + return randomPool.isEmpty(); + } + + private static class EWeight { + final E e; + double sumWeight; + + public EWeight(E e, double sumWeight) { + this.e = e; + this.sumWeight = sumWeight; + } + } + + /** + * 二分查找小于等于key的最大值的元素 + * + * @param key 目标值 + * @return 随机池的一个元素或者null 当key大于所有元素的总权重时,返回null + */ + private E binarySearch(double key) { + int low = 0; + int high = randomPool.size() - 1; + + while (low <= high) { + int mid = (low + high) >>> 1; + double midVal = randomPool.get(mid).sumWeight; + + if (midVal < key) { + low = mid + 1; + } else if (midVal > key) { + high = mid - 1; + } else { + return randomPool.get(mid).e; + } + } + return randomPool.get(low).e; + } + + private double sumWeight() { + if (randomPool.isEmpty()) { + return 0; + } + return randomPool.get(randomPool.size() - 1).sumWeight; + } + + private void checkEmptyPool() { + if (isEmpty()) { + throw new IllegalArgumentException("随机池为空!"); + } + } +} diff --git a/hutool-core/src/test/java/cn/hutool/core/lang/WeightListPoolTest.java b/hutool-core/src/test/java/cn/hutool/core/lang/WeightListRandomTest.java similarity index 83% rename from hutool-core/src/test/java/cn/hutool/core/lang/WeightListPoolTest.java rename to hutool-core/src/test/java/cn/hutool/core/lang/WeightListRandomTest.java index 902d612fc..c6da863aa 100644 --- a/hutool-core/src/test/java/cn/hutool/core/lang/WeightListPoolTest.java +++ b/hutool-core/src/test/java/cn/hutool/core/lang/WeightListRandomTest.java @@ -1,6 +1,7 @@ package cn.hutool.core.lang; import cn.hutool.core.util.RandomUtil; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.HashMap; @@ -8,14 +9,15 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.assertTrue; -public class WeightListPoolTest { +public class WeightListRandomTest { @Test - public void weightRandomTest() { + @Disabled + public void nextTest() { Map timesMap = new HashMap<>(); int size = 100; double sumWeight = 0.0; - WeightListPool pool = new WeightListPool<>(size); + WeightListRandom pool = new WeightListRandom<>(size); for (int i = 0; i < size; i++) { double weight = RandomUtil.randomDouble(100); pool.add(i, weight); @@ -26,7 +28,7 @@ public class WeightListPoolTest { double d = 0.0001;// 随机误差 int times = 100000000;// 随机次数 for (int i = 0; i < times; i++) { - timesMap.get(pool.random()).num++; + timesMap.get(pool.next()).num++; } double finalSumWeight = sumWeight; timesMap.forEach((key, times1) -> { @@ -37,9 +39,7 @@ public class WeightListPoolTest { } private static class Times { - int num; - double weight; public Times(double weight) {