diff --git a/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaFactory.java b/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaFactory.java new file mode 100644 index 000000000..f8c58d79f --- /dev/null +++ b/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaFactory.java @@ -0,0 +1,117 @@ +package cn.hutool.core.lang.func; + +import cn.hutool.core.lang.Assert; +import cn.hutool.core.lang.Opt; +import cn.hutool.core.lang.Tuple; +import cn.hutool.core.map.WeakConcurrentMap; +import cn.hutool.core.reflect.LookupFactory; +import cn.hutool.core.reflect.MethodUtil; + +import java.io.Serializable; +import java.lang.invoke.*; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE; +import static java.lang.invoke.MethodType.methodType; + +/** + * 以类似反射的方式动态创建Lambda,在性能上有一定优势,同时避免每次调用Lambda时创建匿名内部类 + * + * @author nasodaengineer + */ +public class LambdaFactory { + + private LambdaFactory() throws IllegalAccessException { + throw new IllegalAccessException(); + } + + private static final Map CACHE = new WeakConcurrentMap<>(); + + /** + * 构建Lambda + *
{@code
+	 * class Something {
+	 *     private Long id;
+	 *     private String name;
+	 *     // ... 省略GetterSetter方法
+	 * }
+	 * Function getIdFunction = LambdaFactory.buildLambda(Function.class, Something.class, "getId");
+	 * BiConsumer setNameConsumer = LambdaFactory.buildLambda(BiConsumer.class, Something.class, "setName", String.class);
+	 * }
+	 * 
+ * + * @param functionInterfaceType 接受Lambda的函数式接口类型 + * @param methodClass 声明方法的类的类型 + * @param methodName 方法名称 + * @param paramTypes 方法参数数组 + * @return 接受Lambda的函数式接口对象 + */ + public static F buildLambda(Class functionInterfaceType, Class methodClass, String methodName, Class... paramTypes) { + return buildLambda(functionInterfaceType, MethodUtil.getMethod(methodClass, methodName, paramTypes)); + } + + /** + * 构建Lambda + * + * @param functionInterfaceType 接受Lambda的函数式接口类型 + * @param method 方法对象 + * @return 接受Lambda的函数式接口对象 + */ + public static F buildLambda(Class functionInterfaceType, Method method) { + Assert.notNull(functionInterfaceType); + Assert.notNull(method); + Tuple cacheKey = new Tuple(functionInterfaceType, method); + Object cacheValue = CACHE.get(cacheKey); + if (null != cacheValue) { + //noinspection unchecked + return (F) cacheValue; + } + List abstractMethods = Arrays.stream(functionInterfaceType.getMethods()) + .filter(m -> Modifier.isAbstract(m.getModifiers())) + .collect(Collectors.toList()); + Assert.equals(abstractMethods.size(), 1, "不支持非函数式接口"); + if (!method.isAccessible()) { + method.setAccessible(true); + } + Method invokeMethod = abstractMethods.get(0); + MethodHandles.Lookup caller = LookupFactory.lookup(method.getDeclaringClass()); + String invokeName = invokeMethod.getName(); + MethodType invokedType = methodType(functionInterfaceType); + MethodType samMethodType = methodType(invokeMethod.getReturnType(), invokeMethod.getParameterTypes()); + MethodHandle implMethod = Opt.ofTry(() -> caller.unreflect(method)).get(); + MethodType insMethodType = methodType(method.getReturnType(), method.getDeclaringClass(), method.getParameterTypes()); + boolean isSerializable = Arrays.stream(functionInterfaceType.getInterfaces()).anyMatch(i -> i.isAssignableFrom(Serializable.class)); + CallSite callSite = Opt.ofTry(() -> isSerializable ? + LambdaMetafactory.altMetafactory( + caller, + invokeName, + invokedType, + samMethodType, + implMethod, + insMethodType, + FLAG_SERIALIZABLE + ) : + LambdaMetafactory.metafactory( + caller, + invokeName, + invokedType, + samMethodType, + implMethod, + insMethodType + )).get(); + + try { + //noinspection unchecked + F lambda = (F) callSite.getTarget().invoke(); + CACHE.put(cacheKey, lambda); + return lambda; + } catch (Throwable e) { + throw new RuntimeException(e); + } + } +} diff --git a/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaUtil.java b/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaUtil.java index 239f845f0..b57824ac4 100755 --- a/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/lang/func/LambdaUtil.java @@ -13,6 +13,8 @@ import java.lang.invoke.SerializedLambda; import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.lang.reflect.Proxy; +import java.util.function.BiConsumer; +import java.util.function.Function; /** * Lambda相关工具类 @@ -130,7 +132,69 @@ public class LambdaUtil { return BeanUtil.getFieldName(getMethodName(func)); } + /** + * 等效于 Obj::getXxx + * + * @param getMethod getter方法 + * @param 调用getter方法对象类型 + * @param getter方法返回值类型 + * @return Obj::getXxx + */ + public static Function getter(Method getMethod) { + return LambdaFactory.buildLambda(Function.class, getMethod); + } + /** + * 等效于 Obj::getXxx + * + * @param clazz 调用getter方法对象类 + * @param fieldName 字段名称 + * @param 调用getter方法对象类型 + * @param getter方法返回值类型 + * @return Obj::getXxx + */ + public static Function getter(Class clazz, String fieldName) { + return LambdaFactory.buildLambda(Function.class, BeanUtil.getBeanDesc(clazz).getGetter(fieldName)); + } + + /** + * 等效于 Obj::setXxx + * + * @param setMethod setter方法 + * @param 调用setter方法对象类型 + * @param

setter方法返回的值类型 + * @return Obj::setXxx + */ + public static BiConsumer setter(Method setMethod) { + return LambdaFactory.buildLambda(BiConsumer.class, setMethod); + } + + /** + * Obj::setXxx + * + * @param clazz 调用setter方法对象类 + * @param fieldName 字段名称 + * @param 调用setter方法对象类型 + * @param

setter方法返回的值类型 + * @return Obj::setXxx + */ + public static BiConsumer setter(Class clazz, String fieldName) { + return LambdaFactory.buildLambda(BiConsumer.class, BeanUtil.getBeanDesc(clazz).getSetter(fieldName)); + } + + /** + * 等效于 Obj::method + * + * @param lambdaType 接受lambda的函数式接口类型 + * @param clazz 调用类 + * @param methodName 方法名 + * @param paramsTypes 方法参数类型数组 + * @param 函数式接口类型 + * @return Obj::method + */ + public static F lambda(Class lambdaType, Class clazz, String methodName, Class... paramsTypes) { + return LambdaFactory.buildLambda(lambdaType, clazz, methodName, paramsTypes); + } //region Private methods diff --git a/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaFactoryTest.java b/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaFactoryTest.java new file mode 100644 index 000000000..4e1d3e6c8 --- /dev/null +++ b/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaFactoryTest.java @@ -0,0 +1,253 @@ +package cn.hutool.core.lang.func; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.collection.ListUtil; +import lombok.Data; +import lombok.Getter; +import lombok.Setter; +import lombok.SneakyThrows; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.lang.invoke.MethodHandleProxies; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.function.BiConsumer; +import java.util.function.Function; + +/** + * @author nasodaengineer + */ +public class LambdaFactoryTest { + + @Test(expected = RuntimeException.class) + public void testMethodNotMatch() { + LambdaFactory.buildLambda(Function.class, Something.class, "setId", Long.class); + } + + @Test + public void buildLambdaTest() { + Something something = new Something(); + something.setId(1L); + something.setName("name"); + + Function get11 = LambdaFactory.buildLambda(Function.class, Something.class, "getId"); + Function get12 = LambdaFactory.buildLambda(Function.class, Something.class, "getId"); + + Assert.assertEquals(get11, get12); + Assert.assertEquals(something.getId(), get11.apply(something)); + + String name = "sname"; + BiConsumer set = LambdaFactory.buildLambda(BiConsumer.class, Something.class, "setName", String.class); + set.accept(something, name); + + Assert.assertEquals(something.getName(), name); + } + + @Data + private static class Something { + private Long id; + private String name; + } + + /** + * 简单的性能测试,大多数情况下直接调用 快于 lambda 快于 反射 快于 代理 + * + * @author nasodaengineer + */ + @RunWith(Parameterized.class) + public static class PerformanceTest { + + @Parameterized.Parameter + public int count; + + @Parameterized.Parameters + public static Collection parameters() { + return ListUtil.of(1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000); + } + + /** + *

hardCode 运行1次耗时 4600 ns + *

lambda 运行1次耗时 5400 ns + *

reflect 运行1次耗时 7100 ns + *

proxy 运行1次耗时 145400 ns + *

-------------------------------------------- + *

hardCode 运行10次耗时 1200 ns + *

lambda 运行10次耗时 1200 ns + *

proxy 运行10次耗时 10800 ns + *

reflect 运行10次耗时 20100 ns + *

-------------------------------------------- + *

lambda 运行100次耗时 6300 ns + *

hardCode 运行100次耗时 6400 ns + *

proxy 运行100次耗时 65100 ns + *

reflect 运行100次耗时 196800 ns + *

-------------------------------------------- + *

hardCode 运行1000次耗时 54100 ns + *

lambda 运行1000次耗时 82000 ns + *

reflect 运行1000次耗时 257300 ns + *

proxy 运行1000次耗时 822700 ns + *

-------------------------------------------- + *

hardCode 运行10000次耗时 84400 ns + *

lambda 运行10000次耗时 209200 ns + *

reflect 运行10000次耗时 1024300 ns + *

proxy 运行10000次耗时 1467300 ns + *

-------------------------------------------- + *

lambda 运行100000次耗时 618700 ns + *

hardCode 运行100000次耗时 675200 ns + *

reflect 运行100000次耗时 914100 ns + *

proxy 运行100000次耗时 2745800 ns + *

-------------------------------------------- + *

lambda 运行1000000次耗时 5342500 ns + *

hardCode 运行1000000次耗时 5616400 ns + *

reflect 运行1000000次耗时 9176700 ns + *

proxy 运行1000000次耗时 15801800 ns + *

-------------------------------------------- + *

lambda 运行10000000次耗时 53415200 ns + *

hardCode 运行10000000次耗时 63714500 ns + *

proxy 运行10000000次耗时 116420900 ns + *

reflect 运行10000000次耗时 120817900 ns + *

-------------------------------------------- + *

lambda 运行100000000次耗时 546706600 ns + *

hardCode 运行100000000次耗时 557174500 ns + *

reflect 运行100000000次耗时 924166200 ns + *

proxy 运行100000000次耗时 1862735900 ns + *

-------------------------------------------- + */ + @Test + @SneakyThrows + public void lambdaGetPerformanceTest() { + Something something = new Something(); + something.setId(1L); + something.setName("name"); + Method getByReflect = Something.class.getMethod("getId"); + Function getByProxy = MethodHandleProxies.asInterfaceInstance(Function.class, MethodHandles.lookup().unreflect(getByReflect)); + Function getByLambda = LambdaFactory.buildLambda(Function.class, getByReflect); + Task lambdaTask = new Task("lambda", () -> getByLambda.apply(something)); + Task proxyTask = new Task("proxy", () -> getByProxy.apply(something)); + Task reflectTask = new Task("reflect", () -> getByReflect.invoke(something)); + Task hardCodeTask = new Task("hardCode", () -> something.getId()); + Task[] tasks = {hardCodeTask, lambdaTask, proxyTask, reflectTask}; + loop(count, tasks); + } + + /** + *

hardCode 运行1次耗时 4800 ns + *

lambda 运行1次耗时 9100 ns + *

reflect 运行1次耗时 20600 ns + *

-------------------------------------------- + *

hardCode 运行10次耗时 1800 ns + *

lambda 运行10次耗时 2100 ns + *

reflect 运行10次耗时 24500 ns + *

-------------------------------------------- + *

hardCode 运行100次耗时 15700 ns + *

lambda 运行100次耗时 17500 ns + *

reflect 运行100次耗时 418200 ns + *

-------------------------------------------- + *

hardCode 运行1000次耗时 101700 ns + *

lambda 运行1000次耗时 157200 ns + *

reflect 运行1000次耗时 504900 ns + *

-------------------------------------------- + *

hardCode 运行10000次耗时 360800 ns + *

lambda 运行10000次耗时 371700 ns + *

reflect 运行10000次耗时 1887600 ns + *

-------------------------------------------- + *

lambda 运行100000次耗时 581500 ns + *

hardCode 运行100000次耗时 1629900 ns + *

reflect 运行100000次耗时 1781700 ns + *

-------------------------------------------- + *

lambda 运行1000000次耗时 175400 ns + *

hardCode 运行1000000次耗时 2045400 ns + *

reflect 运行1000000次耗时 14363200 ns + *

-------------------------------------------- + *

hardCode 运行10000000次耗时 60149000 ns + *

lambda 运行10000000次耗时 60502600 ns + *

reflect 运行10000000次耗时 187412800 ns + *

-------------------------------------------- + *

hardCode 运行100000000次耗时 562997300 ns + *

lambda 运行100000000次耗时 564359700 ns + *

reflect 运行100000000次耗时 1163617600 ns + * -------------------------------------------- + */ + @Test + @SneakyThrows + public void lambdaSetPerformanceTest() { + Something something = new Something(); + something.setId(1L); + something.setName("name"); + Method setByReflect = Something.class.getMethod("setName", String.class); + BiConsumer setByLambda = LambdaFactory.buildLambda(BiConsumer.class, setByReflect); + String name = "name1"; + Task lambdaTask = new Task("lambda", () -> { + setByLambda.accept(something, name); + return null; + }); + Task reflectTask = new Task("reflect", () -> { + setByReflect.invoke(something, name); + return null; + }); + Task hardCodeTask = new Task("hardCode", () -> { + something.setName(name); + return null; + }); + Task[] tasks = {hardCodeTask, lambdaTask, reflectTask}; + loop(count, tasks); + } + + @SneakyThrows + private void loop(int count, Task... tasks) { + Arrays.stream(tasks) + .peek(task -> { + LambdaFactoryTest.SupplierThrowable runnable = task.getRunnable(); + long cost = System.nanoTime(); + for (int i = 0; i < count; i++) { + runnable.get(); + } + cost = System.nanoTime() - cost; + task.setCost(cost); + task.setCount(count); + }) + .sorted(Comparator.comparing(Task::getCost)) + .map(Task::format) + .forEach(System.out::println); + System.out.println("--------------------------------------------"); + } + + @Getter + private class Task { + private String name; + private LambdaFactoryTest.SupplierThrowable runnable; + @Setter + private long cost; + @Setter + private Integer count; + + public Task(String name, LambdaFactoryTest.SupplierThrowable runnable) { + this.name = name; + this.runnable = runnable; + } + + public String format() { + return String.format("%-10s 运行%d次耗时 %d ns", name, count, cost); + } + } + + } + + @FunctionalInterface + interface SupplierThrowable { + T get0() throws Throwable; + + default T get() { + try { + return get0(); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaUtilTest.java b/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaUtilTest.java index da6e3266e..08d0c144d 100644 --- a/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaUtilTest.java +++ b/hutool-core/src/test/java/cn/hutool/core/lang/func/LambdaUtilTest.java @@ -1,13 +1,19 @@ package cn.hutool.core.lang.func; +import cn.hutool.core.lang.Tuple; +import cn.hutool.core.reflect.MethodUtil; import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.experimental.FieldNameConstants; import org.junit.Assert; import org.junit.Test; import java.io.Serializable; import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; import java.util.stream.Stream; public class LambdaUtilTest { @@ -82,8 +88,10 @@ public class LambdaUtilTest { // 一些特殊的lambda Assert.assertEquals("T", LambdaUtil.>>resolve(Stream::of).getParameterTypes()[0].getTypeName()); Assert.assertEquals(MyTeacher[][].class, LambdaUtil.>resolve(MyTeacher[][]::new).getReturnType()); - Assert.assertEquals(Integer[][][].class, LambdaUtil.>resolve(a -> {}).getParameterTypes()[0]); - Assert.assertEquals(Integer[][][].class, LambdaUtil.resolve((Serializable & SerConsumer3) (a, b, c) -> {}).getParameterTypes()[0]); + Assert.assertEquals(Integer[][][].class, LambdaUtil.>resolve(a -> { + }).getParameterTypes()[0]); + Assert.assertEquals(Integer[][][].class, LambdaUtil.resolve((Serializable & SerConsumer3) (a, b, c) -> { + }).getParameterTypes()[0]); }).forEach(Runnable::run); } @@ -136,11 +144,86 @@ public class LambdaUtilTest { Assert.assertEquals(MyTeacher.class, LambdaUtil.getRealClass(lambda)); }, () -> { // 数组测试 - final SerConsumer lambda = (String[] stringList) -> {}; + final SerConsumer lambda = (String[] stringList) -> { + }; Assert.assertEquals(String[].class, LambdaUtil.getRealClass(lambda)); }).forEach(Runnable::run); } + @Test + public void getterTest() { + Bean bean = new Bean(); + bean.setId(2L); + + Function getId = cn.hutool.core.lang.func.LambdaUtil.getter(MethodUtil.getMethod(Bean.class, "getId")); + Function getId2 = cn.hutool.core.lang.func.LambdaUtil.getter(Bean.class, Bean.Fields.id); + + Assert.assertEquals(getId, getId2); + Assert.assertEquals(bean.getId(), getId.apply(bean)); + } + + @Test + public void setterTest() { + Bean bean = new Bean(); + bean.setId(2L); + bean.setFlag(false); + + BiConsumer setId = cn.hutool.core.lang.func.LambdaUtil.setter(MethodUtil.getMethod(Bean.class, "setId", Long.class)); + BiConsumer setId2 = cn.hutool.core.lang.func.LambdaUtil.setter(Bean.class, Bean.Fields.id); + BiConsumer setFlag = cn.hutool.core.lang.func.LambdaUtil.setter(Bean.class, Bean.Fields.flag); + Assert.assertEquals(setId, setId2); + + setId.accept(bean, 3L); + setFlag.accept(bean, true); + Assert.assertEquals(3L, (long) bean.getId()); + Assert.assertTrue(bean.isFlag()); + } + + @Test + public void lambdaTest() { + Bean bean = new Bean(); + bean.setId(1L); + bean.setPid(0L); + bean.setFlag(true); + BiFunction uniqueKeyFunction = LambdaUtil.lambda(BiFunction.class, Bean.class, "uniqueKey", String.class); + Function4 paramsFunction = LambdaUtil.lambda(Function4.class, Bean.class, "params", String.class, Integer.class, Double.class); + Assert.assertEquals(bean.uniqueKey("test"), uniqueKeyFunction.apply(bean, "test")); + Assert.assertEquals(bean.params("test", 1, 0.5), paramsFunction.apply(bean, "test", 1, 0.5)); + } + + @FunctionalInterface + interface Function4 { + R apply(P1 p1, P2 p2, P3 p3, P4 p4); + } + + @Data + @FieldNameConstants + private static class Bean { + Long id; + Long pid; + boolean flag; + + private Tuple uniqueKey(String name) { + return new Tuple(id, pid, flag, name); + } + + public Tuple params(String name, Integer length, Double score) { + return new Tuple(name, length, score); + } + + public static Function idGetter() { + return Bean::getId; + } + + public Function idGet() { + return bean -> bean.id; + } + + public Function idGetting() { + return Bean::getId; + } + } + @Data @AllArgsConstructor static class MyStudent {