diff --git a/hutool-socket/src/main/java/cn/hutool/socket/nio/NioClient.java b/hutool-socket/src/main/java/cn/hutool/socket/nio/NioClient.java index 131b78c95..ec112bc64 100644 --- a/hutool-socket/src/main/java/cn/hutool/socket/nio/NioClient.java +++ b/hutool-socket/src/main/java/cn/hutool/socket/nio/NioClient.java @@ -2,12 +2,20 @@ package cn.hutool.socket.nio; import cn.hutool.core.io.IORuntimeException; import cn.hutool.core.io.IoUtil; +import cn.hutool.core.thread.ThreadFactoryBuilder; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; +import java.util.Iterator; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; /** * NIO客户端 @@ -15,88 +23,160 @@ import java.nio.channels.SocketChannel; * @author looly * @since 4.4.5 */ -public class NioClient implements Closeable { +public abstract class NioClient implements Closeable { - private SocketChannel channel; + private Selector selector; + private SocketChannel channel; + private ExecutorService executorService; + + /** + * 构造 + * + * @param host 服务器地址 + * @param port 端口 + */ + public NioClient(String host, int port) { + init(new InetSocketAddress(host, port)); + } + + /** + * 构造 + * + * @param address 服务器地址 + */ + public NioClient(InetSocketAddress address) { + init(address); + } + + /** + * 初始化 + * + * @param address 地址和端口 + * @return this + */ + public NioClient init(InetSocketAddress address) { + try { + //创建一个SocketChannel对象,配置成非阻塞模式 + this.channel = SocketChannel.open(); + channel.configureBlocking(false); + + //创建一个选择器,并把SocketChannel交给selector对象 + this.selector = Selector.open(); + channel.register(selector, SelectionKey.OP_CONNECT); + + //发起建立连接的请求,这里会立即返回,当连接建立完成后,SocketChannel就会被选取出来 + channel.connect(address); + } catch (IOException e) { + throw new IORuntimeException(e); + } + return this; + } /** - * 构造 - * - * @param host 服务器地址 - * @param port 端口 + * 检查连接是否建立完成 */ - public NioClient(String host, int port) { - init(new InetSocketAddress(host, port)); - } - - /** - * 构造 - * - * @param address 服务器地址 - */ - public NioClient(InetSocketAddress address) { - init(address); - } - - /** - * 初始化 - * - * @param address 地址和端口 - * @return this - */ - public NioClient init(InetSocketAddress address) { - try { - this.channel = SocketChannel.open(address); - } catch (IOException e) { - throw new IORuntimeException(e); + public boolean waitConnect() throws IOException { + boolean isConnect = false; + while (0 != this.selector.select()) { + final Iterator keyIter = selector.selectedKeys().iterator(); + while (keyIter.hasNext()) { + //连接建立完成 + SelectionKey key = keyIter.next(); + if (key.isConnectable()) { + if (this.channel.finishConnect()) { + this.channel.register(selector, SelectionKey.OP_READ); + isConnect = true; + } + } + keyIter.remove(); + break; + } + if (isConnect) { + break; + } } - return this; + return isConnect; } - /** - * 处理读事件
- * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传世来的消息 - * - * @param buffer 服务端数据存储缓存 - * @return this - */ - public NioClient read(ByteBuffer buffer) { - try { - this.channel.read(buffer); - } catch (IOException e) { - throw new IORuntimeException(e); - } - return this; + /** + * 开始监听 + */ + public void listen() { + this.executorService = Executors.newSingleThreadExecutor(r -> { + final Thread thread = Executors.defaultThreadFactory().newThread(r); + thread.setName("nio-client-listen"); + return thread; + }); + this.executorService.execute(() -> { + try { + doListen(); + } catch (IOException e) { + e.printStackTrace(); + } + }); + } + + /** + * 开始监听 + * + * @throws IOException IO异常 + */ + private void doListen() throws IOException { + while (0 != this.selector.select()) { + // 返回已选择键的集合 + final Iterator keyIter = selector.selectedKeys().iterator(); + while (keyIter.hasNext()) { + handle(keyIter.next()); + keyIter.remove(); + } + } + } + + /** + * 处理SelectionKey + * + * @param key SelectionKey + */ + private void handle(SelectionKey key) throws IOException { + // 读事件就绪 + if (key.isReadable()) { + final SocketChannel socketChannel = (SocketChannel) key.channel(); + read(socketChannel); + } + } + + /** + * 处理读事件
+ * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传出来的消息 + * + * @param socketChannel SocketChannel + */ + protected abstract void read(SocketChannel socketChannel); + + /** + * 实现写逻辑
+ * 当收到写出准备就绪的信号后,回调此方法,用户可向客户端发送消息 + * + * @param datas 发送的数据 + * @return this + */ + public NioClient write(ByteBuffer... datas) { + try { + this.channel.write(datas); + } catch (IOException e) { + throw new IORuntimeException(e); + } + return this; + } + + public void closeListen() { + this.executorService.shutdown(); } - /** - * 实现写逻辑
- * 当收到写出准备就绪的信号后,回调此方法,用户可向客户端发送消息 - * - * @param datas 发送的数据 - * @return this - */ - public NioClient write(ByteBuffer... datas) { - try { - this.channel.write(datas); - } catch (IOException e) { - throw new IORuntimeException(e); - } - return this; - } - - /** - * 获取SocketChannel - * - * @return SocketChannel - * @since 5.3.10 - */ - public SocketChannel getChannel() { - return this.channel; - } - - @Override - public void close() { - IoUtil.close(this.channel); - } + @Override + public void close() { + IoUtil.close(this.selector); + IoUtil.close(this.channel); + closeListen(); + } } diff --git a/hutool-socket/src/main/java/cn/hutool/socket/nio/NioServer.java b/hutool-socket/src/main/java/cn/hutool/socket/nio/NioServer.java index 6ab08524e..20e8307f6 100644 --- a/hutool-socket/src/main/java/cn/hutool/socket/nio/NioServer.java +++ b/hutool-socket/src/main/java/cn/hutool/socket/nio/NioServer.java @@ -137,7 +137,7 @@ public abstract class NioServer implements Closeable { /** * 处理读事件
- * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传世来的消息 + * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传出来的消息 * * @param socketChannel SocketChannel */ diff --git a/hutool-socket/src/test/java/cn/hutool/socket/NioClientTest.java b/hutool-socket/src/test/java/cn/hutool/socket/NioClientTest.java new file mode 100644 index 000000000..1bc217b81 --- /dev/null +++ b/hutool-socket/src/test/java/cn/hutool/socket/NioClientTest.java @@ -0,0 +1,67 @@ +package cn.hutool.socket; + +import cn.hutool.core.util.StrUtil; +import cn.hutool.socket.nio.NioClient; +import lombok.SneakyThrows; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.nio.charset.Charset; +import java.util.Iterator; +import java.util.Scanner; +import java.util.Set; + +public class NioClientTest { + + @SneakyThrows + public static void main(String[] args) { + NioClient client = new NioClient("127.0.0.1", 8080) { + @SneakyThrows + @Override + protected void read(SocketChannel sc) { + ByteBuffer readBuffer = ByteBuffer.allocate(1024); + //从channel读数据到缓冲区 + int readBytes = sc.read(readBuffer); + if (readBytes > 0){ + //Flips this buffer. The limit is set to the current position and then + // the position is set to zero,就是表示要从起始位置开始读取数据 + readBuffer.flip(); + //eturns the number of elements between the current position and the limit. + // 要读取的字节长度 + byte[] bytes = new byte[readBuffer.remaining()]; + //将缓冲区的数据读到bytes数组 + readBuffer.get(bytes); + String body = new String(bytes, "UTF-8"); + System.out.println("the read client receive message: " + body); + }else if(readBytes < 0){ + sc.close(); + } + } + }; + if (client.waitConnect()) { + client.listen(); + } + ByteBuffer buffer = ByteBuffer.wrap("client 发生到 server".getBytes()); + client.write(buffer); + buffer = ByteBuffer.wrap("client 再次发生到 server".getBytes()); + client.write(buffer); + + /** + * 在控制台向服务器端发送数据 + */ + System.out.println("请在下方畅所欲言"); + Scanner scanner = new Scanner(System.in); + while (scanner.hasNextLine()) { + String request = scanner.nextLine(); + if (request != null && request.trim().length() > 0) { + client.write( + Charset.forName("UTF-8") + .encode("测试client" + ": " + request)); + } + } + } +} \ No newline at end of file diff --git a/hutool-socket/src/test/java/cn/hutool/socket/NioServerTest.java b/hutool-socket/src/test/java/cn/hutool/socket/NioServerTest.java new file mode 100644 index 000000000..39d0b6848 --- /dev/null +++ b/hutool-socket/src/test/java/cn/hutool/socket/NioServerTest.java @@ -0,0 +1,82 @@ +package cn.hutool.socket; + +import cn.hutool.core.util.StrUtil; +import cn.hutool.socket.nio.NioServer; +import lombok.SneakyThrows; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.Set; + +public class NioServerTest { + + public static void main(String[] args) { + NioServer server = new NioServer(8080) { + @SneakyThrows + @Override + protected void read(SocketChannel sc) { + ByteBuffer readBuffer = ByteBuffer.allocate(1024); + //从channel读数据到缓冲区 + int readBytes = sc.read(readBuffer); + if (readBytes > 0){ + //Flips this buffer. The limit is set to the current position and then + // the position is set to zero,就是表示要从起始位置开始读取数据 + readBuffer.flip(); + //eturns the number of elements between the current position and the limit. + // 要读取的字节长度 + byte[] bytes = new byte[readBuffer.remaining()]; + //将缓冲区的数据读到bytes数组 + readBuffer.get(bytes); + String body = new String(bytes, "UTF-8"); + System.out.println("the read server receive message: " + body); + doWrite(sc, body); + }else if(readBytes < 0){ + sc.close(); + } + } + + @SneakyThrows + @Override + protected void write(SocketChannel sc) { + ByteBuffer readBuffer = ByteBuffer.allocate(1024); + //从channel读数据到缓冲区 + int readBytes = sc.read(readBuffer); + if (readBytes > 0){ + //Flips this buffer. The limit is set to the current position and then + // the position is set to zero,就是表示要从起始位置开始读取数据 + readBuffer.flip(); + //eturns the number of elements between the current position and the limit. + // 要读取的字节长度 + byte[] bytes = new byte[readBuffer.remaining()]; + //将缓冲区的数据读到bytes数组 + readBuffer.get(bytes); + String body = new String(bytes, "UTF-8"); + System.out.println("the write server receive message: " + body); + doWrite(sc, body); + }else if(readBytes < 0){ + sc.close(); + } + } + }; + server.listen(); + } + + public static void doWrite(SocketChannel channel, String response) throws IOException { + response = "我们已收到消息:"+response; + if(!StrUtil.isBlank(response)){ + byte [] bytes = response.getBytes(); + //分配一个bytes的length长度的ByteBuffer + ByteBuffer write = ByteBuffer.allocate(bytes.length); + //将返回数据写入缓冲区 + write.put(bytes); + write.flip(); + //将缓冲数据写入渠道,返回给客户端 + channel.write(write); + } + } +} \ No newline at end of file