add WeightListRandomSelector

This commit is contained in:
Looly 2024-09-04 20:16:46 +08:00
parent 41e3037d4b
commit 2f4af776c0
3 changed files with 233 additions and 1 deletions

View File

@ -0,0 +1,155 @@
/*
* Copyright (c) 2024 Hutool Team and hutool.cn
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.dromara.hutool.core.lang.selector;
import org.dromara.hutool.core.lang.Assert;
import org.dromara.hutool.core.util.ObjUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
* 动态按权重随机的随机池底层是list实现
* 原理为加入的{@link WeightObj}依次增加权重随机时根据权重计算随机值然后二分查找小于等于随机值的权重返回对应的元素<br>
* 我们假设随机池中有4个对象其权重为4516权重越高'-'越多那么随机池如下
* <pre>{@code
* [obj1, obj2, obj3, obj4 ]
* [----, -----, - , ------]
* }</pre>
* 然后最后一个元素的权重值为总权重值即obj2的权重值为obj1权重+obj2本身权重依次类推<br>
* 我们取一个总权重范围的随机数根据随机数在'-'列表中的位置找到对应的obj即随机到的对象
*
* @param <E> 元素类型
* @author 王叶峰
*/
public class WeightListRandomSelector<E> implements Selector<E>, Serializable {
private static final long serialVersionUID = 1L;
/**
* 随机元素池
*/
private final List<WeightObj<E>> randomPool;
/**
* 构造
*/
public WeightListRandomSelector() {
randomPool = new ArrayList<>();
}
/**
* 构造
*
* @param poolSize 初始随机池大小
*/
public WeightListRandomSelector(final int poolSize) {
randomPool = new ArrayList<>(poolSize);
}
/**
* 增加随机种子
*
* @param e 随机对象
* @param weight 权重
*/
public void add(final E e, final int weight) {
Assert.isTrue(weight > 0, "权重必须大于0");
randomPool.add(new WeightObj<>(e, sumWeight() + weight));
}
/**
* 移除随机种子
*
* @param e 随机对象
* @return 是否移除成功
*/
public boolean remove(final E e) {
boolean removed = false;
int weight = 0;
int i = 0;
final Iterator<WeightObj<E>> iterator = randomPool.iterator();
WeightObj<E> ew;
while (iterator.hasNext()) {
ew = iterator.next();
if (!removed && ObjUtil.equals(ew.obj, e)) {
iterator.remove();
weight = ew.weight - (i == 0 ? 0 : randomPool.get(i - 1).weight);// 权重=当前权重-上一个权重
removed = true;
}
if (removed) {
// 重新计算后续权重
ew.weight -= weight;
}
i++;
}
return removed;
}
/**
* 判断是否为空
*
* @return 是否为空
*/
public boolean isEmpty() {
return randomPool.isEmpty();
}
@Override
public E select() {
if (isEmpty()) {
return null;
}
if (randomPool.size() == 1) {
return randomPool.get(0).obj;
}
return binarySearch((int) (sumWeight() * Math.random()));
}
/**
* 二分查找小于等于key的最大值的元素
*
* @param randomWeight 随机权重值查找这个权重对应的元素
* @return 随机池的一个元素或者null 当key大于所有元素的总权重时返回null
*/
private E binarySearch(final int randomWeight) {
int low = 0;
int high = randomPool.size() - 1;
while (low <= high) {
final int mid = (low + high) >>> 1;
final int midWeight = randomPool.get(mid).weight;
if (midWeight < randomWeight) {
low = mid + 1;
} else if (midWeight > randomWeight) {
high = mid - 1;
} else {
return randomPool.get(mid).obj;
}
}
return randomPool.get(low).obj;
}
private int sumWeight() {
if (randomPool.isEmpty()) {
return 0;
}
return randomPool.get(randomPool.size() - 1).weight;
}
}

View File

@ -29,7 +29,7 @@ public class WeightObj<T> {
/** 对象 */
protected T obj;
/** 权重 */
protected final int weight;
protected int weight;
/**
* 构造

View File

@ -0,0 +1,77 @@
package org.dromara.hutool.core.lang.selector;
import org.dromara.hutool.core.date.DateUtil;
import org.dromara.hutool.core.date.StopWatch;
import org.dromara.hutool.core.lang.Console;
import org.dromara.hutool.core.util.RandomUtil;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.HashMap;
import java.util.Map;
public class WeightListRandomSelectorTest {
@Test
@Disabled
public void weightListRandomTest() {
final Map<Integer, Times> timesMap = new HashMap<>();
final int size = 100;
int sumWeight = 0;
final WeightListRandomSelector<Integer> pool = new WeightListRandomSelector<>(size);
for (int i = 0; i < size; i++) {
final int weight = RandomUtil.randomInt(1, 100);
pool.add(i, weight);
sumWeight += weight;
timesMap.put(i, new Times(weight));
}
final int times = 100000000;// 随机次数
for (int i = 0; i < times; i++) {
timesMap.get(pool.select()).num++;
}
final int finalSumWeight = sumWeight;
timesMap.forEach((key, times1) -> {
final double expected = times1.weight / finalSumWeight;// 期望概率
final double actual = (double) timesMap.get(key).num / times;// 真实随机概率
Console.log(expected, actual);
});
}
private static class Times {
int num;
double weight;
public Times(final double weight) {
this.weight = weight;
}
}
@Test
public void weightRandomBenchTest() {
final int size = 100;
final WeightListRandomSelector<Integer> pool = new WeightListRandomSelector<>(size);
final WeightRandomSelector<Integer> pool2 = new WeightRandomSelector<>();
for (int i = 0; i < size; i++) {
final int weight = RandomUtil.randomInt(1, 100);
pool.add(i, weight);
pool2.add(i, weight);
}
final int count = 1000;
final StopWatch stopWatch = DateUtil.createStopWatch();
stopWatch.start("WeightListRandomSelector");
for (int i = 0; i < count; i++) {
pool.select();
}
stopWatch.stop();
stopWatch.start("WeightRandomSelector");
for (int i = 0; i < count; i++) {
pool2.select();
}
stopWatch.stop();
//Console.log(stopWatch.prettyPrint());
}
}