From b84494a52aa36bf5909083ddcf98bf91cb7637d1 Mon Sep 17 00:00:00 2001 From: Looly Date: Wed, 7 Apr 2021 02:00:12 +0800 Subject: [PATCH] AES support big data --- CHANGELOG.md | 3 +- .../main/java/cn/hutool/core/io/IoUtil.java | 23 ++- .../crypto/symmetric/SymmetricCrypto.java | 171 ++++++++++++++++-- .../crypto/test/symmetric/SymmetricTest.java | 25 ++- 4 files changed, 198 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af91f210d..a2f8481be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ------------------------------------------------------------------------------------------------------------- -# 5.6.3 (2021-04-05) +# 5.6.3 (2021-04-06) ### 新特性 * 【core 】 修改数字转换的实现,增加按照指定端序转换(pr#1492@Github) @@ -15,6 +15,7 @@ * 【core 】 增加银行卡号脱敏(pr#301@Gitee) * 【cache 】 使用LongAddr代替AtomicLong(pr#301@Gitee) * 【cache 】 EnumUtil使用LinkedHashMap(pr#304@Gitee) +* 【crypto 】 SymmetricCrypto支持大量数据加密解密(pr#1497@Gitee) ### Bug修复 * 【core 】 修复Validator.isUrl()传空返回true(issue#I3ETTY@Gitee) diff --git a/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java b/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java index 283feed1b..6796d29fb 100644 --- a/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java +++ b/hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java @@ -13,6 +13,7 @@ import java.io.BufferedOutputStream; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.File; import java.io.FileInputStream; @@ -160,11 +161,11 @@ public class IoUtil extends NioUtil { for (int readSize; (readSize = in.read(buffer)) != EOF; ) { out.write(buffer, 0, readSize); size += readSize; - out.flush(); if (null != streamProgress) { streamProgress.progress(size); } } + out.flush(); } catch (IOException e) { throw new IORuntimeException(e); } @@ -771,7 +772,7 @@ public class IoUtil extends NioUtil { } /** - * 文件转为流 + * 文件转为{@link FileInputStream} * * @param file 文件 * @return {@link FileInputStream} @@ -785,7 +786,7 @@ public class IoUtil extends NioUtil { } /** - * String 转为流 + * byte[] 转为{@link ByteArrayInputStream} * * @param content 内容bytes * @return 字节流 @@ -798,6 +799,20 @@ public class IoUtil extends NioUtil { return new ByteArrayInputStream(content); } + /** + * {@link ByteArrayOutputStream}转为{@link ByteArrayInputStream} + * + * @param out {@link ByteArrayOutputStream} + * @return 字节流 + * @since 5.3.6 + */ + public static ByteArrayInputStream toStream(ByteArrayOutputStream out) { + if (out == null) { + return null; + } + return new ByteArrayInputStream(out.toByteArray()); + } + /** * 转换为{@link BufferedInputStream} * @@ -1069,9 +1084,9 @@ public class IoUtil extends NioUtil { for (Object content : contents) { if (content != null) { osw.writeObject(content); - osw.flush(); } } + osw.flush(); } catch (IOException e) { throw new IORuntimeException(e); } finally { diff --git a/hutool-crypto/src/main/java/cn/hutool/crypto/symmetric/SymmetricCrypto.java b/hutool-crypto/src/main/java/cn/hutool/crypto/symmetric/SymmetricCrypto.java index bbcd43274..d077a8dd6 100644 --- a/hutool-crypto/src/main/java/cn/hutool/crypto/symmetric/SymmetricCrypto.java +++ b/hutool-crypto/src/main/java/cn/hutool/crypto/symmetric/SymmetricCrypto.java @@ -4,19 +4,29 @@ import cn.hutool.core.codec.Base64; import cn.hutool.core.io.IORuntimeException; import cn.hutool.core.io.IoUtil; import cn.hutool.core.lang.Assert; -import cn.hutool.core.util.*; +import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.CharsetUtil; +import cn.hutool.core.util.HexUtil; +import cn.hutool.core.util.RandomUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.crypto.CryptoException; import cn.hutool.crypto.KeyUtil; import cn.hutool.crypto.Padding; import cn.hutool.crypto.SecureUtil; import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.CipherOutputStream; import javax.crypto.SecretKey; import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.PBEParameterSpec; +import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.io.Serializable; import java.nio.charset.Charset; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; import java.security.spec.AlgorithmParameterSpec; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -133,7 +143,7 @@ public class SymmetricCrypto implements Serializable { * 初始化 * * @param algorithm 算法 - * @param key 密钥,如果为null自动生成一个key + * @param key 密钥,如果为{@code null}自动生成一个key * @return SymmetricCrypto的子对象,即子对象自身 */ public SymmetricCrypto init(String algorithm, SecretKey key) { @@ -199,12 +209,8 @@ public class SymmetricCrypto implements Serializable { public byte[] encrypt(byte[] data) { lock.lock(); try { - if (null == this.params) { - cipher.init(Cipher.ENCRYPT_MODE, secretKey); - } else { - cipher.init(Cipher.ENCRYPT_MODE, secretKey, params); - } - return cipher.doFinal(paddingDataWithZero(data, cipher.getBlockSize())); + final Cipher cipher = initCipher(Cipher.ENCRYPT_MODE); + return cipher.doFinal(paddingDataWithZero(data, cipher.getBlockSize())); } catch (Exception e) { throw new CryptoException(e); } finally { @@ -212,6 +218,47 @@ public class SymmetricCrypto implements Serializable { } } + /** + * 加密,针对大数据量,结束后不关闭流 + * + * @param data 被加密的字符串 + * @param out 输出流,可以是文件或网络位置 + * @param isClose 是否关闭流 + * @throws IORuntimeException IO异常 + * @since 5.6.3 + */ + public void encrypt(InputStream data, OutputStream out, boolean isClose) throws IORuntimeException { + lock.lock(); + CipherOutputStream cipherOutputStream = null; + try { + final Cipher cipher = initCipher(Cipher.ENCRYPT_MODE); + cipherOutputStream = new CipherOutputStream(out, cipher); + long length = IoUtil.copy(data, cipherOutputStream); + if(this.isZeroPadding){ + final int blockSize = cipher.getBlockSize(); + if(blockSize > 0){ + // 按照块拆分后的数据中多余的数据 + final int remainLength = (int) (length % blockSize); + if (remainLength > 0) { + // 补充0 + cipherOutputStream.write(new byte[blockSize - remainLength]); + cipherOutputStream.flush(); + } + } + } + } catch (IORuntimeException e) { + throw e; + } catch (Exception e) { + throw new CryptoException(e); + } finally { + lock.unlock(); + if(isClose){ + IoUtil.close(data); + IoUtil.close(cipherOutputStream); + } + } + } + /** * 加密 * @@ -333,7 +380,7 @@ public class SymmetricCrypto implements Serializable { } /** - * 加密 + * 加密,加密后关闭流 * * @param data 被加密的字符串 * @return 加密后的bytes @@ -377,11 +424,7 @@ public class SymmetricCrypto implements Serializable { lock.lock(); try { - if (null == this.params) { - cipher.init(Cipher.DECRYPT_MODE, secretKey); - } else { - cipher.init(Cipher.DECRYPT_MODE, secretKey, params); - } + final Cipher cipher = initCipher(Cipher.DECRYPT_MODE); blockSize = cipher.getBlockSize(); decryptData = cipher.doFinal(bytes); } catch (Exception e) { @@ -393,6 +436,44 @@ public class SymmetricCrypto implements Serializable { return removePadding(decryptData, blockSize); } + /** + * 解密,针对大数据量,结束后不关闭流 + * + * @param data 加密的字符串 + * @param out 输出流,可以是文件或网络位置 + * @param isClose 是否关闭流,包括输入和输出流 + * @throws IORuntimeException IO异常 + * @since 5.6.3 + */ + public void decrypt(InputStream data, OutputStream out, boolean isClose) throws IORuntimeException { + lock.lock(); + CipherInputStream cipherInputStream = null; + try { + final Cipher cipher = initCipher(Cipher.DECRYPT_MODE); + cipherInputStream = new CipherInputStream(data, cipher); + if(this.isZeroPadding){ + final int blockSize = cipher.getBlockSize(); + if(blockSize > 0){ + copyForZeroPadding(cipherInputStream, out, blockSize); + return; + } + } + IoUtil.copy(cipherInputStream, out); + } catch (IOException e) { + throw new IORuntimeException(e); + } catch (IORuntimeException e) { + throw e; + } catch (Exception e) { + throw new CryptoException(e); + } finally { + lock.unlock(); + if(isClose){ + IoUtil.close(data); + IoUtil.close(cipherInputStream); + } + } + } + /** * 解密为字符串 * @@ -446,7 +527,7 @@ public class SymmetricCrypto implements Serializable { } /** - * 解密,不会关闭流 + * 解密,会关闭流 * * @param data 被解密的bytes * @return 解密后的bytes @@ -499,6 +580,24 @@ public class SymmetricCrypto implements Serializable { // --------------------------------------------------------------------------------- Private method start + /** + * 初始化{@link Cipher}为加密或者解密模式 + * + * @param mode 模式,见{@link Cipher#ENCRYPT_MODE} 或 {@link Cipher#DECRYPT_MODE} + * @return {@link Cipher} + * @throws InvalidKeyException 无效key + * @throws InvalidAlgorithmParameterException 无效算法 + */ + private Cipher initCipher(int mode) throws InvalidKeyException, InvalidAlgorithmParameterException { + final Cipher cipher = this.cipher; + if (null == this.params) { + cipher.init(mode, secretKey); + } else { + cipher.init(mode, secretKey, params); + } + return cipher; + } + /** * 数据按照blockSize的整数倍长度填充填充0 * @@ -533,12 +632,12 @@ public class SymmetricCrypto implements Serializable { * 在{@link Padding#ZeroPadding} 模式下,且数据长度不是blockSize的整数倍才有效,否则返回原数据 * * @param data 数据 - * @param blockSize 块大小 + * @param blockSize 块大小,必须大于0 * @return 去除填充后的数据,如果isZeroPadding为false或长度刚好,返回原数据 * @since 4.6.7 */ private byte[] removePadding(byte[] data, int blockSize) { - if (this.isZeroPadding) { + if (this.isZeroPadding && blockSize > 0) { final int length = data.length; final int remainLength = length % blockSize; if (remainLength == 0) { @@ -552,5 +651,43 @@ public class SymmetricCrypto implements Serializable { } return data; } + + /** + * 拷贝解密后的流 + * @param in {@link CipherInputStream} + * @param out 输出流 + * @param blockSize 块大小 + * @throws IOException IO异常 + */ + private void copyForZeroPadding(CipherInputStream in, OutputStream out, int blockSize) throws IOException { + int n = 1; + if(IoUtil.DEFAULT_BUFFER_SIZE > blockSize){ + n = Math.max(n, IoUtil.DEFAULT_BUFFER_SIZE / blockSize); + } + // 此处缓存buffer使用blockSize的整数倍,方便读取时可以正好将补位的0读在一个buffer中 + final int bufSize = blockSize * n; + final byte[] preBuffer = new byte[bufSize]; + final byte[] buffer = new byte[bufSize]; + + boolean isFirst = true; + int preReadSize = 0; + for (int readSize; (readSize = in.read(buffer)) != IoUtil.EOF; ) { + if(isFirst){ + isFirst = false; + } else{ + // 将前一批数据写出 + out.write(preBuffer, 0, preReadSize); + } + ArrayUtil.copy(buffer, preBuffer, readSize); + preReadSize = readSize; + } + // 去掉末尾所有的补位0 + int i = preReadSize - 1; + while (i >= 0 && 0 == preBuffer[i]) { + i--; + } + out.write(preBuffer, 0, i+1); + out.flush(); + } // --------------------------------------------------------------------------------- Private method end } diff --git a/hutool-crypto/src/test/java/cn/hutool/crypto/test/symmetric/SymmetricTest.java b/hutool-crypto/src/test/java/cn/hutool/crypto/test/symmetric/SymmetricTest.java index f35386865..5cee56197 100644 --- a/hutool-crypto/src/test/java/cn/hutool/crypto/test/symmetric/SymmetricTest.java +++ b/hutool-crypto/src/test/java/cn/hutool/crypto/test/symmetric/SymmetricTest.java @@ -1,5 +1,6 @@ package cn.hutool.crypto.test.symmetric; +import cn.hutool.core.io.IoUtil; import cn.hutool.core.util.CharsetUtil; import cn.hutool.core.util.RandomUtil; import cn.hutool.core.util.StrUtil; @@ -7,15 +8,21 @@ import cn.hutool.crypto.KeyUtil; import cn.hutool.crypto.Mode; import cn.hutool.crypto.Padding; import cn.hutool.crypto.SecureUtil; -import cn.hutool.crypto.symmetric.*; +import cn.hutool.crypto.symmetric.AES; +import cn.hutool.crypto.symmetric.DES; +import cn.hutool.crypto.symmetric.DESede; +import cn.hutool.crypto.symmetric.SymmetricAlgorithm; +import cn.hutool.crypto.symmetric.SymmetricCrypto; +import cn.hutool.crypto.symmetric.Vigenere; import org.junit.Assert; import org.junit.Test; +import java.io.ByteArrayOutputStream; + /** * 对称加密算法单元测试 * * @author Looly - * */ public class SymmetricTest { @@ -113,6 +120,20 @@ public class SymmetricTest { Assert.assertEquals(content, decryptStr); } + @Test + public void aesZeroPaddingTest2() { + String content = "RandomUtil.randomString(RandomUtil.randomInt(2000))"; + AES aes = new AES(Mode.CBC, Padding.ZeroPadding, "0123456789ABHAEQ".getBytes(), "DYgjCEIMVrj2W9xN".getBytes()); + + final ByteArrayOutputStream encryptStream = new ByteArrayOutputStream(); + aes.encrypt(IoUtil.toUtf8Stream(content), encryptStream, true); + + final ByteArrayOutputStream contentStream = new ByteArrayOutputStream(); + aes.decrypt(IoUtil.toStream(encryptStream), contentStream, true); + + Assert.assertEquals(content, StrUtil.utf8Str(contentStream.toByteArray())); + } + @Test public void aesPkcs7PaddingTest() { String content = RandomUtil.randomString(RandomUtil.randomInt(200));