diff --git a/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightListRandomSelector.java b/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightListRandomSelector.java new file mode 100644 index 000000000..006230c3f --- /dev/null +++ b/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightListRandomSelector.java @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024 Hutool Team and hutool.cn + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.dromara.hutool.core.lang.selector; + +import org.dromara.hutool.core.lang.Assert; +import org.dromara.hutool.core.util.ObjUtil; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * 动态按权重随机的随机池,底层是list实现。 + * 原理为加入的{@link WeightObj}依次增加权重,随机时根据权重计算随机值,然后二分查找小于等于随机值的权重,返回对应的元素。
+ * 我们假设随机池中有4个对象,其权重为4,5,1,6,权重越高,'-'越多,那么随机池如下: + *
{@code
+ *     [obj1,  obj2, obj3, obj4  ]
+ *     [----, -----,  -  , ------]
+ * }
+ * 然后最后一个元素的权重值为总权重值,即obj2的权重值为obj1权重+obj2本身权重,依次类推。
+ * 我们取一个总权重范围的随机数,根据随机数在'-'列表中的位置,找到对应的obj即随机到的对象。 + * + * @param 元素类型 + * @author 王叶峰 + */ +public class WeightListRandomSelector implements Selector, Serializable { + private static final long serialVersionUID = 1L; + + /** + * 随机元素池 + */ + private final List> randomPool; + + /** + * 构造 + */ + public WeightListRandomSelector() { + randomPool = new ArrayList<>(); + } + + /** + * 构造 + * + * @param poolSize 初始随机池大小 + */ + public WeightListRandomSelector(final int poolSize) { + randomPool = new ArrayList<>(poolSize); + } + + /** + * 增加随机种子 + * + * @param e 随机对象 + * @param weight 权重 + */ + public void add(final E e, final int weight) { + Assert.isTrue(weight > 0, "权重必须大于0!"); + randomPool.add(new WeightObj<>(e, sumWeight() + weight)); + } + + /** + * 移除随机种子 + * + * @param e 随机对象 + * @return 是否移除成功 + */ + public boolean remove(final E e) { + boolean removed = false; + int weight = 0; + int i = 0; + final Iterator> iterator = randomPool.iterator(); + WeightObj ew; + while (iterator.hasNext()) { + ew = iterator.next(); + if (!removed && ObjUtil.equals(ew.obj, e)) { + iterator.remove(); + weight = ew.weight - (i == 0 ? 0 : randomPool.get(i - 1).weight);// 权重=当前权重-上一个权重 + removed = true; + } + if (removed) { + // 重新计算后续权重 + ew.weight -= weight; + } + i++; + } + return removed; + } + + /** + * 判断是否为空 + * + * @return 是否为空 + */ + public boolean isEmpty() { + return randomPool.isEmpty(); + } + + @Override + public E select() { + if (isEmpty()) { + return null; + } + if (randomPool.size() == 1) { + return randomPool.get(0).obj; + } + return binarySearch((int) (sumWeight() * Math.random())); + } + + /** + * 二分查找小于等于key的最大值的元素 + * + * @param randomWeight 随机权重值,查找这个权重对应的元素 + * @return 随机池的一个元素或者null 当key大于所有元素的总权重时,返回null + */ + private E binarySearch(final int randomWeight) { + int low = 0; + int high = randomPool.size() - 1; + + while (low <= high) { + final int mid = (low + high) >>> 1; + final int midWeight = randomPool.get(mid).weight; + + if (midWeight < randomWeight) { + low = mid + 1; + } else if (midWeight > randomWeight) { + high = mid - 1; + } else { + return randomPool.get(mid).obj; + } + } + return randomPool.get(low).obj; + } + + private int sumWeight() { + if (randomPool.isEmpty()) { + return 0; + } + return randomPool.get(randomPool.size() - 1).weight; + } +} diff --git a/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightObj.java b/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightObj.java index d2ad128d4..489c6790c 100644 --- a/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightObj.java +++ b/hutool-core/src/main/java/org/dromara/hutool/core/lang/selector/WeightObj.java @@ -29,7 +29,7 @@ public class WeightObj { /** 对象 */ protected T obj; /** 权重 */ - protected final int weight; + protected int weight; /** * 构造 diff --git a/hutool-core/src/test/java/org/dromara/hutool/core/lang/selector/WeightListRandomSelectorTest.java b/hutool-core/src/test/java/org/dromara/hutool/core/lang/selector/WeightListRandomSelectorTest.java new file mode 100644 index 000000000..2661e7cc1 --- /dev/null +++ b/hutool-core/src/test/java/org/dromara/hutool/core/lang/selector/WeightListRandomSelectorTest.java @@ -0,0 +1,77 @@ +package org.dromara.hutool.core.lang.selector; + +import org.dromara.hutool.core.date.DateUtil; +import org.dromara.hutool.core.date.StopWatch; +import org.dromara.hutool.core.lang.Console; +import org.dromara.hutool.core.util.RandomUtil; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +public class WeightListRandomSelectorTest { + + @Test + @Disabled + public void weightListRandomTest() { + final Map timesMap = new HashMap<>(); + final int size = 100; + int sumWeight = 0; + final WeightListRandomSelector pool = new WeightListRandomSelector<>(size); + for (int i = 0; i < size; i++) { + final int weight = RandomUtil.randomInt(1, 100); + pool.add(i, weight); + sumWeight += weight; + timesMap.put(i, new Times(weight)); + } + + final int times = 100000000;// 随机次数 + for (int i = 0; i < times; i++) { + timesMap.get(pool.select()).num++; + } + + final int finalSumWeight = sumWeight; + timesMap.forEach((key, times1) -> { + final double expected = times1.weight / finalSumWeight;// 期望概率 + final double actual = (double) timesMap.get(key).num / times;// 真实随机概率 + Console.log(expected, actual); + }); + } + + private static class Times { + int num; + double weight; + + public Times(final double weight) { + this.weight = weight; + } + } + + @Test + public void weightRandomBenchTest() { + final int size = 100; + final WeightListRandomSelector pool = new WeightListRandomSelector<>(size); + final WeightRandomSelector pool2 = new WeightRandomSelector<>(); + for (int i = 0; i < size; i++) { + final int weight = RandomUtil.randomInt(1, 100); + pool.add(i, weight); + pool2.add(i, weight); + } + + final int count = 1000; + final StopWatch stopWatch = DateUtil.createStopWatch(); + stopWatch.start("WeightListRandomSelector"); + for (int i = 0; i < count; i++) { + pool.select(); + } + stopWatch.stop(); + stopWatch.start("WeightRandomSelector"); + for (int i = 0; i < count; i++) { + pool2.select(); + } + stopWatch.stop(); + + //Console.log(stopWatch.prettyPrint()); + } +}