Skip to content

Commit

Permalink
feat: make possible to provide additional auth headers needed for AI …
Browse files Browse the repository at this point in the history
…services
  • Loading branch information
nikolaymatrosov committed Oct 30, 2023
1 parent 6d05ff3 commit 08caf19
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@ import {
import { createChannel } from 'nice-grpc';
import { Required } from 'utility-types';
import {
ChannelSslOptions,
GeneratedServiceClientCtor,
IamTokenCredentialsConfig,
OAuthCredentialsConfig,
ServiceAccountCredentialsConfig, WrappedServiceClientType,
SessionConfig, ChannelSslOptions,
ServiceAccountCredentialsConfig,
SessionConfig,
WrappedServiceClientType,
} from './types';
import { IamTokenService } from './token-service/iam-token-service';
import { MetadataTokenService } from './token-service/metadata-token-service';
import { clientFactory } from './utils/client-factory';
import { serviceClients, cloudApi } from '.';
import { cloudApi, serviceClients } from '.';
import { getServiceClientEndpoint } from './service-endpoints';

const isOAuth = (config: SessionConfig): config is OAuthCredentialsConfig => 'oauthToken' in config;
Expand All @@ -39,7 +41,8 @@ const newTokenCreator = (config: SessionConfig): () => Promise<string> => {
yandexPassportOauthToken: config.oauthToken,
});
};
} if (isIamToken(config)) {
}
if (isIamToken(config)) {
const { iamToken } = config;

return async () => iamToken;
Expand All @@ -50,7 +53,11 @@ const newTokenCreator = (config: SessionConfig): () => Promise<string> => {
return async () => tokenService.getToken();
};

const newChannelCredentials = (tokenCreator: TokenCreator, sslOptions?: ChannelSslOptions) => credentials.combineChannelCredentials(
const newChannelCredentials = (
tokenCreator: TokenCreator,
sslOptions?: ChannelSslOptions,
headers?: Record<string, string>,
) => credentials.combineChannelCredentials(
credentials.createSsl(sslOptions?.rootCerts, sslOptions?.privateKey, sslOptions?.certChain),
credentials.createFromMetadataGenerator(
(
Expand All @@ -62,6 +69,15 @@ const newChannelCredentials = (tokenCreator: TokenCreator, sslOptions?: ChannelS
const md = new Metadata();

md.set('authorization', `Bearer ${token}`);
if (headers) {
for (const [key, value] of Object.entries(headers)) {
const lowerCaseKey = key.toLowerCase();

if (lowerCaseKey !== 'authorization') {
md.set(lowerCaseKey, value);
}
}
}

return callback(null, md);
})
Expand All @@ -87,7 +103,7 @@ export class Session {
...config,
};
this.tokenCreator = newTokenCreator(this.config);
this.channelCredentials = newChannelCredentials(this.tokenCreator, this.config.ssl);
this.channelCredentials = newChannelCredentials(this.tokenCreator, this.config.ssl, this.config.headers);
}

get pollInterval(): number {
Expand Down
1 change: 1 addition & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export interface ChannelSslOptions {
export interface GenericCredentialsConfig {
pollInterval?: number;
ssl?: ChannelSslOptions
headers?: Record<string, string>;
}

export interface OAuthCredentialsConfig extends GenericCredentialsConfig {
Expand Down

0 comments on commit 08caf19

Please sign in to comment.