From dec2eb29a1750ce7234dceefaa03a1fe1713d1ce Mon Sep 17 00:00:00 2001 From: Ron Warholic Date: Thu, 2 Nov 2023 10:37:18 -0400 Subject: [PATCH 1/2] Allow configuration overrides in the connection options. In the case that the user has extensions to the underlying HttpConnection instead of relying on supporting all use cases and features the consumer can simply provide a new subclass of IConnectionProvider (or subclass the existing HttpConnection) that will be used to generate new Thrift connections. --- lib/DBSQLClient.ts | 12 ++- lib/connection/connections/HttpConnection.ts | 2 +- .../contracts/IConnectionOptions.ts | 1 + lib/contracts/IDBSQLClient.ts | 12 +-- tests/unit/DBSQLClient.test.js | 100 ++++++++++++------ 5 files changed, 85 insertions(+), 42 deletions(-) diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 6d5b8bbd..53a00f95 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -46,8 +46,6 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext { private static defaultLogger?: IDBSQLLogger; - private connectionProvider?: IConnectionProvider; - private authProvider?: IAuthentication; private client?: TCLIService.Client; @@ -58,6 +56,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I private readonly logger: IDBSQLLogger; + private connectionProvider?: IConnectionProvider; + + private ConnectionProviderConstructor: new(o: IConnectionOptions) => IConnectionProvider; + private readonly thrift = thrift; private sessions = new CloseableCollection(); @@ -73,13 +75,15 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I super(); this.logger = options?.logger ?? DBSQLClient.getDefaultLogger(); this.logger.log(LogLevel.info, 'Created DBSQLClient'); + this.ConnectionProviderConstructor = options?.connectionProvider || HttpConnection; } private getConnectionOptions(options: ConnectionOptions): IConnectionOptions { return { + ...options, host: options.host, port: options.port || 443, - path: prependSlash(options.path), + path: prependSlash(options.path || ''), https: true, socketTimeout: options.socketTimeout, proxy: options.proxy, @@ -129,7 +133,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise { this.authProvider = this.initAuthProvider(options, authProvider); - this.connectionProvider = new HttpConnection(this.getConnectionOptions(options)); + this.connectionProvider = new this.ConnectionProviderConstructor(this.getConnectionOptions(options)); const thriftConnection = await this.connectionProvider.getThriftConnection(); diff --git a/lib/connection/connections/HttpConnection.ts b/lib/connection/connections/HttpConnection.ts index 79f24c3d..8845f367 100644 --- a/lib/connection/connections/HttpConnection.ts +++ b/lib/connection/connections/HttpConnection.ts @@ -61,7 +61,7 @@ export default class HttpConnection implements IConnectionProvider { const httpsAgentOptions: https.AgentOptions = { ...this.getAgentDefaultOptions(), minVersion: 'TLSv1.2', - rejectUnauthorized: false, + rejectUnauthorized: !!this.options.rejectUnauthorized, ca: this.options.ca, cert: this.options.cert, key: this.options.key, diff --git a/lib/connection/contracts/IConnectionOptions.ts b/lib/connection/contracts/IConnectionOptions.ts index 340b6fae..bd9cd304 100644 --- a/lib/connection/contracts/IConnectionOptions.ts +++ b/lib/connection/contracts/IConnectionOptions.ts @@ -22,4 +22,5 @@ export default interface IConnectionOptions { ca?: Buffer | string; cert?: Buffer | string; key?: Buffer | string; + rejectUnauthorized?: boolean; } diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 0a14f435..3b151ebd 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -1,11 +1,14 @@ import IDBSQLLogger from './IDBSQLLogger'; import IDBSQLSession from './IDBSQLSession'; import IAuthentication from '../connection/contracts/IAuthentication'; -import { ProxyOptions } from '../connection/contracts/IConnectionOptions'; +import IConnectionOptions, { ProxyOptions } from '../connection/contracts/IConnectionOptions'; import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence'; +import IConnectionProvider from '../connection/contracts/IConnectionProvider'; export interface ClientOptions { logger?: IDBSQLLogger; + + connectionProvider: new (o: IConnectionOptions) => IConnectionProvider; } type AuthOptions = @@ -26,13 +29,8 @@ type AuthOptions = }; export type ConnectionOptions = { - host: string; - port?: number; - path: string; clientId?: string; - socketTimeout?: number; - proxy?: ProxyOptions; -} & AuthOptions; +} & AuthOptions & IConnectionOptions; export interface OpenSessionRequest { initialCatalog?: string; diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 55231707..e2719b7c 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -20,7 +20,7 @@ class AuthProviderMock { } } -describe('DBSQLClient.connect', () => { +describe('DBSQLClient.connect', function () { const options = { host: '127.0.0.1', path: '', @@ -31,7 +31,7 @@ describe('DBSQLClient.connect', () => { HttpConnectionModule.default.restore?.(); }); - it('should prepend "/" to path if it is missing', async () => { + it('should prepend "/" to path if it is missing', async function () { const client = new DBSQLClient(); const path = 'example/path'; @@ -40,7 +40,7 @@ describe('DBSQLClient.connect', () => { expect(connectionOptions.path).to.equal(`/${path}`); }); - it('should not prepend "/" to path if it is already available', async () => { + it('should not prepend "/" to path if it is already available', async function () { const client = new DBSQLClient(); const path = '/example/path'; @@ -49,7 +49,7 @@ describe('DBSQLClient.connect', () => { expect(connectionOptions.path).to.equal(path); }); - it('should initialize connection state', async () => { + it('should initialize connection state', async function () { const client = new DBSQLClient(); expect(client.client).to.be.undefined; @@ -63,25 +63,65 @@ describe('DBSQLClient.connect', () => { expect(client.connectionProvider).to.be.instanceOf(HttpConnection); }); - it('should listen for Thrift connection events', async () => { - const client = new DBSQLClient(); - + it('uses the overridden connection provider', async function () { const thriftConnectionMock = { on: sinon.stub(), }; - sinon.stub(HttpConnectionModule, 'default').returns({ - getThriftConnection: () => Promise.resolve(thriftConnectionMock), + function fakeClient(opt) { + expect(opt).to.deep.include(options); + this.getThriftConnection = () => Promise.resolve(thriftConnectionMock); + } + + const client = new DBSQLClient({ + connectionProvider: fakeClient }); await client.connect(options); + expect(client.connectionProvider).to.be.instanceOf(fakeClient); + }); + + it('merges the connection options with the generated ones', async function () { + const thriftConnectionMock = { + on: sinon.stub(), + }; + + function fakeClient(opt) { + expect(opt).to.deep.include({ + myprop: 'abc', + https: true + }); + this.getThriftConnection = () => Promise.resolve(thriftConnectionMock); + } + const client = new DBSQLClient({ + connectionProvider: fakeClient + }); + + await client.connect({...options, myprop: 'abc'}); + expect(client.connectionProvider).to.be.instanceOf(fakeClient); + }); + + it('should listen for Thrift connection events', async function () { + const thriftConnectionMock = { + on: sinon.stub(), + }; + + const client = new DBSQLClient({ + connectionProvider: function() { + return { + getThriftConnection: () => Promise.resolve(thriftConnectionMock), + }; + } + }); + + await client.connect(options); expect(thriftConnectionMock.on.called).to.be.true; }); }); -describe('DBSQLClient.openSession', () => { - it('should successfully open session', async () => { +describe('DBSQLClient.openSession', function () { + it('should successfully open session', async function () { const client = new DBSQLClient(); sinon.stub(client, 'getClient').returns( @@ -99,7 +139,7 @@ describe('DBSQLClient.openSession', () => { expect(session).instanceOf(DBSQLSession); }); - it('should use initial namespace options', async () => { + it('should use initial namespace options', async function () { const client = new DBSQLClient(); sinon.stub(client, 'getClient').returns( @@ -129,7 +169,7 @@ describe('DBSQLClient.openSession', () => { } }); - it('should throw an exception when not connected', async () => { + it('should throw an exception when not connected', async function () { const client = new DBSQLClient(); client.connection = null; @@ -141,7 +181,7 @@ describe('DBSQLClient.openSession', () => { } }); - it('should throw an exception when the connection is lost', async () => { + it('should throw an exception when the connection is lost', async function () { const client = new DBSQLClient(); client.connection = { isConnected() { @@ -158,14 +198,14 @@ describe('DBSQLClient.openSession', () => { }); }); -describe('DBSQLClient.getClient', () => { +describe('DBSQLClient.getClient', function () { const options = { host: '127.0.0.1', path: '', token: 'dapi********************************', }; - it('should throw an error if not connected', async () => { + it('should throw an error if not connected', async function () { const client = new DBSQLClient(); try { await client.getClient(); @@ -178,7 +218,7 @@ describe('DBSQLClient.getClient', () => { } }); - it("should create client if wasn't not initialized yet", async () => { + it("should create client if wasn't not initialized yet", async function () { const client = new DBSQLClient(); const thriftClient = {}; @@ -194,7 +234,7 @@ describe('DBSQLClient.getClient', () => { expect(result).to.be.equal(thriftClient); }); - it('should update auth credentials each time when client is requested', async () => { + it('should update auth credentials each time when client is requested', async function () { const client = new DBSQLClient(); const thriftClient = {}; @@ -241,8 +281,8 @@ describe('DBSQLClient.getClient', () => { }); }); -describe('DBSQLClient.close', () => { - it('should close the connection if it was initiated', async () => { +describe('DBSQLClient.close', function () { + it('should close the connection if it was initiated', async function () { const client = new DBSQLClient(); client.client = {}; client.connectionProvider = {}; @@ -255,7 +295,7 @@ describe('DBSQLClient.close', () => { // No additional asserts needed - it should just reach this point }); - it('should do nothing if the connection does not exist', async () => { + it('should do nothing if the connection does not exist', async function () { const client = new DBSQLClient(); await client.close(); @@ -265,7 +305,7 @@ describe('DBSQLClient.close', () => { // No additional asserts needed - it should just reach this point }); - it('should close sessions that belong to it', async () => { + it('should close sessions that belong to it', async function () { const client = new DBSQLClient(); const thriftClientMock = { @@ -306,8 +346,8 @@ describe('DBSQLClient.close', () => { }); }); -describe('DBSQLClient.initAuthProvider', () => { - it('should use access token auth method', () => { +describe('DBSQLClient.initAuthProvider', function () { + it('should use access token auth method', function () { const client = new DBSQLClient(); const testAccessToken = 'token'; @@ -320,7 +360,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.password).to.be.equal(testAccessToken); }); - it('should use access token auth method by default (compatibility)', () => { + it('should use access token auth method by default (compatibility)', function () { const client = new DBSQLClient(); const testAccessToken = 'token'; @@ -333,7 +373,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.password).to.be.equal(testAccessToken); }); - it('should use Databricks OAuth method (AWS)', () => { + it('should use Databricks OAuth method (AWS)', function () { const client = new DBSQLClient(); const provider = client.initAuthProvider({ @@ -346,7 +386,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.manager).to.be.instanceOf(AWSOAuthManager); }); - it('should use Databricks OAuth method (Azure)', () => { + it('should use Databricks OAuth method (Azure)', function () { const client = new DBSQLClient(); const provider = client.initAuthProvider({ @@ -359,7 +399,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.manager).to.be.instanceOf(AzureOAuthManager); }); - it('should throw error when OAuth not supported for host', () => { + it('should throw error when OAuth not supported for host', function () { const client = new DBSQLClient(); expect(() => { @@ -371,7 +411,7 @@ describe('DBSQLClient.initAuthProvider', () => { }).to.throw(); }); - it('should use custom auth method', () => { + it('should use custom auth method', function () { const client = new DBSQLClient(); const customProvider = {}; @@ -384,7 +424,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider).to.be.equal(customProvider); }); - it('should use custom auth method (legacy way)', () => { + it('should use custom auth method (legacy way)', function () { const client = new DBSQLClient(); const customProvider = {}; From a1947cf41a79a447d0b1a48af836555cbd0b5666 Mon Sep 17 00:00:00 2001 From: Ron Warholic Date: Thu, 2 Nov 2023 10:52:19 -0400 Subject: [PATCH 2/2] Extend headers to allow custom headers --- lib/DBSQLClient.ts | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 53a00f95..9bffd60a 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -81,14 +81,12 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I private getConnectionOptions(options: ConnectionOptions): IConnectionOptions { return { ...options, - host: options.host, port: options.port || 443, path: prependSlash(options.path || ''), https: true, - socketTimeout: options.socketTimeout, - proxy: options.proxy, headers: { 'User-Agent': buildUserAgentString(options.clientId), + ...(options.headers || {}) }, }; }