!893 新增LambdaUtil,在一定程度上替代反射,提供更高性能的调用

Merge pull request !893 from javanasoda/v6-dev-lambda-factory
This commit is contained in:
Looly 2023-01-17 03:20:24 +00:00 committed by Gitee
commit 4c8a06e749
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 562 additions and 3 deletions

View File

@ -0,0 +1,113 @@
package cn.hutool.core.lang.func;
import cn.hutool.core.exceptions.UtilException;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.lang.Opt;
import cn.hutool.core.lang.mutable.MutableEntry;
import cn.hutool.core.map.WeakConcurrentMap;
import cn.hutool.core.reflect.LookupFactory;
import cn.hutool.core.reflect.MethodUtil;
import java.io.Serializable;
import java.lang.invoke.*;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE;
import static java.lang.invoke.MethodType.methodType;
/**
* 以类似反射的方式动态创建Lambda在性能上有一定优势同时避免每次调用Lambda时创建匿名内部类
*
* @author nasodaengineer
*/
public class LambdaFactory {
private LambdaFactory() throws IllegalAccessException {
throw new IllegalAccessException();
}
private static final Map<MutableEntry<Class<?>, Method>, Object> CACHE = new WeakConcurrentMap<>();
/**
* 构建Lambda
* <pre>{@code
* class Something {
* private Long id;
* private String name;
* // ... 省略GetterSetter方法
* }
* Function<Something, Long> getIdFunction = LambdaFactory.buildLambda(Function.class, Something.class, "getId");
* BiConsumer<Something, String> setNameConsumer = LambdaFactory.buildLambda(BiConsumer.class, Something.class, "setName", String.class);
* }
* </pre>
*
* @param functionInterfaceType 接受Lambda的函数式接口类型
* @param methodClass 声明方法的类的类型
* @param methodName 方法名称
* @param paramTypes 方法参数数组
* @return 接受Lambda的函数式接口对象
*/
public static <F> F build(Class<F> functionInterfaceType, Class methodClass, String methodName, Class... paramTypes) {
return build(functionInterfaceType, MethodUtil.getMethod(methodClass, methodName, paramTypes));
}
/**
* 构建Lambda
*
* @param functionInterfaceType 接受Lambda的函数式接口类型
* @param method 方法对象
* @return 接受Lambda的函数式接口对象
*/
public static <F> F build(Class<F> functionInterfaceType, Method method) {
Assert.notNull(functionInterfaceType);
Assert.notNull(method);
MutableEntry<Class<?>, Method> cacheKey = new MutableEntry<>(functionInterfaceType, method);
//noinspection unchecked
return (F) CACHE.computeIfAbsent(cacheKey, key -> {
List<Method> abstractMethods = Arrays.stream(functionInterfaceType.getMethods())
.filter(m -> Modifier.isAbstract(m.getModifiers()))
.collect(Collectors.toList());
Assert.equals(abstractMethods.size(), 1, "不支持非函数式接口");
if (!method.isAccessible()) {
method.setAccessible(true);
}
Method invokeMethod = abstractMethods.get(0);
MethodHandles.Lookup caller = LookupFactory.lookup(method.getDeclaringClass());
String invokeName = invokeMethod.getName();
MethodType invokedType = methodType(functionInterfaceType);
MethodType samMethodType = methodType(invokeMethod.getReturnType(), invokeMethod.getParameterTypes());
MethodHandle implMethod = Opt.ofTry(() -> caller.unreflect(method)).get();
MethodType insMethodType = methodType(method.getReturnType(), method.getDeclaringClass(), method.getParameterTypes());
boolean isSerializable = Serializable.class.isAssignableFrom(functionInterfaceType);
try {
CallSite callSite = isSerializable ?
LambdaMetafactory.altMetafactory(
caller,
invokeName,
invokedType,
samMethodType,
implMethod,
insMethodType,
FLAG_SERIALIZABLE
) :
LambdaMetafactory.metafactory(
caller,
invokeName,
invokedType,
samMethodType,
implMethod,
insMethodType
);
return (F) callSite.getTarget().invoke();
} catch (Throwable e) {
throw new UtilException(e);
}
});
}
}

View File

