Skip to content

Commit

Permalink
Add Python processor grpc server (#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Sep 19, 2023
1 parent 9bba4f0 commit e654304
Show file tree
Hide file tree
Showing 33 changed files with 1,686 additions and 350 deletions.
4 changes: 4 additions & 0 deletions langstream-agents/langstream-agent-grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import ai.langstream.api.runner.code.AgentProcessor;
import ai.langstream.api.runner.code.RecordSink;
import ai.langstream.api.runner.code.SimpleRecord;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import io.grpc.ManagedChannel;
import io.grpc.stub.StreamObserver;
import java.io.ByteArrayOutputStream;
Expand All @@ -30,6 +33,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -44,7 +48,7 @@

@Slf4j
public class GrpcAgentProcessor extends AbstractAgentCode implements AgentProcessor {

protected static final ObjectMapper MAPPER = new ObjectMapper();
protected ManagedChannel channel;
private StreamObserver<ProcessorRequest> request;
private RecordSink sink;
Expand All @@ -65,7 +69,8 @@ public class GrpcAgentProcessor extends AbstractAgentCode implements AgentProces
private final Map<Integer, Object> serverSchemas = new ConcurrentHashMap<>();

private final StreamObserver<ProcessorResponse> responseObserver = getResponseObserver();
private AgentContext agentContext;
protected AgentContext agentContext;
protected AgentServiceGrpc.AgentServiceBlockingStub blockingStub;

private record RecordAndSink(
ai.langstream.api.runner.code.Record sourceRecord, RecordSink sink) {}
Expand All @@ -77,30 +82,42 @@ public GrpcAgentProcessor(ManagedChannel channel) {
}

@Override
public void start() {
public void start() throws Exception {
if (channel == null) {
throw new IllegalStateException("Channel not initialized");
}
request = ProcessorGrpc.newStub(channel).process(responseObserver);
blockingStub =
AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS);
request = AgentServiceGrpc.newStub(channel).withWaitForReady().process(responseObserver);
}

@Override
public void setContext(AgentContext context) {
public void setContext(AgentContext context) throws Exception {
this.agentContext = context;
}

@Override
protected Map<String, Object> buildAdditionalInfo() {
try {
return MAPPER.readValue(
blockingStub.agentInfo(Empty.getDefaultInstance()).getJsonInfo(), Map.class);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

@Override
public synchronized void process(
List<ai.langstream.api.runner.code.Record> records, RecordSink recordSink) {
if (sink == null) {
sink = recordSink;
}

Records.Builder recordsBuilder = Records.newBuilder();
ProcessorRequest.Builder requestBuilder = ProcessorRequest.newBuilder();
for (ai.langstream.api.runner.code.Record record : records) {
long rId = recordId.incrementAndGet();
try {
Record.Builder recordBuilder = recordsBuilder.addRecordBuilder().setRecordId(rId);
Record.Builder recordBuilder = requestBuilder.addRecordsBuilder().setRecordId(rId);

if (record.value() != null) {
recordBuilder.setValue(toGrpc(record.value()));
Expand Down Expand Up @@ -132,13 +149,13 @@ public synchronized void process(
recordSink.emit(new SourceRecordAndResult(record, null, e));
}
}
if (recordsBuilder.getRecordCount() > 0) {
request.onNext(ProcessorRequest.newBuilder().setRecords(recordsBuilder).build());
if (requestBuilder.getRecordsCount() > 0) {
request.onNext(requestBuilder.build());
}
}

@Override
public synchronized void close() {
public synchronized void close() throws Exception {
if (request != null) {
request.onCompleted();
}
Expand All @@ -152,17 +169,17 @@ private Object fromGrpc(Value value) throws IOException {
return null;
}
return switch (value.getTypeOneofCase()) {
case BYTESVALUE -> value.getBytesValue().toByteArray();
case BOOLEANVALUE -> value.getBooleanValue();
case STRINGVALUE -> value.getStringValue();
case BYTEVALUE -> (byte) value.getByteValue();
case SHORTVALUE -> (short) value.getShortValue();
case INTVALUE -> value.getIntValue();
case LONGVALUE -> value.getLongValue();
case FLOATVALUE -> value.getFloatValue();
case DOUBLEVALUE -> value.getDoubleValue();
case JSONVALUE -> value.getJsonValue();
case AVROVALUE -> {
case BYTES_VALUE -> value.getBytesValue().toByteArray();
case BOOLEAN_VALUE -> value.getBooleanValue();
case STRING_VALUE -> value.getStringValue();
case BYTE_VALUE -> (byte) value.getByteValue();
case SHORT_VALUE -> (short) value.getShortValue();
case INT_VALUE -> value.getIntValue();
case LONG_VALUE -> value.getLongValue();
case FLOAT_VALUE -> value.getFloatValue();
case DOUBLE_VALUE -> value.getDoubleValue();
case JSON_VALUE -> value.getJsonValue();
case AVRO_VALUE -> {
Object serverSchema = serverSchemas.get(value.getSchemaId());
if (serverSchema instanceof org.apache.avro.Schema schema) {
yield deserializeGenericRecord(schema, value.getAvroValue().toByteArray());
Expand All @@ -184,7 +201,7 @@ private SourceRecordAndResult fromGrpc(
return new SourceRecordAndResult(
sourceRecord, null, new RuntimeException(result.getError()));
}
for (Record record : result.getRecords().getRecordList()) {
for (Record record : result.getRecordsList()) {
resultRecords.add(fromGrpc(record));
}
return new SourceRecordAndResult(sourceRecord, resultRecords, null);
Expand Down Expand Up @@ -296,40 +313,36 @@ public void onNext(ProcessorResponse response) {
new org.apache.avro.Schema.Parser()
.parse(response.getSchema().getValue().toStringUtf8());
serverSchemas.put(response.getSchema().getSchemaId(), schema);
} else {
response.getResults()
.getResultList()
.forEach(
result -> {
RecordAndSink recordAndSink =
sourceRecords.remove(result.getRecordId());
if (recordAndSink == null) {
}
response.getResultsList()
.forEach(
result -> {
RecordAndSink recordAndSink =
sourceRecords.remove(result.getRecordId());
if (recordAndSink == null) {
agentContext.criticalFailure(
new RuntimeException(
"Received unknown record id "
+ result.getRecordId()));
} else {
try {
recordAndSink
.sink()
.emit(
fromGrpc(
recordAndSink.sourceRecord(),
result));
} catch (Exception e) {
agentContext.criticalFailure(
new RuntimeException(
"Received unknown record id "
+ result.getRecordId()));
} else {
try {
recordAndSink
.sink()
.emit(
fromGrpc(
recordAndSink
.sourceRecord(),
result));
} catch (Exception e) {
agentContext.criticalFailure(
new RuntimeException(
"Error while processing record %s: %s"
.formatted(
result
.getRecordId(),
e.getMessage()),
e));
}
"Error while processing record %s: %s"
.formatted(
result.getRecordId(),
e.getMessage()),
e));
}
});
}
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,92 @@
*/
package ai.langstream.agents.grpc;

import com.google.protobuf.Empty;
import io.grpc.ManagedChannelBuilder;
import java.util.UUID;
import java.net.ServerSocket;
import java.nio.file.Path;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class PythonGrpcAgentProcessor extends GrpcAgentProcessor {
private Map<String, Object> configuration;
private Process pythonProcess;

@Override
public void start() {
String target = "uds:///tmp/%s.sock".formatted(UUID.randomUUID());
this.channel = ManagedChannelBuilder.forTarget(target).usePlaintext().build();
// TODO: start the Python server
public void init(Map<String, Object> configuration) {
this.configuration = configuration;
}

@Override
public void start() throws Exception {
// Get a free port
int port;
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(true);
port = socket.getLocalPort();
}

Path pythonCodeDirectory = agentContext.getCodeDirectory().resolve("python");
log.info("Python code directory {}", pythonCodeDirectory);

final String pythonPath = System.getenv("PYTHONPATH");
final String newPythonPath =
"%s:%s:%s"
.formatted(
pythonPath,
pythonCodeDirectory.toAbsolutePath(),
pythonCodeDirectory.resolve("lib").toAbsolutePath());

// copy input/output to standard input/output of the java process
// this allows to use "kubectl logs" easily
ProcessBuilder processBuilder =
new ProcessBuilder(
"python3",
"-m",
"langstream_grpc",
"[::]:%s".formatted(port),
MAPPER.writeValueAsString(configuration))
.inheritIO()
.redirectOutput(ProcessBuilder.Redirect.INHERIT)
.redirectError(ProcessBuilder.Redirect.INHERIT);
processBuilder.environment().put("PYTHONPATH", newPythonPath);
processBuilder.environment().put("NLTK_DATA", "/app/nltk_data");
pythonProcess = processBuilder.start();
this.channel =
ManagedChannelBuilder.forAddress("localhost", port)
.directExecutor()
.usePlaintext()
.build();
AgentServiceGrpc.AgentServiceBlockingStub stub =
AgentServiceGrpc.newBlockingStub(channel).withDeadlineAfter(30, TimeUnit.SECONDS);
for (int i = 0; ; i++) {
try {
stub.agentInfo(Empty.getDefaultInstance());
break;
} catch (Exception e) {
if (i > 8) {
throw e;
}
log.info("Waiting for python agent to start");
Thread.sleep(1000);
}
}
super.start();
}

@Override
public void close() throws Exception {
super.close();
if (pythonProcess != null) {
pythonProcess.destroy();
int exitCode = pythonProcess.waitFor();
log.info("Python process exited with code {}", exitCode);

if (exitCode != 0) {
throw new RuntimeException("Python code exited with code " + exitCode);
}
}
}
}
Loading

0 comments on commit e654304

Please sign in to comment.