fix upsert

This commit is contained in:
Looly 2022-01-15 13:08:46 +08:00
parent 9da17cf6c4
commit 8e67251fd3
15 changed files with 298 additions and 241 deletions

View File

@ -2343,11 +2343,7 @@ public class CollUtil {
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, K... keys) { public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, K... keys) {
final ArrayList<V> values = new ArrayList<>(); return MapUtil.valuesOfKeys(map, new ArrayIter<>(keys));
for (K k : keys) {
values.add(map.get(k));
}
return values;
} }
/** /**
@ -2377,11 +2373,7 @@ public class CollUtil {
* @since 3.0.9 * @since 3.0.9
*/ */
public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, Iterator<K> keys) { public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, Iterator<K> keys) {
final ArrayList<V> values = new ArrayList<>(); return MapUtil.valuesOfKeys(map, keys);
while (keys.hasNext()) {
values.add(map.get(keys.next()));
}
return values;
} }
// ------------------------------------------------------------------------------------------------- sort // ------------------------------------------------------------------------------------------------- sort

View File

@ -1354,4 +1354,23 @@ public class MapUtil {
} }
} }
} }
/**
* 从Map中获取指定键列表对应的值列表<br>
* 如果key在map中不存在或key对应值为null则返回值列表对应位置的值也为null
*
* @param <K> 键类型
* @param <V> 值类型
* @param map {@link Map}
* @param keys 键列表
* @return 值列表
* @since 5.7.20
*/
public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, Iterator<K> keys) {
final ArrayList<V> values = new ArrayList<>();
while (keys.hasNext()) {
values.add(map.get(keys.next()));
}
return values;
}
} }

View File