@ -13,6 +13,8 @@ import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.function.BiConsumer;
import java.util.function.Function;
/**
* Lambda相关工具类
@ -130,7 +132,69 @@ public class LambdaUtil {
return BeanUtil.getFieldName(getMethodName(func));
}
/**
* 等效于 Obj::getXxx
*
* @param getMethod getter方法
* @param <T> 调用getter方法对象类型
* @param <R> getter方法返回值类型
* @return Obj::getXxx
*/
public static <T, R> Function<T, R> buildGetter(Method getMethod) {
return LambdaFactory.build(Function.class, getMethod);
}
/**
* 等效于 Obj::getXxx
*
* @param clazz 调用getter方法对象类
* @param fieldName 字段名称
* @param <T> 调用getter方法对象类型
* @param <R> getter方法返回值类型
* @return Obj::getXxx
*/
public static <T, R> Function<T, R> buildGetter(Class<T> clazz, String fieldName) {
return LambdaFactory.build(Function.class, BeanUtil.getBeanDesc(clazz).getGetter(fieldName));
}
/**
* 等效于 Obj::setXxx
*
* @param setMethod setter方法
* @param <T> 调用setter方法对象类型
* @param <P> setter方法返回的值类型
* @return Obj::setXxx
*/
public static <T, P> BiConsumer<T, P> buildSetter(Method setMethod) {
return LambdaFactory.build(BiConsumer.class, setMethod);
}
/**
* Obj::setXxx
*
* @param clazz 调用setter方法对象类
* @param fieldName 字段名称
* @param <T> 调用setter方法对象类型
* @param <P> setter方法返回的值类型
* @return Obj::setXxx
*/
public static <T, P> BiConsumer<T, P> buildSetter(Class<T> clazz, String fieldName) {
return LambdaFactory.build(BiConsumer.class, BeanUtil.getBeanDesc(clazz).getSetter(fieldName));
}
/**
* 等效于 Obj::method
*
* @param lambdaType 接受lambda的函数式接口类型
* @param clazz 调用类
* @param methodName 方法名
* @param paramsTypes 方法参数类型数组
* @param <F> 函数式接口类型
* @return Obj::method
*/
public static <F> F build(Class<F> lambdaType, Class<?> clazz, String methodName, Class... paramsTypes) {
return LambdaFactory.build(lambdaType, clazz, methodName, paramsTypes);
}
//region Private methods

View File

