
Commit

feat(cohere): add Cohere component (#187)
Because

We want to integrate Cohere's models into our VDP pipeline platform.

This commit

- Added the Cohere AI Component, which supports the following tasks:
(a) TASK_TEXT_GENERATION_CHAT
(b) TASK_TEXT_EMBEDDINGS
(c) TASK_TEXT_RERANKING
namwoam authored Jul 11, 2024
1 parent 4c9281e commit 63fd578
Showing 15 changed files with 1,800 additions and 1 deletion.
125 changes: 125 additions & 0 deletions ai/cohere/v0/README.mdx
@@ -0,0 +1,125 @@
---
title: "Cohere"
lang: "en-US"
draft: false
description: "Learn about how to set up a VDP Cohere component https://github.com/instill-ai/instill-core"
---

The Cohere component is an AI component that allows users to connect to the AI models served on the Cohere Platform.
It can carry out the following tasks:

- [Text Generation Chat](#text-generation-chat)
- [Text Embeddings](#text-embeddings)
- [Text Reranking](#text-reranking)



## Release Stage

`Alpha`



## Configuration

The component configuration is defined and maintained [here](https://github.com/instill-ai/component/blob/main/ai/cohere/v0/config/definition.json).




## Setup


| Field | Field ID | Type | Note |
| :--- | :--- | :--- | :--- |
| API Key (required) | `api-key` | string | Fill in your Cohere API key. To find your keys, visit the Cohere dashboard page. |
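
For illustration, the setup is passed to the component as a `structpb.Struct`, and the `client.go` file added in this commit reads the `api-key` field from it. A minimal sketch of building such a setup in Go (the key value is a placeholder):

```go
package main

import (
	"fmt"

	"google.golang.org/protobuf/types/known/structpb"
)

func main() {
	// Illustrative only: the component setup with the `api-key` field from the
	// table above. The key value is a placeholder.
	setup, err := structpb.NewStruct(map[string]any{
		"api-key": "YOUR_COHERE_API_KEY",
	})
	if err != nil {
		panic(err)
	}
	// client.go (added in this commit) reads the key the same way:
	fmt.Println(setup.GetFields()["api-key"].GetStringValue())
}
```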




## Supported Tasks

### Text Generation Chat

Provide text outputs in response to text inputs.


| Input | ID | Type | Description |
| :--- | :--- | :--- | :--- |
| Task ID (required) | `task` | string | `TASK_TEXT_GENERATION_CHAT` |
| Model Name (required) | `model-name` | string | The Cohere command model to be used |
| Prompt (required) | `prompt` | string | The prompt text |
| System message | `system-message` | string | The system message helps set the behavior of the assistant. For example, you can modify the personality of the assistant or provide specific instructions about how it should behave throughout the conversation. By default, the model uses a generic system message such as "You are a helpful assistant." |
| Documents | `documents` | array[string] | The documents to be used for the model. For optimal performance, each document should be shorter than 300 words. |
| Prompt Images | `prompt-images` | array[string] | The prompt images (Note: as of 2024-06-24, Cohere models are not multimodal, so images will be ignored.) |
| Chat history | `chat-history` | array[object] | Incorporate external chat history, specifically previous messages within the conversation. Each message should adhere to the format: \{"role": "The message role, i.e. 'USER' or 'CHATBOT'", "content": "message content"\}. |
| Seed | `seed` | integer | The seed (default=42) |
| Temperature | `temperature` | number | The temperature for sampling (default=0.7) |
| Top K | `top-k` | integer | Top k for sampling (default=10) |
| Max new tokens | `max-new-tokens` | integer | The maximum number of tokens for the model to generate (default=50) |



| Output | ID | Type | Description |
| :--- | :--- | :--- | :--- |
| Text | `text` | string | Model Output |
| Citations (optional) | `citations` | array[object] | Citations |
| Usage (optional) | `usage` | object | Token Usage on the Cohere Platform Command Models |
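
As an illustration of how these fields fit together, here is a hedged sketch of an input payload for this task built as a `structpb.Struct`, using the field IDs from the table above. The field values and the model name are examples only, and the exact wiring inside a pipeline recipe may differ:

```go
package main

import (
	"fmt"

	"google.golang.org/protobuf/types/known/structpb"
)

func main() {
	// Illustrative only: an input payload for TASK_TEXT_GENERATION_CHAT using
	// the field IDs from the table above. Values (including the model name)
	// are examples, not defaults enforced by the component.
	input, err := structpb.NewStruct(map[string]any{
		"model-name":     "command-r",
		"prompt":         "Summarise the attached documents in one sentence.",
		"system-message": "You are a helpful assistant.",
		"documents": []any{
			"Cohere builds large language models.",
			"VDP is a versatile data pipeline platform.",
		},
		"chat-history": []any{
			map[string]any{"role": "USER", "content": "Hi"},
			map[string]any{"role": "CHATBOT", "content": "Hello! How can I help?"},
		},
		"temperature":    0.7,
		"max-new-tokens": 50,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(input)
}
```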






### Text Embeddings

Turn text into a vector of numbers that capture its meaning, unlocking use cases like semantic search.


| Input | ID | Type | Description |
| :--- | :--- | :--- | :--- |
| Task ID (required) | `task` | string | `TASK_TEXT_EMBEDDINGS` |
| Embedding Type (required) | `embedding-type` | string | Specifies the return type of the embedding. Note that the 'binary'/'ubinary' options mean the component will return packed unsigned binary embeddings; the length of each binary embedding is 1/8 the length of the float embeddings of the provided model. |
| Input Type (required) | `input-type` | string | Specifies the type of input passed to the model |
| Model Name (required) | `model-name` | string | The Cohere embed model to be used |
| Text (required) | `text` | string | The text |



| Output | ID | Type | Description |
| :--- | :--- | :--- | :--- |
| Embedding | `embedding` | array[number] | Embedding of the input text |
| Usage (optional) | `usage` | object | Token usage on the Cohere platform embed models |
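
To make the 'binary'/'ubinary' note above concrete: each byte of a packed embedding encodes eight dimensions, which is why its length is 1/8 of the float embedding length. Below is a minimal sketch of unpacking such a vector back into ±1 values; the most-significant-bit-first ordering is an assumption and should be verified against Cohere's embed documentation:

```go
package main

import "fmt"

// unpackBits expands a packed unsigned-binary embedding (one byte per eight
// dimensions) into +1/-1 values. Most-significant-bit-first ordering is an
// assumption here; check Cohere's embed documentation before relying on it.
func unpackBits(packed []uint8) []int8 {
	out := make([]int8, 0, len(packed)*8)
	for _, b := range packed {
		for bit := 7; bit >= 0; bit-- {
			if (b>>bit)&1 == 1 {
				out = append(out, 1)
			} else {
				out = append(out, -1)
			}
		}
	}
	return out
}

func main() {
	// A 1024-dimensional float embedding corresponds to a 128-byte packed one.
	packed := make([]uint8, 128)
	fmt.Println(len(unpackBits(packed))) // 1024
}
```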






### Text Reranking

Sort text inputs by semantic relevance to a specified query.


| Input | ID | Type | Description |
| :--- | :--- | :--- | :--- |
| Task ID (required) | `task` | string | `TASK_TEXT_RERANKING` |
| Model Name (required) | `model-name` | string | The Cohere rerank model to be used |
| Query (required) | `query` | string | The query |
| Documents (required) | `documents` | array[string] | The documents to be used for reranking |
| Top N | `top-n` | integer | The number of most relevant documents or indices to return. Defaults to the length of the documents (default=3) |
| Maximum number of chunks per document | `max-chunks-per-doc` | integer | The maximum number of chunks to produce internally from a document (default=10) |



| Output | ID | Type | Description |
| :--- | :--- | :--- | :--- |
| Reranked documents | `ranking` | array[string] | Reranked documents |
| Usage (optional) | `usage` | object | Search Usage on the Cohere Platform Rerank Models |
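
As an illustration, here is a hedged sketch of mapping the `query` and `documents` fields onto a cohere-go v2 `RerankRequest`. The `Documents` item shape follows the tests added in this commit, while the `Query` field name is an assumption about the SDK and should be checked against its definitions:

```go
package main

import (
	"fmt"

	cohereSDK "github.com/cohere-ai/cohere-go/v2"
)

// buildRerankRequest is an illustrative helper (not part of this commit) that
// maps the component's `query` and `documents` fields onto the SDK request.
// The Query field name is assumed from the cohere-go v2 SDK; the Documents
// item shape matches the tests added in this commit.
func buildRerankRequest(query string, documents []string) cohereSDK.RerankRequest {
	docs := make([]*cohereSDK.RerankRequestDocumentsItem, 0, len(documents))
	for _, d := range documents {
		docs = append(docs, &cohereSDK.RerankRequestDocumentsItem{String: d})
	}
	return cohereSDK.RerankRequest{
		Query:     query,
		Documents: docs,
	}
}

func main() {
	req := buildRerankRequest("database migrations", []string{"doc a", "doc b", "doc c"})
	fmt.Println(len(req.Documents)) // 3
}
```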







10 changes: 10 additions & 0 deletions ai/cohere/v0/assets/Cohere.svg
79 changes: 79 additions & 0 deletions ai/cohere/v0/client.go
@@ -0,0 +1,79 @@
package cohere

import (
	"context"
	"sync"

	cohereSDK "github.com/cohere-ai/cohere-go/v2"
	cohereClientSDK "github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/core"
	"go.uber.org/zap"
	"google.golang.org/protobuf/types/known/structpb"
)

// cohereClient wraps the Cohere SDK client behind a small interface so that
// it can be mocked in tests.
type cohereClient struct {
	sdkClient cohereClientInterface
	logger    *zap.Logger
	lock      sync.Mutex
}

type cohereClientInterface interface {
	Chat(ctx context.Context, request *cohereSDK.ChatRequest, opts ...core.RequestOption) (*cohereSDK.NonStreamedChatResponse, error)
	Embed(ctx context.Context, request *cohereSDK.EmbedRequest, opts ...core.RequestOption) (*cohereSDK.EmbedResponse, error)
	Rerank(ctx context.Context, request *cohereSDK.RerankRequest, opts ...core.RequestOption) (*cohereSDK.RerankResponse, error)
}

func newClient(apiKey string, logger *zap.Logger) *cohereClient {
	client := cohereClientSDK.NewClient(cohereClientSDK.WithToken(apiKey))
	return &cohereClient{sdkClient: client, logger: logger, lock: sync.Mutex{}}
}

// generateEmbedding calls the Cohere embed endpoint and returns the embeddings.
func (cl *cohereClient) generateEmbedding(request cohereSDK.EmbedRequest) (cohereSDK.EmbedResponse, error) {
	respPtr, err := cl.sdkClient.Embed(
		context.TODO(),
		&request,
	)
	if err != nil {
		return cohereSDK.EmbedResponse{}, err
	}
	resp := cohereSDK.EmbedResponse{
		EmbeddingsFloats: respPtr.EmbeddingsFloats,
		EmbeddingsByType: respPtr.EmbeddingsByType,
	}
	return resp, nil
}

// generateTextChat calls the Cohere chat endpoint (non-streamed) and returns
// the generated text together with its citations and usage metadata.
func (cl *cohereClient) generateTextChat(request cohereSDK.ChatRequest) (cohereSDK.NonStreamedChatResponse, error) {
	respPtr, err := cl.sdkClient.Chat(
		context.TODO(),
		&request,
	)
	if err != nil {
		return cohereSDK.NonStreamedChatResponse{}, err
	}
	resp := cohereSDK.NonStreamedChatResponse{
		Text:         respPtr.Text,
		GenerationId: respPtr.GenerationId,
		Citations:    respPtr.Citations,
		Meta:         respPtr.Meta,
	}
	return resp, nil
}

// generateRerank calls the Cohere rerank endpoint and returns the ranked results.
func (cl *cohereClient) generateRerank(request cohereSDK.RerankRequest) (cohereSDK.RerankResponse, error) {
	respPtr, err := cl.sdkClient.Rerank(
		context.TODO(),
		&request,
	)
	if err != nil {
		return cohereSDK.RerankResponse{}, err
	}
	resp := cohereSDK.RerankResponse{
		Results: respPtr.Results,
		Meta:    respPtr.Meta,
	}
	return resp, nil
}

// getAPIKey reads the Cohere API key from the component setup.
func getAPIKey(setup *structpb.Struct) string {
	return setup.GetFields()[cfgAPIKey].GetStringValue()
}
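
For reference, a minimal usage sketch of the helpers above (illustrative only, not part of this commit; the placeholder API key and the direct call to `generateTextChat` are assumptions about how the component's task execution wires things together):

```go
package cohere

import (
	cohereSDK "github.com/cohere-ai/cohere-go/v2"
	"go.uber.org/zap"
)

// exampleChat is an illustrative sketch (not part of this commit) showing how
// the helpers above might be called. The API key value is a placeholder; the
// real component resolves it from the setup via getAPIKey.
func exampleChat() (string, error) {
	client := newClient("YOUR_COHERE_API_KEY", zap.NewNop())
	resp, err := client.generateTextChat(cohereSDK.ChatRequest{
		Message: "Hello World",
	})
	if err != nil {
		return "", err
	}
	return resp.Text, nil
}
```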
102 changes: 102 additions & 0 deletions ai/cohere/v0/client_test.go
@@ -0,0 +1,102 @@
package cohere

import (
	"context"
	"strings"
	"sync"
	"testing"

	cohereSDK "github.com/cohere-ai/cohere-go/v2"
	"github.com/cohere-ai/cohere-go/v2/core"
	qt "github.com/frankban/quicktest"
	"go.uber.org/zap"
)

func newMockClient() *cohereClient {
	return &cohereClient{sdkClient: &MockSDKClient{}, logger: zap.NewNop(), lock: sync.Mutex{}}
}

// MockSDKClient implements cohereClientInterface with deterministic responses
// so the client helpers can be tested without calling the Cohere API.
type MockSDKClient struct {
}

// Chat echoes the request message in upper case.
func (cl *MockSDKClient) Chat(ctx context.Context, request *cohereSDK.ChatRequest, opts ...core.RequestOption) (*cohereSDK.NonStreamedChatResponse, error) {
	uid := "944a80f0-c485-4fda-a5e8-2dd68890a5b7"
	return &cohereSDK.NonStreamedChatResponse{
		Text:         strings.ToUpper(request.Message),
		GenerationId: &uid,
	}, nil
}

// Embed returns a zero vector whose length equals the length of the input text.
func (cl *MockSDKClient) Embed(ctx context.Context, request *cohereSDK.EmbedRequest, opts ...core.RequestOption) (*cohereSDK.EmbedResponse, error) {
	emb := make([][]float64, 1)
	emb[0] = make([]float64, len(request.Texts[0]))
	return &cohereSDK.EmbedResponse{
		EmbeddingsFloats: &cohereSDK.EmbedFloatsResponse{Embeddings: emb},
	}, nil
}

// Rerank returns the provided documents in reverse order.
func (cl *MockSDKClient) Rerank(ctx context.Context, request *cohereSDK.RerankRequest, opts ...core.RequestOption) (*cohereSDK.RerankResponse, error) {
	docCnt := len(request.Documents)
	res := make([]*cohereSDK.RerankResponseResultsItem, docCnt)
	for i, doc := range request.Documents {
		// reverse the provided documents
		res[(docCnt-1)-i] = &cohereSDK.RerankResponseResultsItem{
			Document: &cohereSDK.RerankResponseResultsItemDocument{Text: doc.String},
		}
	}
	return &cohereSDK.RerankResponse{
		Results: res,
	}, nil
}

func TestClient(t *testing.T) {
	c := qt.New(t)

	clt := newMockClient()

	commandTc := struct {
		request cohereSDK.ChatRequest
		want    string
	}{
		request: cohereSDK.ChatRequest{Message: "Hello World"},
		want:    "HELLO WORLD",
	}
	c.Run("ok - task command", func(c *qt.C) {
		resp, err := clt.generateTextChat(commandTc.request)
		c.Check(err, qt.IsNil)
		c.Check(resp.Text, qt.Equals, commandTc.want)
	})

	embedTc := struct {
		request cohereSDK.EmbedRequest
		want    [][]float64
	}{
		request: cohereSDK.EmbedRequest{Texts: []string{"abcde"}},
		want:    [][]float64{{0, 0, 0, 0, 0}},
	}
	c.Run("ok - task embed", func(c *qt.C) {
		resp, err := clt.generateEmbedding(embedTc.request)
		c.Check(err, qt.IsNil)
		c.Check(len(resp.EmbeddingsFloats.Embeddings[0]), qt.Equals, len(embedTc.want[0]))
	})

	rerankTc := struct {
		request cohereSDK.RerankRequest
		want    []string
	}{
		request: cohereSDK.RerankRequest{Documents: []*cohereSDK.RerankRequestDocumentsItem{{String: "a"}, {String: "b"}, {String: "c"}, {String: "d"}}},
		want:    []string{"d", "c", "b", "a"},
	}
	c.Run("ok - task rerank", func(c *qt.C) {
		resp, err := clt.generateRerank(rerankTc.request)
		c.Check(err, qt.IsNil)
		c.Check(len(resp.Results), qt.Equals, len(rerankTc.want))
		for i, r := range resp.Results {
			c.Check(r.Document.Text, qt.Equals, rerankTc.want[i])
		}
	})
}