diff --git a/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/StreamCallback.java b/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/StreamCallback.java new file mode 100644 index 000000000..60fd49bc2 --- /dev/null +++ b/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/StreamCallback.java @@ -0,0 +1,15 @@ +package com.xiaomi.data.push.uds.processor; + +/** + * @author goodjava@qq.com + * @date 2024/11/7 11:56 + */ +public interface StreamCallback { + + void onContent(String content); + + void onComplete(); + + void onError(Throwable error); + +} diff --git a/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/UdsProcessor.java b/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/UdsProcessor.java index 4f505f5fd..9c99399b7 100644 --- a/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/UdsProcessor.java +++ b/jcommon/api/src/main/java/com/xiaomi/data/push/uds/processor/UdsProcessor.java @@ -25,6 +25,16 @@ public interface UdsProcessor { Response processRequest(Request request); + // 新增:判断是否为流式处理器 + default boolean isStreamProcessor() { + return false; + } + + // 新增:流式处理方法 + default void processStream(Request request, StreamCallback callback) { + throw new UnsupportedOperationException("Stream processing not supported"); + } + default String cmd() { return ""; diff --git a/jcommon/rcurve/pom.xml b/jcommon/rcurve/pom.xml index 3bdc708e7..27df228d3 100644 --- a/jcommon/rcurve/pom.xml +++ b/jcommon/rcurve/pom.xml @@ -68,7 +68,14 @@ io.netty netty-all - 4.1.48.Final + 4.1.114.Final + + + + io.netty + netty-transport-native-kqueue + 4.1.114.Final + osx-aarch_64 diff --git a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/common/NetUtils.java b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/common/NetUtils.java index 0368e760d..d0a289181 100644 --- a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/common/NetUtils.java +++ b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/common/NetUtils.java @@ -14,9 +14,9 @@ public class NetUtils { public static EventLoopGroup getEventLoopGroup() { - if (CommonUtils.isMac() && CommonUtils.isArch64()) { - return new NioEventLoopGroup(); - } +// if (CommonUtils.isMac() && CommonUtils.isArch64()) { +// return new NioEventLoopGroup(); +// } if (CommonUtils.isWindows()) { return new NioEventLoopGroup(); } diff --git a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/ClientStreamCallback.java b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/ClientStreamCallback.java new file mode 100644 index 000000000..9ff88e828 --- /dev/null +++ b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/ClientStreamCallback.java @@ -0,0 +1,15 @@ +package com.xiaomi.data.push.uds.handler; + +/** + * @author goodjava@qq.com + * @date 2024/11/7 10:35 + */ +public interface ClientStreamCallback { + + void onContent(String content); + + void onComplete(); + + void onError(Throwable error); + +} diff --git a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/MessageTypes.java b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/MessageTypes.java new file mode 100644 index 000000000..1c5bab680 --- /dev/null +++ b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/MessageTypes.java @@ -0,0 +1,18 @@ +package com.xiaomi.data.push.uds.handler; + +/** + * @author goodjava@qq.com + * @date 2024/11/6 17:41 + */ +public class MessageTypes { + + public static final String TYPE_KEY = "messageType"; + public static final String TYPE_NORMAL = "normal"; + public static final String TYPE_OPENAI = "openai"; + public static final String STREAM_ID_KEY = "streamId"; + public static final String PROMPT_KEY = "prompt"; + public static final String CONTENT_KEY = "content"; + public static final String STATUS_KEY = "status"; + + +} diff --git a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsClientHandler.java b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsClientHandler.java index 417daad7b..87b504a5d 100644 --- a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsClientHandler.java +++ b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsClientHandler.java @@ -26,8 +26,10 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -42,6 +44,9 @@ public class UdsClientHandler extends SimpleChannelInboundHandler { private ConcurrentHashMap,ExecutorService>> processorMap; + @Getter + private final Map streamCallbacks = new ConcurrentHashMap<>(); + public UdsClientHandler(ConcurrentHashMap,ExecutorService>> processorMap) { this.processorMap = processorMap; @@ -70,26 +75,61 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Excep log.warn("processor is null cmd:{}", command.getCmd()); } } else { - Optional.ofNullable(UdsClient.reqMap.get(command.getId())).ifPresent(f -> { - if (Boolean.TRUE.toString().equals(String.valueOf(f.get("async")))) { - Object res = null; - try { - res = processResult(command, (Class) f.get("returnType")); - if (command.getCode() == 0) { - ((CompletableFuture)f.get("future")).complete(res); - } else { - ((CompletableFuture)f.get("future")).completeExceptionally(new RuntimeException(res.toString())); - } - } catch (Exception e) { - log.error("async response error,", e); - ((CompletableFuture)f.get("future")).completeExceptionally(e); + String messageType = command.getAttachments() + .getOrDefault(MessageTypes.TYPE_KEY, MessageTypes.TYPE_NORMAL); + + //流式的操作 + if (MessageTypes.TYPE_OPENAI.equals(messageType)) { + handleOpenAIResponse(command); + } else { + handleNormalResponse(command); + } + } + } + + private void handleOpenAIResponse(UdsCommand command) { + Map attachments = command.getAttachments(); + String streamId = attachments.get(MessageTypes.STREAM_ID_KEY); + String content = attachments.get(MessageTypes.CONTENT_KEY); + String status = attachments.get(MessageTypes.STATUS_KEY); + + ClientStreamCallback callback = streamCallbacks.get(streamId); + if (callback != null) { + if ("complete".equals(status)) { + callback.onComplete(); + streamCallbacks.remove(streamId); + } else if ("error".equals(status)) { + callback.onError(new RuntimeException(content)); + streamCallbacks.remove(streamId); + } else { + callback.onContent(content); + } + } + } + + private void handleNormalResponse(UdsCommand command) { + // 保持原有的处理逻辑不变 + Optional.ofNullable(UdsClient.reqMap.get(command.getId())).ifPresent(f -> { + if (Boolean.TRUE.toString().equals(String.valueOf(f.get("async")))) { + Object res = null; + try { + res = processResult(command, (Class) f.get("returnType")); + if (command.getCode() == 0) { + ((CompletableFuture)f.get("future")).complete(res); + } else { + ((CompletableFuture)f.get("future")).completeExceptionally( + new RuntimeException(res.toString()) + ); } - UdsClient.reqMap.remove(command.getId()); - } else { - ((CompletableFuture)f.get("future")).complete(command); + } catch (Exception e) { + log.error("async response error,", e); + ((CompletableFuture)f.get("future")).completeExceptionally(e); } - }); - } + UdsClient.reqMap.remove(command.getId()); + } else { + ((CompletableFuture)f.get("future")).complete(command); + } + }); } @Override diff --git a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsServerHandler.java b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsServerHandler.java index 38fb22441..a66e0333c 100644 --- a/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsServerHandler.java +++ b/jcommon/rcurve/src/main/java/com/xiaomi/data/push/uds/handler/UdsServerHandler.java @@ -21,6 +21,7 @@ import com.xiaomi.data.push.uds.UdsServer; import com.xiaomi.data.push.uds.context.UdsServerContext; import com.xiaomi.data.push.uds.po.UdsCommand; +import com.xiaomi.data.push.uds.processor.StreamCallback; import com.xiaomi.data.push.uds.processor.UdsProcessor; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; @@ -32,7 +33,9 @@ import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; /** * @author goodjava@qq.com @@ -45,33 +48,30 @@ public class UdsServerHandler extends ChannelInboundHandlerAdapter { private Map> m; - public UdsServerHandler(ConcurrentHashMap> processorMap) { this.m = processorMap; } - @Override public void channelRead(ChannelHandlerContext ctx, Object _msg) { try { ByteBuf msg = (ByteBuf) _msg; UdsCommand command = new UdsCommand(); command.decode(msg); - log.debug("server receive:id:{}:{}:{}:{}:{}",command.getId(), command.isRequest(), command.getApp(), command.getCmd(), command.getSerializeType()); + log.debug("server receive:id:{}:{}:{}:{}:{}", command.getId(), command.isRequest(), command.getApp(), command.getCmd(), command.getSerializeType()); if (command.isRequest()) { command.setChannel(ctx.channel()); Pair pair = this.m.get(command.getCmd()); if (null != pair) { UdsProcessor processor = pair.getKey(); - pair.getValue().submit(() -> { - log.debug("server received:{}", command.getId()); - UdsCommand res = processor.processRequest(command); - if (null != res) { - Send.send(ctx.channel(), res); - } - }); + // 判断是否为流式处理 + if (processor.isStreamProcessor()) { + handleStreamRequest(ctx, command, processor); + } else { + handleNormalRequest(pair.getValue(), ctx, command, processor); + } } else { - log.warn("processor is null cmd:{},id:{}", command.getCmd(),command.getId()); + log.warn("processor is null cmd:{},id:{}", command.getCmd(), command.getId()); } } else { Optional.ofNullable(UdsServer.reqMap.get(command.getId())).ifPresent(f -> f.complete(command)); @@ -81,6 +81,74 @@ public void channelRead(ChannelHandlerContext ctx, Object _msg) { } } + private void handleNormalRequest(ExecutorService pool, ChannelHandlerContext ctx, UdsCommand command, UdsProcessor processor) { + pool.submit(() -> { + log.debug("server received:{}", command.getId()); + UdsCommand res = processor.processRequest(command); + if (null != res) { + Send.send(ctx.channel(), res); + } + }); + } + + + private void handleStreamRequest(ChannelHandlerContext ctx, UdsCommand command, + UdsProcessor processor) { + + String streamId = command.getAttachments().getOrDefault( + MessageTypes.STREAM_ID_KEY, + UUID.randomUUID().toString() + ); + + StreamCallback callback = new StreamCallback() { + @Override + public void onContent(String content) { + sendStreamContent(ctx, command, streamId, content); + } + + @Override + public void onComplete() { + sendCompleteResponse(ctx, command, streamId); + } + + @Override + public void onError(Throwable error) { + sendErrorResponse(ctx, command, error.getMessage()); + } + }; + + // 执行流式处理 + processor.processStream(command, callback); + } + + + private void sendErrorResponse(ChannelHandlerContext ctx, UdsCommand command, String error) { + UdsCommand response = UdsCommand.createResponse(command); + response.setCode(-1); + response.setMessage(error); + Send.send(ctx.channel(), response); + } + + + private void sendCompleteResponse(ChannelHandlerContext ctx, UdsCommand request, String streamId) { + UdsCommand response = UdsCommand.createResponse(request); + Map attachments = response.getAttachments(); + attachments.put(MessageTypes.TYPE_KEY, MessageTypes.TYPE_OPENAI); + attachments.put(MessageTypes.STREAM_ID_KEY, streamId); + attachments.put(MessageTypes.STATUS_KEY, "complete"); + Send.send(ctx.channel(), response); + } + + + private void sendStreamContent(ChannelHandlerContext ctx, UdsCommand request, String streamId, String content) { + UdsCommand response = UdsCommand.createResponse(request); + Map attachments = response.getAttachments(); + attachments.put(MessageTypes.TYPE_KEY, MessageTypes.TYPE_OPENAI); + attachments.put(MessageTypes.STREAM_ID_KEY, streamId); + attachments.put(MessageTypes.CONTENT_KEY, content); + Send.send(ctx.channel(), response); + } + @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { @@ -91,7 +159,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { public void channelInactive(ChannelHandlerContext ctx) { Attribute attr = ctx.channel().attr(app); String v = attr.get(); - log.error("server channelInactive:{},{},{}", app, v,ctx.channel().id()); + log.error("server channelInactive:{},{},{}", app, v, ctx.channel().id()); if (null != v) { UdsServerContext.ins().remove(v); } @@ -99,7 +167,7 @@ public void channelInactive(ChannelHandlerContext ctx) { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - log.error("exceptionCaught,{}:{}",ctx.channel().id(), cause); + log.error("exceptionCaught,{}:{}", ctx.channel().id(), cause); Attribute attr = ctx.channel().attr(app); String v = attr.get(); if (null != v) { diff --git a/jcommon/rcurve/src/test/java/com/xiaomi/mone/rcurve/test/UdsTest.java b/jcommon/rcurve/src/test/java/com/xiaomi/mone/rcurve/test/UdsTest.java index 128e63ff0..a469094fb 100644 --- a/jcommon/rcurve/src/test/java/com/xiaomi/mone/rcurve/test/UdsTest.java +++ b/jcommon/rcurve/src/test/java/com/xiaomi/mone/rcurve/test/UdsTest.java @@ -28,6 +28,8 @@ public class UdsTest { private String path = "/tmp/test.sock"; + private boolean remote = false; + /** * 模拟启动server */