@ -0,0 +1,299 @@
package cn.hutool.core.lang.func;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.exceptions.UtilException;
import cn.hutool.core.reflect.MethodHandleUtil;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.lang.invoke.*;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Function;
/**
* @author nasodaengineer
*/
public class LambdaFactoryTest {
// @Test
@Test
public void testMethodNotMatch() {
try {
LambdaFactory.build(Function.class, Something.class, "setId", Long.class);
} catch (Exception e) {
Assert.assertTrue(e.getCause() instanceof LambdaConversionException);
}
}
@Test
public void buildLambdaTest() {
Something something = new Something();
something.setId(1L);
something.setName("name");
Function<Something, Long> get11 = LambdaFactory.build(Function.class, Something.class, "getId");
Function<Something, Long> get12 = LambdaFactory.build(Function.class, Something.class, "getId");
Assert.assertEquals(get11, get12);
Assert.assertEquals(something.getId(), get11.apply(something));
String name = "sname";
BiConsumer<Something, String> set = LambdaFactory.build(BiConsumer.class, Something.class, "setName", String.class);
set.accept(something, name);
Assert.assertEquals(something.getName(), name);
}
@Data
private static class Something {
private Long id;
private String name;
}
/**
* 简单的性能测试大多数情况下直接调用 快于 lambda 快于 反射 快于 代理
*
* @author nasodaengineer
*/
@RunWith(Parameterized.class)
public static class PerformanceTest {
@Parameterized.Parameter
public int count;
@Parameterized.Parameters
public static Collection<Integer> parameters() {
return ListUtil.of(1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000);
}
/**
* <p>lambda 运行1次耗时 7000 NANOSECONDS
* <p>reflect 运行1次耗时 11300 NANOSECONDS
* <p>hardCode 运行1次耗时 12800 NANOSECONDS
* <p>proxy 运行1次耗时 160200 NANOSECONDS
* <p>mh 运行1次耗时 197900 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行10次耗时 1500 NANOSECONDS
* <p>lambda 运行10次耗时 2200 NANOSECONDS
* <p>mh 运行10次耗时 11700 NANOSECONDS
* <p>proxy 运行10次耗时 14400 NANOSECONDS
* <p>reflect 运行10次耗时 28600 NANOSECONDS
* <p>--------------------------------------------
* <p>lambda 运行100次耗时 9300 NANOSECONDS
* <p>hardCode 运行100次耗时 14400 NANOSECONDS
* <p>mh 运行100次耗时 42900 NANOSECONDS
* <p>proxy 运行100次耗时 107900 NANOSECONDS
* <p>reflect 运行100次耗时 430800 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行1000次耗时 86300 NANOSECONDS
* <p>lambda 运行1000次耗时 101700 NANOSECONDS
* <p>reflect 运行1000次耗时 754700 NANOSECONDS
* <p>mh 运行1000次耗时 962200 NANOSECONDS
* <p>proxy 运行1000次耗时 1200500 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行10000次耗时 333000 NANOSECONDS
* <p>lambda 运行10000次耗时 367800 NANOSECONDS
* <p>mh 运行10000次耗时 999100 NANOSECONDS
* <p>proxy 运行10000次耗时 2766100 NANOSECONDS
* <p>reflect 运行10000次耗时 3157200 NANOSECONDS
* <p>--------------------------------------------
* <p>lambda 运行100000次耗时 571600 NANOSECONDS
* <p>hardCode 运行100000次耗时 1061700 NANOSECONDS
* <p>reflect 运行100000次耗时 1326800 NANOSECONDS
* <p>proxy 运行100000次耗时 3160900 NANOSECONDS
* <p>mh 运行100000次耗时 4137500 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行1000000次耗时 5066200 NANOSECONDS
* <p>lambda 运行1000000次耗时 5868700 NANOSECONDS
* <p>mh 运行1000000次耗时 8342700 NANOSECONDS
* <p>reflect 运行1000000次耗时 13009400 NANOSECONDS
* <p>proxy 运行1000000次耗时 21787800 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行10000000次耗时 51102700 NANOSECONDS
* <p>lambda 运行10000000次耗时 55007900 NANOSECONDS
* <p>mh 运行10000000次耗时 72751700 NANOSECONDS
* <p>reflect 运行10000000次耗时 92348800 NANOSECONDS
* <p>proxy 运行10000000次耗时 199705500 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行100000000次耗时 456094400 NANOSECONDS
* <p>lambda 运行100000000次耗时 562348600 NANOSECONDS
* <p>reflect 运行100000000次耗时 630433200 NANOSECONDS
* <p>mh 运行100000000次耗时 671914300 NANOSECONDS
* <p>proxy 运行100000000次耗时 1117192600 NANOSECONDS
* <p>--------------------------------------------
*/
@Test
@SneakyThrows
public void lambdaGetPerformanceTest() {
Something something = new Something();
something.setId(1L);
something.setName("name");
Method getByReflect = Something.class.getMethod("getId");
MethodHandle getByMh = MethodHandleUtil.findMethod(Something.class, "getId", MethodType.methodType(Long.class));
Function getByProxy = MethodHandleProxies.asInterfaceInstance(Function.class, MethodHandles.lookup().unreflect(getByReflect));
Function getByLambda = LambdaFactory.build(Function.class, getByReflect);
Task lambdaTask = new Task("lambda", () -> getByLambda.apply(something));
Task mhTask = new Task("mh", () -> getByMh.invoke(something));
Task proxyTask = new Task("proxy", () -> getByProxy.apply(something));
Task reflectTask = new Task("reflect", () -> getByReflect.invoke(something));
Task hardCodeTask = new Task("hardCode", () -> something.getId());
Task[] tasks = {hardCodeTask, lambdaTask, mhTask, proxyTask, reflectTask};
loop(count, tasks);
}
/**
* <p>hardCode 运行1次耗时 7600 NANOSECONDS
* <p>lambda 运行1次耗时 12400 NANOSECONDS
* <p>reflect 运行1次耗时 19900 NANOSECONDS
* <p>mh 运行1次耗时 139900 NANOSECONDS
* <p>proxy 运行1次耗时 261300 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行10次耗时 1700 NANOSECONDS
* <p>lambda 运行10次耗时 2600 NANOSECONDS
* <p>mh 运行10次耗时 3900 NANOSECONDS
* <p>proxy 运行10次耗时 20400 NANOSECONDS
* <p>reflect 运行10次耗时 26500 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行100次耗时 9000 NANOSECONDS
* <p>lambda 运行100次耗时 16900 NANOSECONDS
* <p>mh 运行100次耗时 32200 NANOSECONDS
* <p>proxy 运行100次耗时 315700 NANOSECONDS
* <p>reflect 运行100次耗时 604300 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行1000次耗时 123500 NANOSECONDS
* <p>lambda 运行1000次耗时 253100 NANOSECONDS
* <p>mh 运行1000次耗时 644600 NANOSECONDS
* <p>reflect 运行1000次耗时 793100 NANOSECONDS
* <p>proxy 运行1000次耗时 1111100 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行10000次耗时 346800 NANOSECONDS
* <p>lambda 运行10000次耗时 524900 NANOSECONDS
* <p>mh 运行10000次耗时 931000 NANOSECONDS
* <p>reflect 运行10000次耗时 2046500 NANOSECONDS
* <p>proxy 运行10000次耗时 3108400 NANOSECONDS
* <p>--------------------------------------------
* <p>lambda 运行100000次耗时 608300 NANOSECONDS
* <p>hardCode 运行100000次耗时 1095600 NANOSECONDS
* <p>mh 运行100000次耗时 1430100 NANOSECONDS
* <p>reflect 运行100000次耗时 1558400 NANOSECONDS
* <p>proxy 运行100000次耗时 5566000 NANOSECONDS
* <p>--------------------------------------------
* <p>lambda 运行1000000次耗时 6261000 NANOSECONDS
* <p>hardCode 运行1000000次耗时 6570200 NANOSECONDS
* <p>mh 运行1000000次耗时 8703300 NANOSECONDS
* <p>reflect 运行1000000次耗时 16437800 NANOSECONDS
* <p>proxy 运行1000000次耗时 22161100 NANOSECONDS
* <p>--------------------------------------------
* <p>lambda 运行10000000次耗时 60895800 NANOSECONDS
* <p>hardCode 运行10000000次耗时 61055300 NANOSECONDS
* <p>mh 运行10000000次耗时 69782400 NANOSECONDS
* <p>reflect 运行10000000次耗时 78078800 NANOSECONDS
* <p>proxy 运行10000000次耗时 193799800 NANOSECONDS
* <p>--------------------------------------------
* <p>hardCode 运行100000000次耗时 499826200 NANOSECONDS
* <p>lambda 运行100000000次耗时 537454100 NANOSECONDS
* <p>reflect 运行100000000次耗时 673561400 NANOSECONDS
* <p>mh 运行100000000次耗时 700774100 NANOSECONDS
* <p>proxy 运行100000000次耗时 1169452400 NANOSECONDS
* <p>--------------------------------------------
*/
@Test
@SneakyThrows
public void lambdaSetPerformanceTest() {
Something something = new Something();
something.setId(1L);
something.setName("name");
Method setByReflect = Something.class.getMethod("setName", String.class);
MethodHandle setByMh = MethodHandleUtil.findMethod(Something.class, "setName", MethodType.methodType(Void.TYPE, String.class));
BiConsumer setByProxy = MethodHandleProxies.asInterfaceInstance(BiConsumer.class, setByMh);
BiConsumer setByLambda = LambdaFactory.build(BiConsumer.class, setByReflect);
String name = "name1";
Task lambdaTask = new Task("lambda", () -> {
setByLambda.accept(something, name);
return null;
});
Task proxyTask = new Task("proxy", () -> {
setByProxy.accept(something, name);
return null;
});
Task mhTask = new Task("mh", () -> {
setByMh.invoke(something, name);
return null;
});
Task reflectTask = new Task("reflect", () -> {
setByReflect.invoke(something, name);
return null;
});
Task hardCodeTask = new Task("hardCode", () -> {
something.setName(name);
return null;
});
Task[] tasks = {hardCodeTask, lambdaTask, proxyTask, mhTask, reflectTask};
loop(count, tasks);
}
@SneakyThrows
private void loop(int count, Task... tasks) {
Arrays.stream(tasks)
.peek(task -> {
LambdaFactoryTest.SupplierThrowable runnable = task.getRunnable();
long cost = System.nanoTime();
for (int i = 0; i < count; i++) {
runnable.get();
}
cost = System.nanoTime() - cost;
task.setCost(cost);
task.setCount(count);
})
.sorted(Comparator.comparing(Task::getCost))
.map(Task::format)
.forEach(System.out::println);
System.out.println("--------------------------------------------");
}
@Getter
private class Task {
private String name;
private LambdaFactoryTest.SupplierThrowable<?> runnable;
@Setter
private long cost;
@Setter
private Integer count;
public Task(String name, LambdaFactoryTest.SupplierThrowable<?> runnable) {
this.name = name;
this.runnable = runnable;
}
public String format() {
TimeUnit timeUnit = TimeUnit.NANOSECONDS;
return String.format("%-10s 运行%d次耗时 %d %s", name, count, timeUnit.convert(cost, TimeUnit.NANOSECONDS), timeUnit.name());
}
}
}
@FunctionalInterface
interface SupplierThrowable<T> {
T get0() throws Throwable;
default T get() {
try {
return get0();
} catch (Throwable e) {
throw new RuntimeException(e);
}
}
}
}

