Skip to content

Commit

Permalink
feat: Combine deployment resolution and model (#90)
Browse files Browse the repository at this point in the history
* merge deployment ID and model into one parameter

* add tests

* use in orchestration

* fixes

* fix test

* fix merge

* Update packages/gen-ai-hub/src/utils/deployment-resolver.test.ts

Co-authored-by: Matthias Kuhr <[email protected]>

* fix type tests

* Remove unused types

* Rename

* add executable ID

* rename group ID to resource group

* import type

* improvements from review

* remove duplicate request config merging

* fix typo

---------

Co-authored-by: Matthias Kuhr <[email protected]>
  • Loading branch information
marikaner and MatKuhr authored Aug 29, 2024
1 parent 2bd80b6 commit b20b520
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 168 deletions.
10 changes: 5 additions & 5 deletions packages/core/src/http-client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { removeLeadingSlashes } from '@sap-cloud-sdk/util';
import { mergeIgnoreCase, removeLeadingSlashes } from '@sap-cloud-sdk/util';
import {
executeHttpRequest,
HttpRequestConfig,
Expand Down Expand Up @@ -102,13 +102,13 @@ function mergeWithDefaultRequestConfig(
return {
...defaultConfig,
...requestConfig,
headers: {
headers: mergeIgnoreCase({
...defaultConfig.headers,
...requestConfig?.headers
},
params: {
}),
params: mergeIgnoreCase({
...defaultConfig.params,
...requestConfig?.params
}
})
};
}
20 changes: 8 additions & 12 deletions packages/gen-ai-hub/src/client/openai/openai-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ describe('openai client', () => {
chatCompletionEndpoint
);

const response = await client.chatCompletion(
'gpt-35-turbo',
prompt,
'1234'
);
const response = await client.chatCompletion(prompt, {
deploymentId: '1234'
});
expect(response).toEqual(mockResponse);
});

Expand All @@ -86,7 +84,7 @@ describe('openai client', () => {
);

await expect(
client.chatCompletion('gpt-4', prompt, '1234')
client.chatCompletion(prompt, { deploymentId: '1234' })
).rejects.toThrow('status code 400');
});
});
Expand All @@ -111,11 +109,9 @@ describe('openai client', () => {
},
embeddingsEndpoint
);
const response = await client.embeddings(
'text-embedding-ada-002',
prompt,
'1234'
);
const response = await client.embeddings(prompt, {
deploymentId: '1234'
});
expect(response).toEqual(mockResponse);
});