@ -4,6 +4,7 @@ import cn.hutool.core.comparator.ComparableComparator;
import cn.hutool.core.date.DateUtil; import cn.hutool.core.date.DateUtil;
import cn.hutool.core.lang.Dict; import cn.hutool.core.lang.Dict;
import cn.hutool.core.map.MapUtil; import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import org.junit.Assert; import org.junit.Assert;
@ -302,6 +303,14 @@ public class CollUtilTest {
Assert.assertEquals(CollUtil.newArrayList("b", "c"), filtered); Assert.assertEquals(CollUtil.newArrayList("b", "c"), filtered);
} }
@Test
public void filterSetTest() {
Set<String> set = CollUtil.newLinkedHashSet("a", "b", "", " ", "c");
Set<String> filtered = CollUtil.filter(set, StrUtil::isNotBlank);
Assert.assertEquals(CollUtil.newLinkedHashSet("a", "b", "c"), filtered);
}
@Test @Test
public void filterRemoveTest() { public void filterRemoveTest() {
ArrayList<String> list = CollUtil.newArrayList("a", "b", "c"); ArrayList<String> list = CollUtil.newArrayList("a", "b", "c");

View File

@ -99,7 +99,12 @@ public class DialectRunner implements Serializable {
* @since 5.7.20 * @since 5.7.20
*/ */
public int upsert(Connection conn, Entity record, String... keys) throws SQLException { public int upsert(Connection conn, Entity record, String... keys) throws SQLException {
PreparedStatement ps = getDialect().psForUpsert(conn, record, keys); PreparedStatement ps = null;
try{
ps = getDialect().psForUpsert(conn, record, keys);
}catch (SQLException ignore){
// 方言不支持使用默认
}
if (null != ps) { if (null != ps) {
try { try {
return ps.executeUpdate(); return ps.executeUpdate();

View File

@ -194,7 +194,7 @@ public class StatementUtil {
* @throws SQLException SQL异常 * @throws SQLException SQL异常
* @since 4.6.7 * @since 4.6.7
*/ */
public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, List<String> fields, Entity... entities) throws SQLException { public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, Iterable<String> fields, Entity... entities) throws SQLException {
Assert.notBlank(sql, "Sql String must be not blank!"); Assert.notBlank(sql, "Sql String must be not blank!");
sql = sql.trim(); sql = sql.trim();

View File

@ -37,7 +37,8 @@ public interface Dialect extends Serializable {
// -------------------------------------------- Execute // -------------------------------------------- Execute
/** /**
* 构建用于插入的PreparedStatement * 构建用于插入的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Entity}转换为带有占位符的SQL语句及参数列表
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param entity 数据实体类包含表名 * @param entity 数据实体类包含表名
@ -47,7 +48,8 @@ public interface Dialect extends Serializable {
PreparedStatement psForInsert(Connection conn, Entity entity) throws SQLException; PreparedStatement psForInsert(Connection conn, Entity entity) throws SQLException;
/** /**
* 构建用于批量插入的PreparedStatement * 构建用于批量插入的PreparedStatement<br>
* 用户实现需按照数据库方言格式{@link Entity}转换为带有占位符的SQL语句及参数列表
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param entities 数据实体实体的结构必须全部一致否则插入结果将不可预知 * @param entities 数据实体实体的结构必须全部一致否则插入结果将不可预知
@ -57,7 +59,9 @@ public interface Dialect extends Serializable {
PreparedStatement psForInsertBatch(Connection conn, Entity... entities) throws SQLException; PreparedStatement psForInsertBatch(Connection conn, Entity... entities) throws SQLException;
/** /**
* 构建用于删除的PreparedStatement * 构建用于删除的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Query}转换为带有占位符的SQL语句及参数列表<br>
* {@link Query}中包含了删除所需的表名查询条件等信息可借助SqlBuilder完成SQL语句生成
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param query 查找条件包含表名 * @param query 查找条件包含表名
@ -67,7 +71,9 @@ public interface Dialect extends Serializable {
PreparedStatement psForDelete(Connection conn, Query query) throws SQLException; PreparedStatement psForDelete(Connection conn, Query query) throws SQLException;
/** /**
* 构建用于更新的PreparedStatement * 构建用于更新的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Entity}配合{@link Query}转换为带有占位符的SQL语句及参数列表<br>
* 其中{@link Entity}中包含需要更新的数据信息{@link Query}包含更新的查找条件信息
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param entity 数据实体类包含表名 * @param entity 数据实体类包含表名
@ -80,7 +86,9 @@ public interface Dialect extends Serializable {
// -------------------------------------------- Query // -------------------------------------------- Query
/** /**
* 构建用于获取多条记录的PreparedStatement * 构建用于获取多条记录的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Query}转换为带有占位符的SQL语句及参数列表<br>
* {@link Query}中包含了查询所需的表名查询条件等信息可借助SqlBuilder完成SQL语句生成
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param query 查询条件包含表名 * @param query 查询条件包含表名
@ -90,7 +98,9 @@ public interface Dialect extends Serializable {
PreparedStatement psForFind(Connection conn, Query query) throws SQLException; PreparedStatement psForFind(Connection conn, Query query) throws SQLException;
/** /**
* 构建用于分页查询的PreparedStatement * 构建用于分页查询的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Query}转换为带有占位符的SQL语句及参数列表<br>
* {@link Query}中包含了分页查询所需的表名查询条件分页等信息可借助SqlBuilder完成SQL语句生成
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param query 查询条件包含表名 * @param query 查询条件包含表名
@ -100,7 +110,7 @@ public interface Dialect extends Serializable {
PreparedStatement psForPage(Connection conn, Query query) throws SQLException; PreparedStatement psForPage(Connection conn, Query query) throws SQLException;
/** /**
* 构建用于分页查询的PreparedStatement<br> * 构建用于分页查询的{@link PreparedStatement}<br>
* 可以在此方法中使用{@link SqlBuilder#orderBy(Order...)}方法加入排序信息 * 可以在此方法中使用{@link SqlBuilder#orderBy(Order...)}方法加入排序信息
* 排序信息通过{@link Page#getOrders()}获取 * 排序信息通过{@link Page#getOrders()}获取
* *
@ -114,7 +124,9 @@ public interface Dialect extends Serializable {
PreparedStatement psForPage(Connection conn, SqlBuilder sqlBuilder, Page page) throws SQLException; PreparedStatement psForPage(Connection conn, SqlBuilder sqlBuilder, Page page) throws SQLException;
/** /**
* 构建用于查询行数的PreparedStatement * 构建用于查询行数的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Query}转换为带有占位符的SQL语句及参数列表<br>
* {@link Query}中包含了表名查询条件等信息可借助SqlBuilder完成SQL语句生成
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param query 查询条件包含表名 * @param query 查询条件包含表名
@ -127,7 +139,9 @@ public interface Dialect extends Serializable {
} }
/** /**
* 构建用于查询行数的PreparedStatement * 构建用于查询行数的{@link PreparedStatement}<br>
* 用户实现需按照数据库方言格式{@link Query}转换为带有占位符的SQL语句及参数列表<br>
* {@link Query}中包含了表名查询条件等信息可借助SqlBuilder完成SQL语句生成
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param sqlBuilder 查询语句应该包含分页等信息 * @param sqlBuilder 查询语句应该包含分页等信息
@ -144,18 +158,18 @@ public interface Dialect extends Serializable {
} }
/** /**
* 构建用于upsert的PreparedStatement<br> * 构建用于upsert的{@link PreparedStatement}<br>
* 方言实现需实现此默认方法默认返回{@code null} * 方言实现需实现此默认方法如果没有实现抛出{@link SQLException}
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param entity 数据实体类包含表名 * @param entity 数据实体类包含表名
* @param keys 查找字段 * @param keys 查找字段某些数据库此字段必须如H2某些数据库无需此字段如MySQL通过主键
* @return PreparedStatement * @return PreparedStatement
* @throws SQLException SQL执行异常 * @throws SQLException SQL执行异常或方言数据不支持此操作
* @since 5.7.20 * @since 5.7.20
*/ */
default PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { default PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
return null; throw new SQLException("Unsupported upsert operation of " + dialectName());
} }

View File

@ -1,5 +1,6 @@
package cn.hutool.db.dialect.impl; package cn.hutool.db.dialect.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
@ -17,16 +18,17 @@ import cn.hutool.db.sql.Wrapper;
import java.sql.Connection; import java.sql.Connection;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Set;
/** /**
* ANSI SQL 方言 * ANSI SQL 方言
* *
* @author loolly * @author loolly
* *
*/ */
public class AnsiSqlDialect implements Dialect { public class AnsiSqlDialect implements Dialect {
private static final long serialVersionUID = 2088101129774974580L; private static final long serialVersionUID = 2088101129774974580L;
protected Wrapper wrapper = new Wrapper(); protected Wrapper wrapper = new Wrapper();
@Override @Override
@ -53,7 +55,8 @@ public class AnsiSqlDialect implements Dialect {
} }
// 批量根据第一行数据结构生成SQL占位符 // 批量根据第一行数据结构生成SQL占位符
final SqlBuilder insert = SqlBuilder.create(wrapper).insert(entities[0], this.dialectName()); final SqlBuilder insert = SqlBuilder.create(wrapper).insert(entities[0], this.dialectName());
return StatementUtil.prepareStatementForBatch(conn, insert.build(), insert.getFields(), entities); final Set<String> fields = CollUtil.filter(entities[0].keySet(), StrUtil::isNotBlank);
return StatementUtil.prepareStatementForBatch(conn, insert.build(), fields, entities);
} }
@Override @Override
@ -113,7 +116,7 @@ public class AnsiSqlDialect implements Dialect {
/** /**
* 根据不同数据库在查询SQL语句基础上包装其分页的语句<br> * 根据不同数据库在查询SQL语句基础上包装其分页的语句<br>
* 各自数据库通过重写此方法实现最小改动情况下修改分页语句 * 各自数据库通过重写此方法实现最小改动情况下修改分页语句
* *
* @param find 标准查询语句 * @param find 标准查询语句
* @param page 分页对象 * @param page 分页对象
* @return 分页语句 * @return 分页语句

View File

@ -2,19 +2,16 @@ package cn.hutool.db.dialect.impl;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.Entity; import cn.hutool.db.Entity;
import cn.hutool.db.Page; import cn.hutool.db.Page;
import cn.hutool.db.StatementUtil; import cn.hutool.db.StatementUtil;
import cn.hutool.db.dialect.DialectName; import cn.hutool.db.dialect.DialectName;
import cn.hutool.db.sql.Condition;
import cn.hutool.db.sql.Query;
import cn.hutool.db.sql.SqlBuilder; import cn.hutool.db.sql.SqlBuilder;
import java.sql.Connection; import java.sql.Connection;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Arrays;
import java.util.function.Function;
/** /**
* H2数据库方言 * H2数据库方言
@ -39,18 +36,42 @@ public class H2Dialect extends AnsiSqlDialect {
return find.append(" limit ").append(page.getStartPosition()).append(" , ").append(page.getPageSize()); return find.append(" limit ").append(page.getStartPosition()).append(" , ").append(page.getPageSize());
} }
/**
* 构建用于upsert的PreparedStatement
*
* @param conn 数据库连接对象
* @param entity 数据实体类包含表名
* @param keys 查找字段 如果不提供keys将自动使用主键
* @return PreparedStatement
* @throws SQLException SQL执行异常
*/
@Override @Override
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); Assert.notEmpty(keys, "Keys must be not empty for H2 MERGE SQL.");
return StatementUtil.prepareStatement(conn, upsert); SqlBuilder.validateEntity(entity);
final SqlBuilder builder = SqlBuilder.create(wrapper);
final StringBuilder fieldsPart = new StringBuilder();
final StringBuilder placeHolder = new StringBuilder();
// 构建字段部分和参数占位符部分
entity.forEach((field, value)->{
if (StrUtil.isNotBlank(field)) {
if (fieldsPart.length() > 0) {
// 非第一个参数追加逗号
fieldsPart.append(", ");
placeHolder.append(", ");
}
fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field);
placeHolder.append("?");
builder.addParams(value);
}
});
String tableName = entity.getTableName();
if (null != this.wrapper) {
tableName = this.wrapper.wrap(tableName);
}
builder.append("MERGE INTO ").append(tableName)
// 字段列表
.append(" (").append(fieldsPart)
// 更新关键字列表
.append(") KEY(").append(ArrayUtil.join(keys, ", "))
// 更新值列表
.append(") VALUES (").append(placeHolder).append(")");
return StatementUtil.prepareStatement(conn, builder);
} }
} }

View File

@ -1,5 +1,6 @@
package cn.hutool.db.dialect.impl; package cn.hutool.db.dialect.impl;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.Entity; import cn.hutool.db.Entity;
import cn.hutool.db.Page; import cn.hutool.db.Page;
import cn.hutool.db.StatementUtil; import cn.hutool.db.StatementUtil;
@ -34,17 +35,58 @@ public class MysqlDialect extends AnsiSqlDialect{
} }
/** /**
* 构建用于upsert的PreparedStatement * 构建用于upsert的{@link PreparedStatement}<br>
* MySQL通过主键方式实现Upsert故keys无效生成SQL语法为
* <pre>
* INSERT INTO demo(a,b,c) values(?, ?, ?) ON DUPLICATE KEY UPDATE a=values(a), b=values(b), c=values(c);
* </pre>
* *
* @param conn 数据库连接对象 * @param conn 数据库连接对象
* @param entity 数据实体类包含表名 * @param entity 数据实体类包含表名
* @param keys 查找字段 * @param keys 此参数无效
* @return PreparedStatement * @return PreparedStatement
* @throws SQLException SQL执行异常 * @throws SQLException SQL执行异常
* @since 5.7.20
*/ */
@Override @Override
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); SqlBuilder.validateEntity(entity);
return StatementUtil.prepareStatement(conn, upsert); final SqlBuilder builder = SqlBuilder.create(wrapper);
final StringBuilder fieldsPart = new StringBuilder();
final StringBuilder placeHolder = new StringBuilder();
final StringBuilder updateHolder = new StringBuilder();
// 构建字段部分和参数占位符部分
entity.forEach((field, value)->{
if (StrUtil.isNotBlank(field)) {
if (fieldsPart.length() > 0) {
// 非第一个参数追加逗号
fieldsPart.append(", ");
placeHolder.append(", ");
updateHolder.append(", ");
}
field = (null != wrapper) ? wrapper.wrap(field) : field;
fieldsPart.append(field);
updateHolder.append(field).append("=values(").append(field).append(")");
placeHolder.append("?");
builder.addParams(value);
}
});
String tableName = entity.getTableName();
if (null != this.wrapper) {
tableName = this.wrapper.wrap(tableName);
}
builder.append("INSERT INTO ").append(tableName)
// 字段列表
.append(" (").append(fieldsPart)
// 更新值列表
.append(") VALUES (").append(placeHolder)
// 主键冲突后的更新操作
.append(") ON DUPLICATE KEY UPDATE ").append(updateHolder);
return StatementUtil.prepareStatement(conn, builder);
} }
} }

