diff --git a/hutool-extra/src/main/java/cn/hutool/extra/spring/SpringUtil.java b/hutool-extra/src/main/java/cn/hutool/extra/spring/SpringUtil.java index 5a3b46c13..dad9236a6 100644 --- a/hutool-extra/src/main/java/cn/hutool/extra/spring/SpringUtil.java +++ b/hutool-extra/src/main/java/cn/hutool/extra/spring/SpringUtil.java @@ -1,10 +1,14 @@ package cn.hutool.extra.spring; +import cn.hutool.core.lang.TypeReference; import cn.hutool.core.util.ArrayUtil; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; +import org.springframework.core.ResolvableType; import org.springframework.stereotype.Component; +import java.lang.reflect.ParameterizedType; +import java.util.Arrays; import java.util.Map; /** @@ -73,6 +77,22 @@ public class SpringUtil implements ApplicationContextAware { return applicationContext.getBean(name, clazz); } + /** + * 通过类型参考返回带泛型参数的Bean + * + * @param reference 类型参考,用于持有转换后的泛型类型 + * @param Bean类型 + * @return 带泛型参数的Bean + */ + @SuppressWarnings("unchecked") + public static T getBean(TypeReference reference) { + ParameterizedType parameterizedType = (ParameterizedType) reference.getType(); + Class rawType = (Class) parameterizedType.getRawType(); + Class[] genericTypes = Arrays.stream(parameterizedType.getActualTypeArguments()).map(type -> (Class) type).toArray(Class[]::new); + String[] beanNames = applicationContext.getBeanNamesForType(ResolvableType.forClassWithGenerics(rawType, genericTypes)); + return getBean(beanNames[0], rawType); + } + /** * 获取指定类型对应的所有Bean,包括子类 * diff --git a/hutool-extra/src/test/java/cn/hutool/extra/spring/SpringUtilTest.java b/hutool-extra/src/test/java/cn/hutool/extra/spring/SpringUtilTest.java index 0bff7ee5a..e7fbaed87 100644 --- a/hutool-extra/src/test/java/cn/hutool/extra/spring/SpringUtilTest.java +++ b/hutool-extra/src/test/java/cn/hutool/extra/spring/SpringUtilTest.java @@ -1,5 +1,7 @@ package cn.hutool.extra.spring; +import cn.hutool.core.lang.TypeReference; +import cn.hutool.core.map.MapUtil; import lombok.Data; import org.junit.Assert; import org.junit.Test; @@ -8,6 +10,9 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import java.util.HashMap; +import java.util.Map; + @RunWith(SpringJUnit4ClassRunner.class) @SpringBootTest(classes = {SpringUtil.class, SpringUtilTest.Demo2.class}) //@Import(cn.hutool.extra.spring.SpringUtil.class) @@ -20,6 +25,14 @@ public class SpringUtilTest { Assert.assertEquals("test", testDemo.getName()); } + @Test + public void getBeanWithTypeReferenceTest() { + Map mapBean = SpringUtil.getBean(new TypeReference>() {}); + Assert.assertNotNull(mapBean); + Assert.assertEquals("value1", mapBean.get("key1")); + Assert.assertEquals("value2", mapBean.get("key2")); + } + @Data public static class Demo2{ private long id; @@ -32,5 +45,13 @@ public class SpringUtilTest { demo.setName("test"); return demo; } + + @Bean(name="mapDemo") + public Map generateMap() { + HashMap map = MapUtil.newHashMap(); + map.put("key1", "value1"); + map.put("key2", "value2"); + return map; + } } }