From 8e67251fd393b9ad955c2ca97d24b1db230ff7be Mon Sep 17 00:00:00 2001 From: Looly Date: Sat, 15 Jan 2022 13:08:46 +0800 Subject: [PATCH] fix upsert --- .../cn/hutool/core/collection/CollUtil.java | 12 +- .../main/java/cn/hutool/core/map/MapUtil.java | 19 ++ .../hutool/core/collection/CollUtilTest.java | 9 + .../main/java/cn/hutool/db/DialectRunner.java | 7 +- .../main/java/cn/hutool/db/StatementUtil.java | 2 +- .../java/cn/hutool/db/dialect/Dialect.java | 42 ++-- .../db/dialect/impl/AnsiSqlDialect.java | 11 +- .../cn/hutool/db/dialect/impl/H2Dialect.java | 51 +++-- .../hutool/db/dialect/impl/MysqlDialect.java | 50 ++++- .../hutool/db/dialect/impl/OracleDialect.java | 28 ++- .../db/dialect/impl/PhoenixDialect.java | 7 + .../db/dialect/impl/PostgresqlDialect.java | 56 ++++-- .../java/cn/hutool/db/sql/SqlBuilder.java | 187 ++++-------------- .../main/java/cn/hutool/db/sql/Wrapper.java | 56 ++++-- .../src/test/java/cn/hutool/db/H2Test.java | 2 +- 15 files changed, 298 insertions(+), 241 deletions(-) diff --git a/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java b/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java index 201308b2e..dd22d5275 100644 --- a/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java @@ -2343,11 +2343,7 @@ public class CollUtil { */ @SuppressWarnings("unchecked") public static ArrayList valuesOfKeys(Map map, K... keys) { - final ArrayList values = new ArrayList<>(); - for (K k : keys) { - values.add(map.get(k)); - } - return values; + return MapUtil.valuesOfKeys(map, new ArrayIter<>(keys)); } /** @@ -2377,11 +2373,7 @@ public class CollUtil { * @since 3.0.9 */ public static ArrayList valuesOfKeys(Map map, Iterator keys) { - final ArrayList values = new ArrayList<>(); - while (keys.hasNext()) { - values.add(map.get(keys.next())); - } - return values; + return MapUtil.valuesOfKeys(map, keys); } // ------------------------------------------------------------------------------------------------- sort diff --git a/hutool-core/src/main/java/cn/hutool/core/map/MapUtil.java b/hutool-core/src/main/java/cn/hutool/core/map/MapUtil.java index 7bbce4c4b..b0fb79250 100644 --- a/hutool-core/src/main/java/cn/hutool/core/map/MapUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/map/MapUtil.java @@ -1354,4 +1354,23 @@ public class MapUtil { } } } + + /** + * 从Map中获取指定键列表对应的值列表
+ * 如果key在map中不存在或key对应值为null,则返回值列表对应位置的值也为null + * + * @param 键类型 + * @param 值类型 + * @param map {@link Map} + * @param keys 键列表 + * @return 值列表 + * @since 5.7.20 + */ + public static ArrayList valuesOfKeys(Map map, Iterator keys) { + final ArrayList values = new ArrayList<>(); + while (keys.hasNext()) { + values.add(map.get(keys.next())); + } + return values; + } } diff --git a/hutool-core/src/test/java/cn/hutool/core/collection/CollUtilTest.java b/hutool-core/src/test/java/cn/hutool/core/collection/CollUtilTest.java index 4e866979d..ad0582e18 100644 --- a/hutool-core/src/test/java/cn/hutool/core/collection/CollUtilTest.java +++ b/hutool-core/src/test/java/cn/hutool/core/collection/CollUtilTest.java @@ -4,6 +4,7 @@ import cn.hutool.core.comparator.ComparableComparator; import cn.hutool.core.date.DateUtil; import cn.hutool.core.lang.Dict; import cn.hutool.core.map.MapUtil; +import cn.hutool.core.util.StrUtil; import lombok.AllArgsConstructor; import lombok.Data; import org.junit.Assert; @@ -302,6 +303,14 @@ public class CollUtilTest { Assert.assertEquals(CollUtil.newArrayList("b", "c"), filtered); } + @Test + public void filterSetTest() { + Set set = CollUtil.newLinkedHashSet("a", "b", "", " ", "c"); + Set filtered = CollUtil.filter(set, StrUtil::isNotBlank); + + Assert.assertEquals(CollUtil.newLinkedHashSet("a", "b", "c"), filtered); + } + @Test public void filterRemoveTest() { ArrayList list = CollUtil.newArrayList("a", "b", "c"); diff --git a/hutool-db/src/main/java/cn/hutool/db/DialectRunner.java b/hutool-db/src/main/java/cn/hutool/db/DialectRunner.java index 26e18bc09..24c535c4b 100644 --- a/hutool-db/src/main/java/cn/hutool/db/DialectRunner.java +++ b/hutool-db/src/main/java/cn/hutool/db/DialectRunner.java @@ -99,7 +99,12 @@ public class DialectRunner implements Serializable { * @since 5.7.20 */ public int upsert(Connection conn, Entity record, String... keys) throws SQLException { - PreparedStatement ps = getDialect().psForUpsert(conn, record, keys); + PreparedStatement ps = null; + try{ + ps = getDialect().psForUpsert(conn, record, keys); + }catch (SQLException ignore){ + // 方言不支持,使用默认 + } if (null != ps) { try { return ps.executeUpdate(); diff --git a/hutool-db/src/main/java/cn/hutool/db/StatementUtil.java b/hutool-db/src/main/java/cn/hutool/db/StatementUtil.java index b3a6eed12..b821c819e 100644 --- a/hutool-db/src/main/java/cn/hutool/db/StatementUtil.java +++ b/hutool-db/src/main/java/cn/hutool/db/StatementUtil.java @@ -194,7 +194,7 @@ public class StatementUtil { * @throws SQLException SQL异常 * @since 4.6.7 */ - public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, List fields, Entity... entities) throws SQLException { + public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, Iterable fields, Entity... entities) throws SQLException { Assert.notBlank(sql, "Sql String must be not blank!"); sql = sql.trim(); diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/Dialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/Dialect.java index d37a7e867..6ad5922bd 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/Dialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/Dialect.java @@ -37,7 +37,8 @@ public interface Dialect extends Serializable { // -------------------------------------------- Execute /** - * 构建用于插入的PreparedStatement + * 构建用于插入的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Entity}转换为带有占位符的SQL语句及参数列表 * * @param conn 数据库连接对象 * @param entity 数据实体类(包含表名) @@ -47,7 +48,8 @@ public interface Dialect extends Serializable { PreparedStatement psForInsert(Connection conn, Entity entity) throws SQLException; /** - * 构建用于批量插入的PreparedStatement + * 构建用于批量插入的PreparedStatement
+ * 用户实现需按照数据库方言格式,将{@link Entity}转换为带有占位符的SQL语句及参数列表 * * @param conn 数据库连接对象 * @param entities 数据实体,实体的结构必须全部一致,否则插入结果将不可预知 @@ -57,7 +59,9 @@ public interface Dialect extends Serializable { PreparedStatement psForInsertBatch(Connection conn, Entity... entities) throws SQLException; /** - * 构建用于删除的PreparedStatement + * 构建用于删除的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表
+ * {@link Query}中包含了删除所需的表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。 * * @param conn 数据库连接对象 * @param query 查找条件(包含表名) @@ -67,7 +71,9 @@ public interface Dialect extends Serializable { PreparedStatement psForDelete(Connection conn, Query query) throws SQLException; /** - * 构建用于更新的PreparedStatement + * 构建用于更新的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Entity}配合{@link Query}转换为带有占位符的SQL语句及参数列表
+ * 其中{@link Entity}中包含需要更新的数据信息,{@link Query}包含更新的查找条件信息。 * * @param conn 数据库连接对象 * @param entity 数据实体类(包含表名) @@ -80,7 +86,9 @@ public interface Dialect extends Serializable { // -------------------------------------------- Query /** - * 构建用于获取多条记录的PreparedStatement + * 构建用于获取多条记录的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表
+ * {@link Query}中包含了查询所需的表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。 * * @param conn 数据库连接对象 * @param query 查询条件(包含表名) @@ -90,7 +98,9 @@ public interface Dialect extends Serializable { PreparedStatement psForFind(Connection conn, Query query) throws SQLException; /** - * 构建用于分页查询的PreparedStatement + * 构建用于分页查询的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表
+ * {@link Query}中包含了分页查询所需的表名、查询条件、分页等信息,可借助SqlBuilder完成SQL语句生成。 * * @param conn 数据库连接对象 * @param query 查询条件(包含表名) @@ -100,7 +110,7 @@ public interface Dialect extends Serializable { PreparedStatement psForPage(Connection conn, Query query) throws SQLException; /** - * 构建用于分页查询的PreparedStatement
+ * 构建用于分页查询的{@link PreparedStatement}
* 可以在此方法中使用{@link SqlBuilder#orderBy(Order...)}方法加入排序信息, * 排序信息通过{@link Page#getOrders()}获取 * @@ -114,7 +124,9 @@ public interface Dialect extends Serializable { PreparedStatement psForPage(Connection conn, SqlBuilder sqlBuilder, Page page) throws SQLException; /** - * 构建用于查询行数的PreparedStatement + * 构建用于查询行数的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表
+ * {@link Query}中包含了表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。 * * @param conn 数据库连接对象 * @param query 查询条件(包含表名) @@ -127,7 +139,9 @@ public interface Dialect extends Serializable { } /** - * 构建用于查询行数的PreparedStatement + * 构建用于查询行数的{@link PreparedStatement}
+ * 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表
+ * {@link Query}中包含了表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。 * * @param conn 数据库连接对象 * @param sqlBuilder 查询语句,应该包含分页等信息 @@ -144,18 +158,18 @@ public interface Dialect extends Serializable { } /** - * 构建用于upsert的PreparedStatement
- * 方言实现需实现此默认方法,默认返回{@code null} + * 构建用于upsert的{@link PreparedStatement}
+ * 方言实现需实现此默认方法,如果没有实现,抛出{@link SQLException} * * @param conn 数据库连接对象 * @param entity 数据实体类(包含表名) - * @param keys 查找字段 + * @param keys 查找字段,某些数据库此字段必须,如H2,某些数据库无需此字段,如MySQL(通过主键) * @return PreparedStatement - * @throws SQLException SQL执行异常 + * @throws SQLException SQL执行异常,或方言数据不支持此操作 * @since 5.7.20 */ default PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { - return null; + throw new SQLException("Unsupported upsert operation of " + dialectName()); } diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/AnsiSqlDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/AnsiSqlDialect.java index 3793d41f0..a615a45f3 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/AnsiSqlDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/AnsiSqlDialect.java @@ -1,5 +1,6 @@ package cn.hutool.db.dialect.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.StrUtil; @@ -17,16 +18,17 @@ import cn.hutool.db.sql.Wrapper; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.Set; /** * ANSI SQL 方言 - * + * * @author loolly * */ public class AnsiSqlDialect implements Dialect { private static final long serialVersionUID = 2088101129774974580L; - + protected Wrapper wrapper = new Wrapper(); @Override @@ -53,7 +55,8 @@ public class AnsiSqlDialect implements Dialect { } // 批量,根据第一行数据结构生成SQL占位符 final SqlBuilder insert = SqlBuilder.create(wrapper).insert(entities[0], this.dialectName()); - return StatementUtil.prepareStatementForBatch(conn, insert.build(), insert.getFields(), entities); + final Set fields = CollUtil.filter(entities[0].keySet(), StrUtil::isNotBlank); + return StatementUtil.prepareStatementForBatch(conn, insert.build(), fields, entities); } @Override @@ -113,7 +116,7 @@ public class AnsiSqlDialect implements Dialect { /** * 根据不同数据库在查询SQL语句基础上包装其分页的语句
* 各自数据库通过重写此方法实现最小改动情况下修改分页语句 - * + * * @param find 标准查询语句 * @param page 分页对象 * @return 分页语句 diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/H2Dialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/H2Dialect.java index 917980049..110aea5f6 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/H2Dialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/H2Dialect.java @@ -2,19 +2,16 @@ package cn.hutool.db.dialect.impl; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.db.Entity; import cn.hutool.db.Page; import cn.hutool.db.StatementUtil; import cn.hutool.db.dialect.DialectName; -import cn.hutool.db.sql.Condition; -import cn.hutool.db.sql.Query; import cn.hutool.db.sql.SqlBuilder; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.util.Arrays; -import java.util.function.Function; /** * H2数据库方言 @@ -39,18 +36,42 @@ public class H2Dialect extends AnsiSqlDialect { return find.append(" limit ").append(page.getStartPosition()).append(" , ").append(page.getPageSize()); } - /** - * 构建用于upsert的PreparedStatement - * - * @param conn 数据库连接对象 - * @param entity 数据实体类(包含表名) - * @param keys 查找字段 如果不提供keys将自动使用主键 - * @return PreparedStatement - * @throws SQLException SQL执行异常 - */ @Override public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { - final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); - return StatementUtil.prepareStatement(conn, upsert); + Assert.notEmpty(keys, "Keys must be not empty for H2 MERGE SQL."); + SqlBuilder.validateEntity(entity); + final SqlBuilder builder = SqlBuilder.create(wrapper); + + final StringBuilder fieldsPart = new StringBuilder(); + final StringBuilder placeHolder = new StringBuilder(); + + // 构建字段部分和参数占位符部分 + entity.forEach((field, value)->{ + if (StrUtil.isNotBlank(field)) { + if (fieldsPart.length() > 0) { + // 非第一个参数,追加逗号 + fieldsPart.append(", "); + placeHolder.append(", "); + } + + fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field); + placeHolder.append("?"); + builder.addParams(value); + } + }); + + String tableName = entity.getTableName(); + if (null != this.wrapper) { + tableName = this.wrapper.wrap(tableName); + } + builder.append("MERGE INTO ").append(tableName) + // 字段列表 + .append(" (").append(fieldsPart) + // 更新关键字列表 + .append(") KEY(").append(ArrayUtil.join(keys, ", ")) + // 更新值列表 + .append(") VALUES (").append(placeHolder).append(")"); + + return StatementUtil.prepareStatement(conn, builder); } } diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java index 3ce7a199a..a52cfa759 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java @@ -1,5 +1,6 @@ package cn.hutool.db.dialect.impl; +import cn.hutool.core.util.StrUtil; import cn.hutool.db.Entity; import cn.hutool.db.Page; import cn.hutool.db.StatementUtil; @@ -34,17 +35,58 @@ public class MysqlDialect extends AnsiSqlDialect{ } /** - * 构建用于upsert的PreparedStatement + * 构建用于upsert的{@link PreparedStatement}
+ * MySQL通过主键方式实现Upsert,故keys无效,生成SQL语法为: + *
+	 *     INSERT INTO demo(a,b,c) values(?, ?, ?) ON DUPLICATE KEY UPDATE a=values(a), b=values(b), c=values(c);
+	 * 
* * @param conn 数据库连接对象 * @param entity 数据实体类(包含表名) - * @param keys 查找字段 + * @param keys 此参数无效 * @return PreparedStatement * @throws SQLException SQL执行异常 + * @since 5.7.20 */ @Override public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { - final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); - return StatementUtil.prepareStatement(conn, upsert); + SqlBuilder.validateEntity(entity); + final SqlBuilder builder = SqlBuilder.create(wrapper); + + final StringBuilder fieldsPart = new StringBuilder(); + final StringBuilder placeHolder = new StringBuilder(); + final StringBuilder updateHolder = new StringBuilder(); + + // 构建字段部分和参数占位符部分 + entity.forEach((field, value)->{ + if (StrUtil.isNotBlank(field)) { + if (fieldsPart.length() > 0) { + // 非第一个参数,追加逗号 + fieldsPart.append(", "); + placeHolder.append(", "); + updateHolder.append(", "); + } + + field = (null != wrapper) ? wrapper.wrap(field) : field; + fieldsPart.append(field); + updateHolder.append(field).append("=values(").append(field).append(")"); + placeHolder.append("?"); + builder.addParams(value); + } + }); + + String tableName = entity.getTableName(); + if (null != this.wrapper) { + tableName = this.wrapper.wrap(tableName); + } + builder.append("INSERT INTO ").append(tableName) + // 字段列表 + .append(" (").append(fieldsPart) + // 更新值列表 + .append(") VALUES (").append(placeHolder) + // 主键冲突后的更新操作 + .append(") ON DUPLICATE KEY UPDATE ").append(updateHolder); + + return StatementUtil.prepareStatement(conn, builder); } } diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/OracleDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/OracleDialect.java index 925889611..037aaf8a0 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/OracleDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/OracleDialect.java @@ -1,32 +1,44 @@ package cn.hutool.db.dialect.impl; +import cn.hutool.core.util.StrUtil; import cn.hutool.db.Page; import cn.hutool.db.dialect.DialectName; import cn.hutool.db.sql.SqlBuilder; /** * Oracle 方言 - * @author loolly * + * @author loolly */ -public class OracleDialect extends AnsiSqlDialect{ +public class OracleDialect extends AnsiSqlDialect { private static final long serialVersionUID = 6122761762247483015L; + /** + * 检查字段值是否为Oracle自增字段,自增字段以`.nextval`结尾 + * + * @param value 检查的字段值 + * @return 是否为Oracle自增字段 + * @since 5.7.20 + */ + public static boolean isNextVal(Object value) { + return (value instanceof CharSequence) && StrUtil.endWithIgnoreCase(value.toString(), ".nextval"); + } + public OracleDialect() { //Oracle所有字段名用双引号包围,防止字段名或表名与系统关键字冲突 //wrapper = new Wrapper('"'); } - + @Override protected SqlBuilder wrapPageSql(SqlBuilder find, Page page) { final int[] startEnd = page.getStartEnd(); return find - .insertPreFragment("SELECT * FROM ( SELECT row_.*, rownum rownum_ from ( ") - .append(" ) row_ where rownum <= ").append(startEnd[1])// - .append(") table_alias")// - .append(" where table_alias.rownum_ > ").append(startEnd[0]);// + .insertPreFragment("SELECT * FROM ( SELECT row_.*, rownum rownum_ from ( ") + .append(" ) row_ where rownum <= ").append(startEnd[1])// + .append(") table_alias")// + .append(" where table_alias.rownum_ > ").append(startEnd[0]);// } - + @Override public String dialectName() { return DialectName.ORACLE.name(); diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PhoenixDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PhoenixDialect.java index c2ad5bdd5..8064e8982 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PhoenixDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PhoenixDialect.java @@ -24,6 +24,7 @@ public class PhoenixDialect extends AnsiSqlDialect { @Override public PreparedStatement psForUpdate(Connection conn, Entity entity, Query query) throws SQLException { // Phoenix的插入、更新语句是统一的,统一使用upsert into关键字 + // Phoenix只支持通过主键更新操作,因此query无效,自动根据entity中的主键更新 return super.psForInsert(conn, entity); } @@ -31,4 +32,10 @@ public class PhoenixDialect extends AnsiSqlDialect { public String dialectName() { return DialectName.PHOENIX.name(); } + + @Override + public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { + // Phoenix只支持通过主键更新操作,因此query无效,自动根据entity中的主键更新 + return psForInsert(conn, entity); + } } diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java index 82f5fe373..1f5a90122 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java @@ -1,5 +1,8 @@ package cn.hutool.db.dialect.impl; +import cn.hutool.core.lang.Assert; +import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.db.Entity; import cn.hutool.db.StatementUtil; import cn.hutool.db.dialect.DialectName; @@ -28,21 +31,48 @@ public class PostgresqlDialect extends AnsiSqlDialect{ return DialectName.POSTGREESQL.name(); } - /** - * 构建用于upsert的PreparedStatement - * - * @param conn 数据库连接对象 - * @param entity 数据实体类(包含表名) - * @param keys 查找字段 必须是有唯一索引的列且不能为空 - * @return PreparedStatement - * @throws SQLException SQL执行异常 - */ @Override public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { - if (null==keys || keys.length==0){ - throw new SQLException("keys不能为空"); + Assert.notEmpty(keys, "Keys must be not empty for Postgres."); + SqlBuilder.validateEntity(entity); + final SqlBuilder builder = SqlBuilder.create(wrapper); + + final StringBuilder fieldsPart = new StringBuilder(); + final StringBuilder placeHolder = new StringBuilder(); + final StringBuilder updateHolder = new StringBuilder(); + + // 构建字段部分和参数占位符部分 + entity.forEach((field, value)->{ + if (StrUtil.isNotBlank(field)) { + if (fieldsPart.length() > 0) { + // 非第一个参数,追加逗号 + fieldsPart.append(", "); + placeHolder.append(", "); + updateHolder.append(", "); + } + + final String wrapedField = (null != wrapper) ? wrapper.wrap(field) : field; + fieldsPart.append(wrapedField); + updateHolder.append(wrapedField).append("=EXCLUDED.").append(field); + placeHolder.append("?"); + builder.addParams(value); + } + }); + + String tableName = entity.getTableName(); + if (null != this.wrapper) { + tableName = this.wrapper.wrap(tableName); } - final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); - return StatementUtil.prepareStatement(conn, upsert); + builder.append("INSERT INTO ").append(tableName) + // 字段列表 + .append(" (").append(fieldsPart) + // 更新值列表 + .append(") VALUES (").append(placeHolder) + // 定义检查冲突的主键或字段 + .append(") ON CONFLICT (").append(ArrayUtil.join(keys,", ")) + // 主键冲突后的更新操作 + .append(") DO UPDATE SET ").append(updateHolder); + + return StatementUtil.prepareStatement(conn, builder); } } diff --git a/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java b/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java index fcf09c671..a2a52634d 100644 --- a/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java +++ b/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java @@ -7,13 +7,13 @@ import cn.hutool.core.util.StrUtil; import cn.hutool.db.DbRuntimeException; import cn.hutool.db.Entity; import cn.hutool.db.dialect.DialectName; +import cn.hutool.db.dialect.impl.OracleDialect; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.Map.Entry; /** * SQL构建器
@@ -57,6 +57,24 @@ public class SqlBuilder implements Builder { return create().append(sql); } + /** + * 验证实体类对象的有效性 + * + * @param entity 实体类对象 + * @throws DbRuntimeException SQL异常包装,获取元数据信息失败 + */ + public static void validateEntity(Entity entity) throws DbRuntimeException { + if (null == entity) { + throw new DbRuntimeException("Entity is null !"); + } + if (StrUtil.isBlank(entity.getTableName())) { + throw new DbRuntimeException("Entity`s table name is null !"); + } + if (entity.isEmpty()) { + throw new DbRuntimeException("No filed and value in this entity !"); + } + } + // --------------------------------------------------------------- Static methods end // --------------------------------------------------------------- Enums start @@ -87,10 +105,6 @@ public class SqlBuilder implements Builder { // --------------------------------------------------------------- Enums end private final StringBuilder sql = new StringBuilder(); - /** - * 字段列表(仅用于插入和更新) - */ - private final List fields = new ArrayList<>(); /** * 占位符对应的值列表 */ @@ -146,41 +160,29 @@ public class SqlBuilder implements Builder { // 验证 validateEntity(entity); - if (null != wrapper) { - // 包装表名 entity = wrapper.wrap(entity); - entity.setTableName(wrapper.wrap(entity.getTableName())); - } - final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理 final StringBuilder fieldsPart = new StringBuilder(); final StringBuilder placeHolder = new StringBuilder(); - boolean isFirst = true; - String field; - Object value; - for (Entry entry : entity.entrySet()) { - field = entry.getKey(); - value = entry.getValue(); - if (StrUtil.isNotBlank(field) /* && null != value */) { - if (isFirst) { - isFirst = false; - } else { + entity.forEach((field, value) -> { + if (StrUtil.isNotBlank(field)) { + if (fieldsPart.length() > 0) { // 非第一个参数,追加逗号 fieldsPart.append(", "); placeHolder.append(", "); } - this.fields.add(field); fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field); - if (isOracle && value instanceof String && StrUtil.endWithIgnoreCase((String) value, ".nextval")) { + if (isOracle && OracleDialect.isNextVal(value)) { // Oracle的特殊自增键,通过字段名.nextval获得下一个值 placeHolder.append(value); } else { + // 普通字段使用占位符 placeHolder.append("?"); this.paramValues.add(value); } } - } + }); // issue#1656@Github Phoenix兼容 if (DialectName.PHOENIX.match(dialectName)) { @@ -189,94 +191,18 @@ public class SqlBuilder implements Builder { sql.append("INSERT INTO "); } - sql.append(entity.getTableName()) + String tableName = entity.getTableName(); + if (null != this.wrapper) { + // 包装表名 entity = wrapper.wrap(entity); + tableName = this.wrapper.wrap(tableName); + } + sql.append(tableName) .append(" (").append(fieldsPart).append(") VALUES (")// .append(placeHolder).append(")"); return this; } - /** - * 插入
- * 插入会忽略空的字段名及其对应值,但是对于有字段名对应值为{@code null}的情况不忽略 - * - * @param entity 实体 - * @param dialectName 方言名,用于对特殊数据库特殊处理 - * @param keys 根据何字段来确认唯一性,不传则用主键 - * @return 自己 - * @since 5.7.21 - */ - public SqlBuilder upsert(Entity entity, String dialectName, String... keys) { - // 验证 - validateEntity(entity); - - if (null != wrapper) { - // 包装表名 entity = wrapper.wrap(entity); - entity.setTableName(wrapper.wrap(entity.getTableName())); - } - - final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理 - final StringBuilder fieldsPart = new StringBuilder(); - final StringBuilder placeHolder = new StringBuilder(); - - boolean isFirst = true; - String field; - Object value; - for (Entry entry : entity.entrySet()) { - field = entry.getKey(); - value = entry.getValue(); - if (StrUtil.isNotBlank(field) /* && null != value */) { - if (isFirst) { - isFirst = false; - } else { - // 非第一个参数,追加逗号 - fieldsPart.append(", "); - placeHolder.append(", "); - } - - this.fields.add(field); - fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field); - if (isOracle && value instanceof String && StrUtil.endWithIgnoreCase((String) value, ".nextval")) { - // Oracle的特殊自增键,通过字段名.nextval获得下一个值 - placeHolder.append(value); - } else { - placeHolder.append("?"); - this.paramValues.add(value); - } - } - } - - // issue#1656@Github Phoenix兼容 - if (DialectName.PHOENIX.match(dialectName)) { - sql.append("UPSERT INTO ").append(entity.getTableName()); - } else if (DialectName.MYSQL.match(dialectName)) { - sql.append("INSERT INTO "); - sql.append(entity.getTableName()) - .append(" (").append(fieldsPart).append(") VALUES (") - .append(placeHolder).append(") on duplicate key update ") - .append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=values(" + k + ")"), ",")); - } else if (DialectName.H2.match(dialectName)) { - sql.append("MERGE INTO ").append(entity.getTableName()); - if (null != keys && keys.length > 0) { - sql.append(" KEY(").append(ArrayUtil.join(keys, ",")) - .append(") VALUES (") - .append(placeHolder) - .append(")"); - } - } else if (DialectName.POSTGREESQL.match(dialectName)) { - sql.append("INSERT INTO "); - sql.append(entity.getTableName()) - .append(" (").append(fieldsPart).append(") VALUES (") - .append(placeHolder).append(") on conflict (") - .append(ArrayUtil.join(keys,",")) - .append(") do update set ") - .append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=excluded." + k ), ",")); - } else { - throw new RuntimeException(dialectName + " not support yet"); - } - return this; - } - /** * 删除 * @@ -308,25 +234,22 @@ public class SqlBuilder implements Builder { // 验证 validateEntity(entity); + String tableName = entity.getTableName(); if (null != wrapper) { // 包装表名 - entity.setTableName(wrapper.wrap(entity.getTableName())); + tableName = wrapper.wrap(tableName); } - sql.append("UPDATE ").append(entity.getTableName()).append(" SET "); - String field; - for (Entry entry : entity.entrySet()) { - field = entry.getKey(); + sql.append("UPDATE ").append(tableName).append(" SET "); + entity.forEach((field, value) -> { if (StrUtil.isNotBlank(field)) { if (paramValues.size() > 0) { sql.append(", "); } - this.fields.add(field); sql.append((null != wrapper) ? wrapper.wrap(field) : field).append(" = ? "); - this.paramValues.add(entry.getValue());// 更新不对空做处理,因为存在清空字段的情况 + this.paramValues.add(value);// 更新不对空做处理,因为存在清空字段的情况 } - } - + }); return this; } @@ -653,24 +576,6 @@ public class SqlBuilder implements Builder { } // --------------------------------------------------------------- Builder end - /** - * 获得插入或更新的数据库字段列表 - * - * @return 插入或更新的数据库字段列表 - */ - public String[] getFieldArray() { - return this.fields.toArray(new String[0]); - } - - /** - * 获得插入或更新的数据库字段列表 - * - * @return 插入或更新的数据库字段列表 - */ - public List getFields() { - return this.fields; - } - /** * 获得占位符对应的值列表
* @@ -725,23 +630,5 @@ public class SqlBuilder implements Builder { return ConditionBuilder.of(conditions).build(this.paramValues); } - - /** - * 验证实体类对象的有效性 - * - * @param entity 实体类对象 - * @throws DbRuntimeException SQL异常包装,获取元数据信息失败 - */ - private static void validateEntity(Entity entity) throws DbRuntimeException { - if (null == entity) { - throw new DbRuntimeException("Entity is null !"); - } - if (StrUtil.isBlank(entity.getTableName())) { - throw new DbRuntimeException("Entity`s table name is null !"); - } - if (entity.isEmpty()) { - throw new DbRuntimeException("No filed and value in this entity !"); - } - } // --------------------------------------------------------------- private method end } diff --git a/hutool-db/src/main/java/cn/hutool/db/sql/Wrapper.java b/hutool-db/src/main/java/cn/hutool/db/sql/Wrapper.java index 12ee7778d..8851b2eac 100644 --- a/hutool-db/src/main/java/cn/hutool/db/sql/Wrapper.java +++ b/hutool-db/src/main/java/cn/hutool/db/sql/Wrapper.java @@ -15,15 +15,19 @@ import java.util.Map.Entry; /** * 包装器
* 主要用于字段名的包装(在字段名的前后加字符,例如反引号来避免与数据库的关键字冲突) - * @author Looly * + * @author Looly */ public class Wrapper implements Serializable { private static final long serialVersionUID = 1L; - /** 前置包装符号 */ + /** + * 前置包装符号 + */ private Character preWrapQuote; - /** 后置包装符号 */ + /** + * 后置包装符号 + */ private Character sufWrapQuote; public Wrapper() { @@ -31,6 +35,7 @@ public class Wrapper implements Serializable { /** * 构造 + * * @param wrapQuote 单包装字符 */ public Wrapper(Character wrapQuote) { @@ -40,6 +45,7 @@ public class Wrapper implements Serializable { /** * 包装符号 + * * @param preWrapQuote 前置包装符号 * @param sufWrapQuote 后置包装符号 */ @@ -49,14 +55,17 @@ public class Wrapper implements Serializable { } //--------------------------------------------------------------- Getters and Setters start + /** * @return 前置包装符号 */ public char getPreWrapQuote() { return preWrapQuote; } + /** * 设置前置包装的符号 + * * @param preWrapQuote 前置包装符号 */ public void setPreWrapQuote(Character preWrapQuote) { @@ -69,8 +78,10 @@ public class Wrapper implements Serializable { public char getSufWrapQuote() { return sufWrapQuote; } + /** * 设置后置包装的符号 + * * @param sufWrapQuote 后置包装符号 */ public void setSufWrapQuote(Character sufWrapQuote) { @@ -81,26 +92,27 @@ public class Wrapper implements Serializable { /** * 包装字段名
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突 + * * @param field 字段名 * @return 包装后的字段名 */ - public String wrap(String field){ - if(preWrapQuote == null || sufWrapQuote == null || StrUtil.isBlank(field)) { + public String wrap(String field) { + if (preWrapQuote == null || sufWrapQuote == null || StrUtil.isBlank(field)) { return field; } //如果已经包含包装的引号,返回原字符 - if(StrUtil.isSurround(field, preWrapQuote, sufWrapQuote)){ + if (StrUtil.isSurround(field, preWrapQuote, sufWrapQuote)) { return field; } //如果字段中包含通配符或者括号(字段通配符或者函数),不做包装 - if(StrUtil.containsAnyIgnoreCase(field, "*", "(", " ", " as ")) { + if (StrUtil.containsAnyIgnoreCase(field, "*", "(", " ", " as ")) { return field; } //对于Oracle这类数据库,表名中包含用户名需要单独拆分包装 - if(field.contains(StrUtil.DOT)){ + if (field.contains(StrUtil.DOT)) { final Collection target = CollUtil.edit(StrUtil.split(field, CharUtil.DOT, 2), t -> StrUtil.format("{}{}{}", preWrapQuote, t, sufWrapQuote)); return CollectionUtil.join(target, StrUtil.DOT); } @@ -111,16 +123,17 @@ public class Wrapper implements Serializable { /** * 包装字段名
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突 + * * @param fields 字段名 * @return 包装后的字段名 */ - public String[] wrap(String... fields){ - if(ArrayUtil.isEmpty(fields)) { + public String[] wrap(String... fields) { + if (ArrayUtil.isEmpty(fields)) { return fields; } String[] wrappedFields = new String[fields.length]; - for(int i = 0; i < fields.length; i++) { + for (int i = 0; i < fields.length; i++) { wrappedFields[i] = wrap(fields[i]); } @@ -130,11 +143,12 @@ public class Wrapper implements Serializable { /** * 包装字段名
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突 + * * @param fields 字段名 * @return 包装后的字段名 */ - public Collection wrap(Collection fields){ - if(CollectionUtil.isEmpty(fields)) { + public Collection wrap(Collection fields) { + if (CollectionUtil.isEmpty(fields)) { return fields; } @@ -142,13 +156,14 @@ public class Wrapper implements Serializable { } /** - * 包装字段名
+ * 包装表名和字段名,此方法返回一个新的Entity实体类
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突 + * * @param entity 被包装的实体 - * @return 包装后的字段名 + * @return 新的实体 */ - public Entity wrap(Entity entity){ - if(null == entity) { + public Entity wrap(Entity entity) { + if (null == entity) { return null; } @@ -168,14 +183,15 @@ public class Wrapper implements Serializable { /** * 包装字段名
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突 + * * @param conditions 被包装的实体 * @return 包装后的字段名 */ - public Condition[] wrap(Condition... conditions){ + public Condition[] wrap(Condition... conditions) { final Condition[] clonedConditions = new Condition[conditions.length]; - if(ArrayUtil.isNotEmpty(conditions)) { + if (ArrayUtil.isNotEmpty(conditions)) { Condition clonedCondition; - for(int i = 0; i < conditions.length; i++) { + for (int i = 0; i < conditions.length; i++) { clonedCondition = conditions[i].clone(); clonedCondition.setField(wrap(clonedCondition.getField())); clonedConditions[i] = clonedCondition; diff --git a/hutool-db/src/test/java/cn/hutool/db/H2Test.java b/hutool-db/src/test/java/cn/hutool/db/H2Test.java index 85bd27627..bc7e48dd7 100644 --- a/hutool-db/src/test/java/cn/hutool/db/H2Test.java +++ b/hutool-db/src/test/java/cn/hutool/db/H2Test.java @@ -1,6 +1,5 @@ package cn.hutool.db; -import com.alibaba.druid.support.json.JSONUtils; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; @@ -40,6 +39,7 @@ public class H2Test { List query = Db.use(DS_GROUP_NAME).find(Entity.create("test")); Assert.assertEquals(4, query.size()); } + @Test public void upsertTest() throws SQLException { Db db=Db.use(DS_GROUP_NAME);