diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 6d5b8bbd..9bffd60a 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -46,8 +46,6 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext { private static defaultLogger?: IDBSQLLogger; - private connectionProvider?: IConnectionProvider; - private authProvider?: IAuthentication; private client?: TCLIService.Client; @@ -58,6 +56,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I private readonly logger: IDBSQLLogger; + private connectionProvider?: IConnectionProvider; + + private ConnectionProviderConstructor: new(o: IConnectionOptions) => IConnectionProvider; + private readonly thrift = thrift; private sessions = new CloseableCollection(); @@ -73,18 +75,18 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I super(); this.logger = options?.logger ?? DBSQLClient.getDefaultLogger(); this.logger.log(LogLevel.info, 'Created DBSQLClient'); + this.ConnectionProviderConstructor = options?.connectionProvider || HttpConnection; } private getConnectionOptions(options: ConnectionOptions): IConnectionOptions { return { - host: options.host, + ...options, port: options.port || 443, - path: prependSlash(options.path), + path: prependSlash(options.path || ''), https: true, - socketTimeout: options.socketTimeout, - proxy: options.proxy, headers: { 'User-Agent': buildUserAgentString(options.clientId), + ...(options.headers || {}) }, }; } @@ -129,7 +131,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise { this.authProvider = this.initAuthProvider(options, authProvider); - this.connectionProvider = new HttpConnection(this.getConnectionOptions(options)); + this.connectionProvider = new this.ConnectionProviderConstructor(this.getConnectionOptions(options)); const thriftConnection = await this.connectionProvider.getThriftConnection(); diff --git a/lib/connection/connections/HttpConnection.ts b/lib/connection/connections/HttpConnection.ts index 79f24c3d..8845f367 100644 --- a/lib/connection/connections/HttpConnection.ts +++ b/lib/connection/connections/HttpConnection.ts @@ -61,7 +61,7 @@ export default class HttpConnection implements IConnectionProvider { const httpsAgentOptions: https.AgentOptions = { ...this.getAgentDefaultOptions(), minVersion: 'TLSv1.2', - rejectUnauthorized: false, + rejectUnauthorized: !!this.options.rejectUnauthorized, ca: this.options.ca, cert: this.options.cert, key: this.options.key, diff --git a/lib/connection/contracts/IConnectionOptions.ts b/lib/connection/contracts/IConnectionOptions.ts index 340b6fae..bd9cd304 100644 --- a/lib/connection/contracts/IConnectionOptions.ts +++ b/lib/connection/contracts/IConnectionOptions.ts @@ -22,4 +22,5 @@ export default interface IConnectionOptions { ca?: Buffer | string; cert?: Buffer | string; key?: Buffer | string; + rejectUnauthorized?: boolean; } diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 0a14f435..3b151ebd 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -1,11 +1,14 @@ import IDBSQLLogger from './IDBSQLLogger'; import IDBSQLSession from './IDBSQLSession'; import IAuthentication from '../connection/contracts/IAuthentication'; -import { ProxyOptions } from '../connection/contracts/IConnectionOptions'; +import IConnectionOptions, { ProxyOptions } from '../connection/contracts/IConnectionOptions'; import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence'; +import IConnectionProvider from '../connection/contracts/IConnectionProvider'; export interface ClientOptions { logger?: IDBSQLLogger; + + connectionProvider: new (o: IConnectionOptions) => IConnectionProvider; } type AuthOptions = @@ -26,13 +29,8 @@ type AuthOptions = }; export type ConnectionOptions = { - host: string; - port?: number; - path: string; clientId?: string; - socketTimeout?: number; - proxy?: ProxyOptions; -} & AuthOptions; +} & AuthOptions & IConnectionOptions; export interface OpenSessionRequest { initialCatalog?: string; diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 55231707..e2719b7c 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -20,7 +20,7 @@ class AuthProviderMock { } } -describe('DBSQLClient.connect', () => { +describe('DBSQLClient.connect', function () { const options = { host: '127.0.0.1', path: '', @@ -31,7 +31,7 @@ describe('DBSQLClient.connect', () => { HttpConnectionModule.default.restore?.(); }); - it('should prepend "/" to path if it is missing', async () => { + it('should prepend "/" to path if it is missing', async function () { const client = new DBSQLClient(); const path = 'example/path'; @@ -40,7 +40,7 @@ describe('DBSQLClient.connect', () => { expect(connectionOptions.path).to.equal(`/${path}`); }); - it('should not prepend "/" to path if it is already available', async () => { + it('should not prepend "/" to path if it is already available', async function () { const client = new DBSQLClient(); const path = '/example/path'; @@ -49,7 +49,7 @@ describe('DBSQLClient.connect', () => { expect(connectionOptions.path).to.equal(path); }); - it('should initialize connection state', async () => { + it('should initialize connection state', async function () { const client = new DBSQLClient(); expect(client.client).to.be.undefined; @@ -63,25 +63,65 @@ describe('DBSQLClient.connect', () => { expect(client.connectionProvider).to.be.instanceOf(HttpConnection); }); - it('should listen for Thrift connection events', async () => { - const client = new DBSQLClient(); - + it('uses the overridden connection provider', async function () { const thriftConnectionMock = { on: sinon.stub(), }; - sinon.stub(HttpConnectionModule, 'default').returns({ - getThriftConnection: () => Promise.resolve(thriftConnectionMock), + function fakeClient(opt) { + expect(opt).to.deep.include(options); + this.getThriftConnection = () => Promise.resolve(thriftConnectionMock); + } + + const client = new DBSQLClient({ + connectionProvider: fakeClient }); await client.connect(options); + expect(client.connectionProvider).to.be.instanceOf(fakeClient); + }); + + it('merges the connection options with the generated ones', async function () { + const thriftConnectionMock = { + on: sinon.stub(), + }; + + function fakeClient(opt) { + expect(opt).to.deep.include({ + myprop: 'abc', + https: true + }); + this.getThriftConnection = () => Promise.resolve(thriftConnectionMock); + } + const client = new DBSQLClient({ + connectionProvider: fakeClient + }); + + await client.connect({...options, myprop: 'abc'}); + expect(client.connectionProvider).to.be.instanceOf(fakeClient); + }); + + it('should listen for Thrift connection events', async function () { + const thriftConnectionMock = { + on: sinon.stub(), + }; + + const client = new DBSQLClient({ + connectionProvider: function() { + return { + getThriftConnection: () => Promise.resolve(thriftConnectionMock), + }; + } + }); + + await client.connect(options); expect(thriftConnectionMock.on.called).to.be.true; }); }); -describe('DBSQLClient.openSession', () => { - it('should successfully open session', async () => { +describe('DBSQLClient.openSession', function () { + it('should successfully open session', async function () { const client = new DBSQLClient(); sinon.stub(client, 'getClient').returns( @@ -99,7 +139,7 @@ describe('DBSQLClient.openSession', () => { expect(session).instanceOf(DBSQLSession); }); - it('should use initial namespace options', async () => { + it('should use initial namespace options', async function () { const client = new DBSQLClient(); sinon.stub(client, 'getClient').returns( @@ -129,7 +169,7 @@ describe('DBSQLClient.openSession', () => { } }); - it('should throw an exception when not connected', async () => { + it('should throw an exception when not connected', async function () { const client = new DBSQLClient(); client.connection = null; @@ -141,7 +181,7 @@ describe('DBSQLClient.openSession', () => { } }); - it('should throw an exception when the connection is lost', async () => { + it('should throw an exception when the connection is lost', async function () { const client = new DBSQLClient(); client.connection = { isConnected() { @@ -158,14 +198,14 @@ describe('DBSQLClient.openSession', () => { }); }); -describe('DBSQLClient.getClient', () => { +describe('DBSQLClient.getClient', function () { const options = { host: '127.0.0.1', path: '', token: 'dapi********************************', }; - it('should throw an error if not connected', async () => { + it('should throw an error if not connected', async function () { const client = new DBSQLClient(); try { await client.getClient(); @@ -178,7 +218,7 @@ describe('DBSQLClient.getClient', () => { } }); - it("should create client if wasn't not initialized yet", async () => { + it("should create client if wasn't not initialized yet", async function () { const client = new DBSQLClient(); const thriftClient = {}; @@ -194,7 +234,7 @@ describe('DBSQLClient.getClient', () => { expect(result).to.be.equal(thriftClient); }); - it('should update auth credentials each time when client is requested', async () => { + it('should update auth credentials each time when client is requested', async function () { const client = new DBSQLClient(); const thriftClient = {}; @@ -241,8 +281,8 @@ describe('DBSQLClient.getClient', () => { }); }); -describe('DBSQLClient.close', () => { - it('should close the connection if it was initiated', async () => { +describe('DBSQLClient.close', function () { + it('should close the connection if it was initiated', async function () { const client = new DBSQLClient(); client.client = {}; client.connectionProvider = {}; @@ -255,7 +295,7 @@ describe('DBSQLClient.close', () => { // No additional asserts needed - it should just reach this point }); - it('should do nothing if the connection does not exist', async () => { + it('should do nothing if the connection does not exist', async function () { const client = new DBSQLClient(); await client.close(); @@ -265,7 +305,7 @@ describe('DBSQLClient.close', () => { // No additional asserts needed - it should just reach this point }); - it('should close sessions that belong to it', async () => { + it('should close sessions that belong to it', async function () { const client = new DBSQLClient(); const thriftClientMock = { @@ -306,8 +346,8 @@ describe('DBSQLClient.close', () => { }); }); -describe('DBSQLClient.initAuthProvider', () => { - it('should use access token auth method', () => { +describe('DBSQLClient.initAuthProvider', function () { + it('should use access token auth method', function () { const client = new DBSQLClient(); const testAccessToken = 'token'; @@ -320,7 +360,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.password).to.be.equal(testAccessToken); }); - it('should use access token auth method by default (compatibility)', () => { + it('should use access token auth method by default (compatibility)', function () { const client = new DBSQLClient(); const testAccessToken = 'token'; @@ -333,7 +373,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.password).to.be.equal(testAccessToken); }); - it('should use Databricks OAuth method (AWS)', () => { + it('should use Databricks OAuth method (AWS)', function () { const client = new DBSQLClient(); const provider = client.initAuthProvider({ @@ -346,7 +386,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.manager).to.be.instanceOf(AWSOAuthManager); }); - it('should use Databricks OAuth method (Azure)', () => { + it('should use Databricks OAuth method (Azure)', function () { const client = new DBSQLClient(); const provider = client.initAuthProvider({ @@ -359,7 +399,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.manager).to.be.instanceOf(AzureOAuthManager); }); - it('should throw error when OAuth not supported for host', () => { + it('should throw error when OAuth not supported for host', function () { const client = new DBSQLClient(); expect(() => { @@ -371,7 +411,7 @@ describe('DBSQLClient.initAuthProvider', () => { }).to.throw(); }); - it('should use custom auth method', () => { + it('should use custom auth method', function () { const client = new DBSQLClient(); const customProvider = {}; @@ -384,7 +424,7 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider).to.be.equal(customProvider); }); - it('should use custom auth method (legacy way)', () => { + it('should use custom auth method (legacy way)', function () { const client = new DBSQLClient(); const customProvider = {};