Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: mesh streaming response support #896

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xiaomi.data.push.uds.processor;

/**
* @author [email protected]
* @date 2024/11/7 11:56
*/
public interface StreamCallback {

void onContent(String content);

void onComplete();

void onError(Throwable error);

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ public interface UdsProcessor<Request, Response> {

Response processRequest(Request request);

// 新增:判断是否为流式处理器
default boolean isStreamProcessor() {
return false;
}

// 新增:流式处理方法
default void processStream(Request request, StreamCallback callback) {
throw new UnsupportedOperationException("Stream processing not supported");
}


default String cmd() {
return "";
Expand Down
9 changes: 8 additions & 1 deletion jcommon/rcurve/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.48.Final</version>
<version>4.1.114.Final</version>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-transport-native-kqueue</artifactId>
<version>4.1.114.Final</version>
<classifier>osx-aarch_64</classifier>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
public class NetUtils {

public static EventLoopGroup getEventLoopGroup() {
if (CommonUtils.isMac() && CommonUtils.isArch64()) {
return new NioEventLoopGroup();
}
// if (CommonUtils.isMac() && CommonUtils.isArch64()) {
// return new NioEventLoopGroup();
// }
if (CommonUtils.isWindows()) {
return new NioEventLoopGroup();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xiaomi.data.push.uds.handler;

/**
* @author [email protected]
* @date 2024/11/7 10:35
*/
public interface ClientStreamCallback {

void onContent(String content);

void onComplete();

void onError(Throwable error);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.xiaomi.data.push.uds.handler;

/**
* @author [email protected]
* @date 2024/11/6 17:41
*/
public class MessageTypes {

public static final String TYPE_KEY = "messageType";
public static final String TYPE_NORMAL = "normal";
public static final String TYPE_OPENAI = "openai";
public static final String STREAM_ID_KEY = "streamId";
public static final String PROMPT_KEY = "prompt";
public static final String CONTENT_KEY = "content";
public static final String STATUS_KEY = "status";


}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -42,6 +44,9 @@ public class UdsClientHandler extends SimpleChannelInboundHandler<ByteBuf> {

private ConcurrentHashMap<String, Pair<UdsProcessor<UdsCommand, UdsCommand>,ExecutorService>> processorMap;

@Getter
private final Map<String, ClientStreamCallback> streamCallbacks = new ConcurrentHashMap<>();


public UdsClientHandler(ConcurrentHashMap<String, Pair<UdsProcessor<UdsCommand, UdsCommand>,ExecutorService>> processorMap) {
this.processorMap = processorMap;
Expand Down Expand Up @@ -70,26 +75,61 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Excep
log.warn("processor is null cmd:{}", command.getCmd());
}
} else {
Optional.ofNullable(UdsClient.reqMap.get(command.getId())).ifPresent(f -> {
if (Boolean.TRUE.toString().equals(String.valueOf(f.get("async")))) {
Object res = null;
try {
res = processResult(command, (Class<?>) f.get("returnType"));
if (command.getCode() == 0) {
((CompletableFuture)f.get("future")).complete(res);
} else {
((CompletableFuture)f.get("future")).completeExceptionally(new RuntimeException(res.toString()));
}
} catch (Exception e) {
log.error("async response error,", e);
((CompletableFuture)f.get("future")).completeExceptionally(e);
String messageType = command.getAttachments()
.getOrDefault(MessageTypes.TYPE_KEY, MessageTypes.TYPE_NORMAL);

//流式的操作
if (MessageTypes.TYPE_OPENAI.equals(messageType)) {
handleOpenAIResponse(command);
} else {
handleNormalResponse(command);
}
}
}

private void handleOpenAIResponse(UdsCommand command) {
Map<String, String> attachments = command.getAttachments();
String streamId = attachments.get(MessageTypes.STREAM_ID_KEY);
String content = attachments.get(MessageTypes.CONTENT_KEY);
String status = attachments.get(MessageTypes.STATUS_KEY);

ClientStreamCallback callback = streamCallbacks.get(streamId);
if (callback != null) {
if ("complete".equals(status)) {
callback.onComplete();
streamCallbacks.remove(streamId);
} else if ("error".equals(status)) {
callback.onError(new RuntimeException(content));
streamCallbacks.remove(streamId);
} else {
callback.onContent(content);
}
}
}

private void handleNormalResponse(UdsCommand command) {
// 保持原有的处理逻辑不变
Optional.ofNullable(UdsClient.reqMap.get(command.getId())).ifPresent(f -> {
if (Boolean.TRUE.toString().equals(String.valueOf(f.get("async")))) {
Object res = null;
try {
res = processResult(command, (Class<?>) f.get("returnType"));
if (command.getCode() == 0) {
((CompletableFuture)f.get("future")).complete(res);
} else {
((CompletableFuture)f.get("future")).completeExceptionally(
new RuntimeException(res.toString())
);
}
UdsClient.reqMap.remove(command.getId());
} else {
((CompletableFuture)f.get("future")).complete(command);
} catch (Exception e) {
log.error("async response error,", e);
((CompletableFuture)f.get("future")).completeExceptionally(e);
}
});
}
UdsClient.reqMap.remove(command.getId());
} else {
((CompletableFuture)f.get("future")).complete(command);
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.xiaomi.data.push.uds.UdsServer;
import com.xiaomi.data.push.uds.context.UdsServerContext;
import com.xiaomi.data.push.uds.po.UdsCommand;
import com.xiaomi.data.push.uds.processor.StreamCallback;
import com.xiaomi.data.push.uds.processor.UdsProcessor;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -32,7 +33,9 @@

import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
* @author [email protected]
Expand All @@ -45,33 +48,30 @@ public class UdsServerHandler extends ChannelInboundHandlerAdapter {

private Map<String, Pair<UdsProcessor, ExecutorService>> m;


public UdsServerHandler(ConcurrentHashMap<String, Pair<UdsProcessor, ExecutorService>> processorMap) {
this.m = processorMap;
}


@Override
public void channelRead(ChannelHandlerContext ctx, Object _msg) {
try {
ByteBuf msg = (ByteBuf) _msg;
UdsCommand command = new UdsCommand();
command.decode(msg);
log.debug("server receive:id:{}:{}:{}:{}:{}",command.getId(), command.isRequest(), command.getApp(), command.getCmd(), command.getSerializeType());
log.debug("server receive:id:{}:{}:{}:{}:{}", command.getId(), command.isRequest(), command.getApp(), command.getCmd(), command.getSerializeType());
if (command.isRequest()) {
command.setChannel(ctx.channel());
Pair<UdsProcessor, ExecutorService> pair = this.m.get(command.getCmd());
if (null != pair) {
UdsProcessor<UdsCommand, UdsCommand> processor = pair.getKey();
pair.getValue().submit(() -> {
log.debug("server received:{}", command.getId());
UdsCommand res = processor.processRequest(command);
if (null != res) {
Send.send(ctx.channel(), res);
}
});
// 判断是否为流式处理
if (processor.isStreamProcessor()) {
handleStreamRequest(ctx, command, processor);
} else {
handleNormalRequest(pair.getValue(), ctx, command, processor);
}
} else {
log.warn("processor is null cmd:{},id:{}", command.getCmd(),command.getId());
log.warn("processor is null cmd:{},id:{}", command.getCmd(), command.getId());
}
} else {
Optional.ofNullable(UdsServer.reqMap.get(command.getId())).ifPresent(f -> f.complete(command));
Expand All @@ -81,6 +81,74 @@ public void channelRead(ChannelHandlerContext ctx, Object _msg) {
}
}

private void handleNormalRequest(ExecutorService pool, ChannelHandlerContext ctx, UdsCommand command, UdsProcessor<UdsCommand, UdsCommand> processor) {
pool.submit(() -> {
log.debug("server received:{}", command.getId());
UdsCommand res = processor.processRequest(command);
if (null != res) {
Send.send(ctx.channel(), res);
}
});
}


private void handleStreamRequest(ChannelHandlerContext ctx, UdsCommand command,
UdsProcessor<UdsCommand, UdsCommand> processor) {

String streamId = command.getAttachments().getOrDefault(
MessageTypes.STREAM_ID_KEY,
UUID.randomUUID().toString()
);

StreamCallback callback = new StreamCallback() {
@Override
public void onContent(String content) {
sendStreamContent(ctx, command, streamId, content);
}

@Override
public void onComplete() {
sendCompleteResponse(ctx, command, streamId);
}

@Override
public void onError(Throwable error) {
sendErrorResponse(ctx, command, error.getMessage());
}
};

// 执行流式处理
processor.processStream(command, callback);
}


private void sendErrorResponse(ChannelHandlerContext ctx, UdsCommand command, String error) {
UdsCommand response = UdsCommand.createResponse(command);
response.setCode(-1);
response.setMessage(error);
Send.send(ctx.channel(), response);
}


private void sendCompleteResponse(ChannelHandlerContext ctx, UdsCommand request, String streamId) {
UdsCommand response = UdsCommand.createResponse(request);
Map<String, String> attachments = response.getAttachments();
attachments.put(MessageTypes.TYPE_KEY, MessageTypes.TYPE_OPENAI);
attachments.put(MessageTypes.STREAM_ID_KEY, streamId);
attachments.put(MessageTypes.STATUS_KEY, "complete");
Send.send(ctx.channel(), response);
}


private void sendStreamContent(ChannelHandlerContext ctx, UdsCommand request, String streamId, String content) {
UdsCommand response = UdsCommand.createResponse(request);
Map<String, String> attachments = response.getAttachments();
attachments.put(MessageTypes.TYPE_KEY, MessageTypes.TYPE_OPENAI);
attachments.put(MessageTypes.STREAM_ID_KEY, streamId);
attachments.put(MessageTypes.CONTENT_KEY, content);
Send.send(ctx.channel(), response);
}


@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
Expand All @@ -91,15 +159,15 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
public void channelInactive(ChannelHandlerContext ctx) {
Attribute<String> attr = ctx.channel().attr(app);
String v = attr.get();
log.error("server channelInactive:{},{},{}", app, v,ctx.channel().id());
log.error("server channelInactive:{},{},{}", app, v, ctx.channel().id());
if (null != v) {
UdsServerContext.ins().remove(v);
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.error("exceptionCaught,{}:{}",ctx.channel().id(), cause);
log.error("exceptionCaught,{}:{}", ctx.channel().id(), cause);
Attribute<String> attr = ctx.channel().attr(app);
String v = attr.get();
if (null != v) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class UdsTest {

private String path = "/tmp/test.sock";

private boolean remote = false;

/**
* 模拟启动server
*/
Expand Down
Loading