View File

@ -1,32 +1,44 @@
package cn.hutool.db.dialect.impl; package cn.hutool.db.dialect.impl;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.Page; import cn.hutool.db.Page;
import cn.hutool.db.dialect.DialectName; import cn.hutool.db.dialect.DialectName;
import cn.hutool.db.sql.SqlBuilder; import cn.hutool.db.sql.SqlBuilder;
/** /**
* Oracle 方言 * Oracle 方言
* @author loolly
* *
* @author loolly
*/ */
public class OracleDialect extends AnsiSqlDialect{ public class OracleDialect extends AnsiSqlDialect {
private static final long serialVersionUID = 6122761762247483015L; private static final long serialVersionUID = 6122761762247483015L;
/**
* 检查字段值是否为Oracle自增字段自增字段以`.nextval`结尾
*
* @param value 检查的字段值
* @return 是否为Oracle自增字段
* @since 5.7.20
*/
public static boolean isNextVal(Object value) {
return (value instanceof CharSequence) && StrUtil.endWithIgnoreCase(value.toString(), ".nextval");
}
public OracleDialect() { public OracleDialect() {
//Oracle所有字段名用双引号包围防止字段名或表名与系统关键字冲突 //Oracle所有字段名用双引号包围防止字段名或表名与系统关键字冲突
//wrapper = new Wrapper('"'); //wrapper = new Wrapper('"');
} }
@Override @Override
protected SqlBuilder wrapPageSql(SqlBuilder find, Page page) { protected SqlBuilder wrapPageSql(SqlBuilder find, Page page) {
final int[] startEnd = page.getStartEnd(); final int[] startEnd = page.getStartEnd();
return find return find
.insertPreFragment("SELECT * FROM ( SELECT row_.*, rownum rownum_ from ( ") .insertPreFragment("SELECT * FROM ( SELECT row_.*, rownum rownum_ from ( ")
.append(" ) row_ where rownum <= ").append(startEnd[1])// .append(" ) row_ where rownum <= ").append(startEnd[1])//
.append(") table_alias")// .append(") table_alias")//
.append(" where table_alias.rownum_ > ").append(startEnd[0]);// .append(" where table_alias.rownum_ > ").append(startEnd[0]);//
} }
@Override @Override
public String dialectName() { public String dialectName() {
return DialectName.ORACLE.name(); return DialectName.ORACLE.name();

View File

@ -24,6 +24,7 @@ public class PhoenixDialect extends AnsiSqlDialect {
@Override @Override
public PreparedStatement psForUpdate(Connection conn, Entity entity, Query query) throws SQLException { public PreparedStatement psForUpdate(Connection conn, Entity entity, Query query) throws SQLException {
// Phoenix的插入更新语句是统一的统一使用upsert into关键字 // Phoenix的插入更新语句是统一的统一使用upsert into关键字
// Phoenix只支持通过主键更新操作因此query无效自动根据entity中的主键更新
return super.psForInsert(conn, entity); return super.psForInsert(conn, entity);
} }
@ -31,4 +32,10 @@ public class PhoenixDialect extends AnsiSqlDialect {
public String dialectName() { public String dialectName() {
return DialectName.PHOENIX.name(); return DialectName.PHOENIX.name();
} }
@Override
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
// Phoenix只支持通过主键更新操作因此query无效自动根据entity中的主键更新
return psForInsert(conn, entity);
}
} }

View File

@ -1,5 +1,8 @@
package cn.hutool.db.dialect.impl; package cn.hutool.db.dialect.impl;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.Entity; import cn.hutool.db.Entity;
import cn.hutool.db.StatementUtil; import cn.hutool.db.StatementUtil;
import cn.hutool.db.dialect.DialectName; import cn.hutool.db.dialect.DialectName;
@ -28,21 +31,48 @@ public class PostgresqlDialect extends AnsiSqlDialect{
return DialectName.POSTGREESQL.name(); return DialectName.POSTGREESQL.name();
} }
/**
* 构建用于upsert的PreparedStatement
*
* @param conn 数据库连接对象
* @param entity 数据实体类包含表名
* @param keys 查找字段 必须是有唯一索引的列且不能为空
* @return PreparedStatement
* @throws SQLException SQL执行异常
*/
@Override @Override
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
if (null==keys || keys.length==0){ Assert.notEmpty(keys, "Keys must be not empty for Postgres.");
throw new SQLException("keys不能为空"); SqlBuilder.validateEntity(entity);
final SqlBuilder builder = SqlBuilder.create(wrapper);
final StringBuilder fieldsPart = new StringBuilder();
final StringBuilder placeHolder = new StringBuilder();
final StringBuilder updateHolder = new StringBuilder();
// 构建字段部分和参数占位符部分
entity.forEach((field, value)->{
if (StrUtil.isNotBlank(field)) {
if (fieldsPart.length() > 0) {
// 非第一个参数追加逗号
fieldsPart.append(", ");
placeHolder.append(", ");
updateHolder.append(", ");
}
final String wrapedField = (null != wrapper) ? wrapper.wrap(field) : field;
fieldsPart.append(wrapedField);
updateHolder.append(wrapedField).append("=EXCLUDED.").append(field);
placeHolder.append("?");
builder.addParams(value);
}
});
String tableName = entity.getTableName();
if (null != this.wrapper) {
tableName = this.wrapper.wrap(tableName);
} }
final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); builder.append("INSERT INTO ").append(tableName)
return StatementUtil.prepareStatement(conn, upsert); // 字段列表
.append(" (").append(fieldsPart)
// 更新值列表
.append(") VALUES (").append(placeHolder)
// 定义检查冲突的主键或字段
.append(") ON CONFLICT (").append(ArrayUtil.join(keys,", "))
// 主键冲突后的更新操作
.append(") DO UPDATE SET ").append(updateHolder);
return StatementUtil.prepareStatement(conn, builder);
} }
} }