Expand All @@ -138,7 +134,7 @@ describe('openai client', () => {
);

await expect(
client.embeddings('text-embedding-3-large', prompt, '1234')
client.embeddings(prompt, { deploymentId: '1234' })
).rejects.toThrow('status code 400');
});
});
Expand Down
72 changes: 19 additions & 53 deletions packages/gen-ai-hub/src/client/openai/openai-client.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import { HttpRequestConfig } from '@sap-cloud-sdk/http-client';
import { CustomRequestConfig, executeRequest } from '@sap-ai-sdk/core';
import { type CustomRequestConfig, executeRequest } from '@sap-ai-sdk/core';
import {
DeploymentResolver,
resolveDeployment
getDeploymentId,
type ModelDeployment
} from '../../utils/deployment-resolver.js';
import {
import type {
OpenAiChatCompletionParameters,
OpenAiEmbeddingParameters,
OpenAiEmbeddingOutput,
Expand All @@ -21,87 +20,54 @@ const apiVersion = '2024-02-01';
export class OpenAiClient {
/**
* Creates a completion for the chat messages.
* @param model - The model to use for the chat completion.
* @param data - The input parameters for the chat completion.
* @param deploymentResolver - A deployment id or a function to retrieve it.
   * @param modelDeployment - This configuration is used to retrieve a deployment. Depending on the configuration, either the given deployment ID or the model name is used to retrieve matching deployments. If both a model and a deployment ID are given, the model is verified against the deployment.
* @param requestConfig - The request configuration.
* @returns The completion result.
*/
async chatCompletion(
model: OpenAiChatModel | { name: OpenAiChatModel; version: string },
data: OpenAiChatCompletionParameters,
deploymentResolver?: DeploymentResolver,
modelDeployment: ModelDeployment<OpenAiChatModel>,
requestConfig?: CustomRequestConfig
): Promise<OpenAiChatCompletionOutput> {
const deploymentId = await resolveOpenAiDeployment(
model,
deploymentResolver
const deploymentId = await getDeploymentId(
modelDeployment,
'azure-openai',
requestConfig
);
const response = await executeRequest(
{
url: `/inference/deployments/${deploymentId}/chat/completions`,
apiVersion
},
data,
mergeRequestConfig(requestConfig)
requestConfig
);
return response.data;
}

/**
* Creates an embedding vector representing the given text.
* @param model - The model to use for the embedding computation.
* @param data - The text to embed.
* @param deploymentResolver - A deployment id or a function to retrieve it.
   * @param modelDeployment - This configuration is used to retrieve a deployment. Depending on the configuration, either the given deployment ID or the model name is used to retrieve matching deployments. If both a model and a deployment ID are given, the model is verified against the deployment.
* @param requestConfig - The request configuration.
   * @returns The embedding result.
*/
async embeddings(
model:
| OpenAiEmbeddingModel
| { name: OpenAiEmbeddingModel; version: string },
data: OpenAiEmbeddingParameters,
deploymentResolver?: DeploymentResolver,
modelDeployment: ModelDeployment<OpenAiEmbeddingModel>,
requestConfig?: CustomRequestConfig
): Promise<OpenAiEmbeddingOutput> {
const deploymentId = await resolveOpenAiDeployment(
model,
deploymentResolver
const deploymentId = await getDeploymentId(
modelDeployment,
'azure-openai',
requestConfig
);
const response = await executeRequest(
{ url: `/inference/deployments/${deploymentId}/embeddings`, apiVersion },
data,
mergeRequestConfig(requestConfig)
requestConfig
);
return response.data;
}
}

async function resolveOpenAiDeployment(
model: string | { name: string; version: string },
resolver?: DeploymentResolver
) {
if (typeof resolver === 'string') {
return resolver;
}
const llm =
typeof model === 'string' ? { name: model, version: 'latest' } : model;
const deployment = await resolveDeployment({
scenarioId: 'foundation-models',
executableId: 'azure-openai',
model: llm
});
return deployment.id;
}

function mergeRequestConfig(
requestConfig?: CustomRequestConfig
): HttpRequestConfig {
return {
method: 'POST',
headers: {
'content-type': 'application/json'
},
params: { 'api-version': apiVersion },
...requestConfig
};
}
32 changes: 20 additions & 12 deletions packages/gen-ai-hub/src/orchestration/orchestration-client.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { executeRequest, CustomRequestConfig } from '@sap-ai-sdk/core';
import {
DeploymentResolver,
resolveDeployment
} from '../utils/deployment-resolver.js';
import { pickValueIgnoreCase } from '@sap-cloud-sdk/util';
import { resolveDeployment } from '../utils/deployment-resolver.js';
import {
CompletionPostRequest,
CompletionPostResponse
Expand All @@ -16,25 +14,35 @@ export class OrchestrationClient {
/**
* Creates a completion for the chat messages.
* @param data - The input parameters for the chat completion.
* @param deploymentResolver - A deployment ID or a function to retrieve it.
* @param deploymentId - A deployment ID or undefined to retrieve it based on the given model.
* @param requestConfig - Request configuration.
* @returns The completion result.
*/
async chatCompletion(
data: OrchestrationCompletionParameters,
deploymentResolver: DeploymentResolver = () =>
resolveDeployment({ scenarioId: 'orchestration' }),
deploymentId?: string,
requestConfig?: CustomRequestConfig
): Promise<CompletionPostResponse> {
const body = constructCompletionPostRequest(data);
const deployment =
typeof deploymentResolver === 'function'
? (await deploymentResolver()).id
: deploymentResolver;
deploymentId =
deploymentId ??
(
await resolveDeployment({
scenarioId: 'orchestration',
model: {
name: data.llmConfig.model_name,
version: data.llmConfig.model_version
},
resourceGroup: pickValueIgnoreCase(
requestConfig?.headers,
'ai-resource-group'
)
})
).id;

const response = await executeRequest(
{
url: `/inference/deployments/${deployment}/completion`
url: `/inference/deployments/${deploymentId}/completion`
},
body,
requestConfig
Expand Down
81 changes: 55 additions & 26 deletions packages/gen-ai-hub/src/utils/deployment-resolver.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
} from '../../../../test-util/mock-http.js';
import { resolveDeployment } from './deployment-resolver.js';

describe('Deployment resolver', () => {
describe('deployment resolver', () => {
beforeEach(() => {
mockClientCredentialsGrantCall();
});
Expand All @@ -18,29 +18,30 @@ describe('Deployment resolver', () => {
beforeEach(() => {
mockResponse();
});

it('should return the first deployment, if multiple are given', async () => {
const { id, configurationId } = await resolveDeployment({
const { id } = await resolveDeployment({
scenarioId: 'foundation-models'
});
expect(id).toBe('1');
expect(configurationId).toBe('c1');
});
it('should return the deployment with the correct model name', async () => {
const { id, configurationId } = await resolveDeployment({

it('should return the first deployment with the correct model name', async () => {
const { id } = await resolveDeployment({
scenarioId: 'foundation-models',
model: { name: 'gpt-4o' }
});
expect(id).toBe('2');
expect(configurationId).toBe('c2');
expect(id).toBe('1');
});
it('should return the deployment with the correct model name', async () => {
const { id, configurationId } = await resolveDeployment({

it('should return the deployment with the correct model name and version', async () => {
const { id } = await resolveDeployment({
scenarioId: 'foundation-models',
model: { name: 'gpt-4o', version: '0613' }
});
expect(id).toBe('2');
expect(configurationId).toBe('c2');
});

it('should throw in case no deployment with the given model name is found', async () => {
await expect(
resolveDeployment({
Expand All @@ -49,7 +50,8 @@ describe('Deployment resolver', () => {
})
).rejects.toThrow('No deployment matched the given criteria');
});
it('should throw in case no deployment with the given model version is found', async () => {

it('should throw in case no deployment with the given model and version is found', async () => {
await expect(
resolveDeployment({
scenarioId: 'foundation-models',
Expand All @@ -73,9 +75,47 @@ describe('Deployment resolver', () => {
});

await expect(
resolveDeployment({ scenarioId: 'foundation-models' })
resolveDeployment({
scenarioId: 'foundation-models',
model: { name: 'gpt-4o', version: '0613' }
})
).rejects.toThrow('No deployment matched the given criteria');
});

it('should consider custom resource group', async () => {
nock(aiCoreDestination.url, {
reqheaders: {
'ai-resource-group': 'otherId'
}
})
.get('/v2/lm/deployments')
.query({ scenarioId: 'foundation-models', status: 'RUNNING' })
.reply(200, {
resources: [
{
id: '5',
details: {
resources: {
backend_details: {
model: {
name: 'gpt-4o',
version: 'latest'
}
}
}
}
}
]
});

const { id } = await resolveDeployment({
scenarioId: 'foundation-models',
model: { name: 'gpt-4o' },
resourceGroup: 'otherId'
});

expect(id).toBe('5');
});
});

function mockResponse() {
Expand All @@ -87,32 +127,22 @@ function mockResponse() {
.get('/v2/lm/deployments')
.query({ scenarioId: 'foundation-models', status: 'RUNNING' })
.reply(200, {
count: 1,
resources: [
{
configurationId: 'c1',
id: '1',
deploymentUrl: 'https://foo.com/v2/inference/deployments/1',
details: {
resources: {
backend_details: {
model: {
name: 'gpt-4-32k',
name: 'gpt-4o',
version: 'latest'
}
}
},
scaling: {
backend_details: {}
}
},
lastOperation: 'CREATE',
status: 'RUNNING'
}
},
{
configurationId: 'c2',
id: '2',
deploymentUrl: 'https://foo.com/v2/inference/deployments/2',
details: {
resources: {
backend_details: {
Expand All @@ -122,8 +152,7 @@ function mockResponse() {
}
}
}
},
status: 'RUNNING'
}
}
]
});
Expand Down
Loading

0 comments on commit b20b520

Please sign in to comment.