This repository has been archived by the owner on Oct 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(cohere): add Cohere component (#187)
Because we want to integrate Cohere's models into our VDP pipeline platform, this commit adds the Cohere AI component, which supports the following tasks: (a) TASK_TEXT_GENERATION_CHAT (b) TASK_TEXT_EMBEDDINGS (c) TASK_TEXT_RERANKING
- Loading branch information
Showing
15 changed files
with
1,800 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
--- | ||
title: "Cohere" | ||
lang: "en-US" | ||
draft: false | ||
description: "Learn about how to set up a VDP Cohere component https://github.com/instill-ai/instill-core" | ||
--- | ||
|
||
The Cohere component is an AI component that allows users to connect to the AI models served on the Cohere Platform. | ||
It can carry out the following tasks: | ||
|
||
- [Text Generation Chat](#text-generation-chat) | ||
- [Text Embeddings](#text-embeddings) | ||
- [Text Reranking](#text-reranking) | ||
|
||
|
||
|
||
## Release Stage | ||
|
||
`Alpha` | ||
|
||
|
||
|
||
## Configuration | ||
|
||
The component configuration is defined and maintained [here](https://github.com/instill-ai/component/blob/main/ai/cohere/v0/config/definition.json). | ||
|
||
|
||
|
||
|
||
## Setup | ||
|
||
|
||
| Field | Field ID | Type | Note | | ||
| :--- | :--- | :--- | :--- | | ||
| API Key (required) | `api-key` | string | Fill in your Cohere API key. To find your keys, visit the Cohere dashboard page. | | ||
|
||
|
||
|
||
|
||
## Supported Tasks | ||
|
||
### Text Generation Chat | ||
|
||
Provide text outputs in response to text inputs. | ||
|
||
|
||
| Input | ID | Type | Description | | ||
| :--- | :--- | :--- | :--- | | ||
| Task ID (required) | `task` | string | `TASK_TEXT_GENERATION_CHAT` | | ||
| Model Name (required) | `model-name` | string | The Cohere command model to be used | | ||
| Prompt (required) | `prompt` | string | The prompt text | | ||
| System message | `system-message` | string | The system message helps set the behavior of the assistant. For example, you can modify the personality of the assistant or provide specific instructions about how it should behave throughout the conversation. By default, the model’s behavior is set using the generic message "You are a helpful assistant." | ||
| Documents | `documents` | array[string] | The documents to be used by the model. For optimal performance, each document should be shorter than 300 words. | ||
| Prompt Images | `prompt-images` | array[string] | The prompt images (Note: As of 2024-06-24, Cohere models are not multimodal, so images will be ignored.) | ||
| Chat history | `chat-history` | array[object] | Incorporate external chat history, specifically previous messages within the conversation. Each message should adhere to the format: \{"role": "The message role, i.e. 'USER' or 'CHATBOT'", "content": "message content"\}. | ||
| Seed | `seed` | integer | The seed (default=42) | | ||
| Temperature | `temperature` | number | The temperature for sampling (default=0.7) | | ||
| Top K | `top-k` | integer | Top k for sampling (default=10) | | ||
| Max new tokens | `max-new-tokens` | integer | The maximum number of tokens for model to generate (default=50) | | ||
|
||
|
||
|
||
| Output | ID | Type | Description | | ||
| :--- | :--- | :--- | :--- | | ||
| Text | `text` | string | Model Output | | ||
| Citations (optional) | `citations` | array[object] | Citations | | ||
| Usage (optional) | `usage` | object | Token Usage on the Cohere Platform Command Models | | ||
|
||
|
||
|
||
|
||
|
||
|
||
### Text Embeddings | ||
|
||
Turn text into a vector of numbers that capture its meaning, unlocking use cases like semantic search. | ||
|
||
|
||
| Input | ID | Type | Description | | ||
| :--- | :--- | :--- | :--- | | ||
| Task ID (required) | `task` | string | `TASK_TEXT_EMBEDDINGS` | | ||
| Embedding Type (required) | `embedding-type` | string | Specifies the return type of the embedding. Note that the 'binary'/'ubinary' options mean the component will return packed unsigned binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. | ||
| Input Type (required) | `input-type` | string | Specifies the type of input passed to the model | | ||
| Model Name (required) | `model-name` | string | The Cohere embed model to be used | | ||
| Text (required) | `text` | string | The text | | ||
|
||
|
||
|
||
| Output | ID | Type | Description | | ||
| :--- | :--- | :--- | :--- | | ||
| Embedding | `embedding` | array[number] | Embedding of the input text | | ||
| Usage (optional) | `usage` | object | Token usage on the Cohere platform embed models | | ||
|
||
|
||
|
||
|
||
|
||
|
||
### Text Reranking | ||
|
||
Sort text inputs by semantic relevance to a specified query. | ||
|
||
|
||
| Input | ID | Type | Description | | ||
| :--- | :--- | :--- | :--- | | ||
| Task ID (required) | `task` | string | `TASK_TEXT_RERANKING` | | ||
| Model Name (required) | `model-name` | string | The Cohere rerank model to be used | | ||
| Query (required) | `query` | string | The query | | ||
| Documents (required) | `documents` | array[string] | The documents to be used for reranking | | ||
| Top N | `top-n` | integer | The number of most relevant documents or indices to return. Defaults to the length of the documents (default=3) | | ||
| Maximum number of chunks per document | `max-chunks-per-doc` | integer | The maximum number of chunks to produce internally from a document (default=10) | | ||
|
||
|
||
|
||
| Output | ID | Type | Description | | ||
| :--- | :--- | :--- | :--- | | ||
| Reranked documents | `ranking` | array[string] | Reranked documents | | ||
| Usage (optional) | `usage` | object | Search Usage on the Cohere Platform Rerank Models | | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package cohere | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
|
||
cohereSDK "github.com/cohere-ai/cohere-go/v2" | ||
cohereClientSDK "github.com/cohere-ai/cohere-go/v2/client" | ||
"github.com/cohere-ai/cohere-go/v2/core" | ||
"go.uber.org/zap" | ||
"google.golang.org/protobuf/types/known/structpb" | ||
) | ||
|
||
type cohereClient struct { | ||
sdkClient cohereClientInterface | ||
logger *zap.Logger | ||
lock sync.Mutex | ||
} | ||
|
||
type cohereClientInterface interface { | ||
Chat(ctx context.Context, request *cohereSDK.ChatRequest, opts ...core.RequestOption) (*cohereSDK.NonStreamedChatResponse, error) | ||
Embed(ctx context.Context, request *cohereSDK.EmbedRequest, opts ...core.RequestOption) (*cohereSDK.EmbedResponse, error) | ||
Rerank(ctx context.Context, request *cohereSDK.RerankRequest, opts ...core.RequestOption) (*cohereSDK.RerankResponse, error) | ||
} | ||
|
||
func newClient(apiKey string, logger *zap.Logger) *cohereClient { | ||
client := cohereClientSDK.NewClient(cohereClientSDK.WithToken(apiKey)) | ||
return &cohereClient{sdkClient: client, logger: logger, lock: sync.Mutex{}} | ||
} | ||
|
||
func (cl *cohereClient) generateEmbedding(request cohereSDK.EmbedRequest) (cohereSDK.EmbedResponse, error) { | ||
respPtr, err := cl.sdkClient.Embed( | ||
context.TODO(), | ||
&request, | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
resp := cohereSDK.EmbedResponse{ | ||
EmbeddingsFloats: respPtr.EmbeddingsFloats, | ||
EmbeddingsByType: respPtr.EmbeddingsByType, | ||
} | ||
return resp, nil | ||
} | ||
|
||
func (cl *cohereClient) generateTextChat(request cohereSDK.ChatRequest) (cohereSDK.NonStreamedChatResponse, error) { | ||
respPtr, err := cl.sdkClient.Chat( | ||
context.TODO(), | ||
&request, | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
resp := cohereSDK.NonStreamedChatResponse{ | ||
Text: respPtr.Text, | ||
GenerationId: respPtr.GenerationId, | ||
Citations: respPtr.Citations, | ||
Meta: respPtr.Meta, | ||
} | ||
return resp, nil | ||
} | ||
func (cl *cohereClient) generateRerank(request cohereSDK.RerankRequest) (cohereSDK.RerankResponse, error) { | ||
respPtr, err := cl.sdkClient.Rerank( | ||
context.TODO(), | ||
&request, | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
resp := cohereSDK.RerankResponse{ | ||
Results: respPtr.Results, | ||
Meta: respPtr.Meta, | ||
} | ||
return resp, nil | ||
} | ||
|
||
func getAPIKey(setup *structpb.Struct) string { | ||
return setup.GetFields()[cfgAPIKey].GetStringValue() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package cohere | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
"sync" | ||
"testing" | ||
|
||
cohereSDK "github.com/cohere-ai/cohere-go/v2" | ||
"github.com/cohere-ai/cohere-go/v2/core" | ||
qt "github.com/frankban/quicktest" | ||
"go.uber.org/zap" | ||
) | ||
|
||
func newMockClient() *cohereClient { | ||
return &cohereClient{sdkClient: &MockSDKClient{}, logger: zap.NewNop(), lock: sync.Mutex{}} | ||
} | ||
|
||
// MockSDKClient is a stateless stand-in for the Cohere SDK client used in
// tests; its methods produce canned responses derived from the request.
type MockSDKClient struct{}
|
||
func (cl *MockSDKClient) Chat(ctx context.Context, request *cohereSDK.ChatRequest, opts ...core.RequestOption) (*cohereSDK.NonStreamedChatResponse, error) { | ||
uid := "944a80f0-c485-4fda-a5e8-2dd68890a5b7" | ||
return &cohereSDK.NonStreamedChatResponse{ | ||
Text: strings.ToUpper(request.Message), | ||
GenerationId: &uid, | ||
}, nil | ||
} | ||
|
||
func (cl *MockSDKClient) Embed(ctx context.Context, request *cohereSDK.EmbedRequest, opts ...core.RequestOption) (*cohereSDK.EmbedResponse, error) { | ||
emb := make([][]float64, 1) | ||
emb[0] = make([]float64, len(request.Texts[0])) | ||
return &cohereSDK.EmbedResponse{ | ||
EmbeddingsFloats: &cohereSDK.EmbedFloatsResponse{Embeddings: emb}, | ||
}, nil | ||
} | ||
|
||
func (cl *MockSDKClient) Rerank(ctx context.Context, request *cohereSDK.RerankRequest, opts ...core.RequestOption) (*cohereSDK.RerankResponse, error) { | ||
docCnt := len(request.Documents) | ||
res := make([]*cohereSDK.RerankResponseResultsItem, docCnt) | ||
for i, doc := range request.Documents { | ||
// reverse the provided documents | ||
res[(docCnt-1)-i] = &cohereSDK.RerankResponseResultsItem{ | ||
Document: &cohereSDK.RerankResponseResultsItemDocument{Text: doc.String}, | ||
} | ||
} | ||
return &cohereSDK.RerankResponse{ | ||
Results: res, | ||
}, nil | ||
} | ||
|
||
func TestClient(t *testing.T) { | ||
c := qt.New(t) | ||
|
||
clt := newMockClient() | ||
|
||
commandTc := struct { | ||
request cohereSDK.ChatRequest | ||
want string | ||
}{ | ||
request: cohereSDK.ChatRequest{Message: "Hello World"}, | ||
want: "HELLO WORLD", | ||
} | ||
c.Run("ok - task command", func(c *qt.C) { | ||
resp, err := clt.generateTextChat(commandTc.request) | ||
c.Check(err, qt.IsNil) | ||
c.Check(resp.Text, qt.Equals, commandTc.want) | ||
|
||
}) | ||
|
||
embedTc := struct { | ||
request cohereSDK.EmbedRequest | ||
want [][]float64 | ||
}{ | ||
request: cohereSDK.EmbedRequest{Texts: []string{"abcde"}}, | ||
want: [][]float64{{0, 0, 0, 0, 0}}, | ||
} | ||
c.Run("ok - task embed", func(c *qt.C) { | ||
resp, err := clt.generateEmbedding(embedTc.request) | ||
c.Check(err, qt.IsNil) | ||
c.Check(len(resp.EmbeddingsFloats.Embeddings[0]), qt.Equals, len(embedTc.want[0])) | ||
|
||
}) | ||
|
||
rerankTc := struct { | ||
request cohereSDK.RerankRequest | ||
want []string | ||
}{ | ||
|
||
request: cohereSDK.RerankRequest{Documents: []*cohereSDK.RerankRequestDocumentsItem{{String: "a"}, {String: "b"}, {String: "c"}, {String: "d"}}}, | ||
|
||
want: []string{"d", "c", "b", "a"}, | ||
} | ||
c.Run("ok - task rerank", func(c *qt.C) { | ||
resp, err := clt.generateRerank(rerankTc.request) | ||
c.Check(err, qt.IsNil) | ||
c.Check(len(resp.Results), qt.Equals, len(rerankTc.want)) | ||
for i, r := range resp.Results { | ||
c.Check(r.Document.Text, qt.Equals, rerankTc.want[i]) | ||
} | ||
}) | ||
} |
Oops, something went wrong.