View File

@ -7,13 +7,13 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.db.DbRuntimeException; import cn.hutool.db.DbRuntimeException;
import cn.hutool.db.Entity; import cn.hutool.db.Entity;
import cn.hutool.db.dialect.DialectName; import cn.hutool.db.dialect.DialectName;
import cn.hutool.db.dialect.impl.OracleDialect;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map.Entry;
/** /**
* SQL构建器<br> * SQL构建器<br>
@ -57,6 +57,24 @@ public class SqlBuilder implements Builder<String> {
return create().append(sql); return create().append(sql);
} }
/**
* 验证实体类对象的有效性
*
* @param entity 实体类对象
* @throws DbRuntimeException SQL异常包装获取元数据信息失败
*/
public static void validateEntity(Entity entity) throws DbRuntimeException {
if (null == entity) {
throw new DbRuntimeException("Entity is null !");
}
if (StrUtil.isBlank(entity.getTableName())) {
throw new DbRuntimeException("Entity`s table name is null !");
}
if (entity.isEmpty()) {
throw new DbRuntimeException("No filed and value in this entity !");
}
}
// --------------------------------------------------------------- Static methods end // --------------------------------------------------------------- Static methods end
// --------------------------------------------------------------- Enums start // --------------------------------------------------------------- Enums start
@ -87,10 +105,6 @@ public class SqlBuilder implements Builder<String> {
// --------------------------------------------------------------- Enums end // --------------------------------------------------------------- Enums end
private final StringBuilder sql = new StringBuilder(); private final StringBuilder sql = new StringBuilder();
/**
* 字段列表仅用于插入和更新
*/
private final List<String> fields = new ArrayList<>();
/** /**
* 占位符对应的值列表 * 占位符对应的值列表
*/ */
@ -146,41 +160,29 @@ public class SqlBuilder implements Builder<String> {
// 验证 // 验证
validateEntity(entity); validateEntity(entity);
if (null != wrapper) {
// 包装表名 entity = wrapper.wrap(entity);
entity.setTableName(wrapper.wrap(entity.getTableName()));
}
final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理 final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理
final StringBuilder fieldsPart = new StringBuilder(); final StringBuilder fieldsPart = new StringBuilder();
final StringBuilder placeHolder = new StringBuilder(); final StringBuilder placeHolder = new StringBuilder();
boolean isFirst = true; entity.forEach((field, value) -> {
String field; if (StrUtil.isNotBlank(field)) {
Object value; if (fieldsPart.length() > 0) {
for (Entry<String, Object> entry : entity.entrySet()) {
field = entry.getKey();
value = entry.getValue();
if (StrUtil.isNotBlank(field) /* && null != value */) {
if (isFirst) {
isFirst = false;
} else {
// 非第一个参数追加逗号 // 非第一个参数追加逗号
fieldsPart.append(", "); fieldsPart.append(", ");
placeHolder.append(", "); placeHolder.append(", ");
} }
this.fields.add(field);
fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field); fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field);
if (isOracle && value instanceof String && StrUtil.endWithIgnoreCase((String) value, ".nextval")) { if (isOracle && OracleDialect.isNextVal(value)) {
// Oracle的特殊自增键通过字段名.nextval获得下一个值 // Oracle的特殊自增键通过字段名.nextval获得下一个值
placeHolder.append(value); placeHolder.append(value);
} else { } else {
// 普通字段使用占位符
placeHolder.append("?"); placeHolder.append("?");
this.paramValues.add(value); this.paramValues.add(value);
} }
} }
} });
// issue#1656@Github Phoenix兼容 // issue#1656@Github Phoenix兼容
if (DialectName.PHOENIX.match(dialectName)) { if (DialectName.PHOENIX.match(dialectName)) {
@ -189,94 +191,18 @@ public class SqlBuilder implements Builder<String> {
sql.append("INSERT INTO "); sql.append("INSERT INTO ");
} }
sql.append(entity.getTableName()) String tableName = entity.getTableName();
if (null != this.wrapper) {
// 包装表名 entity = wrapper.wrap(entity);
tableName = this.wrapper.wrap(tableName);
}
sql.append(tableName)
.append(" (").append(fieldsPart).append(") VALUES (")// .append(" (").append(fieldsPart).append(") VALUES (")//
.append(placeHolder).append(")"); .append(placeHolder).append(")");
return this; return this;
} }
/**
* 插入<br>
* 插入会忽略空的字段名及其对应值但是对于有字段名对应值为{@code null}的情况不忽略
*
* @param entity 实体
* @param dialectName 方言名用于对特殊数据库特殊处理
* @param keys 根据何字段来确认唯一性不传则用主键
* @return 自己
* @since 5.7.21
*/
public SqlBuilder upsert(Entity entity, String dialectName, String... keys) {
// 验证
validateEntity(entity);
if (null != wrapper) {
// 包装表名 entity = wrapper.wrap(entity);
entity.setTableName(wrapper.wrap(entity.getTableName()));
}
final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理
final StringBuilder fieldsPart = new StringBuilder();
final StringBuilder placeHolder = new StringBuilder();
boolean isFirst = true;
String field;
Object value;
for (Entry<String, Object> entry : entity.entrySet()) {
field = entry.getKey();
value = entry.getValue();
if (StrUtil.isNotBlank(field) /* && null != value */) {
if (isFirst) {
isFirst = false;
} else {
// 非第一个参数追加逗号
fieldsPart.append(", ");
placeHolder.append(", ");
}
this.fields.add(field);
fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field);
if (isOracle && value instanceof String && StrUtil.endWithIgnoreCase((String) value, ".nextval")) {
// Oracle的特殊自增键通过字段名.nextval获得下一个值
placeHolder.append(value);
} else {
placeHolder.append("?");
this.paramValues.add(value);
}
}
}
// issue#1656@Github Phoenix兼容
if (DialectName.PHOENIX.match(dialectName)) {
sql.append("UPSERT INTO ").append(entity.getTableName());
} else if (DialectName.MYSQL.match(dialectName)) {
sql.append("INSERT INTO ");
sql.append(entity.getTableName())
.append(" (").append(fieldsPart).append(") VALUES (")
.append(placeHolder).append(") on duplicate key update ")
.append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=values(" + k + ")"), ","));
} else if (DialectName.H2.match(dialectName)) {
sql.append("MERGE INTO ").append(entity.getTableName());
if (null != keys && keys.length > 0) {
sql.append(" KEY(").append(ArrayUtil.join(keys, ","))
.append(") VALUES (")
.append(placeHolder)
.append(")");
}
} else if (DialectName.POSTGREESQL.match(dialectName)) {
sql.append("INSERT INTO ");
sql.append(entity.getTableName())
.append(" (").append(fieldsPart).append(") VALUES (")
.append(placeHolder).append(") on conflict (")
.append(ArrayUtil.join(keys,","))
.append(") do update set ")
.append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=excluded." + k ), ","));
} else {
throw new RuntimeException(dialectName + " not support yet");
}
return this;
}
/** /**
* 删除 * 删除
* *
@ -308,25 +234,22 @@ public class SqlBuilder implements Builder<String> {
// 验证 // 验证
validateEntity(entity); validateEntity(entity);
String tableName = entity.getTableName();
if (null != wrapper) { if (null != wrapper) {
// 包装表名 // 包装表名
entity.setTableName(wrapper.wrap(entity.getTableName())); tableName = wrapper.wrap(tableName);
} }
sql.append("UPDATE ").append(entity.getTableName()).append(" SET "); sql.append("UPDATE ").append(tableName).append(" SET ");
String field; entity.forEach((field, value) -> {
for (Entry<String, Object> entry : entity.entrySet()) {
field = entry.getKey();
if (StrUtil.isNotBlank(field)) { if (StrUtil.isNotBlank(field)) {
if (paramValues.size() > 0) { if (paramValues.size() > 0) {
sql.append(", "); sql.append(", ");
} }
this.fields.add(field);
sql.append((null != wrapper) ? wrapper.wrap(field) : field).append(" = ? "); sql.append((null != wrapper) ? wrapper.wrap(field) : field).append(" = ? ");
this.paramValues.add(entry.getValue());// 更新不对空做处理因为存在清空字段的情况 this.paramValues.add(value);// 更新不对空做处理因为存在清空字段的情况
} }
} });
return this; return this;
} }
@ -653,24 +576,6 @@ public class SqlBuilder implements Builder<String> {
} }
// --------------------------------------------------------------- Builder end // --------------------------------------------------------------- Builder end
/**
* 获得插入或更新的数据库字段列表
*
* @return 插入或更新的数据库字段列表
*/
public String[] getFieldArray() {
return this.fields.toArray(new String[0]);
}
/**
* 获得插入或更新的数据库字段列表
*
* @return 插入或更新的数据库字段列表
*/
public List<String> getFields() {
return this.fields;
}
/** /**
* 获得占位符对应的值列表<br> * 获得占位符对应的值列表<br>
* *
@ -725,23 +630,5 @@ public class SqlBuilder implements Builder<String> {
return ConditionBuilder.of(conditions).build(this.paramValues); return ConditionBuilder.of(conditions).build(this.paramValues);
} }
/**
* 验证实体类对象的有效性
*
* @param entity 实体类对象
* @throws DbRuntimeException SQL异常包装获取元数据信息失败
*/
private static void validateEntity(Entity entity) throws DbRuntimeException {
if (null == entity) {
throw new DbRuntimeException("Entity is null !");
}
if (StrUtil.isBlank(entity.getTableName())) {
throw new DbRuntimeException("Entity`s table name is null !");
}
if (entity.isEmpty()) {
throw new DbRuntimeException("No filed and value in this entity !");
}
}
// --------------------------------------------------------------- private method end // --------------------------------------------------------------- private method end
} }

