mirror of
https://gitee.com/chinabugotech/hutool.git
synced 2025-04-19 03:01:48 +08:00
fix upsert
This commit is contained in:
parent
9da17cf6c4
commit
8e67251fd3
@ -2343,11 +2343,7 @@ public class CollUtil {
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, K... keys) {
|
||||
final ArrayList<V> values = new ArrayList<>();
|
||||
for (K k : keys) {
|
||||
values.add(map.get(k));
|
||||
}
|
||||
return values;
|
||||
return MapUtil.valuesOfKeys(map, new ArrayIter<>(keys));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -2377,11 +2373,7 @@ public class CollUtil {
|
||||
* @since 3.0.9
|
||||
*/
|
||||
public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, Iterator<K> keys) {
|
||||
final ArrayList<V> values = new ArrayList<>();
|
||||
while (keys.hasNext()) {
|
||||
values.add(map.get(keys.next()));
|
||||
}
|
||||
return values;
|
||||
return MapUtil.valuesOfKeys(map, keys);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------- sort
|
||||
|
@ -1354,4 +1354,23 @@ public class MapUtil {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从Map中获取指定键列表对应的值列表<br>
|
||||
* 如果key在map中不存在或key对应值为null,则返回值列表对应位置的值也为null
|
||||
*
|
||||
* @param <K> 键类型
|
||||
* @param <V> 值类型
|
||||
* @param map {@link Map}
|
||||
* @param keys 键列表
|
||||
* @return 值列表
|
||||
* @since 5.7.20
|
||||
*/
|
||||
public static <K, V> ArrayList<V> valuesOfKeys(Map<K, V> map, Iterator<K> keys) {
|
||||
final ArrayList<V> values = new ArrayList<>();
|
||||
while (keys.hasNext()) {
|
||||
values.add(map.get(keys.next()));
|
||||
}
|
||||
return values;
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import cn.hutool.core.comparator.ComparableComparator;
|
||||
import cn.hutool.core.date.DateUtil;
|
||||
import cn.hutool.core.lang.Dict;
|
||||
import cn.hutool.core.map.MapUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import org.junit.Assert;
|
||||
@ -302,6 +303,14 @@ public class CollUtilTest {
|
||||
Assert.assertEquals(CollUtil.newArrayList("b", "c"), filtered);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterSetTest() {
|
||||
Set<String> set = CollUtil.newLinkedHashSet("a", "b", "", " ", "c");
|
||||
Set<String> filtered = CollUtil.filter(set, StrUtil::isNotBlank);
|
||||
|
||||
Assert.assertEquals(CollUtil.newLinkedHashSet("a", "b", "c"), filtered);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void filterRemoveTest() {
|
||||
ArrayList<String> list = CollUtil.newArrayList("a", "b", "c");
|
||||
|
@ -99,7 +99,12 @@ public class DialectRunner implements Serializable {
|
||||
* @since 5.7.20
|
||||
*/
|
||||
public int upsert(Connection conn, Entity record, String... keys) throws SQLException {
|
||||
PreparedStatement ps = getDialect().psForUpsert(conn, record, keys);
|
||||
PreparedStatement ps = null;
|
||||
try{
|
||||
ps = getDialect().psForUpsert(conn, record, keys);
|
||||
}catch (SQLException ignore){
|
||||
// 方言不支持,使用默认
|
||||
}
|
||||
if (null != ps) {
|
||||
try {
|
||||
return ps.executeUpdate();
|
||||
|
@ -194,7 +194,7 @@ public class StatementUtil {
|
||||
* @throws SQLException SQL异常
|
||||
* @since 4.6.7
|
||||
*/
|
||||
public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, List<String> fields, Entity... entities) throws SQLException {
|
||||
public static PreparedStatement prepareStatementForBatch(Connection conn, String sql, Iterable<String> fields, Entity... entities) throws SQLException {
|
||||
Assert.notBlank(sql, "Sql String must be not blank!");
|
||||
|
||||
sql = sql.trim();
|
||||
|
@ -37,7 +37,8 @@ public interface Dialect extends Serializable {
|
||||
// -------------------------------------------- Execute
|
||||
|
||||
/**
|
||||
* 构建用于插入的PreparedStatement
|
||||
* 构建用于插入的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Entity}转换为带有占位符的SQL语句及参数列表
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entity 数据实体类(包含表名)
|
||||
@ -47,7 +48,8 @@ public interface Dialect extends Serializable {
|
||||
PreparedStatement psForInsert(Connection conn, Entity entity) throws SQLException;
|
||||
|
||||
/**
|
||||
* 构建用于批量插入的PreparedStatement
|
||||
* 构建用于批量插入的PreparedStatement<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Entity}转换为带有占位符的SQL语句及参数列表
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entities 数据实体,实体的结构必须全部一致,否则插入结果将不可预知
|
||||
@ -57,7 +59,9 @@ public interface Dialect extends Serializable {
|
||||
PreparedStatement psForInsertBatch(Connection conn, Entity... entities) throws SQLException;
|
||||
|
||||
/**
|
||||
* 构建用于删除的PreparedStatement
|
||||
* 构建用于删除的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表<br>
|
||||
* {@link Query}中包含了删除所需的表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param query 查找条件(包含表名)
|
||||
@ -67,7 +71,9 @@ public interface Dialect extends Serializable {
|
||||
PreparedStatement psForDelete(Connection conn, Query query) throws SQLException;
|
||||
|
||||
/**
|
||||
* 构建用于更新的PreparedStatement
|
||||
* 构建用于更新的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Entity}配合{@link Query}转换为带有占位符的SQL语句及参数列表<br>
|
||||
* 其中{@link Entity}中包含需要更新的数据信息,{@link Query}包含更新的查找条件信息。
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entity 数据实体类(包含表名)
|
||||
@ -80,7 +86,9 @@ public interface Dialect extends Serializable {
|
||||
// -------------------------------------------- Query
|
||||
|
||||
/**
|
||||
* 构建用于获取多条记录的PreparedStatement
|
||||
* 构建用于获取多条记录的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表<br>
|
||||
* {@link Query}中包含了查询所需的表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param query 查询条件(包含表名)
|
||||
@ -90,7 +98,9 @@ public interface Dialect extends Serializable {
|
||||
PreparedStatement psForFind(Connection conn, Query query) throws SQLException;
|
||||
|
||||
/**
|
||||
* 构建用于分页查询的PreparedStatement
|
||||
* 构建用于分页查询的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表<br>
|
||||
* {@link Query}中包含了分页查询所需的表名、查询条件、分页等信息,可借助SqlBuilder完成SQL语句生成。
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param query 查询条件(包含表名)
|
||||
@ -100,7 +110,7 @@ public interface Dialect extends Serializable {
|
||||
PreparedStatement psForPage(Connection conn, Query query) throws SQLException;
|
||||
|
||||
/**
|
||||
* 构建用于分页查询的PreparedStatement<br>
|
||||
* 构建用于分页查询的{@link PreparedStatement}<br>
|
||||
* 可以在此方法中使用{@link SqlBuilder#orderBy(Order...)}方法加入排序信息,
|
||||
* 排序信息通过{@link Page#getOrders()}获取
|
||||
*
|
||||
@ -114,7 +124,9 @@ public interface Dialect extends Serializable {
|
||||
PreparedStatement psForPage(Connection conn, SqlBuilder sqlBuilder, Page page) throws SQLException;
|
||||
|
||||
/**
|
||||
* 构建用于查询行数的PreparedStatement
|
||||
* 构建用于查询行数的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表<br>
|
||||
* {@link Query}中包含了表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param query 查询条件(包含表名)
|
||||
@ -127,7 +139,9 @@ public interface Dialect extends Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建用于查询行数的PreparedStatement
|
||||
* 构建用于查询行数的{@link PreparedStatement}<br>
|
||||
* 用户实现需按照数据库方言格式,将{@link Query}转换为带有占位符的SQL语句及参数列表<br>
|
||||
* {@link Query}中包含了表名、查询条件等信息,可借助SqlBuilder完成SQL语句生成。
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param sqlBuilder 查询语句,应该包含分页等信息
|
||||
@ -144,18 +158,18 @@ public interface Dialect extends Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建用于upsert的PreparedStatement<br>
|
||||
* 方言实现需实现此默认方法,默认返回{@code null}
|
||||
* 构建用于upsert的{@link PreparedStatement}<br>
|
||||
* 方言实现需实现此默认方法,如果没有实现,抛出{@link SQLException}
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entity 数据实体类(包含表名)
|
||||
* @param keys 查找字段
|
||||
* @param keys 查找字段,某些数据库此字段必须,如H2,某些数据库无需此字段,如MySQL(通过主键)
|
||||
* @return PreparedStatement
|
||||
* @throws SQLException SQL执行异常
|
||||
* @throws SQLException SQL执行异常,或方言数据不支持此操作
|
||||
* @since 5.7.20
|
||||
*/
|
||||
default PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
|
||||
return null;
|
||||
throw new SQLException("Unsupported upsert operation of " + dialectName());
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
package cn.hutool.db.dialect.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
@ -17,16 +18,17 @@ import cn.hutool.db.sql.Wrapper;
|
||||
import java.sql.Connection;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* ANSI SQL 方言
|
||||
*
|
||||
*
|
||||
* @author loolly
|
||||
*
|
||||
*/
|
||||
public class AnsiSqlDialect implements Dialect {
|
||||
private static final long serialVersionUID = 2088101129774974580L;
|
||||
|
||||
|
||||
protected Wrapper wrapper = new Wrapper();
|
||||
|
||||
@Override
|
||||
@ -53,7 +55,8 @@ public class AnsiSqlDialect implements Dialect {
|
||||
}
|
||||
// 批量,根据第一行数据结构生成SQL占位符
|
||||
final SqlBuilder insert = SqlBuilder.create(wrapper).insert(entities[0], this.dialectName());
|
||||
return StatementUtil.prepareStatementForBatch(conn, insert.build(), insert.getFields(), entities);
|
||||
final Set<String> fields = CollUtil.filter(entities[0].keySet(), StrUtil::isNotBlank);
|
||||
return StatementUtil.prepareStatementForBatch(conn, insert.build(), fields, entities);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -113,7 +116,7 @@ public class AnsiSqlDialect implements Dialect {
|
||||
/**
|
||||
* 根据不同数据库在查询SQL语句基础上包装其分页的语句<br>
|
||||
* 各自数据库通过重写此方法实现最小改动情况下修改分页语句
|
||||
*
|
||||
*
|
||||
* @param find 标准查询语句
|
||||
* @param page 分页对象
|
||||
* @return 分页语句
|
||||
|
@ -2,19 +2,16 @@ package cn.hutool.db.dialect.impl;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.db.Entity;
|
||||
import cn.hutool.db.Page;
|
||||
import cn.hutool.db.StatementUtil;
|
||||
import cn.hutool.db.dialect.DialectName;
|
||||
import cn.hutool.db.sql.Condition;
|
||||
import cn.hutool.db.sql.Query;
|
||||
import cn.hutool.db.sql.SqlBuilder;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.SQLException;
|
||||
import java.util.Arrays;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* H2数据库方言
|
||||
@ -39,18 +36,42 @@ public class H2Dialect extends AnsiSqlDialect {
|
||||
return find.append(" limit ").append(page.getStartPosition()).append(" , ").append(page.getPageSize());
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建用于upsert的PreparedStatement
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entity 数据实体类(包含表名)
|
||||
* @param keys 查找字段 如果不提供keys将自动使用主键
|
||||
* @return PreparedStatement
|
||||
* @throws SQLException SQL执行异常
|
||||
*/
|
||||
@Override
|
||||
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
|
||||
final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys);
|
||||
return StatementUtil.prepareStatement(conn, upsert);
|
||||
Assert.notEmpty(keys, "Keys must be not empty for H2 MERGE SQL.");
|
||||
SqlBuilder.validateEntity(entity);
|
||||
final SqlBuilder builder = SqlBuilder.create(wrapper);
|
||||
|
||||
final StringBuilder fieldsPart = new StringBuilder();
|
||||
final StringBuilder placeHolder = new StringBuilder();
|
||||
|
||||
// 构建字段部分和参数占位符部分
|
||||
entity.forEach((field, value)->{
|
||||
if (StrUtil.isNotBlank(field)) {
|
||||
if (fieldsPart.length() > 0) {
|
||||
// 非第一个参数,追加逗号
|
||||
fieldsPart.append(", ");
|
||||
placeHolder.append(", ");
|
||||
}
|
||||
|
||||
fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field);
|
||||
placeHolder.append("?");
|
||||
builder.addParams(value);
|
||||
}
|
||||
});
|
||||
|
||||
String tableName = entity.getTableName();
|
||||
if (null != this.wrapper) {
|
||||
tableName = this.wrapper.wrap(tableName);
|
||||
}
|
||||
builder.append("MERGE INTO ").append(tableName)
|
||||
// 字段列表
|
||||
.append(" (").append(fieldsPart)
|
||||
// 更新关键字列表
|
||||
.append(") KEY(").append(ArrayUtil.join(keys, ", "))
|
||||
// 更新值列表
|
||||
.append(") VALUES (").append(placeHolder).append(")");
|
||||
|
||||
return StatementUtil.prepareStatement(conn, builder);
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package cn.hutool.db.dialect.impl;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.db.Entity;
|
||||
import cn.hutool.db.Page;
|
||||
import cn.hutool.db.StatementUtil;
|
||||
@ -34,17 +35,58 @@ public class MysqlDialect extends AnsiSqlDialect{
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建用于upsert的PreparedStatement
|
||||
* 构建用于upsert的{@link PreparedStatement}<br>
|
||||
* MySQL通过主键方式实现Upsert,故keys无效,生成SQL语法为:
|
||||
* <pre>
|
||||
* INSERT INTO demo(a,b,c) values(?, ?, ?) ON DUPLICATE KEY UPDATE a=values(a), b=values(b), c=values(c);
|
||||
* </pre>
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entity 数据实体类(包含表名)
|
||||
* @param keys 查找字段
|
||||
* @param keys 此参数无效
|
||||
* @return PreparedStatement
|
||||
* @throws SQLException SQL执行异常
|
||||
* @since 5.7.20
|
||||
*/
|
||||
@Override
|
||||
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
|
||||
final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys);
|
||||
return StatementUtil.prepareStatement(conn, upsert);
|
||||
SqlBuilder.validateEntity(entity);
|
||||
final SqlBuilder builder = SqlBuilder.create(wrapper);
|
||||
|
||||
final StringBuilder fieldsPart = new StringBuilder();
|
||||
final StringBuilder placeHolder = new StringBuilder();
|
||||
final StringBuilder updateHolder = new StringBuilder();
|
||||
|
||||
// 构建字段部分和参数占位符部分
|
||||
entity.forEach((field, value)->{
|
||||
if (StrUtil.isNotBlank(field)) {
|
||||
if (fieldsPart.length() > 0) {
|
||||
// 非第一个参数,追加逗号
|
||||
fieldsPart.append(", ");
|
||||
placeHolder.append(", ");
|
||||
updateHolder.append(", ");
|
||||
}
|
||||
|
||||
field = (null != wrapper) ? wrapper.wrap(field) : field;
|
||||
fieldsPart.append(field);
|
||||
updateHolder.append(field).append("=values(").append(field).append(")");
|
||||
placeHolder.append("?");
|
||||
builder.addParams(value);
|
||||
}
|
||||
});
|
||||
|
||||
String tableName = entity.getTableName();
|
||||
if (null != this.wrapper) {
|
||||
tableName = this.wrapper.wrap(tableName);
|
||||
}
|
||||
builder.append("INSERT INTO ").append(tableName)
|
||||
// 字段列表
|
||||
.append(" (").append(fieldsPart)
|
||||
// 更新值列表
|
||||
.append(") VALUES (").append(placeHolder)
|
||||
// 主键冲突后的更新操作
|
||||
.append(") ON DUPLICATE KEY UPDATE ").append(updateHolder);
|
||||
|
||||
return StatementUtil.prepareStatement(conn, builder);
|
||||
}
|
||||
}
|
||||
|
@ -1,32 +1,44 @@
|
||||
package cn.hutool.db.dialect.impl;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.db.Page;
|
||||
import cn.hutool.db.dialect.DialectName;
|
||||
import cn.hutool.db.sql.SqlBuilder;
|
||||
|
||||
/**
|
||||
* Oracle 方言
|
||||
* @author loolly
|
||||
*
|
||||
* @author loolly
|
||||
*/
|
||||
public class OracleDialect extends AnsiSqlDialect{
|
||||
public class OracleDialect extends AnsiSqlDialect {
|
||||
private static final long serialVersionUID = 6122761762247483015L;
|
||||
|
||||
/**
|
||||
* 检查字段值是否为Oracle自增字段,自增字段以`.nextval`结尾
|
||||
*
|
||||
* @param value 检查的字段值
|
||||
* @return 是否为Oracle自增字段
|
||||
* @since 5.7.20
|
||||
*/
|
||||
public static boolean isNextVal(Object value) {
|
||||
return (value instanceof CharSequence) && StrUtil.endWithIgnoreCase(value.toString(), ".nextval");
|
||||
}
|
||||
|
||||
public OracleDialect() {
|
||||
//Oracle所有字段名用双引号包围,防止字段名或表名与系统关键字冲突
|
||||
//wrapper = new Wrapper('"');
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected SqlBuilder wrapPageSql(SqlBuilder find, Page page) {
|
||||
final int[] startEnd = page.getStartEnd();
|
||||
return find
|
||||
.insertPreFragment("SELECT * FROM ( SELECT row_.*, rownum rownum_ from ( ")
|
||||
.append(" ) row_ where rownum <= ").append(startEnd[1])//
|
||||
.append(") table_alias")//
|
||||
.append(" where table_alias.rownum_ > ").append(startEnd[0]);//
|
||||
.insertPreFragment("SELECT * FROM ( SELECT row_.*, rownum rownum_ from ( ")
|
||||
.append(" ) row_ where rownum <= ").append(startEnd[1])//
|
||||
.append(") table_alias")//
|
||||
.append(" where table_alias.rownum_ > ").append(startEnd[0]);//
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String dialectName() {
|
||||
return DialectName.ORACLE.name();
|
||||
|
@ -24,6 +24,7 @@ public class PhoenixDialect extends AnsiSqlDialect {
|
||||
@Override
|
||||
public PreparedStatement psForUpdate(Connection conn, Entity entity, Query query) throws SQLException {
|
||||
// Phoenix的插入、更新语句是统一的,统一使用upsert into关键字
|
||||
// Phoenix只支持通过主键更新操作,因此query无效,自动根据entity中的主键更新
|
||||
return super.psForInsert(conn, entity);
|
||||
}
|
||||
|
||||
@ -31,4 +32,10 @@ public class PhoenixDialect extends AnsiSqlDialect {
|
||||
public String dialectName() {
|
||||
return DialectName.PHOENIX.name();
|
||||
}
|
||||
|
||||
@Override
|
||||
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
|
||||
// Phoenix只支持通过主键更新操作,因此query无效,自动根据entity中的主键更新
|
||||
return psForInsert(conn, entity);
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
package cn.hutool.db.dialect.impl;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.db.Entity;
|
||||
import cn.hutool.db.StatementUtil;
|
||||
import cn.hutool.db.dialect.DialectName;
|
||||
@ -28,21 +31,48 @@ public class PostgresqlDialect extends AnsiSqlDialect{
|
||||
return DialectName.POSTGREESQL.name();
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建用于upsert的PreparedStatement
|
||||
*
|
||||
* @param conn 数据库连接对象
|
||||
* @param entity 数据实体类(包含表名)
|
||||
* @param keys 查找字段 必须是有唯一索引的列且不能为空
|
||||
* @return PreparedStatement
|
||||
* @throws SQLException SQL执行异常
|
||||
*/
|
||||
@Override
|
||||
public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException {
|
||||
if (null==keys || keys.length==0){
|
||||
throw new SQLException("keys不能为空");
|
||||
Assert.notEmpty(keys, "Keys must be not empty for Postgres.");
|
||||
SqlBuilder.validateEntity(entity);
|
||||
final SqlBuilder builder = SqlBuilder.create(wrapper);
|
||||
|
||||
final StringBuilder fieldsPart = new StringBuilder();
|
||||
final StringBuilder placeHolder = new StringBuilder();
|
||||
final StringBuilder updateHolder = new StringBuilder();
|
||||
|
||||
// 构建字段部分和参数占位符部分
|
||||
entity.forEach((field, value)->{
|
||||
if (StrUtil.isNotBlank(field)) {
|
||||
if (fieldsPart.length() > 0) {
|
||||
// 非第一个参数,追加逗号
|
||||
fieldsPart.append(", ");
|
||||
placeHolder.append(", ");
|
||||
updateHolder.append(", ");
|
||||
}
|
||||
|
||||
final String wrapedField = (null != wrapper) ? wrapper.wrap(field) : field;
|
||||
fieldsPart.append(wrapedField);
|
||||
updateHolder.append(wrapedField).append("=EXCLUDED.").append(field);
|
||||
placeHolder.append("?");
|
||||
builder.addParams(value);
|
||||
}
|
||||
});
|
||||
|
||||
String tableName = entity.getTableName();
|
||||
if (null != this.wrapper) {
|
||||
tableName = this.wrapper.wrap(tableName);
|
||||
}
|
||||
final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys);
|
||||
return StatementUtil.prepareStatement(conn, upsert);
|
||||
builder.append("INSERT INTO ").append(tableName)
|
||||
// 字段列表
|
||||
.append(" (").append(fieldsPart)
|
||||
// 更新值列表
|
||||
.append(") VALUES (").append(placeHolder)
|
||||
// 定义检查冲突的主键或字段
|
||||
.append(") ON CONFLICT (").append(ArrayUtil.join(keys,", "))
|
||||
// 主键冲突后的更新操作
|
||||
.append(") DO UPDATE SET ").append(updateHolder);
|
||||
|
||||
return StatementUtil.prepareStatement(conn, builder);
|
||||
}
|
||||
}
|
||||
|
@ -7,13 +7,13 @@ import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.db.DbRuntimeException;
|
||||
import cn.hutool.db.Entity;
|
||||
import cn.hutool.db.dialect.DialectName;
|
||||
import cn.hutool.db.dialect.impl.OracleDialect;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
|
||||
/**
|
||||
* SQL构建器<br>
|
||||
@ -57,6 +57,24 @@ public class SqlBuilder implements Builder<String> {
|
||||
return create().append(sql);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证实体类对象的有效性
|
||||
*
|
||||
* @param entity 实体类对象
|
||||
* @throws DbRuntimeException SQL异常包装,获取元数据信息失败
|
||||
*/
|
||||
public static void validateEntity(Entity entity) throws DbRuntimeException {
|
||||
if (null == entity) {
|
||||
throw new DbRuntimeException("Entity is null !");
|
||||
}
|
||||
if (StrUtil.isBlank(entity.getTableName())) {
|
||||
throw new DbRuntimeException("Entity`s table name is null !");
|
||||
}
|
||||
if (entity.isEmpty()) {
|
||||
throw new DbRuntimeException("No filed and value in this entity !");
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------- Static methods end
|
||||
|
||||
// --------------------------------------------------------------- Enums start
|
||||
@ -87,10 +105,6 @@ public class SqlBuilder implements Builder<String> {
|
||||
// --------------------------------------------------------------- Enums end
|
||||
|
||||
private final StringBuilder sql = new StringBuilder();
|
||||
/**
|
||||
* 字段列表(仅用于插入和更新)
|
||||
*/
|
||||
private final List<String> fields = new ArrayList<>();
|
||||
/**
|
||||
* 占位符对应的值列表
|
||||
*/
|
||||
@ -146,41 +160,29 @@ public class SqlBuilder implements Builder<String> {
|
||||
// 验证
|
||||
validateEntity(entity);
|
||||
|
||||
if (null != wrapper) {
|
||||
// 包装表名 entity = wrapper.wrap(entity);
|
||||
entity.setTableName(wrapper.wrap(entity.getTableName()));
|
||||
}
|
||||
|
||||
final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理
|
||||
final StringBuilder fieldsPart = new StringBuilder();
|
||||
final StringBuilder placeHolder = new StringBuilder();
|
||||
|
||||
boolean isFirst = true;
|
||||
String field;
|
||||
Object value;
|
||||
for (Entry<String, Object> entry : entity.entrySet()) {
|
||||
field = entry.getKey();
|
||||
value = entry.getValue();
|
||||
if (StrUtil.isNotBlank(field) /* && null != value */) {
|
||||
if (isFirst) {
|
||||
isFirst = false;
|
||||
} else {
|
||||
entity.forEach((field, value) -> {
|
||||
if (StrUtil.isNotBlank(field)) {
|
||||
if (fieldsPart.length() > 0) {
|
||||
// 非第一个参数,追加逗号
|
||||
fieldsPart.append(", ");
|
||||
placeHolder.append(", ");
|
||||
}
|
||||
|
||||
this.fields.add(field);
|
||||
fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field);
|
||||
if (isOracle && value instanceof String && StrUtil.endWithIgnoreCase((String) value, ".nextval")) {
|
||||
if (isOracle && OracleDialect.isNextVal(value)) {
|
||||
// Oracle的特殊自增键,通过字段名.nextval获得下一个值
|
||||
placeHolder.append(value);
|
||||
} else {
|
||||
// 普通字段使用占位符
|
||||
placeHolder.append("?");
|
||||
this.paramValues.add(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// issue#1656@Github Phoenix兼容
|
||||
if (DialectName.PHOENIX.match(dialectName)) {
|
||||
@ -189,94 +191,18 @@ public class SqlBuilder implements Builder<String> {
|
||||
sql.append("INSERT INTO ");
|
||||
}
|
||||
|
||||
sql.append(entity.getTableName())
|
||||
String tableName = entity.getTableName();
|
||||
if (null != this.wrapper) {
|
||||
// 包装表名 entity = wrapper.wrap(entity);
|
||||
tableName = this.wrapper.wrap(tableName);
|
||||
}
|
||||
sql.append(tableName)
|
||||
.append(" (").append(fieldsPart).append(") VALUES (")//
|
||||
.append(placeHolder).append(")");
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 插入<br>
|
||||
* 插入会忽略空的字段名及其对应值,但是对于有字段名对应值为{@code null}的情况不忽略
|
||||
*
|
||||
* @param entity 实体
|
||||
* @param dialectName 方言名,用于对特殊数据库特殊处理
|
||||
* @param keys 根据何字段来确认唯一性,不传则用主键
|
||||
* @return 自己
|
||||
* @since 5.7.21
|
||||
*/
|
||||
public SqlBuilder upsert(Entity entity, String dialectName, String... keys) {
|
||||
// 验证
|
||||
validateEntity(entity);
|
||||
|
||||
if (null != wrapper) {
|
||||
// 包装表名 entity = wrapper.wrap(entity);
|
||||
entity.setTableName(wrapper.wrap(entity.getTableName()));
|
||||
}
|
||||
|
||||
final boolean isOracle = DialectName.ORACLE.match(dialectName);// 对Oracle的特殊处理
|
||||
final StringBuilder fieldsPart = new StringBuilder();
|
||||
final StringBuilder placeHolder = new StringBuilder();
|
||||
|
||||
boolean isFirst = true;
|
||||
String field;
|
||||
Object value;
|
||||
for (Entry<String, Object> entry : entity.entrySet()) {
|
||||
field = entry.getKey();
|
||||
value = entry.getValue();
|
||||
if (StrUtil.isNotBlank(field) /* && null != value */) {
|
||||
if (isFirst) {
|
||||
isFirst = false;
|
||||
} else {
|
||||
// 非第一个参数,追加逗号
|
||||
fieldsPart.append(", ");
|
||||
placeHolder.append(", ");
|
||||
}
|
||||
|
||||
this.fields.add(field);
|
||||
fieldsPart.append((null != wrapper) ? wrapper.wrap(field) : field);
|
||||
if (isOracle && value instanceof String && StrUtil.endWithIgnoreCase((String) value, ".nextval")) {
|
||||
// Oracle的特殊自增键,通过字段名.nextval获得下一个值
|
||||
placeHolder.append(value);
|
||||
} else {
|
||||
placeHolder.append("?");
|
||||
this.paramValues.add(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// issue#1656@Github Phoenix兼容
|
||||
if (DialectName.PHOENIX.match(dialectName)) {
|
||||
sql.append("UPSERT INTO ").append(entity.getTableName());
|
||||
} else if (DialectName.MYSQL.match(dialectName)) {
|
||||
sql.append("INSERT INTO ");
|
||||
sql.append(entity.getTableName())
|
||||
.append(" (").append(fieldsPart).append(") VALUES (")
|
||||
.append(placeHolder).append(") on duplicate key update ")
|
||||
.append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=values(" + k + ")"), ","));
|
||||
} else if (DialectName.H2.match(dialectName)) {
|
||||
sql.append("MERGE INTO ").append(entity.getTableName());
|
||||
if (null != keys && keys.length > 0) {
|
||||
sql.append(" KEY(").append(ArrayUtil.join(keys, ","))
|
||||
.append(") VALUES (")
|
||||
.append(placeHolder)
|
||||
.append(")");
|
||||
}
|
||||
} else if (DialectName.POSTGREESQL.match(dialectName)) {
|
||||
sql.append("INSERT INTO ");
|
||||
sql.append(entity.getTableName())
|
||||
.append(" (").append(fieldsPart).append(") VALUES (")
|
||||
.append(placeHolder).append(") on conflict (")
|
||||
.append(ArrayUtil.join(keys,","))
|
||||
.append(") do update set ")
|
||||
.append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=excluded." + k ), ","));
|
||||
} else {
|
||||
throw new RuntimeException(dialectName + " not support yet");
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除
|
||||
*
|
||||
@ -308,25 +234,22 @@ public class SqlBuilder implements Builder<String> {
|
||||
// 验证
|
||||
validateEntity(entity);
|
||||
|
||||
String tableName = entity.getTableName();
|
||||
if (null != wrapper) {
|
||||
// 包装表名
|
||||
entity.setTableName(wrapper.wrap(entity.getTableName()));
|
||||
tableName = wrapper.wrap(tableName);
|
||||
}
|
||||
|
||||
sql.append("UPDATE ").append(entity.getTableName()).append(" SET ");
|
||||
String field;
|
||||
for (Entry<String, Object> entry : entity.entrySet()) {
|
||||
field = entry.getKey();
|
||||
sql.append("UPDATE ").append(tableName).append(" SET ");
|
||||
entity.forEach((field, value) -> {
|
||||
if (StrUtil.isNotBlank(field)) {
|
||||
if (paramValues.size() > 0) {
|
||||
sql.append(", ");
|
||||
}
|
||||
this.fields.add(field);
|
||||
sql.append((null != wrapper) ? wrapper.wrap(field) : field).append(" = ? ");
|
||||
this.paramValues.add(entry.getValue());// 更新不对空做处理,因为存在清空字段的情况
|
||||
this.paramValues.add(value);// 更新不对空做处理,因为存在清空字段的情况
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
return this;
|
||||
}
|
||||
|
||||
@ -653,24 +576,6 @@ public class SqlBuilder implements Builder<String> {
|
||||
}
|
||||
// --------------------------------------------------------------- Builder end
|
||||
|
||||
/**
|
||||
* 获得插入或更新的数据库字段列表
|
||||
*
|
||||
* @return 插入或更新的数据库字段列表
|
||||
*/
|
||||
public String[] getFieldArray() {
|
||||
return this.fields.toArray(new String[0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获得插入或更新的数据库字段列表
|
||||
*
|
||||
* @return 插入或更新的数据库字段列表
|
||||
*/
|
||||
public List<String> getFields() {
|
||||
return this.fields;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获得占位符对应的值列表<br>
|
||||
*
|
||||
@ -725,23 +630,5 @@ public class SqlBuilder implements Builder<String> {
|
||||
|
||||
return ConditionBuilder.of(conditions).build(this.paramValues);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证实体类对象的有效性
|
||||
*
|
||||
* @param entity 实体类对象
|
||||
* @throws DbRuntimeException SQL异常包装,获取元数据信息失败
|
||||
*/
|
||||
private static void validateEntity(Entity entity) throws DbRuntimeException {
|
||||
if (null == entity) {
|
||||
throw new DbRuntimeException("Entity is null !");
|
||||
}
|
||||
if (StrUtil.isBlank(entity.getTableName())) {
|
||||
throw new DbRuntimeException("Entity`s table name is null !");
|
||||
}
|
||||
if (entity.isEmpty()) {
|
||||
throw new DbRuntimeException("No filed and value in this entity !");
|
||||
}
|
||||
}
|
||||
// --------------------------------------------------------------- private method end
|
||||
}
|
||||
|
@ -15,15 +15,19 @@ import java.util.Map.Entry;
|
||||
/**
|
||||
* 包装器<br>
|
||||
* 主要用于字段名的包装(在字段名的前后加字符,例如反引号来避免与数据库的关键字冲突)
|
||||
* @author Looly
|
||||
*
|
||||
* @author Looly
|
||||
*/
|
||||
public class Wrapper implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/** 前置包装符号 */
|
||||
/**
|
||||
* 前置包装符号
|
||||
*/
|
||||
private Character preWrapQuote;
|
||||
/** 后置包装符号 */
|
||||
/**
|
||||
* 后置包装符号
|
||||
*/
|
||||
private Character sufWrapQuote;
|
||||
|
||||
public Wrapper() {
|
||||
@ -31,6 +35,7 @@ public class Wrapper implements Serializable {
|
||||
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param wrapQuote 单包装字符
|
||||
*/
|
||||
public Wrapper(Character wrapQuote) {
|
||||
@ -40,6 +45,7 @@ public class Wrapper implements Serializable {
|
||||
|
||||
/**
|
||||
* 包装符号
|
||||
*
|
||||
* @param preWrapQuote 前置包装符号
|
||||
* @param sufWrapQuote 后置包装符号
|
||||
*/
|
||||
@ -49,14 +55,17 @@ public class Wrapper implements Serializable {
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------- Getters and Setters start
|
||||
|
||||
/**
|
||||
* @return 前置包装符号
|
||||
*/
|
||||
public char getPreWrapQuote() {
|
||||
return preWrapQuote;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置前置包装的符号
|
||||
*
|
||||
* @param preWrapQuote 前置包装符号
|
||||
*/
|
||||
public void setPreWrapQuote(Character preWrapQuote) {
|
||||
@ -69,8 +78,10 @@ public class Wrapper implements Serializable {
|
||||
public char getSufWrapQuote() {
|
||||
return sufWrapQuote;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置后置包装的符号
|
||||
*
|
||||
* @param sufWrapQuote 后置包装符号
|
||||
*/
|
||||
public void setSufWrapQuote(Character sufWrapQuote) {
|
||||
@ -81,26 +92,27 @@ public class Wrapper implements Serializable {
|
||||
/**
|
||||
* 包装字段名<br>
|
||||
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突
|
||||
*
|
||||
* @param field 字段名
|
||||
* @return 包装后的字段名
|
||||
*/
|
||||
public String wrap(String field){
|
||||
if(preWrapQuote == null || sufWrapQuote == null || StrUtil.isBlank(field)) {
|
||||
public String wrap(String field) {
|
||||
if (preWrapQuote == null || sufWrapQuote == null || StrUtil.isBlank(field)) {
|
||||
return field;
|
||||
}
|
||||
|
||||
//如果已经包含包装的引号,返回原字符
|
||||
if(StrUtil.isSurround(field, preWrapQuote, sufWrapQuote)){
|
||||
if (StrUtil.isSurround(field, preWrapQuote, sufWrapQuote)) {
|
||||
return field;
|
||||
}
|
||||
|
||||
//如果字段中包含通配符或者括号(字段通配符或者函数),不做包装
|
||||
if(StrUtil.containsAnyIgnoreCase(field, "*", "(", " ", " as ")) {
|
||||
if (StrUtil.containsAnyIgnoreCase(field, "*", "(", " ", " as ")) {
|
||||
return field;
|
||||
}
|
||||
|
||||
//对于Oracle这类数据库,表名中包含用户名需要单独拆分包装
|
||||
if(field.contains(StrUtil.DOT)){
|
||||
if (field.contains(StrUtil.DOT)) {
|
||||
final Collection<String> target = CollUtil.edit(StrUtil.split(field, CharUtil.DOT, 2), t -> StrUtil.format("{}{}{}", preWrapQuote, t, sufWrapQuote));
|
||||
return CollectionUtil.join(target, StrUtil.DOT);
|
||||
}
|
||||
@ -111,16 +123,17 @@ public class Wrapper implements Serializable {
|
||||
/**
|
||||
* 包装字段名<br>
|
||||
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突
|
||||
*
|
||||
* @param fields 字段名
|
||||
* @return 包装后的字段名
|
||||
*/
|
||||
public String[] wrap(String... fields){
|
||||
if(ArrayUtil.isEmpty(fields)) {
|
||||
public String[] wrap(String... fields) {
|
||||
if (ArrayUtil.isEmpty(fields)) {
|
||||
return fields;
|
||||
}
|
||||
|
||||
String[] wrappedFields = new String[fields.length];
|
||||
for(int i = 0; i < fields.length; i++) {
|
||||
for (int i = 0; i < fields.length; i++) {
|
||||
wrappedFields[i] = wrap(fields[i]);
|
||||
}
|
||||
|
||||
@ -130,11 +143,12 @@ public class Wrapper implements Serializable {
|
||||
/**
|
||||
* 包装字段名<br>
|
||||
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突
|
||||
*
|
||||
* @param fields 字段名
|
||||
* @return 包装后的字段名
|
||||
*/
|
||||
public Collection<String> wrap(Collection<String> fields){
|
||||
if(CollectionUtil.isEmpty(fields)) {
|
||||
public Collection<String> wrap(Collection<String> fields) {
|
||||
if (CollectionUtil.isEmpty(fields)) {
|
||||
return fields;
|
||||
}
|
||||
|
||||
@ -142,13 +156,14 @@ public class Wrapper implements Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* 包装字段名<br>
|
||||
* 包装表名和字段名,此方法返回一个新的Entity实体类<br>
|
||||
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突
|
||||
*
|
||||
* @param entity 被包装的实体
|
||||
* @return 包装后的字段名
|
||||
* @return 新的实体
|
||||
*/
|
||||
public Entity wrap(Entity entity){
|
||||
if(null == entity) {
|
||||
public Entity wrap(Entity entity) {
|
||||
if (null == entity) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -168,14 +183,15 @@ public class Wrapper implements Serializable {
|
||||
/**
|
||||
* 包装字段名<br>
|
||||
* 有时字段与SQL的某些关键字冲突,导致SQL出错,因此需要将字段名用单引号或者反引号包装起来,避免冲突
|
||||
*
|
||||
* @param conditions 被包装的实体
|
||||
* @return 包装后的字段名
|
||||
*/
|
||||
public Condition[] wrap(Condition... conditions){
|
||||
public Condition[] wrap(Condition... conditions) {
|
||||
final Condition[] clonedConditions = new Condition[conditions.length];
|
||||
if(ArrayUtil.isNotEmpty(conditions)) {
|
||||
if (ArrayUtil.isNotEmpty(conditions)) {
|
||||
Condition clonedCondition;
|
||||
for(int i = 0; i < conditions.length; i++) {
|
||||
for (int i = 0; i < conditions.length; i++) {
|
||||
clonedCondition = conditions[i].clone();
|
||||
clonedCondition.setField(wrap(clonedCondition.getField()));
|
||||
clonedConditions[i] = clonedCondition;
|
||||
|
@ -1,6 +1,5 @@
|
||||
package cn.hutool.db;
|
||||
|
||||
import com.alibaba.druid.support.json.JSONUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
@ -40,6 +39,7 @@ public class H2Test {
|
||||
List<Entity> query = Db.use(DS_GROUP_NAME).find(Entity.create("test"));
|
||||
Assert.assertEquals(4, query.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void upsertTest() throws SQLException {
|
||||
Db db=Db.use(DS_GROUP_NAME);
|
||||
|
Loading…
x
Reference in New Issue
Block a user