This commit is contained in:
Looly 2023-01-17 11:26:16 +08:00
parent 4c8a06e749
commit 2f38e2e138
3 changed files with 69 additions and 63 deletions

View File

@ -51,8 +51,9 @@ public class LambdaFactory {
* @param methodName 方法名称 * @param methodName 方法名称
* @param paramTypes 方法参数数组 * @param paramTypes 方法参数数组
* @return 接受Lambda的函数式接口对象 * @return 接受Lambda的函数式接口对象
* @param <F> Function类型
*/ */
public static <F> F build(Class<F> functionInterfaceType, Class methodClass, String methodName, Class... paramTypes) { public static <F> F build(final Class<F> functionInterfaceType, final Class<?> methodClass, final String methodName, final Class<?>... paramTypes) {
return build(functionInterfaceType, MethodUtil.getMethod(methodClass, methodName, paramTypes)); return build(functionInterfaceType, MethodUtil.getMethod(methodClass, methodName, paramTypes));
} }
@ -62,30 +63,31 @@ public class LambdaFactory {
* @param functionInterfaceType 接受Lambda的函数式接口类型 * @param functionInterfaceType 接受Lambda的函数式接口类型
* @param method 方法对象 * @param method 方法对象
* @return 接受Lambda的函数式接口对象 * @return 接受Lambda的函数式接口对象
* @param <F> Function类型
*/ */
public static <F> F build(Class<F> functionInterfaceType, Method method) { public static <F> F build(final Class<F> functionInterfaceType, final Method method) {
Assert.notNull(functionInterfaceType); Assert.notNull(functionInterfaceType);
Assert.notNull(method); Assert.notNull(method);
MutableEntry<Class<?>, Method> cacheKey = new MutableEntry<>(functionInterfaceType, method); final MutableEntry<Class<?>, Method> cacheKey = new MutableEntry<>(functionInterfaceType, method);
//noinspection unchecked //noinspection unchecked
return (F) CACHE.computeIfAbsent(cacheKey, key -> { return (F) CACHE.computeIfAbsent(cacheKey, key -> {
List<Method> abstractMethods = Arrays.stream(functionInterfaceType.getMethods()) final List<Method> abstractMethods = Arrays.stream(functionInterfaceType.getMethods())
.filter(m -> Modifier.isAbstract(m.getModifiers())) .filter(m -> Modifier.isAbstract(m.getModifiers()))
.collect(Collectors.toList()); .collect(Collectors.toList());
Assert.equals(abstractMethods.size(), 1, "不支持非函数式接口"); Assert.equals(abstractMethods.size(), 1, "不支持非函数式接口");
if (!method.isAccessible()) { if (!method.isAccessible()) {
method.setAccessible(true); method.setAccessible(true);
} }
Method invokeMethod = abstractMethods.get(0); final Method invokeMethod = abstractMethods.get(0);
MethodHandles.Lookup caller = LookupFactory.lookup(method.getDeclaringClass()); final MethodHandles.Lookup caller = LookupFactory.lookup(method.getDeclaringClass());
String invokeName = invokeMethod.getName(); final String invokeName = invokeMethod.getName();
MethodType invokedType = methodType(functionInterfaceType); final MethodType invokedType = methodType(functionInterfaceType);
MethodType samMethodType = methodType(invokeMethod.getReturnType(), invokeMethod.getParameterTypes()); final MethodType samMethodType = methodType(invokeMethod.getReturnType(), invokeMethod.getParameterTypes());
MethodHandle implMethod = Opt.ofTry(() -> caller.unreflect(method)).get(); final MethodHandle implMethod = Opt.ofTry(() -> caller.unreflect(method)).get();
MethodType insMethodType = methodType(method.getReturnType(), method.getDeclaringClass(), method.getParameterTypes()); final MethodType insMethodType = methodType(method.getReturnType(), method.getDeclaringClass(), method.getParameterTypes());
boolean isSerializable = Serializable.class.isAssignableFrom(functionInterfaceType); final boolean isSerializable = Serializable.class.isAssignableFrom(functionInterfaceType);
try { try {
CallSite callSite = isSerializable ? final CallSite callSite = isSerializable ?
LambdaMetafactory.altMetafactory( LambdaMetafactory.altMetafactory(
caller, caller,
invokeName, invokeName,
@ -103,8 +105,9 @@ public class LambdaFactory {
implMethod, implMethod,
insMethodType insMethodType
); );
//noinspection unchecked
return (F) callSite.getTarget().invoke(); return (F) callSite.getTarget().invoke();
} catch (Throwable e) { } catch (final Throwable e) {
throw new UtilException(e); throw new UtilException(e);
} }
}); });

View File

@ -1,7 +1,6 @@
package cn.hutool.core.lang.func; package cn.hutool.core.lang.func;
import cn.hutool.core.collection.ListUtil; import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.exceptions.UtilException;
import cn.hutool.core.reflect.MethodHandleUtil; import cn.hutool.core.reflect.MethodHandleUtil;
import lombok.Data; import lombok.Data;
import lombok.Getter; import lombok.Getter;
@ -12,7 +11,11 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
import java.lang.invoke.*; import java.lang.invoke.LambdaConversionException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandleProxies;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
@ -31,25 +34,25 @@ public class LambdaFactoryTest {
public void testMethodNotMatch() { public void testMethodNotMatch() {
try { try {
LambdaFactory.build(Function.class, Something.class, "setId", Long.class); LambdaFactory.build(Function.class, Something.class, "setId", Long.class);
} catch (Exception e) { } catch (final Exception e) {
Assert.assertTrue(e.getCause() instanceof LambdaConversionException); Assert.assertTrue(e.getCause() instanceof LambdaConversionException);
} }
} }
@Test @Test
public void buildLambdaTest() { public void buildLambdaTest() {
Something something = new Something(); final Something something = new Something();
something.setId(1L); something.setId(1L);
something.setName("name"); something.setName("name");
Function<Something, Long> get11 = LambdaFactory.build(Function.class, Something.class, "getId"); final Function<Something, Long> get11 = LambdaFactory.build(Function.class, Something.class, "getId");
Function<Something, Long> get12 = LambdaFactory.build(Function.class, Something.class, "getId"); final Function<Something, Long> get12 = LambdaFactory.build(Function.class, Something.class, "getId");
Assert.assertEquals(get11, get12); Assert.assertEquals(get11, get12);
Assert.assertEquals(something.getId(), get11.apply(something)); Assert.assertEquals(something.getId(), get11.apply(something));
String name = "sname"; final String name = "sname";
BiConsumer<Something, String> set = LambdaFactory.build(BiConsumer.class, Something.class, "setName", String.class); final BiConsumer<Something, String> set = LambdaFactory.build(BiConsumer.class, Something.class, "setName", String.class);
set.accept(something, name); set.accept(something, name);
Assert.assertEquals(something.getName(), name); Assert.assertEquals(something.getName(), name);
@ -136,19 +139,19 @@ public class LambdaFactoryTest {
@Test @Test
@SneakyThrows @SneakyThrows
public void lambdaGetPerformanceTest() { public void lambdaGetPerformanceTest() {
Something something = new Something(); final Something something = new Something();
something.setId(1L); something.setId(1L);
something.setName("name"); something.setName("name");
Method getByReflect = Something.class.getMethod("getId"); final Method getByReflect = Something.class.getMethod("getId");
MethodHandle getByMh = MethodHandleUtil.findMethod(Something.class, "getId", MethodType.methodType(Long.class)); final MethodHandle getByMh = MethodHandleUtil.findMethod(Something.class, "getId", MethodType.methodType(Long.class));
Function getByProxy = MethodHandleProxies.asInterfaceInstance(Function.class, MethodHandles.lookup().unreflect(getByReflect)); final Function getByProxy = MethodHandleProxies.asInterfaceInstance(Function.class, MethodHandles.lookup().unreflect(getByReflect));
Function getByLambda = LambdaFactory.build(Function.class, getByReflect); final Function getByLambda = LambdaFactory.build(Function.class, getByReflect);
Task lambdaTask = new Task("lambda", () -> getByLambda.apply(something)); final Task lambdaTask = new Task("lambda", () -> getByLambda.apply(something));
Task mhTask = new Task("mh", () -> getByMh.invoke(something)); final Task mhTask = new Task("mh", () -> getByMh.invoke(something));
Task proxyTask = new Task("proxy", () -> getByProxy.apply(something)); final Task proxyTask = new Task("proxy", () -> getByProxy.apply(something));
Task reflectTask = new Task("reflect", () -> getByReflect.invoke(something)); final Task reflectTask = new Task("reflect", () -> getByReflect.invoke(something));
Task hardCodeTask = new Task("hardCode", () -> something.getId()); final Task hardCodeTask = new Task("hardCode", () -> something.getId());
Task[] tasks = {hardCodeTask, lambdaTask, mhTask, proxyTask, reflectTask}; final Task[] tasks = {hardCodeTask, lambdaTask, mhTask, proxyTask, reflectTask};
loop(count, tasks); loop(count, tasks);
} }
@ -211,43 +214,43 @@ public class LambdaFactoryTest {
@Test @Test
@SneakyThrows @SneakyThrows
public void lambdaSetPerformanceTest() { public void lambdaSetPerformanceTest() {
Something something = new Something(); final Something something = new Something();
something.setId(1L); something.setId(1L);
something.setName("name"); something.setName("name");
Method setByReflect = Something.class.getMethod("setName", String.class); final Method setByReflect = Something.class.getMethod("setName", String.class);
MethodHandle setByMh = MethodHandleUtil.findMethod(Something.class, "setName", MethodType.methodType(Void.TYPE, String.class)); final MethodHandle setByMh = MethodHandleUtil.findMethod(Something.class, "setName", MethodType.methodType(Void.TYPE, String.class));
BiConsumer setByProxy = MethodHandleProxies.asInterfaceInstance(BiConsumer.class, setByMh); final BiConsumer setByProxy = MethodHandleProxies.asInterfaceInstance(BiConsumer.class, setByMh);
BiConsumer setByLambda = LambdaFactory.build(BiConsumer.class, setByReflect); final BiConsumer setByLambda = LambdaFactory.build(BiConsumer.class, setByReflect);
String name = "name1"; final String name = "name1";
Task lambdaTask = new Task("lambda", () -> { final Task lambdaTask = new Task("lambda", () -> {
setByLambda.accept(something, name); setByLambda.accept(something, name);
return null; return null;
}); });
Task proxyTask = new Task("proxy", () -> { final Task proxyTask = new Task("proxy", () -> {
setByProxy.accept(something, name); setByProxy.accept(something, name);
return null; return null;
}); });
Task mhTask = new Task("mh", () -> { final Task mhTask = new Task("mh", () -> {
setByMh.invoke(something, name); setByMh.invoke(something, name);
return null; return null;
}); });
Task reflectTask = new Task("reflect", () -> { final Task reflectTask = new Task("reflect", () -> {
setByReflect.invoke(something, name); setByReflect.invoke(something, name);
return null; return null;
}); });
Task hardCodeTask = new Task("hardCode", () -> { final Task hardCodeTask = new Task("hardCode", () -> {
something.setName(name); something.setName(name);
return null; return null;
}); });
Task[] tasks = {hardCodeTask, lambdaTask, proxyTask, mhTask, reflectTask}; final Task[] tasks = {hardCodeTask, lambdaTask, proxyTask, mhTask, reflectTask};
loop(count, tasks); loop(count, tasks);
} }
@SneakyThrows @SneakyThrows
private void loop(int count, Task... tasks) { private void loop(final int count, final Task... tasks) {
Arrays.stream(tasks) Arrays.stream(tasks)
.peek(task -> { .peek(task -> {
LambdaFactoryTest.SupplierThrowable runnable = task.getRunnable(); final LambdaFactoryTest.SupplierThrowable runnable = task.getRunnable();
long cost = System.nanoTime(); long cost = System.nanoTime();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
runnable.get(); runnable.get();
@ -271,13 +274,13 @@ public class LambdaFactoryTest {
@Setter @Setter
private Integer count; private Integer count;
public Task(String name, LambdaFactoryTest.SupplierThrowable<?> runnable) { public Task(final String name, final LambdaFactoryTest.SupplierThrowable<?> runnable) {
this.name = name; this.name = name;
this.runnable = runnable; this.runnable = runnable;
} }
public String format() { public String format() {
TimeUnit timeUnit = TimeUnit.NANOSECONDS; final TimeUnit timeUnit = TimeUnit.NANOSECONDS;
return String.format("%-10s 运行%d次耗时 %d %s", name, count, timeUnit.convert(cost, TimeUnit.NANOSECONDS), timeUnit.name()); return String.format("%-10s 运行%d次耗时 %d %s", name, count, timeUnit.convert(cost, TimeUnit.NANOSECONDS), timeUnit.name());
} }
} }
@ -291,7 +294,7 @@ public class LambdaFactoryTest {
default T get() { default T get() {
try { try {
return get0(); return get0();
} catch (Throwable e) { } catch (final Throwable e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }

View File

@ -144,7 +144,7 @@ public class LambdaUtilTest {
Assert.assertEquals(MyTeacher.class, LambdaUtil.getRealClass(lambda)); Assert.assertEquals(MyTeacher.class, LambdaUtil.getRealClass(lambda));
}, () -> { }, () -> {
// 数组测试 // 数组测试
final SerConsumer<String[]> lambda = (String[] stringList) -> { final SerConsumer<String[]> lambda = (final String[] stringList) -> {
}; };
Assert.assertEquals(String[].class, LambdaUtil.getRealClass(lambda)); Assert.assertEquals(String[].class, LambdaUtil.getRealClass(lambda));
}).forEach(Runnable::run); }).forEach(Runnable::run);
@ -152,11 +152,11 @@ public class LambdaUtilTest {
@Test @Test
public void getterTest() { public void getterTest() {
Bean bean = new Bean(); final Bean bean = new Bean();
bean.setId(2L); bean.setId(2L);
Function<Bean, Long> getId = LambdaUtil.buildGetter(MethodUtil.getMethod(Bean.class, "getId")); final Function<Bean, Long> getId = LambdaUtil.buildGetter(MethodUtil.getMethod(Bean.class, "getId"));
Function<Bean, Long> getId2 = LambdaUtil.buildGetter(Bean.class, Bean.Fields.id); final Function<Bean, Long> getId2 = LambdaUtil.buildGetter(Bean.class, Bean.Fields.id);
Assert.assertEquals(getId, getId2); Assert.assertEquals(getId, getId2);
Assert.assertEquals(bean.getId(), getId.apply(bean)); Assert.assertEquals(bean.getId(), getId.apply(bean));
@ -164,13 +164,13 @@ public class LambdaUtilTest {
@Test @Test
public void setterTest() { public void setterTest() {
Bean bean = new Bean(); final Bean bean = new Bean();
bean.setId(2L); bean.setId(2L);
bean.setFlag(false); bean.setFlag(false);
BiConsumer<Bean, Long> setId = LambdaUtil.buildSetter(MethodUtil.getMethod(Bean.class, "setId", Long.class)); final BiConsumer<Bean, Long> setId = LambdaUtil.buildSetter(MethodUtil.getMethod(Bean.class, "setId", Long.class));
BiConsumer<Bean, Long> setId2 = LambdaUtil.buildSetter(Bean.class, Bean.Fields.id); final BiConsumer<Bean, Long> setId2 = LambdaUtil.buildSetter(Bean.class, Bean.Fields.id);
BiConsumer<Bean, Boolean> setFlag = LambdaUtil.buildSetter(Bean.class, Bean.Fields.flag); final BiConsumer<Bean, Boolean> setFlag = LambdaUtil.buildSetter(Bean.class, Bean.Fields.flag);
Assert.assertEquals(setId, setId2); Assert.assertEquals(setId, setId2);
setId.accept(bean, 3L); setId.accept(bean, 3L);
@ -181,12 +181,12 @@ public class LambdaUtilTest {
@Test @Test
public void lambdaTest() { public void lambdaTest() {
Bean bean = new Bean(); final Bean bean = new Bean();
bean.setId(1L); bean.setId(1L);
bean.setPid(0L); bean.setPid(0L);
bean.setFlag(true); bean.setFlag(true);
BiFunction<Bean, String, Tuple> uniqueKeyFunction = LambdaUtil.build(BiFunction.class, Bean.class, "uniqueKey", String.class); final BiFunction<Bean, String, Tuple> uniqueKeyFunction = LambdaUtil.build(BiFunction.class, Bean.class, "uniqueKey", String.class);
Function4<Tuple, Bean, String, Integer, Double> paramsFunction = LambdaUtil.build(Function4.class, Bean.class, "params", String.class, Integer.class, Double.class); final Function4<Tuple, Bean, String, Integer, Double> paramsFunction = LambdaUtil.build(Function4.class, Bean.class, "params", String.class, Integer.class, Double.class);
Assert.assertEquals(bean.uniqueKey("test"), uniqueKeyFunction.apply(bean, "test")); Assert.assertEquals(bean.uniqueKey("test"), uniqueKeyFunction.apply(bean, "test"));
Assert.assertEquals(bean.params("test", 1, 0.5), paramsFunction.apply(bean, "test", 1, 0.5)); Assert.assertEquals(bean.params("test", 1, 0.5), paramsFunction.apply(bean, "test", 1, 0.5));
} }
@ -203,11 +203,11 @@ public class LambdaUtilTest {
Long pid; Long pid;
boolean flag; boolean flag;
private Tuple uniqueKey(String name) { private Tuple uniqueKey(final String name) {
return new Tuple(id, pid, flag, name); return new Tuple(id, pid, flag, name);
} }
public Tuple params(String name, Integer length, Double score) { public Tuple params(final String name, final Integer length, final Double score) {
return new Tuple(name, length, score); return new Tuple(name, length, score);
} }