View File

@ -15,15 +15,19 @@ import java.util.Map.Entry;
/** /**
* 包装器<br> * 包装器<br>
* 主要用于字段名的包装在字段名的前后加字符例如反引号来避免与数据库的关键字冲突 * 主要用于字段名的包装在字段名的前后加字符例如反引号来避免与数据库的关键字冲突
* @author Looly
* *
* @author Looly
*/ */
public class Wrapper implements Serializable { public class Wrapper implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
/** 前置包装符号 */ /**
* 前置包装符号
*/
private Character preWrapQuote; private Character preWrapQuote;
/** 后置包装符号 */ /**
* 后置包装符号
*/
private Character sufWrapQuote; private Character sufWrapQuote;
public Wrapper() { public Wrapper() {
@ -31,6 +35,7 @@ public class Wrapper implements Serializable {
/** /**
* 构造 * 构造
*
* @param wrapQuote 单包装字符 * @param wrapQuote 单包装字符
*/ */
public Wrapper(Character wrapQuote) { public Wrapper(Character wrapQuote) {
@ -40,6 +45,7 @@ public class Wrapper implements Serializable {
/** /**
* 包装符号 * 包装符号
*
* @param preWrapQuote 前置包装符号 * @param preWrapQuote 前置包装符号
* @param sufWrapQuote 后置包装符号 * @param sufWrapQuote 后置包装符号
*/ */
@ -49,14 +55,17 @@ public class Wrapper implements Serializable {
} }
//--------------------------------------------------------------- Getters and Setters start //--------------------------------------------------------------- Getters and Setters start
/** /**
* @return 前置包装符号 * @return 前置包装符号
*/ */
public char getPreWrapQuote() { public char getPreWrapQuote() {
return preWrapQuote; return preWrapQuote;
} }
/** /**
* 设置前置包装的符号 * 设置前置包装的符号
*
* @param preWrapQuote 前置包装符号 * @param preWrapQuote 前置包装符号
*/ */
public void setPreWrapQuote(Character preWrapQuote) { public void setPreWrapQuote(Character preWrapQuote) {
@ -69,8 +78,10 @@ public class Wrapper implements Serializable {
public char getSufWrapQuote() { public char getSufWrapQuote() {
return sufWrapQuote; return sufWrapQuote;
} }
/** /**
* 设置后置包装的符号 * 设置后置包装的符号
*
* @param sufWrapQuote 后置包装符号 * @param sufWrapQuote 后置包装符号
*/ */
public void setSufWrapQuote(Character sufWrapQuote) { public void setSufWrapQuote(Character sufWrapQuote) {
@ -81,26 +92,27 @@ public class Wrapper implements Serializable {
/** /**
* 包装字段名<br> * 包装字段名<br>
* 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突 * 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突
*
* @param field 字段名 * @param field 字段名
* @return 包装后的字段名 * @return 包装后的字段名
*/ */
public String wrap(String field){ public String wrap(String field) {
if(preWrapQuote == null || sufWrapQuote == null || StrUtil.isBlank(field)) { if (preWrapQuote == null || sufWrapQuote == null || StrUtil.isBlank(field)) {
return field; return field;
} }
//如果已经包含包装的引号返回原字符 //如果已经包含包装的引号返回原字符
if(StrUtil.isSurround(field, preWrapQuote, sufWrapQuote)){ if (StrUtil.isSurround(field, preWrapQuote, sufWrapQuote)) {
return field; return field;
} }
//如果字段中包含通配符或者括号字段通配符或者函数不做包装 //如果字段中包含通配符或者括号字段通配符或者函数不做包装
if(StrUtil.containsAnyIgnoreCase(field, "*", "(", " ", " as ")) { if (StrUtil.containsAnyIgnoreCase(field, "*", "(", " ", " as ")) {
return field; return field;
} }
//对于Oracle这类数据库表名中包含用户名需要单独拆分包装 //对于Oracle这类数据库表名中包含用户名需要单独拆分包装
if(field.contains(StrUtil.DOT)){ if (field.contains(StrUtil.DOT)) {
final Collection<String> target = CollUtil.edit(StrUtil.split(field, CharUtil.DOT, 2), t -> StrUtil.format("{}{}{}", preWrapQuote, t, sufWrapQuote)); final Collection<String> target = CollUtil.edit(StrUtil.split(field, CharUtil.DOT, 2), t -> StrUtil.format("{}{}{}", preWrapQuote, t, sufWrapQuote));
return CollectionUtil.join(target, StrUtil.DOT); return CollectionUtil.join(target, StrUtil.DOT);
} }
@ -111,16 +123,17 @@ public class Wrapper implements Serializable {
/** /**
* 包装字段名<br> * 包装字段名<br>
* 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突 * 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突
*
* @param fields 字段名 * @param fields 字段名
* @return 包装后的字段名 * @return 包装后的字段名
*/ */
public String[] wrap(String... fields){ public String[] wrap(String... fields) {
if(ArrayUtil.isEmpty(fields)) { if (ArrayUtil.isEmpty(fields)) {
return fields; return fields;
} }
String[] wrappedFields = new String[fields.length]; String[] wrappedFields = new String[fields.length];
for(int i = 0; i < fields.length; i++) { for (int i = 0; i < fields.length; i++) {
wrappedFields[i] = wrap(fields[i]); wrappedFields[i] = wrap(fields[i]);
} }
@ -130,11 +143,12 @@ public class Wrapper implements Serializable {
/** /**
* 包装字段名<br> * 包装字段名<br>
* 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突 * 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突
*
* @param fields 字段名 * @param fields 字段名
* @return 包装后的字段名 * @return 包装后的字段名
*/ */
public Collection<String> wrap(Collection<String> fields){ public Collection<String> wrap(Collection<String> fields) {
if(CollectionUtil.isEmpty(fields)) { if (CollectionUtil.isEmpty(fields)) {
return fields; return fields;
} }
@ -142,13 +156,14 @@ public class Wrapper implements Serializable {
} }
/** /**
* 包装字段名<br> * 包装表名和字段名此方法返回一个新的Entity实体类<br>
* 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突 * 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突
*
* @param entity 被包装的实体 * @param entity 被包装的实体
* @return 包装后的字段名 * @return 新的实体
*/ */
public Entity wrap(Entity entity){ public Entity wrap(Entity entity) {
if(null == entity) { if (null == entity) {
return null; return null;
} }
@ -168,14 +183,15 @@ public class Wrapper implements Serializable {
/** /**
* 包装字段名<br> * 包装字段名<br>
* 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突 * 有时字段与SQL的某些关键字冲突导致SQL出错因此需要将字段名用单引号或者反引号包装起来避免冲突
*
* @param conditions 被包装的实体 * @param conditions 被包装的实体
* @return 包装后的字段名 * @return 包装后的字段名
*/ */
public Condition[] wrap(Condition... conditions){ public Condition[] wrap(Condition... conditions) {
final Condition[] clonedConditions = new Condition[conditions.length]; final Condition[] clonedConditions = new Condition[conditions.length];
if(ArrayUtil.isNotEmpty(conditions)) { if (ArrayUtil.isNotEmpty(conditions)) {
Condition clonedCondition; Condition clonedCondition;
for(int i = 0; i < conditions.length; i++) { for (int i = 0; i < conditions.length; i++) {
clonedCondition = conditions[i].clone(); clonedCondition = conditions[i].clone();
clonedCondition.setField(wrap(clonedCondition.getField())); clonedCondition.setField(wrap(clonedCondition.getField()));
clonedConditions[i] = clonedCondition; clonedConditions[i] = clonedCondition;

View File

@ -1,6 +1,5 @@
package cn.hutool.db; package cn.hutool.db;
import com.alibaba.druid.support.json.JSONUtils;
import org.junit.Assert; import org.junit.Assert;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
@ -40,6 +39,7 @@ public class H2Test {
List<Entity> query = Db.use(DS_GROUP_NAME).find(Entity.create("test")); List<Entity> query = Db.use(DS_GROUP_NAME).find(Entity.create("test"));
Assert.assertEquals(4, query.size()); Assert.assertEquals(4, query.size());
} }
@Test @Test
public void upsertTest() throws SQLException { public void upsertTest() throws SQLException {
Db db=Db.use(DS_GROUP_NAME); Db db=Db.use(DS_GROUP_NAME);