Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the conversation API #646

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/validate_examples.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
GOPROXY: https://proxy.golang.org
DAPR_INSTALL_URL: https://raw.githubusercontent.com/dapr/cli/master/install/install.sh
DAPR_CLI_REF: ${{ github.event.inputs.daprcli_commit }}
DAPR_REF: ${{ github.event.inputs.daprdapr_commit }}
DAPR_REF: 334ae9eea43d487a7b29a0e4aef904e3eba57a10
CHECKOUT_REPO: ${{ github.repository }}
CHECKOUT_REF: ${{ github.ref }}
outputs:
Expand Down Expand Up @@ -164,6 +164,7 @@ jobs:
[
"actor",
"configuration",
"conversation",
"crypto",
"dist-scheduler",
"grpc-service",
Expand Down
3 changes: 3 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ type Client interface {
// DeleteJobAlpha1 deletes a scheduled job.
DeleteJobAlpha1(ctx context.Context, name string) error

// ConverseAlpha1 interacts with a conversational AI model.
ConverseAlpha1(ctx context.Context, request conversationRequest, options ...conversationRequestOption) (*ConversationResponse, error)

// GrpcClient returns the base grpc client if grpc is used and nil otherwise
GrpcClient() pb.DaprClient

Expand Down
145 changes: 145 additions & 0 deletions client/conversation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"
runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1"
"google.golang.org/protobuf/types/known/anypb"
)

// conversationRequest object - currently unexported as used in a functions option pattern
type conversationRequest struct {
name string
inputs []ConversationInput
Parameters map[string]*anypb.Any
Metadata map[string]string
ContextID *string
ScrubPII *bool // Scrub PII from the output
Temperature *float64
}

// NewConversationRequest defines a request with a component name and one or more inputs as a slice
func NewConversationRequest(llmName string, inputs []ConversationInput) conversationRequest {
return conversationRequest{
name: llmName,
inputs: inputs,
}

Check warning on line 38 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L34-L38

Added lines #L34 - L38 were not covered by tests
}

type conversationRequestOption func(request *conversationRequest)

// ConversationInput defines a single input.
type ConversationInput struct {
// The string to send to the llm.
Message string
// The role of the message.
Role *string
// Whether to Scrub PII from the input
ScrubPII *bool
}

// ConversationResponse is the basic response from a conversationRequest.
type ConversationResponse struct {
ContextID string
Outputs []ConversationResult
}

// ConversationResult is the individual
type ConversationResult struct {
Result string
Parameters map[string]*anypb.Any
}

// WithParameters should be used to provide parameters for custom fields.
func WithParameters(parameters map[string]*anypb.Any) conversationRequestOption {
return func(o *conversationRequest) {
o.Parameters = parameters
}

Check warning on line 69 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L66-L69

Added lines #L66 - L69 were not covered by tests
}

// WithMetadata used to define metadata to be passed to components.
func WithMetadata(metadata map[string]string) conversationRequestOption {
return func(o *conversationRequest) {
o.Metadata = metadata
}

Check warning on line 76 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L73-L76

Added lines #L73 - L76 were not covered by tests
}

// WithContextID to provide a new context or continue an existing one.
func WithContextID(id string) conversationRequestOption {
return func(o *conversationRequest) {
o.ContextID = &id
}

Check warning on line 83 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L80-L83

Added lines #L80 - L83 were not covered by tests
}

// WithScrubPII to define whether the outputs should have PII removed.
func WithScrubPII(scrub bool) conversationRequestOption {
return func(o *conversationRequest) {
o.ScrubPII = &scrub
}

Check warning on line 90 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L87-L90

Added lines #L87 - L90 were not covered by tests
}

// WithTemperature to specify which way the LLM leans.
func WithTemperature(temp float64) conversationRequestOption {
return func(o *conversationRequest) {
o.Temperature = &temp
}

Check warning on line 97 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L94-L97

Added lines #L94 - L97 were not covered by tests
}

// ConverseAlpha1 can invoke an LLM given a request created by the NewConversationRequest function.
func (c *GRPCClient) ConverseAlpha1(ctx context.Context, req conversationRequest, options ...conversationRequestOption) (*ConversationResponse, error) {

var cinputs []*runtimev1pb.ConversationInput
for _, i := range req.inputs {
cinputs = append(cinputs, &runtimev1pb.ConversationInput{
Message: i.Message,
Role: i.Role,
ScrubPII: i.ScrubPII,
})
}

Check warning on line 110 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L101-L110

Added lines #L101 - L110 were not covered by tests

for _, opt := range options {
if opt != nil {
opt(&req)
}

Check warning on line 115 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L112-L115

Added lines #L112 - L115 were not covered by tests
}

request := runtimev1pb.ConversationRequest{
Name: req.name,
ContextID: req.ContextID,
Inputs: cinputs,
Parameters: req.Parameters,
Metadata: req.Metadata,
ScrubPII: req.ScrubPII,
Temperature: req.Temperature,
}

resp, err := c.protoClient.ConverseAlpha1(ctx, &request)
if err != nil {
return nil, err
}

Check warning on line 131 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L118-L131

Added lines #L118 - L131 were not covered by tests

var outputs []ConversationResult
for _, i := range resp.GetOutputs() {
outputs = append(outputs, ConversationResult{
Result: i.GetResult(),
Parameters: i.GetParameters(),
})
}

Check warning on line 139 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L133-L139

Added lines #L133 - L139 were not covered by tests

return &ConversationResponse{
ContextID: resp.GetContextID(),
Outputs: outputs,
}, nil

Check warning on line 144 in client/conversation.go

View check run for this annotation

Codecov / codecov/patch

client/conversation.go#L141-L144

Added lines #L141 - L144 were not covered by tests
}
36 changes: 36 additions & 0 deletions examples/conversation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Dapr Conversation Example with go-sdk

## Step

### Prepare

- Dapr installed

### Run Conversation Example

<!-- STEP
name: Run Conversation
output_match_mode: substring
expected_stdout_lines:
- '== APP == conversation output: hello world'

background: true
sleep: 60
timeout_seconds: 60
-->

```bash
dapr run --app-id conversation \
--dapr-grpc-port 50001 \
--log-level debug \
--resources-path ./config \
-- go run ./main.go
```

<!-- END_STEP -->

## Result

```
- '== APP == conversation output: hello world'
```
7 changes: 7 additions & 0 deletions examples/conversation/config/conversation-echo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
apiVersion: dapr.io/v1alpha1
kind: Component
metadata:
name: echo
spec:
type: conversation.echo
version: v1
48 changes: 48 additions & 0 deletions examples/conversation/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main

import (
"context"
"fmt"
dapr "github.com/dapr/go-sdk/client"
"log"
)

func main() {
client, err := dapr.NewClient()
if err != nil {
panic(err)
}

input := dapr.ConversationInput{
Message: "hello world",
// Role: nil, // Optional
// ScrubPII: nil, // Optional
}

fmt.Printf("conversation input: %s\n", input.Message)

var conversationComponent = "echo"

request := dapr.NewConversationRequest(conversationComponent, []dapr.ConversationInput{input})

resp, err := client.ConverseAlpha1(context.Background(), request)
if err != nil {
log.Fatalf("err: %v", err)
}

fmt.Printf("conversation output: %s\n", resp.Outputs[0].Result)
}
20 changes: 10 additions & 10 deletions examples/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/dapr/go-sdk/examples

go 1.22.6
go 1.23.1

replace github.com/dapr/go-sdk => ../

Expand All @@ -9,7 +9,7 @@ require (
github.com/dapr/go-sdk v0.0.0-00010101000000-000000000000
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.6.0
google.golang.org/grpc v1.65.0
google.golang.org/grpc v1.67.0
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809
google.golang.org/protobuf v1.34.2
)
Expand All @@ -18,7 +18,7 @@ require (
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dapr/dapr v1.14.1 // indirect
github.com/dapr/dapr v1.14.3-0.20241104205526-334ae9eea43d // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-chi/chi/v5 v5.1.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
Expand All @@ -28,12 +28,12 @@ require (
github.com/marusama/semaphore/v2 v2.5.0 // indirect
github.com/microsoft/durabletask-go v0.5.1-0.20241014200046-fac9dd959f4d // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
go.opentelemetry.io/otel v1.27.0 // indirect
go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
40 changes: 20 additions & 20 deletions examples/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyY
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dapr/dapr v1.14.1 h1:n+FGF82caTsBjmnmKdBfrO94GRuLeuYs6qrAN5oG4ZM=
github.com/dapr/dapr v1.14.1/go.mod h1:oDNgaPHQIDZ3G4n4g89TElXWgkluYwcar41DI/oF4gw=
github.com/dapr/dapr v1.14.3-0.20241104205526-334ae9eea43d h1:0mkhz/uwGP+FE7EkMpnFZATnCshQY/9z3yHLnp+j9Ts=
github.com/dapr/dapr v1.14.3-0.20241104205526-334ae9eea43d/go.mod h1:/G9Z/yj9eQQlZSh14X4WQyF/KyzlQfxZqk2ut3LfqhM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -59,24 +59,24 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
go.opentelemetry.io/otel v1.27.0 h1:9BZoF3yMK/O1AafMiQTVu0YDj5Ea4hPhxCs7sGva+cg=
go.opentelemetry.io/otel v1.27.0/go.mod h1:DMpAK8fzYRzs+bi3rS5REupisuqTheUlSZJ1WnZaPAQ=
go.opentelemetry.io/otel/metric v1.27.0 h1:hvj3vdEKyeCi4YaYfNjv2NUje8FqKqUY8IlF0FxV/ik=
go.opentelemetry.io/otel/metric v1.27.0/go.mod h1:mVFgmRlhljgBiuk/MP/oKylr4hs85GZAylncepAX/ak=
go.opentelemetry.io/otel/trace v1.27.0 h1:IqYb813p7cmbHk0a5y6pD5JPakbVfftRXABGt5/Rscw=
go.opentelemetry.io/otel/trace v1.27.0/go.mod h1:6RiD1hkAprV4/q+yd2ln1HG9GoPx39SuvvstaLBl+l4=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61 h1:N9BgCIAUvn/M+p4NJccWPWb3BWh88+zyL0ll9HgbEeM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240924160255-9d4c2d233b61/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw=
google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809 h1:f96Rv5C5Y2CWlbKK6KhKDdyFgGOjPHPEMsdyaxE9k0c=
google.golang.org/grpc/examples v0.0.0-20240516203910-e22436abb809/go.mod h1:uaPEAc5V00jjG3DPhGFLXGT290RUV3+aNQigs1W50/8=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
Expand Down
Loading
Loading