View File

@ -1,13 +1,19 @@
package cn.hutool.core.lang.func;
import cn.hutool.core.lang.Tuple;
import cn.hutool.core.reflect.MethodUtil;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.FieldNameConstants;
import org.junit.Assert;
import org.junit.Test;
import java.io.Serializable;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
public class LambdaUtilTest {
@ -82,8 +88,10 @@ public class LambdaUtilTest {
// 一些特殊的lambda
Assert.assertEquals("T", LambdaUtil.<SerFunction<Object, Stream<?>>>resolve(Stream::of).getParameterTypes()[0].getTypeName());
Assert.assertEquals(MyTeacher[][].class, LambdaUtil.<SerFunction<Integer, MyTeacher[][]>>resolve(MyTeacher[][]::new).getReturnType());
Assert.assertEquals(Integer[][][].class, LambdaUtil.<SerConsumer<Integer[][][]>>resolve(a -> {}).getParameterTypes()[0]);
Assert.assertEquals(Integer[][][].class, LambdaUtil.resolve((Serializable & SerConsumer3<Integer[][][], Integer[][], Integer>) (a, b, c) -> {}).getParameterTypes()[0]);
Assert.assertEquals(Integer[][][].class, LambdaUtil.<SerConsumer<Integer[][][]>>resolve(a -> {
}).getParameterTypes()[0]);
Assert.assertEquals(Integer[][][].class, LambdaUtil.resolve((Serializable & SerConsumer3<Integer[][][], Integer[][], Integer>) (a, b, c) -> {
}).getParameterTypes()[0]);
}).forEach(Runnable::run);
}
@ -136,11 +144,86 @@ public class LambdaUtilTest {
Assert.assertEquals(MyTeacher.class, LambdaUtil.getRealClass(lambda));
}, () -> {
// 数组测试
final SerConsumer<String[]> lambda = (String[] stringList) -> {};
final SerConsumer<String[]> lambda = (String[] stringList) -> {
};
Assert.assertEquals(String[].class, LambdaUtil.getRealClass(lambda));
}).forEach(Runnable::run);
}
@Test
public void getterTest() {
Bean bean = new Bean();
bean.setId(2L);
Function<Bean, Long> getId = LambdaUtil.buildGetter(MethodUtil.getMethod(Bean.class, "getId"));
Function<Bean, Long> getId2 = LambdaUtil.buildGetter(Bean.class, Bean.Fields.id);
Assert.assertEquals(getId, getId2);
Assert.assertEquals(bean.getId(), getId.apply(bean));
}
@Test
public void setterTest() {
Bean bean = new Bean();
bean.setId(2L);
bean.setFlag(false);
BiConsumer<Bean, Long> setId = LambdaUtil.buildSetter(MethodUtil.getMethod(Bean.class, "setId", Long.class));
BiConsumer<Bean, Long> setId2 = LambdaUtil.buildSetter(Bean.class, Bean.Fields.id);
BiConsumer<Bean, Boolean> setFlag = LambdaUtil.buildSetter(Bean.class, Bean.Fields.flag);
Assert.assertEquals(setId, setId2);
setId.accept(bean, 3L);
setFlag.accept(bean, true);
Assert.assertEquals(3L, (long) bean.getId());
Assert.assertTrue(bean.isFlag());
}
@Test
public void lambdaTest() {
Bean bean = new Bean();
bean.setId(1L);
bean.setPid(0L);
bean.setFlag(true);
BiFunction<Bean, String, Tuple> uniqueKeyFunction = LambdaUtil.build(BiFunction.class, Bean.class, "uniqueKey", String.class);
Function4<Tuple, Bean, String, Integer, Double> paramsFunction = LambdaUtil.build(Function4.class, Bean.class, "params", String.class, Integer.class, Double.class);
Assert.assertEquals(bean.uniqueKey("test"), uniqueKeyFunction.apply(bean, "test"));
Assert.assertEquals(bean.params("test", 1, 0.5), paramsFunction.apply(bean, "test", 1, 0.5));
}
@FunctionalInterface
interface Function4<R, P1, P2, P3, P4> {
R apply(P1 p1, P2 p2, P3 p3, P4 p4);
}
@Data
@FieldNameConstants
private static class Bean {
Long id;
Long pid;
boolean flag;
private Tuple uniqueKey(String name) {
return new Tuple(id, pid, flag, name);
}
public Tuple params(String name, Integer length, Double score) {
return new Tuple(name, length, score);
}
public static Function<Bean, Long> idGetter() {
return Bean::getId;
}
public Function<Bean, Long> idGet() {
return bean -> bean.id;
}
public Function<Bean, Long> idGetting() {
return Bean::getId;
}
}
@Data
@AllArgsConstructor
static class MyStudent {