From ab02d5332c1c813c1946db727e789a714e8576c1 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Wed, 15 Jan 2025 15:42:51 -0500 Subject: [PATCH 01/13] feat(NODE-6258): add signal support to cursor APIs --- src/client-side-encryption/auto_encrypter.ts | 20 +- .../client_encryption.ts | 10 +- src/client-side-encryption/state_machine.ts | 116 ++- src/cmap/connection.ts | 24 +- src/cmap/connection_pool.ts | 20 +- src/cmap/wire_protocol/on_data.ts | 11 +- src/collection.ts | 23 +- src/cursor/abstract_cursor.ts | 45 +- src/cursor/aggregation_cursor.ts | 3 +- src/cursor/find_cursor.ts | 3 +- src/db.ts | 18 +- src/index.ts | 1 + src/mongo_types.ts | 32 + src/operations/execute_operation.ts | 9 +- src/operations/list_collections.ts | 5 +- src/operations/operation.ts | 5 +- src/sdam/server.ts | 27 +- src/sdam/topology.ts | 12 +- src/utils.ts | 32 +- .../client-side-encryption/driver.test.ts | 10 +- ...lient_side_operations_timeout.unit.test.ts | 2 +- .../node-specific/abort_signal.test.ts | 872 ++++++++++++++++++ test/tools/utils.ts | 45 +- .../state_machine.test.ts | 15 +- 24 files changed, 1237 insertions(+), 123 deletions(-) create mode 100644 test/integration/node-specific/abort_signal.test.ts diff --git a/src/client-side-encryption/auto_encrypter.ts b/src/client-side-encryption/auto_encrypter.ts index edf731b92ac..1d7a9de4c66 100644 --- a/src/client-side-encryption/auto_encrypter.ts +++ b/src/client-side-encryption/auto_encrypter.ts @@ -11,6 +11,7 @@ import { kDecorateResult } from '../constants'; import { getMongoDBClientEncryption } from '../deps'; import { MongoRuntimeError } from '../error'; import { MongoClient, type MongoClientOptions } from '../mongo_client'; +import { type Abortable } from '../mongo_types'; import { MongoDBCollectionNamespace } from '../utils'; import { autoSelectSocketOptions } from './client_encryption'; import * as cryptoCallbacks from './crypto_callbacks'; @@ -372,8 +373,10 @@ export class AutoEncrypter { async encrypt( ns: string, cmd: Document, - options: CommandOptions = {} + options: CommandOptions & Abortable = {} ): Promise { + options.signal?.throwIfAborted(); + if (this._bypassEncryption) { // If `bypassAutoEncryption` has been specified, don't encrypt return cmd; @@ -398,7 +401,7 @@ export class AutoEncrypter { socketOptions: autoSelectSocketOptions(this._client.s.options) }); - return deserialize(await stateMachine.execute(this, context, options.timeoutContext), { + return deserialize(await stateMachine.execute(this, context, options), { promoteValues: false, promoteLongs: false }); @@ -407,7 +410,12 @@ export class AutoEncrypter { /** * Decrypt a command response */ - async decrypt(response: Uint8Array, options: CommandOptions = {}): Promise { + async decrypt( + response: Uint8Array, + options: CommandOptions & Abortable = {} + ): Promise { + options.signal?.throwIfAborted(); + const context = this._mongocrypt.makeDecryptionContext(response); context.id = this._contextCounter++; @@ -419,11 +427,7 @@ export class AutoEncrypter { socketOptions: autoSelectSocketOptions(this._client.s.options) }); - return await stateMachine.execute( - this, - context, - options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined - ); + return await stateMachine.execute(this, context, options); } /** diff --git a/src/client-side-encryption/client_encryption.ts b/src/client-side-encryption/client_encryption.ts index 7482c513d37..487969cf4de 100644 --- a/src/client-side-encryption/client_encryption.ts +++ b/src/client-side-encryption/client_encryption.ts @@ -225,7 +225,7 @@ export class ClientEncryption { TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })); const dataKey = deserialize( - await stateMachine.execute(this, context, timeoutContext) + await stateMachine.execute(this, context, { timeoutContext }) ) as DataKey; const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString( @@ -293,7 +293,9 @@ export class ClientEncryption { resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }) ); - const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext)); + const { v: dataKeys } = deserialize( + await stateMachine.execute(this, context, { timeoutContext }) + ); if (dataKeys.length === 0) { return {}; } @@ -696,7 +698,7 @@ export class ClientEncryption { ? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })) : undefined; - const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext)); + const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext })); return v; } @@ -780,7 +782,7 @@ export class ClientEncryption { this._timeoutMS != null ? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })) : undefined; - const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext)); + const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext })); return v; } } diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index d10776abe73..07dad3c578a 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor'; import { getSocks, type SocksLib } from '../deps'; import { MongoOperationTimeoutError } from '../error'; import { type MongoClient, type MongoClientOptions } from '../mongo_client'; +import { type Abortable } from '../mongo_types'; import { Timeout, type TimeoutContext, TimeoutError } from '../timeout'; -import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils'; +import { + addAbortListener, + BufferPool, + kDispose, + MongoDBCollectionNamespace, + promiseWithResolvers +} from '../utils'; import { autoSelectSocketOptions, type DataKey } from './client_encryption'; import { MongoCryptError } from './errors'; import { type MongocryptdManager } from './mongocryptd_manager'; @@ -189,7 +196,7 @@ export class StateMachine { async execute( executor: StateMachineExecutable, context: MongoCryptContext, - timeoutContext?: TimeoutContext + options: { timeoutContext?: TimeoutContext } & Abortable ): Promise { const keyVaultNamespace = executor._keyVaultNamespace; const keyVaultClient = executor._keyVaultClient; @@ -199,6 +206,7 @@ export class StateMachine { let result: Uint8Array | null = null; while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) { + options.signal?.throwIfAborted(); debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`); switch (context.state) { @@ -214,7 +222,7 @@ export class StateMachine { metaDataClient, context.ns, filter, - timeoutContext + options ); if (collInfo) { context.addMongoOperationResponse(collInfo); @@ -235,9 +243,9 @@ export class StateMachine { // When we are using the shared library, we don't have a mongocryptd manager. const markedCommand: Uint8Array = mongocryptdManager ? await mongocryptdManager.withRespawn( - this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext) + this.markCommand.bind(this, mongocryptdClient, context.ns, command, options) ) - : await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext); + : await this.markCommand(mongocryptdClient, context.ns, command, options); context.addMongoOperationResponse(markedCommand); context.finishMongoOperation(); @@ -246,12 +254,7 @@ export class StateMachine { case MONGOCRYPT_CTX_NEED_MONGO_KEYS: { const filter = context.nextMongoOperation(); - const keys = await this.fetchKeys( - keyVaultClient, - keyVaultNamespace, - filter, - timeoutContext - ); + const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options); if (keys.length === 0) { // See docs on EMPTY_V @@ -273,7 +276,7 @@ export class StateMachine { } case MONGOCRYPT_CTX_NEED_KMS: { - await Promise.all(this.requests(context, timeoutContext)); + await Promise.all(this.requests(context, options)); context.finishKMSRequests(); break; } @@ -315,11 +318,13 @@ export class StateMachine { * @param kmsContext - A C++ KMS context returned from the bindings * @returns A promise that resolves when the KMS reply has be fully parsed */ - async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise { + async kmsRequest( + request: MongoCryptKMSRequest, + options?: { timeoutContext?: TimeoutContext } & Abortable + ): Promise { const parsedUrl = request.endpoint.split(':'); const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT; - const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {}); - const options: tls.ConnectionOptions & { + const socketOptions: tls.ConnectionOptions & { host: string; port: number; autoSelectFamily?: boolean; @@ -328,7 +333,7 @@ export class StateMachine { host: parsedUrl[0], servername: parsedUrl[0], port, - ...socketOptions + ...autoSelectSocketOptions(this.options.socketOptions || {}) }; const message = request.message; const buffer = new BufferPool(); @@ -363,7 +368,7 @@ export class StateMachine { throw error; } try { - await this.setTlsOptions(providerTlsOptions, options); + await this.setTlsOptions(providerTlsOptions, socketOptions); } catch (err) { throw onerror(err); } @@ -380,23 +385,25 @@ export class StateMachine { .once('close', () => rejectOnNetSocketError(onclose())) .once('connect', () => resolveOnNetSocketConnect()); + let abortListener; + try { if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) { const netSocketOptions = { + ...socketOptions, host: this.options.proxyOptions.proxyHost, - port: this.options.proxyOptions.proxyPort || 1080, - ...socketOptions + port: this.options.proxyOptions.proxyPort || 1080 }; netSocket.connect(netSocketOptions); await willConnect; try { socks ??= loadSocks(); - options.socket = ( + socketOptions.socket = ( await socks.SocksClient.createConnection({ existing_socket: netSocket, command: 'connect', - destination: { host: options.host, port: options.port }, + destination: { host: socketOptions.host, port: socketOptions.port }, proxy: { // host and port are ignored because we pass existing_socket host: 'iLoveJavaScript', @@ -412,7 +419,7 @@ export class StateMachine { } } - socket = tls.connect(options, () => { + socket = tls.connect(socketOptions, () => { socket.write(message); }); @@ -422,6 +429,11 @@ export class StateMachine { resolve } = promiseWithResolvers(); + abortListener = addAbortListener(options?.signal, function () { + destroySockets(); + rejectOnTlsSocketError(this.reason); + }); + socket .once('error', err => rejectOnTlsSocketError(onerror(err))) .once('close', () => rejectOnTlsSocketError(onclose())) @@ -436,8 +448,11 @@ export class StateMachine { resolve(); } }); - await (timeoutContext?.csotEnabled() - ? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)]) + await (options?.timeoutContext?.csotEnabled() + ? Promise.all([ + willResolveKmsRequest, + Timeout.expires(options.timeoutContext?.remainingTimeMS) + ]) : willResolveKmsRequest); } catch (error) { if (error instanceof TimeoutError) @@ -446,16 +461,17 @@ export class StateMachine { } finally { // There's no need for any more activity on this socket at this point. destroySockets(); + abortListener?.[kDispose](); } } - *requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) { + *requests(context: MongoCryptContext, options?: { timeoutContext?: TimeoutContext } & Abortable) { for ( let request = context.nextKMSRequest(); request != null; request = context.nextKMSRequest() ) { - yield this.kmsRequest(request, timeoutContext); + yield this.kmsRequest(request, options); } } @@ -516,14 +532,16 @@ export class StateMachine { client: MongoClient, ns: string, filter: Document, - timeoutContext?: TimeoutContext + options?: { timeoutContext?: TimeoutContext } & Abortable ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); const cursor = client.db(db).listCollections(filter, { promoteLongs: false, promoteValues: false, - timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol()) + timeoutContext: + options?.timeoutContext && new CursorTimeoutContext(options?.timeoutContext, Symbol()), + signal: options?.signal }); // There is always exactly zero or one matching documents, so this should always exhaust the cursor @@ -547,17 +565,30 @@ export class StateMachine { client: MongoClient, ns: string, command: Uint8Array, - timeoutContext?: TimeoutContext + options?: { timeoutContext?: TimeoutContext } & Abortable ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); const bsonOptions = { promoteLongs: false, promoteValues: false }; const rawCommand = deserialize(command, bsonOptions); + const commandOptions: { + timeoutMS?: number; + signal?: AbortSignal; + } = { + timeoutMS: undefined, + signal: undefined + }; + + if (options?.timeoutContext?.csotEnabled()) { + commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS; + } + if (options?.signal) { + commandOptions.signal = options.signal; + } + const response = await client.db(db).command(rawCommand, { ...bsonOptions, - ...(timeoutContext?.csotEnabled() - ? { timeoutMS: timeoutContext?.remainingTimeMS } - : undefined) + ...commandOptions }); return serialize(response, this.bsonOptions); @@ -575,17 +606,30 @@ export class StateMachine { client: MongoClient, keyVaultNamespace: string, filter: Uint8Array, - timeoutContext?: TimeoutContext + options?: { timeoutContext?: TimeoutContext } & Abortable ): Promise> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(keyVaultNamespace); + const commandOptions: { + timeoutContext?: CursorTimeoutContext; + signal?: AbortSignal; + } = { + timeoutContext: undefined, + signal: undefined + }; + + if (options?.timeoutContext != null) { + commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol()); + } + if (options?.signal != null) { + commandOptions.signal = options.signal; + } + return client .db(dbName) .collection(collectionName, { readConcern: { level: 'majority' } }) - .find(deserialize(filter), { - timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol()) - }) + .find(deserialize(filter), commandOptions) .toArray(); } } diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 6df81b34d94..40644bf1be5 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -33,7 +33,7 @@ import { import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client'; import { type MongoClientAuthProviders } from '../mongo_client_auth_providers'; import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger'; -import { type CancellationToken, TypedEventEmitter } from '../mongo_types'; +import { type Abortable, type CancellationToken, TypedEventEmitter } from '../mongo_types'; import { ReadPreference, type ReadPreferenceLike } from '../read_preference'; import { ServerType } from '../sdam/common'; import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions'; @@ -438,7 +438,7 @@ export class Connection extends TypedEventEmitter { private async *sendWire( message: WriteProtocolMessageType, - options: CommandOptions, + options: CommandOptions & Abortable, responseType?: MongoDBResponseConstructor ): AsyncGenerator { this.throwIfAborted(); @@ -453,7 +453,8 @@ export class Connection extends TypedEventEmitter { await this.writeCommand(message, { agreedCompressor: this.description.compressor ?? 'none', zlibCompressionLevel: this.description.zlibCompressionLevel, - timeoutContext: options.timeoutContext + timeoutContext: options.timeoutContext, + signal: options.signal }); if (options.noResponse || message.moreToCome) { @@ -473,7 +474,10 @@ export class Connection extends TypedEventEmitter { ); } - for await (const response of this.readMany({ timeoutContext: options.timeoutContext })) { + for await (const response of this.readMany({ + timeoutContext: options.timeoutContext, + signal: options.signal + })) { this.socket.setTimeout(0); const bson = response.parse(); @@ -676,7 +680,7 @@ export class Connection extends TypedEventEmitter { agreedCompressor?: CompressorName; zlibCompressionLevel?: number; timeoutContext?: TimeoutContext; - } + } & Abortable ): Promise { const finalCommand = options.agreedCompressor === 'none' || !OpCompressedRequest.canCompress(command) @@ -701,7 +705,7 @@ export class Connection extends TypedEventEmitter { if (this.socket.write(buffer)) return; - const drainEvent = once(this.socket, 'drain'); + const drainEvent = once(this.socket, 'drain', { signal: options.signal }); const timeout = options?.timeoutContext?.timeoutForSocketWrite; if (timeout) { try { @@ -729,9 +733,11 @@ export class Connection extends TypedEventEmitter { * * Note that `for-await` loops call `return` automatically when the loop is exited. */ - private async *readMany(options: { - timeoutContext?: TimeoutContext; - }): AsyncGenerator { + private async *readMany( + options: { + timeoutContext?: TimeoutContext; + } & Abortable + ): AsyncGenerator { try { this.dataEvents = onData(this.messageStream, options); this.messageStream.resume(); diff --git a/src/cmap/connection_pool.ts b/src/cmap/connection_pool.ts index bb2069de846..63c4860259c 100644 --- a/src/cmap/connection_pool.ts +++ b/src/cmap/connection_pool.ts @@ -25,10 +25,18 @@ import { MongoRuntimeError, MongoServerError } from '../error'; -import { CancellationToken, TypedEventEmitter } from '../mongo_types'; +import { type Abortable, CancellationToken, TypedEventEmitter } from '../mongo_types'; import type { Server } from '../sdam/server'; import { type TimeoutContext, TimeoutError } from '../timeout'; -import { type Callback, List, makeCounter, now, promiseWithResolvers } from '../utils'; +import { + addAbortListener, + type Callback, + kDispose, + List, + makeCounter, + now, + promiseWithResolvers +} from '../utils'; import { connect } from './connect'; import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection'; import { @@ -316,7 +324,7 @@ export class ConnectionPool extends TypedEventEmitter { * will be held by the pool. This means that if a connection is checked out it MUST be checked back in or * explicitly destroyed by the new owner. */ - async checkOut(options: { timeoutContext: TimeoutContext }): Promise { + async checkOut(options: { timeoutContext: TimeoutContext } & Abortable): Promise { const checkoutTime = now(); this.emitAndLog( ConnectionPool.CONNECTION_CHECK_OUT_STARTED, @@ -334,6 +342,11 @@ export class ConnectionPool extends TypedEventEmitter { checkoutTime }; + const abortListener = addAbortListener(options.signal, function () { + waitQueueMember.cancelled = true; + reject(this.reason); + }); + this.waitQueue.push(waitQueueMember); process.nextTick(() => this.processWaitQueue()); @@ -364,6 +377,7 @@ export class ConnectionPool extends TypedEventEmitter { } throw error; } finally { + abortListener?.[kDispose](); timeout?.clear(); } } diff --git a/src/cmap/wire_protocol/on_data.ts b/src/cmap/wire_protocol/on_data.ts index f6732618330..82dd7b40dbe 100644 --- a/src/cmap/wire_protocol/on_data.ts +++ b/src/cmap/wire_protocol/on_data.ts @@ -1,7 +1,8 @@ import { type EventEmitter } from 'events'; +import { type Abortable } from '../../mongo_types'; import { type TimeoutContext } from '../../timeout'; -import { List, promiseWithResolvers } from '../../utils'; +import { addAbortListener, kDispose, List, promiseWithResolvers } from '../../utils'; /** * @internal @@ -21,8 +22,10 @@ type PendingPromises = Omit< */ export function onData( emitter: EventEmitter, - { timeoutContext }: { timeoutContext?: TimeoutContext } + { timeoutContext, signal }: { timeoutContext?: TimeoutContext } & Abortable ) { + signal?.throwIfAborted(); + // Setup pending events and pending promise lists /** * When the caller has not yet called .next(), we store the @@ -90,6 +93,9 @@ export function onData( // Adding event handlers emitter.on('data', eventHandler); emitter.on('error', errorHandler); + const abortListener = addAbortListener(signal, function () { + errorHandler(this.reason); + }); const timeoutForSocketRead = timeoutContext?.timeoutForSocketRead; timeoutForSocketRead?.throwIfExpired(); @@ -115,6 +121,7 @@ export function onData( // Adding event handlers emitter.off('data', eventHandler); emitter.off('error', errorHandler); + abortListener?.[kDispose](); finished = true; timeoutForSocketRead?.clear(); const doneResult = { value: undefined, done: finished } as const; diff --git a/src/collection.ts b/src/collection.ts index d7cdc12e8e5..468797bd0fd 100644 --- a/src/collection.ts +++ b/src/collection.ts @@ -14,6 +14,7 @@ import type { Db } from './db'; import { MongoInvalidArgumentError, MongoOperationTimeoutError } from './error'; import type { MongoClient, PkFactory } from './mongo_client'; import type { + Abortable, Filter, Flatten, OptionalUnlessRequiredId, @@ -505,7 +506,7 @@ export class Collection { async findOne(filter: Filter): Promise | null>; async findOne( filter: Filter, - options: Omit + options: Omit & Abortable ): Promise | null>; // allow an override of the schema. @@ -513,12 +514,12 @@ export class Collection { async findOne(filter: Filter): Promise; async findOne( filter: Filter, - options?: Omit + options?: Omit & Abortable ): Promise; async findOne( filter: Filter = {}, - options: FindOptions = {} + options: FindOptions & Abortable = {} ): Promise | null> { const cursor = this.find(filter, options).limit(-1).batchSize(1); const res = await cursor.next(); @@ -532,9 +533,15 @@ export class Collection { * @param filter - The filter predicate. If unspecified, then all documents in the collection will match the predicate */ find(): FindCursor>; - find(filter: Filter, options?: FindOptions): FindCursor>; - find(filter: Filter, options?: FindOptions): FindCursor; - find(filter: Filter = {}, options: FindOptions = {}): FindCursor> { + find(filter: Filter, options?: FindOptions & Abortable): FindCursor>; + find( + filter: Filter, + options?: FindOptions & Abortable + ): FindCursor; + find( + filter: Filter = {}, + options: FindOptions & Abortable = {} + ): FindCursor> { return new FindCursor>( this.client, this.s.namespace, @@ -792,7 +799,7 @@ export class Collection { */ async countDocuments( filter: Filter = {}, - options: CountDocumentsOptions = {} + options: CountDocumentsOptions & Abortable = {} ): Promise { const pipeline = []; pipeline.push({ $match: filter }); @@ -1006,7 +1013,7 @@ export class Collection { */ aggregate( pipeline: Document[] = [], - options?: AggregateOptions + options?: AggregateOptions & Abortable ): AggregationCursor { if (!Array.isArray(pipeline)) { throw new MongoInvalidArgumentError( diff --git a/src/cursor/abstract_cursor.ts b/src/cursor/abstract_cursor.ts index 8eccdfcf630..feaf07347cf 100644 --- a/src/cursor/abstract_cursor.ts +++ b/src/cursor/abstract_cursor.ts @@ -12,7 +12,7 @@ import { MongoTailableCursorError } from '../error'; import type { MongoClient } from '../mongo_client'; -import { TypedEventEmitter } from '../mongo_types'; +import { type Abortable, TypedEventEmitter } from '../mongo_types'; import { executeOperation } from '../operations/execute_operation'; import { GetMoreOperation } from '../operations/get_more'; import { KillCursorsOperation } from '../operations/kill_cursors'; @@ -22,7 +22,13 @@ import { type AsyncDisposable, configureResourceManagement } from '../resource_m import type { Server } from '../sdam/server'; import { ClientSession, maybeClearPinnedConnection } from '../sessions'; import { type CSOTTimeoutContext, type Timeout, TimeoutContext } from '../timeout'; -import { type MongoDBNamespace, squashError } from '../utils'; +import { + addAbortListener, + type Disposable, + kDispose, + type MongoDBNamespace, + squashError +} from '../utils'; /** * @internal @@ -247,12 +253,14 @@ export abstract class AbstractCursor< /** @internal */ protected deserializationOptions: OnDemandDocumentDeserializeOptions; + protected signal: AbortSignal | undefined; + private abortListener: Disposable | undefined; /** @internal */ protected constructor( client: MongoClient, namespace: MongoDBNamespace, - options: AbstractCursorOptions = {} + options: AbstractCursorOptions & Abortable = {} ) { super(); @@ -352,6 +360,11 @@ export abstract class AbstractCursor< }; this.timeoutContext = options.timeoutContext; + this.signal = options.signal; + this.abortListener = addAbortListener( + this.signal, + () => void this.close().then(undefined, squashError) + ); } /** @@ -455,6 +468,8 @@ export abstract class AbstractCursor< } async *[Symbol.asyncIterator](): AsyncGenerator { + this.signal?.throwIfAborted(); + if (this.closed) { return; } @@ -481,6 +496,8 @@ export abstract class AbstractCursor< } yield document; + + this.signal?.throwIfAborted(); } } finally { // Only close the cursor if it has not already been closed. This finally clause handles @@ -496,9 +513,16 @@ export abstract class AbstractCursor< } stream(options?: CursorStreamOptions): Readable & AsyncIterable { + const readable = new ReadableCursorStream(this); + const abortListener = addAbortListener(this.signal, function () { + readable.destroy(this.reason); + }); + readable.once('end', () => { + abortListener?.[kDispose](); + }); + if (options?.transform) { const transform = options.transform; - const readable = new ReadableCursorStream(this); const transformedStream = readable.pipe( new Transform({ @@ -522,10 +546,12 @@ export abstract class AbstractCursor< return transformedStream; } - return new ReadableCursorStream(this); + return readable; } async hasNext(): Promise { + this.signal?.throwIfAborted(); + if (this.cursorId === Long.ZERO) { return false; } @@ -551,6 +577,8 @@ export abstract class AbstractCursor< /** Get the next available document from the cursor, returns null if no more documents are available. */ async next(): Promise { + this.signal?.throwIfAborted(); + if (this.cursorId === Long.ZERO) { throw new MongoCursorExhaustedError(); } @@ -581,6 +609,8 @@ export abstract class AbstractCursor< * Try to get the next available document from the cursor or `null` if an empty batch is returned */ async tryNext(): Promise { + this.signal?.throwIfAborted(); + if (this.cursorId === Long.ZERO) { throw new MongoCursorExhaustedError(); } @@ -620,6 +650,8 @@ export abstract class AbstractCursor< * @deprecated - Will be removed in a future release. Use for await...of instead. */ async forEach(iterator: (doc: TSchema) => boolean | void): Promise { + this.signal?.throwIfAborted(); + if (typeof iterator !== 'function') { throw new MongoInvalidArgumentError('Argument "iterator" must be a function'); } @@ -645,6 +677,8 @@ export abstract class AbstractCursor< * cursor.rewind() can be used to reset the cursor. */ async toArray(): Promise { + this.signal?.throwIfAborted(); + const array: TSchema[] = []; // at the end of the loop (since readBufferedDocuments is called) the buffer will be empty // then, the 'await of' syntax will run a getMore call @@ -968,6 +1002,7 @@ export abstract class AbstractCursor< /** @internal */ private async cleanup(timeoutMS?: number, error?: Error) { + this.abortListener?.[kDispose](); this.isClosed = true; const session = this.cursorSession; const timeoutContextForKillCursors = (): CursorTimeoutContext | undefined => { diff --git a/src/cursor/aggregation_cursor.ts b/src/cursor/aggregation_cursor.ts index cace0a4b6a2..5598485c822 100644 --- a/src/cursor/aggregation_cursor.ts +++ b/src/cursor/aggregation_cursor.ts @@ -8,6 +8,7 @@ import { validateExplainTimeoutOptions } from '../explain'; import type { MongoClient } from '../mongo_client'; +import { type Abortable } from '../mongo_types'; import { AggregateOperation, type AggregateOptions } from '../operations/aggregate'; import { executeOperation } from '../operations/execute_operation'; import type { ClientSession } from '../sessions'; @@ -39,7 +40,7 @@ export class AggregationCursor extends ExplainableCursor client: MongoClient, namespace: MongoDBNamespace, pipeline: Document[] = [], - options: AggregateOptions = {} + options: AggregateOptions & Abortable = {} ) { super(client, namespace, options); diff --git a/src/cursor/find_cursor.ts b/src/cursor/find_cursor.ts index 28cb373614d..4c89307e66a 100644 --- a/src/cursor/find_cursor.ts +++ b/src/cursor/find_cursor.ts @@ -72,7 +72,8 @@ export class FindCursor extends ExplainableCursor { const options = { ...this.findOptions, // NOTE: order matters here, we may need to refine this ...this.cursorOptions, - session + session, + signal: this.signal }; if (options.explain) { diff --git a/src/db.ts b/src/db.ts index 121d6fc4f1e..bcadc4937e7 100644 --- a/src/db.ts +++ b/src/db.ts @@ -8,7 +8,7 @@ import { ListCollectionsCursor } from './cursor/list_collections_cursor'; import { RunCommandCursor, type RunCursorCommandOptions } from './cursor/run_command_cursor'; import { MongoInvalidArgumentError } from './error'; import type { MongoClient, PkFactory } from './mongo_client'; -import type { TODO_NODE_3286 } from './mongo_types'; +import type { Abortable, TODO_NODE_3286 } from './mongo_types'; import type { AggregateOptions } from './operations/aggregate'; import { CollectionsOperation } from './operations/collections'; import { @@ -273,7 +273,7 @@ export class Db { * @param command - The command to run * @param options - Optional settings for the command */ - async command(command: Document, options?: RunCommandOptions): Promise { + async command(command: Document, options?: RunCommandOptions & Abortable): Promise { // Intentionally, we do not inherit options from parent for this operation. return await executeOperation( this.client, @@ -284,7 +284,8 @@ export class Db { ...resolveBSONOptions(options), timeoutMS: options?.timeoutMS ?? this.timeoutMS, session: options?.session, - readPreference: options?.readPreference + readPreference: options?.readPreference, + signal: options?.signal }) ) ); @@ -351,22 +352,25 @@ export class Db { */ listCollections( filter: Document, - options: Exclude & { nameOnly: true } + options: Exclude & { nameOnly: true } & Abortable ): ListCollectionsCursor>; listCollections( filter: Document, - options: Exclude & { nameOnly: false } + options: Exclude & { nameOnly: false } & Abortable ): ListCollectionsCursor; listCollections< T extends Pick | CollectionInfo = | Pick | CollectionInfo - >(filter?: Document, options?: ListCollectionsOptions): ListCollectionsCursor; + >(filter?: Document, options?: ListCollectionsOptions & Abortable): ListCollectionsCursor; listCollections< T extends Pick | CollectionInfo = | Pick | CollectionInfo - >(filter: Document = {}, options: ListCollectionsOptions = {}): ListCollectionsCursor { + >( + filter: Document = {}, + options: ListCollectionsOptions & Abortable = {} + ): ListCollectionsCursor { return new ListCollectionsCursor(this, filter, resolveOptions(this, options)); } diff --git a/src/index.ts b/src/index.ts index 794bed95884..a80cf54b891 100644 --- a/src/index.ts +++ b/src/index.ts @@ -430,6 +430,7 @@ export type { SeverityLevel } from './mongo_logger'; export type { + Abortable, CommonEvents, EventsDescription, GenericListener, diff --git a/src/mongo_types.ts b/src/mongo_types.ts index be116b36997..f042cf661bf 100644 --- a/src/mongo_types.ts +++ b/src/mongo_types.ts @@ -474,6 +474,38 @@ export class TypedEventEmitter extends EventEm /** @public */ export class CancellationToken extends TypedEventEmitter<{ cancel(): void }> {} +/** @public */ +export type Abortable = { + /** + * When provided the corresponding `AbortController` can be used to cancel an asynchronous action. + * + * The driver will convert the abort event into a promise rejection with an error that has the name `'AbortError'`. + * + * The cause of the error will be set to `signal.reason` + * + * @example + * ```js + * const controller = new AbortController(); + * const { signal } = controller; + * req,on('close', () => controller.abort(new Error('Request aborted by user'))); + * + * try { + * const res = await fetch('...', { signal }); + * await collection.insertOne(await res.json(), { signal }); + * catch (error) { + * if (error.name === 'AbortError') { + * // error is MongoAbortError or DOMException, + * // both represent the signal being aborted + * error.cause === signal.reason; // true + * } + * } + * ``` + * + * @see MongoAbortError + */ + signal?: AbortSignal | undefined; +}; + /** * Helper types for dot-notation filter attributes */ diff --git a/src/operations/execute_operation.ts b/src/operations/execute_operation.ts index 81601a6e160..30bf0dd8343 100644 --- a/src/operations/execute_operation.ts +++ b/src/operations/execute_operation.ts @@ -64,7 +64,10 @@ export async function executeOperation< throw new MongoRuntimeError('This method requires a valid operation instance'); } + // Like CSOT, an operation signal interruption does not relate to auto-connect + operation.options.signal?.throwIfAborted(); const topology = await autoConnect(client); + operation.options.signal?.throwIfAborted(); // The driver sessions spec mandates that we implicitly create sessions for operations // that are not explicitly provided with a session. @@ -198,7 +201,8 @@ async function tryOperation< let server = await topology.selectServer(selector, { session, operationName: operation.commandName, - timeoutContext + timeoutContext, + signal: operation.options.signal }); const hasReadAspect = operation.hasAspect(Aspect.READ_OPERATION); @@ -260,7 +264,8 @@ async function tryOperation< server = await topology.selectServer(selector, { session, operationName: operation.commandName, - previousServer + previousServer, + signal: operation.options.signal }); if (hasWriteAspect && !supportsRetryableWrites(server)) { diff --git a/src/operations/list_collections.ts b/src/operations/list_collections.ts index 6b3296fcf00..57f8aff45e6 100644 --- a/src/operations/list_collections.ts +++ b/src/operations/list_collections.ts @@ -2,6 +2,7 @@ import type { Binary, Document } from '../bson'; import { CursorResponse } from '../cmap/wire_protocol/responses'; import { type CursorTimeoutContext, type CursorTimeoutMode } from '../cursor/abstract_cursor'; import type { Db } from '../db'; +import { type Abortable } from '../mongo_types'; import type { Server } from '../sdam/server'; import type { ClientSession } from '../sessions'; import { type TimeoutContext } from '../timeout'; @@ -10,7 +11,9 @@ import { CommandOperation, type CommandOperationOptions } from './command'; import { Aspect, defineAspects } from './operation'; /** @public */ -export interface ListCollectionsOptions extends Omit { +export interface ListCollectionsOptions + extends Omit, + Abortable { /** Since 4.0: If true, will only return the collection name in the response, and will omit additional info */ nameOnly?: boolean; /** Since 4.0: If true and nameOnly is true, allows a user without the required privilege (i.e. listCollections action on the database) to run the command when access control is enforced. */ diff --git a/src/operations/operation.ts b/src/operations/operation.ts index 029047543a3..190f2a522bd 100644 --- a/src/operations/operation.ts +++ b/src/operations/operation.ts @@ -1,4 +1,5 @@ import { type BSONSerializeOptions, type Document, resolveBSONOptions } from '../bson'; +import { type Abortable } from '../mongo_types'; import { ReadPreference, type ReadPreferenceLike } from '../read_preference'; import type { Server } from '../sdam/server'; import type { ClientSession } from '../sessions'; @@ -59,7 +60,7 @@ export abstract class AbstractOperation { // BSON serialization options bsonOptions?: BSONSerializeOptions; - options: OperationOptions; + options: OperationOptions & Abortable; /** Specifies the time an operation will run until it throws a timeout error. */ timeoutMS?: number; @@ -68,7 +69,7 @@ export abstract class AbstractOperation { static aspects?: Set; - constructor(options: OperationOptions = {}) { + constructor(options: OperationOptions & Abortable = {}) { this.readPreference = this.hasAspect(Aspect.WRITE_OPERATION) ? ReadPreference.primary : (ReadPreference.fromOptions(options) ?? ReadPreference.primary); diff --git a/src/sdam/server.ts b/src/sdam/server.ts index 1aa19a3e18c..9094c2fc4ef 100644 --- a/src/sdam/server.ts +++ b/src/sdam/server.ts @@ -36,7 +36,7 @@ import { needsRetryableWriteLabel } from '../error'; import type { ServerApi } from '../mongo_client'; -import { TypedEventEmitter } from '../mongo_types'; +import { type Abortable, TypedEventEmitter } from '../mongo_types'; import type { GetMoreOptions } from '../operations/get_more'; import type { ClientSession } from '../sessions'; import { type TimeoutContext } from '../timeout'; @@ -107,7 +107,7 @@ export type ServerEvents = { /** @internal */ export type ServerCommandOptions = Omit & { timeoutContext: TimeoutContext; -}; +} & Abortable; /** @internal */ export class Server extends TypedEventEmitter { @@ -285,7 +285,7 @@ export class Server extends TypedEventEmitter { public async command( ns: MongoDBNamespace, cmd: Document, - options: ServerCommandOptions, + paramOpts: ServerCommandOptions, responseType?: MongoDBResponseConstructor ): Promise { if (ns.db == null || typeof ns === 'string') { @@ -297,24 +297,25 @@ export class Server extends TypedEventEmitter { } // Clone the options - const finalOptions = Object.assign({}, options, { + const options = { + ...paramOpts, wireProtocolCommand: false, directConnection: this.topology.s.options.directConnection - }); + }; // There are cases where we need to flag the read preference not to get sent in // the command, such as pre-5.0 servers attempting to perform an aggregate write // with a non-primary read preference. In this case the effective read preference // (primary) is not the same as the provided and must be removed completely. - if (finalOptions.omitReadPreference) { - delete finalOptions.readPreference; + if (options.omitReadPreference) { + delete options.readPreference; } if (this.description.iscryptd) { - finalOptions.omitMaxTimeMS = true; + options.omitMaxTimeMS = true; } - const session = finalOptions.session; + const session = options.session; let conn = session?.pinnedConnection; this.incrementOperationCount(); @@ -333,11 +334,11 @@ export class Server extends TypedEventEmitter { try { try { - const res = await conn.command(ns, cmd, finalOptions, responseType); + const res = await conn.command(ns, cmd, options, responseType); throwIfWriteConcernError(res); return res; } catch (commandError) { - throw this.decorateCommandError(conn, cmd, finalOptions, commandError); + throw this.decorateCommandError(conn, cmd, options, commandError); } } catch (operationError) { if ( @@ -346,11 +347,11 @@ export class Server extends TypedEventEmitter { ) { await this.pool.reauthenticate(conn); try { - const res = await conn.command(ns, cmd, finalOptions, responseType); + const res = await conn.command(ns, cmd, options, responseType); throwIfWriteConcernError(res); return res; } catch (commandError) { - throw this.decorateCommandError(conn, cmd, finalOptions, commandError); + throw this.decorateCommandError(conn, cmd, options, commandError); } } else { throw operationError; diff --git a/src/sdam/topology.ts b/src/sdam/topology.ts index b6cad4097e8..6f87e922710 100644 --- a/src/sdam/topology.ts +++ b/src/sdam/topology.ts @@ -31,15 +31,17 @@ import { } from '../error'; import type { MongoClient, ServerApi } from '../mongo_client'; import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger'; -import { TypedEventEmitter } from '../mongo_types'; +import { type Abortable, TypedEventEmitter } from '../mongo_types'; import { ReadPreference, type ReadPreferenceLike } from '../read_preference'; import type { ClientSession } from '../sessions'; import { Timeout, TimeoutContext, TimeoutError } from '../timeout'; import type { Transaction } from '../transactions'; import { + addAbortListener, type Callback, type EventEmitterWithState, HostAddress, + kDispose, List, makeStateMachine, now, @@ -525,7 +527,7 @@ export class Topology extends TypedEventEmitter { */ async selectServer( selector: string | ReadPreference | ServerSelector, - options: SelectServerOptions + options: SelectServerOptions & Abortable ): Promise { let serverSelector; if (typeof selector !== 'function') { @@ -602,6 +604,11 @@ export class Topology extends TypedEventEmitter { previousServer: options.previousServer }; + const abortListener = addAbortListener(options.signal, function () { + waitQueueMember.cancelled = true; + reject(this.reason); + }); + this.waitQueue.push(waitQueueMember); processWaitQueue(this); @@ -647,6 +654,7 @@ export class Topology extends TypedEventEmitter { // Other server selection error throw error; } finally { + abortListener?.[kDispose](); if (options.timeoutContext?.clearServerSelectionTimeout) timeout?.clear(); } } diff --git a/src/utils.ts b/src/utils.ts index c23161612a8..1db48e35faa 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -27,6 +27,7 @@ import { MongoRuntimeError } from './error'; import type { MongoClient } from './mongo_client'; +import { type Abortable } from './mongo_types'; import type { CommandOperationOptions, OperationParent } from './operations/command'; import type { Hint, OperationOptions } from './operations/operation'; import { ReadConcern } from './read_concern'; @@ -1349,19 +1350,24 @@ export const randomBytes = promisify(crypto.randomBytes); * @param ee - An event emitter that may emit `ev` * @param name - An event name to wait for */ -export async function once(ee: EventEmitter, name: string): Promise { +export async function once(ee: EventEmitter, name: string, options?: Abortable): Promise { + options?.signal?.throwIfAborted(); + const { promise, resolve, reject } = promiseWithResolvers(); const onEvent = (data: T) => resolve(data); const onError = (error: Error) => reject(error); + const abortListener = addAbortListener(options?.signal, function () { + reject(this.reason); + }); ee.once(name, onEvent).once('error', onError); + try { - const res = await promise; - ee.off('error', onError); - return res; - } catch (error) { + return await promise; + } finally { ee.off(name, onEvent); - throw error; + ee.off('error', onError); + abortListener?.[kDispose](); } } @@ -1468,3 +1474,17 @@ export function decorateDecryptionResult( decorateDecryptionResult(decrypted[k], originalValue, false); } } + +export const kDispose: unique symbol = (Symbol.dispose as any) ?? Symbol('dispose'); +export interface Disposable { + [kDispose](): void; +} + +export function addAbortListener( + signal: AbortSignal | undefined | null, + listener: (this: AbortSignal, event: Event) => void +): Disposable | undefined { + if (signal == null) return; + signal.addEventListener('abort', listener, { once: true }); + return { [kDispose]: () => signal.removeEventListener('abort', listener) }; +} diff --git a/test/integration/client-side-encryption/driver.test.ts b/test/integration/client-side-encryption/driver.test.ts index 720d67c4565..a7c1e617c2a 100644 --- a/test/integration/client-side-encryption/driver.test.ts +++ b/test/integration/client-side-encryption/driver.test.ts @@ -831,12 +831,12 @@ describe('CSOT', function () { describe('State machine', function () { const stateMachine = new StateMachine({} as any); - const timeoutContext = () => { - return new CSOTTimeoutContext({ + const timeoutContext = () => ({ + timeoutContext: new CSOTTimeoutContext({ timeoutMS: 1000, serverSelectionTimeoutMS: 30000 - }); - }; + }) + }); const timeoutMS = 1000; @@ -1001,7 +1001,7 @@ describe('CSOT', function () { const { result: error } = await measureDuration(() => stateMachine - .fetchKeys(client, 'test.test', BSON.serialize({}), timeoutContext) + .fetchKeys(client, 'test.test', BSON.serialize({}), { timeoutContext }) .catch(e => e) ); expect(error).to.be.instanceOf(MongoOperationTimeoutError); diff --git a/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts b/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts index 58bfb79de23..3515aaad921 100644 --- a/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts +++ b/test/integration/client-side-operations-timeout/client_side_operations_timeout.unit.test.ts @@ -138,7 +138,7 @@ describe('CSOT spec unit tests', function () { timeoutMS: 500, serverSelectionTimeoutMS: 30000 }); - const err = await stateMachine.kmsRequest(request, timeoutContext).catch(e => e); + const err = await stateMachine.kmsRequest(request, { timeoutContext }).catch(e => e); expect(err).to.be.instanceOf(MongoOperationTimeoutError); expect(err.errmsg).to.equal('KMS request timed out'); }); diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts new file mode 100644 index 00000000000..cef6fab66ce --- /dev/null +++ b/test/integration/node-specific/abort_signal.test.ts @@ -0,0 +1,872 @@ +import * as events from 'node:events'; +import { TLSSocket } from 'node:tls'; +import * as util from 'node:util'; + +import { expect } from 'chai'; +import * as semver from 'semver'; +import * as sinon from 'sinon'; + +import { + type AbstractCursor, + AggregationCursor, + type AutoEncryptionOptions, + ClientEncryption, + type Collection, + type Db, + FindCursor, + ListCollectionsCursor, + type Log, + type MongoClient, + promiseWithResolvers, + ReadPreference, + setDifference, + StateMachine +} from '../../mongodb'; +import { + clearFailPoint, + configureFailPoint, + DOMException, + findLast, + sleep +} from '../../tools/utils'; + +const failPointMetadata = { requires: { mongodb: '>=4.4' } }; + +const isAsyncGenerator = (value: any): value is AsyncGenerator => + value[Symbol.toStringTag] === 'AsyncGenerator'; + +const makeDescriptorGetter = value => prop => [prop, Object.getOwnPropertyDescriptor(value, prop)]; + +function getAllProps(value) { + const props = []; + for (let obj = value; obj !== Object.prototype; obj = Object.getPrototypeOf(obj)) { + props.push(...Object.getOwnPropertyNames(obj).map(makeDescriptorGetter(obj))); + props.push(...Object.getOwnPropertySymbols(obj).map(makeDescriptorGetter(obj))); + } + return props; +} + +describe('AbortSignal support', () => { + let client: MongoClient; + let db: Db; + let collection: Collection<{ a: number; ssn: string }>; + const logs: Log[] = []; + + beforeEach(async function () { + logs.length = 0; + + client = this.configuration.newClient( + {}, + { + monitorCommands: true, + appName: 'abortSignalClient', + __enableMongoLogger: true, + __internalLoggerConfig: { MONGODB_LOG_SERVER_SELECTION: 'debug' }, + mongodbLogPath: { write: log => logs.push(log) }, + serverSelectionTimeoutMS: 10_000, + maxPoolSize: 1 + } + ); + await client.connect(); + db = client.db('abortSignal'); + collection = db.collection('support'); + }); + + afterEach(async function () { + logs.length = 0; + const utilClient = this.configuration.newClient(); + try { + await utilClient.db('abortSignal').collection('support').deleteMany({}); + } finally { + await utilClient.close(); + } + await client?.close(); + }); + + function testCursor(cursorName: string, constructor: any) { + let method; + let filter; + + beforeEach(function () { + method = (cursorName === 'listCollections' ? db[cursorName] : collection[cursorName]).bind( + cursorName === 'listCollections' ? db : collection + ); + filter = cursorName === 'aggregate' ? [] : {}; + }); + + describe(`when ${cursorName}() is given a signal`, () => { + const cursorAPIs = { + tryNext: [], + hasNext: [], + next: [], + toArray: [], + forEach: [async () => true], + [Symbol.asyncIterator]: [] + }; + + async function captureCursorAPIResult(cursor, cursorAPI, args) { + try { + const apiReturnValue = cursor[cursorAPI](...args); + return isAsyncGenerator(apiReturnValue) + ? await apiReturnValue.next() + : await apiReturnValue; + } catch (error) { + return error; + } + } + + it('should test all the async APIs', () => { + const knownNotTested = [ + 'asyncDispose', + 'close', + 'getMore', + 'cursorInit', + 'fetchBatch', + 'cleanup', + 'transformDocument', + Symbol.asyncDispose + ]; + + const allCursorAsyncAPIs = getAllProps(constructor.prototype) + .filter(([, { value }]) => util.types.isAsyncFunction(value)) + .map(([key]) => key); + + expect(setDifference(Object.keys(cursorAPIs), allCursorAsyncAPIs)).to.be.empty; + + const notTested = allCursorAsyncAPIs.filter( + fn => knownNotTested.includes(fn) && Object.keys(cursorAPIs).includes(fn) + ); + + expect(notTested, 'new async function found, should respond to signal state or be internal') + .to.be.empty; + }); + + describe('and the signal is already aborted', () => { + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(() => { + const controller = new AbortController(); + signal = controller.signal; + controller.abort(); + + cursor = method(cursorName === 'aggregate' ? [] : {}, { signal }); + }); + + afterEach(async () => { + await cursor.close(); + }); + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + it(`rejects ${cursorAPI.toString()}`, async () => { + const result = await captureCursorAPIResult(cursor, cursorAPI, args); + expect(result).to.be.instanceOf(DOMException); + }); + } + }); + + describe('and the signal is aborted after use', () => { + let controller: AbortController; + let signal: AbortSignal; + let cursor: FindCursor<{ a: number }>; + + beforeEach(() => { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async () => { + await cursor.close(); + }); + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + it(`resolves ${cursorAPI.toString()} without Error`, async () => { + const result = await captureCursorAPIResult(cursor, cursorAPI, args); + controller.abort(); + expect(result).to.not.be.instanceOf(Error); + }); + + it(`rejects ${cursorAPI.toString()} when aborted after start but before await`, async () => { + const willBeResultBlocked = /* await */ captureCursorAPIResult(cursor, cursorAPI, args); + + controller.abort(); + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + + it(`rejects ${cursorAPI.toString()} on the subsequent call`, async () => { + const result = await captureCursorAPIResult(cursor, cursorAPI, args); + expect(result).to.not.be.instanceOf(Error); + + controller.abort(); + + const error = await captureCursorAPIResult(cursor, cursorAPI, args); + expect(error).to.be.instanceOf(DOMException); + }); + } + }); + + describe('and the signal is aborted in between iterations', () => { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + const commandsStarted = []; + + beforeEach(async function () { + commandsStarted.length = 0; + const utilClient = this.configuration.newClient(); + try { + const collection = utilClient.db('abortSignal').collection('support'); + await collection.drop({}).catch(() => null); + await collection.insertMany([ + { a: 1, ssn: '0000-00-0001' }, + { a: 2, ssn: '0000-00-0002' }, + { a: 3, ssn: '0000-00-0003' } + ]); + if (cursorName === 'listCollections') { + for (let i = 0; i < 3; i++) { + await db.dropCollection(`c${i}`).catch(() => null); + await db.createCollection(`c${i}`); + } + } + } finally { + await utilClient.close(); + } + + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal, batchSize: 1 }); + client.on('commandStarted', e => commandsStarted.push(e)); + }); + + afterEach(async () => { + await cursor?.close(); + sinon.restore(); + }); + + it(`rejects for-await on the next iteration`, async () => { + let didLoop = false; + let thrownError; + + try { + for await (const _ of cursor) { + if (didLoop) controller.abort(); + didLoop = true; + } + } catch (error) { + thrownError = error; + } + + expect(thrownError).to.be.instanceOf(DOMException); + // Check that we didn't run two getMore before inspecting the state of the signal. + // If we didn't check _after_ re-entering our asyncIterator on `yield`, + // we may have called .next()->.fetchBatch() etc. without preventing that work from being done + expect(commandsStarted.map(e => e.commandName)).to.deep.equal([cursorName, 'getMore']); + await sleep(10); + expect(commandsStarted.map(e => e.commandName)).to.deep.equal([ + cursorName, + 'getMore', + 'killCursors' + ]); + }); + }); + + describe('and the signal is aborted during server selection', () => { + const metadata: MongoDBMetadataUI = { requires: { topology: 'replicaset' } }; + + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(() => { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { + signal, + // Pick an unselectable server + readPreference: new ReadPreference('secondary', [ + { something: 'that does not exist' } + ]) + }); + }); + + afterEach(async () => { + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, metadata, async () => { + const willBeResult = captureCursorAPIResult(cursor, cursorAPI, args); + + await sleep(3); + expect( + findLast( + logs, + l => + l.operation === cursorName && + l.message === 'Waiting for suitable server to become available' + ) + ).to.exist; + + controller.abort(); + const start = performance.now(); + const result = await willBeResult; + const end = performance.now(); + expect(end - start).to.be.lessThan(1000); // should be way less than 5s server selection timeout + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + describe('and the signal is aborted during connection checkout', failPointMetadata, () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: [cursorName], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, async () => { + const checkoutSucceededFirst = events.once(client, 'connectionCheckedOut'); + const checkoutStartedBlocked = events.once(client, 'connectionCheckOutStarted'); + + const _ = captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + + await checkoutSucceededFirst; + await checkoutStartedBlocked; + + controller.abort(); + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + describe('and the signal is aborted during connection write', () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + sinon.restore(); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, async () => { + await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. + + // client.once('commandStarted', () => controller.abort()); + const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + + for (const [, server] of client.topology.s.servers) { + //@ts-expect-error: private property + for (const connection of server.pool.connections) { + //@ts-expect-error: private property + const stub = sinon.stub(connection.socket, 'write').callsFake(function (...args) { + controller.abort(); + sleep(1).then(() => { + stub.wrappedMethod.apply(this, args); + this.emit('drain'); + }); + return false; + }); + } + } + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + describe('and the signal is aborted during connection read', failPointMetadata, () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: [cursorName], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, async () => { + await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. + + client.on('commandStarted', e => e.commandName === cursorName && controller.abort()); + const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + const fleMetadata: MongoDBMetadataUI = { + requires: { + clientSideEncryption: true, + mongodb: '>=7.0.0', + topology: '!single' + } + }; + + if (cursorName !== 'listCollections') { + describe('setup fle', fleMetadata, () => { + let autoEncryption: AutoEncryptionOptions; + let client: MongoClient; + let db; + let collection; + let method; + let filter; + + before(async function () { + if ( + !this.configuration.clientSideEncryption.enabled || + semver.lt(this.configuration.version, '7.0.0') || + this.configuration.topologyType === 'Single' + ) { + return this.skip(); + } + + autoEncryption = { + keyVaultNamespace: 'admin.datakeys', + kmsProviders: { + local: { key: Buffer.alloc(96) } + }, + tlsOptions: { + kmip: { + tlsCAFile: process.env.KMIP_TLS_CA_FILE, + tlsCertificateKeyFile: process.env.KMIP_TLS_CERT_FILE + } + }, + encryptedFieldsMap: { + 'abortSignal.support': { + fields: [ + { + path: 'ssn', + keyId: null, + bsonType: 'string' + } + ] + } + } + }; + + let utilClient = this.configuration.newClient({}, {}); + + try { + await utilClient + .db('abortSignal') + .collection('support') + .drop({}) + .catch(() => null); + + const clientEncryption = new ClientEncryption(utilClient, { + ...autoEncryption, + encryptedFieldsMap: undefined + }); + + autoEncryption.encryptedFieldsMap['abortSignal.support'] = ( + await clientEncryption.createEncryptedCollection( + utilClient.db('abortSignal'), + 'support', + { + provider: 'local', + createCollectionOptions: { + encryptedFields: autoEncryption.encryptedFieldsMap['abortSignal.support'] + } + } + ) + ).encryptedFields; + } finally { + await utilClient.close(); + } + + utilClient = this.configuration.newClient({}, { autoEncryption }); + try { + await utilClient + .db('abortSignal') + .collection('support') + .insertMany([ + { a: 1, ssn: '0000-00-0001' }, + { a: 2, ssn: '0000-00-0002' }, + { a: 3, ssn: '0000-00-0003' } + ]); + } finally { + await utilClient.close(); + } + }); + + beforeEach(async function () { + client = this.configuration.newClient( + {}, + { + autoEncryption, + monitorCommands: true, + appName: 'abortSignalClient', + __enableMongoLogger: true, + __internalLoggerConfig: { MONGODB_LOG_SERVER_SELECTION: 'debug' }, + mongodbLogPath: { write: log => logs.push(log) }, + serverSelectionTimeoutMS: 10_000, + maxPoolSize: 1 + } + ); + await client.connect(); + db = client.db('abortSignal'); + collection = db.collection('support'); + + method = collection[cursorName].bind(collection); + filter = cursorName === 'aggregate' ? [] : {}; + }); + + afterEach(async function () { + await client?.close(); + }); + + describe('and the signal is aborted during command encryption', fleMetadata, () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + sinon.restore(); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { + const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + + const stub = sinon + .stub(client.options.autoEncrypter, 'encrypt') + .callsFake(function (...args) { + controller.abort(); + return stub.wrappedMethod.apply(this, args); + }); + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + describe('and the signal is aborted during command decryption', fleMetadata, () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + sinon.restore(); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { + const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + + const stub = sinon + .stub(client.options.autoEncrypter, 'decrypt') + .callsFake(function (...args) { + controller.abort(); + return stub.wrappedMethod.apply(this, args); + }); + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + }); + } + }); + } + + testCursor('find', FindCursor); + testCursor('aggregate', AggregationCursor); + testCursor('listCollections', ListCollectionsCursor); + + describe('cursor stream example', () => { + beforeEach(async function () { + const utilClient = this.configuration.newClient(); + try { + const collection = utilClient.db('abortSignal').collection('support'); + await collection.drop({}).catch(() => null); + await collection.insertMany([ + { a: 1, ssn: '0000-00-0001' }, + { a: 2, ssn: '0000-00-0002' }, + { a: 3, ssn: '0000-00-0003' } + ]); + } finally { + await utilClient.close(); + } + }); + + it('follows expected stream error handling', async () => { + const controller = new AbortController(); + const { signal } = controller; + const cursor = collection.find({}, { signal, batchSize: 1 }); + const cursorStream = cursor.stream(); + + const { promise, resolve, reject } = promiseWithResolvers(); + + cursorStream + .on('data', () => controller.abort()) + .on('error', reject) + .on('close', resolve); + + expect(await promise.catch(error => error)).to.be.instanceOf(DOMException); + }); + }); + + describe('KMS requests', function () { + const stateMachine = new StateMachine({} as any); + const request = { + addResponse: _response => undefined, + status: { + type: 1, + code: 1, + message: 'notARealStatus' + }, + bytesNeeded: 500, + kmsProvider: 'notRealAgain', + endpoint: 'fake', + message: Buffer.from('foobar') + }; + + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + controller = new AbortController(); + signal = controller.signal; + }); + + afterEach(async function () { + sinon.restore(); + await cursor?.close(); + }); + + describe('when StateMachine.kmsRequest() is passed an AbortSignal', function () { + beforeEach(async function () { + sinon.stub(TLSSocket.prototype, 'connect').callsFake(function (..._args) { + return this; + }); + }); + + afterEach(async function () { + sinon.restore(); + }); + + it('the kms request rejects when signal is aborted', async function () { + const err = stateMachine.kmsRequest(request, { signal }).catch(e => e); + await sleep(1); + controller.abort(); + expect(await err).to.be.instanceOf(DOMException); + }); + }); + }); + + describe('when a signal passed to countDocuments() is aborted', failPointMetadata, () => { + let controller: AbortController; + let signal: AbortSignal; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: ['aggregate'], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + }); + + // We don't fully cover countDocuments because of the above tests for aggregate. + // However, if countDocuments were ever to be implemented using a different command + // This would catch the change: + it(`rejects countDocuments`, async () => { + client.on( + 'commandStarted', + // Abort a bit after aggregate has started: + ev => ev.commandName === 'aggregate' && sleep(10).then(() => controller.abort()) + ); + + const result = await collection.countDocuments({}, { signal }).catch(error => error); + + expect(result).to.be.instanceOf(DOMException); + }); + }); + + describe('when a signal passed to findOne() is aborted', failPointMetadata, () => { + let controller: AbortController; + let signal: AbortSignal; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: ['find'], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + }); + + it(`rejects findOne`, async () => { + client.on( + 'commandStarted', + // Abort a bit after find has started: + ev => ev.commandName === 'find' && sleep(10).then(() => controller.abort()) + ); + + const result = await collection.findOne({}, { signal }).catch(error => error); + + expect(result).to.be.instanceOf(DOMException); + }); + }); + + describe('when a signal passed to db.command() is aborted', failPointMetadata, () => { + let controller: AbortController; + let signal: AbortSignal; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: ['ping'], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + }); + + it(`rejects command`, async () => { + client.on( + 'commandStarted', + // Abort a bit after ping has started: + ev => ev.commandName === 'ping' && sleep(10).then(() => controller.abort()) + ); + + const result = await db.command({ ping: 1 }, { signal }).catch(error => error); + + expect(result).to.be.instanceOf(DOMException); + }); + }); +}); diff --git a/test/tools/utils.ts b/test/tools/utils.ts index 1829fd4412c..549175dc910 100644 --- a/test/tools/utils.ts +++ b/test/tools/utils.ts @@ -622,7 +622,7 @@ export async function clearFailPoint(configuration: TestConfiguration, url = con export async function makeMultiBatchWrite( configuration: TestConfiguration -): Promise { +): Promise[]> { const { maxBsonObjectSize, maxMessageSizeBytes } = await configuration.hello(); const length = maxMessageSizeBytes / maxBsonObjectSize + 1; @@ -637,10 +637,10 @@ export async function makeMultiBatchWrite( export async function makeMultiResponseBatchModelArray( configuration: TestConfiguration -): Promise { +): Promise[]> { const { maxBsonObjectSize } = await configuration.hello(); const namespace = `foo.${new BSON.ObjectId().toHexString()}`; - const models: AnyClientBulkWriteModel[] = [ + const models: AnyClientBulkWriteModel[] = [ { name: 'updateOne', namespace, @@ -693,3 +693,42 @@ export function mergeTestMetadata( } }; } + +export function findLast( + array: T[], + predicate: (value: T, index: number, array: T[]) => value is S, + thisArg?: any +): S | undefined; +export function findLast( + array: T[], + predicate: (value: T, index: number, array: T[]) => boolean, + thisArg?: any +): T | undefined; +export function findLast( + array: unknown[], + predicate: (value: unknown, index: number, array: unknown[]) => boolean, + thisArg?: any +): unknown | undefined { + if (typeof array.findLast === 'function') return array.findLast(predicate, thisArg); + + for (let i = array.length - 1; i >= 0; i--) { + if (predicate.call(thisArg, array[i], i, array)) { + return array[i]; + } + } + + return undefined; +} + +// Node.js 16 doesn't make this global, but it can still be obtained. +export const DOMException: { + new: ( + message?: string, + nameOrOptions?: string | { name?: string; cause?: unknown } + ) => DOMException; +} = (() => { + if (globalThis.DOMException != null) return globalThis.DOMException; + const ac = new AbortController(); + ac.abort(); + return ac.signal.reason.constructor; +})(); diff --git a/test/unit/client-side-encryption/state_machine.test.ts b/test/unit/client-side-encryption/state_machine.test.ts index ad319c44ade..3d6a92765a8 100644 --- a/test/unit/client-side-encryption/state_machine.test.ts +++ b/test/unit/client-side-encryption/state_machine.test.ts @@ -81,7 +81,12 @@ describe('StateMachine', function () { a: new Long('0'), b: new Int32(0) }; - const options = { promoteLongs: false, promoteValues: false }; + const options = { + promoteLongs: false, + promoteValues: false, + signal: undefined, + timeoutMS: undefined + }; const serializedCommand = serialize(command); const stateMachine = new StateMachine({} as any); @@ -493,7 +498,7 @@ describe('StateMachine', function () { }); await stateMachine - .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 }), context) + .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 }), { timeoutContext: context }) .catch(e => squashError(e)); const { timeoutContext } = findSpy.getCalls()[0].args[1] as FindOptions; @@ -535,7 +540,7 @@ describe('StateMachine', function () { }); await sleep(300); await stateMachine - .markCommand(client, 'keyVault', BSON.serialize({ a: 1 }), timeoutContext) + .markCommand(client, 'keyVault', BSON.serialize({ a: 1 }), { timeoutContext }) .catch(e => squashError(e)); expect(dbCommandSpy.getCalls()[0].args[1].timeoutMS).to.not.be.undefined; expect(dbCommandSpy.getCalls()[0].args[1].timeoutMS).to.be.lessThanOrEqual(205); @@ -576,7 +581,9 @@ describe('StateMachine', function () { }); await sleep(300); await stateMachine - .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 }), context) + .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 }), { + timeoutContext: context + }) .catch(e => squashError(e)); const [_filter, { timeoutContext }] = listCollectionsSpy.getCalls()[0].args; expect(timeoutContext).to.exist; From 0d1c165eb97da3f949557c9b0dc88f7cfdb53b32 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 11:41:19 -0500 Subject: [PATCH 02/13] chore: readmany options --- src/cmap/connection.ts | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 40644bf1be5..26d4b142172 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -474,10 +474,7 @@ export class Connection extends TypedEventEmitter { ); } - for await (const response of this.readMany({ - timeoutContext: options.timeoutContext, - signal: options.signal - })) { + for await (const response of this.readMany(options)) { this.socket.setTimeout(0); const bson = response.parse(); From 3de15cfa54eac364a258ba0c887f9a686e725ff0 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 11:42:19 -0500 Subject: [PATCH 03/13] chore: drain options --- src/cmap/connection.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 26d4b142172..461aa9c8a8f 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -702,7 +702,7 @@ export class Connection extends TypedEventEmitter { if (this.socket.write(buffer)) return; - const drainEvent = once(this.socket, 'drain', { signal: options.signal }); + const drainEvent = once(this.socket, 'drain', options); const timeout = options?.timeoutContext?.timeoutForSocketWrite; if (timeout) { try { From 96c36128b6b1d9f642c06abaece643d9674e5921 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 11:52:20 -0500 Subject: [PATCH 04/13] chore: explicit signal --- src/cursor/aggregation_cursor.ts | 3 ++- src/cursor/list_collections_cursor.ts | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cursor/aggregation_cursor.ts b/src/cursor/aggregation_cursor.ts index 5598485c822..2ba3e810c1a 100644 --- a/src/cursor/aggregation_cursor.ts +++ b/src/cursor/aggregation_cursor.ts @@ -74,7 +74,8 @@ export class AggregationCursor extends ExplainableCursor const options = { ...this.aggregateOptions, ...this.cursorOptions, - session + session, + signal: this.signal }; if (options.explain) { try { diff --git a/src/cursor/list_collections_cursor.ts b/src/cursor/list_collections_cursor.ts index 9b69de1b935..a1e8aa35ad1 100644 --- a/src/cursor/list_collections_cursor.ts +++ b/src/cursor/list_collections_cursor.ts @@ -38,7 +38,8 @@ export class ListCollectionsCursor< const operation = new ListCollectionsOperation(this.parent, this.filter, { ...this.cursorOptions, ...this.options, - session + session, + signal: this.signal }); const response = await executeOperation(this.parent.client, operation, this.timeoutContext); From ee00a38ade9dcf63b3acf4c2600379c3212311cf Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 11:58:34 -0500 Subject: [PATCH 05/13] docs: fix up api docs --- src/mongo_types.ts | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/mongo_types.ts b/src/mongo_types.ts index f042cf661bf..ddbd86a5077 100644 --- a/src/mongo_types.ts +++ b/src/mongo_types.ts @@ -477,31 +477,26 @@ export class CancellationToken extends TypedEventEmitter<{ cancel(): void }> {} /** @public */ export type Abortable = { /** - * When provided the corresponding `AbortController` can be used to cancel an asynchronous action. + * When provided, the corresponding `AbortController` can be used to abort an asynchronous action. * - * The driver will convert the abort event into a promise rejection with an error that has the name `'AbortError'`. - * - * The cause of the error will be set to `signal.reason` + * The `signal.reason` value is used as the error thrown. * * @example * ```js * const controller = new AbortController(); * const { signal } = controller; - * req,on('close', () => controller.abort(new Error('Request aborted by user'))); + * req.on('close', () => controller.abort(new Error('Request aborted by user'))); * * try { * const res = await fetch('...', { signal }); - * await collection.insertOne(await res.json(), { signal }); + * await collection.findOne(await res.json(), { signal }); * catch (error) { + * // depends on abort reason used, but by default this is true. * if (error.name === 'AbortError') { - * // error is MongoAbortError or DOMException, - * // both represent the signal being aborted - * error.cause === signal.reason; // true + * error === signal.reason; // true * } * } * ``` - * - * @see MongoAbortError */ signal?: AbortSignal | undefined; }; From 8955d8e57fca96f5f4cbea87ea894d6492d11461 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 12:16:18 -0500 Subject: [PATCH 06/13] test: better helper name --- .../node-specific/abort_signal.test.ts | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index cef6fab66ce..824c9c6a60e 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -104,7 +104,7 @@ describe('AbortSignal support', () => { [Symbol.asyncIterator]: [] }; - async function captureCursorAPIResult(cursor, cursorAPI, args) { + async function iterateUntilDocumentOrError(cursor, cursorAPI, args) { try { const apiReturnValue = cursor[cursorAPI](...args); return isAsyncGenerator(apiReturnValue) @@ -159,7 +159,7 @@ describe('AbortSignal support', () => { for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { it(`rejects ${cursorAPI.toString()}`, async () => { - const result = await captureCursorAPIResult(cursor, cursorAPI, args); + const result = await iterateUntilDocumentOrError(cursor, cursorAPI, args); expect(result).to.be.instanceOf(DOMException); }); } @@ -182,13 +182,17 @@ describe('AbortSignal support', () => { for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { it(`resolves ${cursorAPI.toString()} without Error`, async () => { - const result = await captureCursorAPIResult(cursor, cursorAPI, args); + const result = await iterateUntilDocumentOrError(cursor, cursorAPI, args); controller.abort(); expect(result).to.not.be.instanceOf(Error); }); it(`rejects ${cursorAPI.toString()} when aborted after start but before await`, async () => { - const willBeResultBlocked = /* await */ captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResultBlocked = /* await */ iterateUntilDocumentOrError( + cursor, + cursorAPI, + args + ); controller.abort(); const result = await willBeResultBlocked; @@ -197,12 +201,12 @@ describe('AbortSignal support', () => { }); it(`rejects ${cursorAPI.toString()} on the subsequent call`, async () => { - const result = await captureCursorAPIResult(cursor, cursorAPI, args); + const result = await iterateUntilDocumentOrError(cursor, cursorAPI, args); expect(result).to.not.be.instanceOf(Error); controller.abort(); - const error = await captureCursorAPIResult(cursor, cursorAPI, args); + const error = await iterateUntilDocumentOrError(cursor, cursorAPI, args); expect(error).to.be.instanceOf(DOMException); }); } @@ -298,7 +302,7 @@ describe('AbortSignal support', () => { }); it(`rejects ${cursorAPI.toString()}`, metadata, async () => { - const willBeResult = captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResult = iterateUntilDocumentOrError(cursor, cursorAPI, args); await sleep(3); expect( @@ -357,8 +361,8 @@ describe('AbortSignal support', () => { const checkoutSucceededFirst = events.once(client, 'connectionCheckedOut'); const checkoutStartedBlocked = events.once(client, 'connectionCheckOutStarted'); - const _ = captureCursorAPIResult(cursor, cursorAPI, args); - const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + const _ = iterateUntilDocumentOrError(cursor, cursorAPI, args); + const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); await checkoutSucceededFirst; await checkoutStartedBlocked; @@ -396,7 +400,7 @@ describe('AbortSignal support', () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. // client.once('commandStarted', () => controller.abort()); - const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); for (const [, server] of client.topology.s.servers) { //@ts-expect-error: private property @@ -456,7 +460,7 @@ describe('AbortSignal support', () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. client.on('commandStarted', e => e.commandName === cursorName && controller.abort()); - const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); const result = await willBeResultBlocked; @@ -608,7 +612,7 @@ describe('AbortSignal support', () => { }); it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { - const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); const stub = sinon .stub(client.options.autoEncrypter, 'encrypt') @@ -646,7 +650,7 @@ describe('AbortSignal support', () => { }); it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { - const willBeResultBlocked = captureCursorAPIResult(cursor, cursorAPI, args); + const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); const stub = sinon .stub(client.options.autoEncrypter, 'decrypt') From 8e6ec041ed2e1360d824e0495ff5fc26c8d11a22 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 12:23:28 -0500 Subject: [PATCH 07/13] test: update name --- test/integration/node-specific/abort_signal.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 824c9c6a60e..ee1c5d3931f 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -187,7 +187,7 @@ describe('AbortSignal support', () => { expect(result).to.not.be.instanceOf(Error); }); - it(`rejects ${cursorAPI.toString()} when aborted after start but before await`, async () => { + it(`aborts in-flight ${cursorAPI.toString()} when aborted after start but before await`, async () => { const willBeResultBlocked = /* await */ iterateUntilDocumentOrError( cursor, cursorAPI, From a73940ef2bb60d4d2137ef9c34b324b26f51e799 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 12:36:59 -0500 Subject: [PATCH 08/13] test: improve iteration test organization --- .../node-specific/abort_signal.test.ts | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index ee1c5d3931f..51c2af9f52d 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -245,35 +245,51 @@ describe('AbortSignal support', () => { client.on('commandStarted', e => commandsStarted.push(e)); }); + const waitForKillCursors = async () => { + for await (const [ev] of events.on(client, 'commandStarted')) { + if (ev.commandName === 'killCursors') return ev; + } + }; + afterEach(async () => { await cursor?.close(); sinon.restore(); }); it(`rejects for-await on the next iteration`, async () => { - let didLoop = false; + let loop = 0; let thrownError; try { for await (const _ of cursor) { - if (didLoop) controller.abort(); - didLoop = true; + if (loop) controller.abort(); + loop += 1; } } catch (error) { thrownError = error; } expect(thrownError).to.be.instanceOf(DOMException); + expect(loop).to.equal(2); + }); + + it('does not run more than one getMore and kills the cursor', async () => { + const killCursors = waitForKillCursors(); + try { + let loop = 0; + for await (const _ of cursor) { + if (loop) controller.abort(); + loop += 1; + } + } catch { + //ignore; + } + // Check that we didn't run two getMore before inspecting the state of the signal. // If we didn't check _after_ re-entering our asyncIterator on `yield`, // we may have called .next()->.fetchBatch() etc. without preventing that work from being done expect(commandsStarted.map(e => e.commandName)).to.deep.equal([cursorName, 'getMore']); - await sleep(10); - expect(commandsStarted.map(e => e.commandName)).to.deep.equal([ - cursorName, - 'getMore', - 'killCursors' - ]); + await killCursors; }); }); From 45a4b65ea91ab9524e6d44b813b3ee2f426cb99f Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 12:37:50 -0500 Subject: [PATCH 09/13] test: cruft --- test/integration/node-specific/abort_signal.test.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 51c2af9f52d..19158e2ea39 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -415,7 +415,6 @@ describe('AbortSignal support', () => { it(`rejects ${cursorAPI.toString()}`, async () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. - // client.once('commandStarted', () => controller.abort()); const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); for (const [, server] of client.topology.s.servers) { From e9338cbc272c72458adcbbfbcb514486899b34da Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 12:57:43 -0500 Subject: [PATCH 10/13] feat: make sure connections are closed after abort if aborted during socket r/w --- src/cmap/connection.ts | 37 ++++++++++--------- .../node-specific/abort_signal.test.ts | 26 +++++++++++++ 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 461aa9c8a8f..8be150d4ce5 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -704,21 +704,21 @@ export class Connection extends TypedEventEmitter { const drainEvent = once(this.socket, 'drain', options); const timeout = options?.timeoutContext?.timeoutForSocketWrite; - if (timeout) { - try { - return await Promise.race([drainEvent, timeout]); - } catch (error) { - let err = error; - if (TimeoutError.is(error)) { - err = new MongoOperationTimeoutError('Timed out at socket write'); - this.cleanup(err); - } - throw error; - } finally { - timeout.clear(); + const drained = timeout ? Promise.race([drainEvent, timeout]) : drainEvent; + try { + return await drained; + } catch (writeError) { + if (TimeoutError.is(writeError)) { + const timeoutError = new MongoOperationTimeoutError('Timed out at socket write'); + this.onError(timeoutError); + throw timeoutError; + } else if (writeError === options.signal?.reason) { + this.onError(writeError); } + throw writeError; + } finally { + timeout?.clear(); } - return await drainEvent; } /** @@ -748,16 +748,17 @@ export class Connection extends TypedEventEmitter { } } } catch (readError) { - const err = readError; if (TimeoutError.is(readError)) { - const error = new MongoOperationTimeoutError( + const timeoutError = new MongoOperationTimeoutError( `Timed out during socket read (${readError.duration}ms)` ); this.dataEvents = null; - this.onError(error); - throw error; + this.onError(timeoutError); + throw timeoutError; + } else if (readError === options.signal?.reason) { + this.onError(readError); } - throw err; + throw readError; } finally { this.dataEvents = null; this.messageStream.pause(); diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 19158e2ea39..7136a577642 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -12,6 +12,7 @@ import { type AutoEncryptionOptions, ClientEncryption, type Collection, + type ConnectionClosedEvent, type Db, FindCursor, ListCollectionsCursor, @@ -400,21 +401,31 @@ describe('AbortSignal support', () => { let controller: AbortController; let signal: AbortSignal; let cursor: AbstractCursor<{ a: number }>; + let checkedOutId; + const waitForConnectionClosed = async () => { + for await (const [ev] of events.on(client, 'connectionClosed')) { + if ((ev as ConnectionClosedEvent).connectionId === checkedOutId) return ev; + } + }; beforeEach(async function () { + checkedOutId = undefined; controller = new AbortController(); signal = controller.signal; cursor = method(filter, { signal }); }); afterEach(async function () { + checkedOutId = undefined; sinon.restore(); await cursor?.close(); }); it(`rejects ${cursorAPI.toString()}`, async () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. + const connectionClosed = waitForConnectionClosed(); + client.on('connectionCheckedOut', ev => (checkedOutId = ev.connectionId)); const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); for (const [, server] of client.topology.s.servers) { @@ -435,6 +446,8 @@ describe('AbortSignal support', () => { const result = await willBeResultBlocked; expect(result).to.be.instanceOf(DOMException); + + await connectionClosed; }); } @@ -461,25 +474,38 @@ describe('AbortSignal support', () => { } }); + checkedOutId = undefined; controller = new AbortController(); signal = controller.signal; cursor = method(filter, { signal }); }); + let checkedOutId; + const waitForConnectionClosed = async () => { + for await (const [ev] of events.on(client, 'connectionClosed')) { + if ((ev as ConnectionClosedEvent).connectionId === checkedOutId) return ev; + } + }; + afterEach(async function () { + checkedOutId = undefined; await clearFailPoint(this.configuration); await cursor?.close(); }); it(`rejects ${cursorAPI.toString()}`, async () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. + const connectionClosed = waitForConnectionClosed(); + client.on('connectionCheckedOut', ev => (checkedOutId = ev.connectionId)); client.on('commandStarted', e => e.commandName === cursorName && controller.abort()); const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); const result = await willBeResultBlocked; expect(result).to.be.instanceOf(DOMException); + + await connectionClosed; }); } From 152be95bb9b93b76fca840e31a2897676a20311d Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 13:03:03 -0500 Subject: [PATCH 11/13] test: remove redundant fle tests --- .../node-specific/abort_signal.test.ts | 199 ------------------ 1 file changed, 199 deletions(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 7136a577642..7c9d2507f02 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -513,205 +513,6 @@ describe('AbortSignal support', () => { test(cursorAPI, args); } }); - - const fleMetadata: MongoDBMetadataUI = { - requires: { - clientSideEncryption: true, - mongodb: '>=7.0.0', - topology: '!single' - } - }; - - if (cursorName !== 'listCollections') { - describe('setup fle', fleMetadata, () => { - let autoEncryption: AutoEncryptionOptions; - let client: MongoClient; - let db; - let collection; - let method; - let filter; - - before(async function () { - if ( - !this.configuration.clientSideEncryption.enabled || - semver.lt(this.configuration.version, '7.0.0') || - this.configuration.topologyType === 'Single' - ) { - return this.skip(); - } - - autoEncryption = { - keyVaultNamespace: 'admin.datakeys', - kmsProviders: { - local: { key: Buffer.alloc(96) } - }, - tlsOptions: { - kmip: { - tlsCAFile: process.env.KMIP_TLS_CA_FILE, - tlsCertificateKeyFile: process.env.KMIP_TLS_CERT_FILE - } - }, - encryptedFieldsMap: { - 'abortSignal.support': { - fields: [ - { - path: 'ssn', - keyId: null, - bsonType: 'string' - } - ] - } - } - }; - - let utilClient = this.configuration.newClient({}, {}); - - try { - await utilClient - .db('abortSignal') - .collection('support') - .drop({}) - .catch(() => null); - - const clientEncryption = new ClientEncryption(utilClient, { - ...autoEncryption, - encryptedFieldsMap: undefined - }); - - autoEncryption.encryptedFieldsMap['abortSignal.support'] = ( - await clientEncryption.createEncryptedCollection( - utilClient.db('abortSignal'), - 'support', - { - provider: 'local', - createCollectionOptions: { - encryptedFields: autoEncryption.encryptedFieldsMap['abortSignal.support'] - } - } - ) - ).encryptedFields; - } finally { - await utilClient.close(); - } - - utilClient = this.configuration.newClient({}, { autoEncryption }); - try { - await utilClient - .db('abortSignal') - .collection('support') - .insertMany([ - { a: 1, ssn: '0000-00-0001' }, - { a: 2, ssn: '0000-00-0002' }, - { a: 3, ssn: '0000-00-0003' } - ]); - } finally { - await utilClient.close(); - } - }); - - beforeEach(async function () { - client = this.configuration.newClient( - {}, - { - autoEncryption, - monitorCommands: true, - appName: 'abortSignalClient', - __enableMongoLogger: true, - __internalLoggerConfig: { MONGODB_LOG_SERVER_SELECTION: 'debug' }, - mongodbLogPath: { write: log => logs.push(log) }, - serverSelectionTimeoutMS: 10_000, - maxPoolSize: 1 - } - ); - await client.connect(); - db = client.db('abortSignal'); - collection = db.collection('support'); - - method = collection[cursorName].bind(collection); - filter = cursorName === 'aggregate' ? [] : {}; - }); - - afterEach(async function () { - await client?.close(); - }); - - describe('and the signal is aborted during command encryption', fleMetadata, () => { - function test(cursorAPI, args) { - let controller: AbortController; - let signal: AbortSignal; - let cursor: AbstractCursor<{ a: number }>; - - beforeEach(async function () { - controller = new AbortController(); - signal = controller.signal; - cursor = method(filter, { signal }); - }); - - afterEach(async function () { - sinon.restore(); - await cursor?.close(); - }); - - it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { - const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); - - const stub = sinon - .stub(client.options.autoEncrypter, 'encrypt') - .callsFake(function (...args) { - controller.abort(); - return stub.wrappedMethod.apply(this, args); - }); - - const result = await willBeResultBlocked; - - expect(result).to.be.instanceOf(DOMException); - }); - } - - for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { - test(cursorAPI, args); - } - }); - - describe('and the signal is aborted during command decryption', fleMetadata, () => { - function test(cursorAPI, args) { - let controller: AbortController; - let signal: AbortSignal; - let cursor: AbstractCursor<{ a: number }>; - - beforeEach(async function () { - controller = new AbortController(); - signal = controller.signal; - cursor = method(filter, { signal }); - }); - - afterEach(async function () { - sinon.restore(); - await cursor?.close(); - }); - - it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { - const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); - - const stub = sinon - .stub(client.options.autoEncrypter, 'decrypt') - .callsFake(function (...args) { - controller.abort(); - return stub.wrappedMethod.apply(this, args); - }); - - const result = await willBeResultBlocked; - - expect(result).to.be.instanceOf(DOMException); - }); - } - - for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { - test(cursorAPI, args); - } - }); - }); - } }); } From 5af48af775fa4f57d36d8b186150faee4bf13280 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 13:04:27 -0500 Subject: [PATCH 12/13] chore: make findLast simple --- test/integration/node-specific/abort_signal.test.ts | 3 --- test/tools/utils.ts | 9 +-------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 7c9d2507f02..52cab22bef1 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -3,14 +3,11 @@ import { TLSSocket } from 'node:tls'; import * as util from 'node:util'; import { expect } from 'chai'; -import * as semver from 'semver'; import * as sinon from 'sinon'; import { type AbstractCursor, AggregationCursor, - type AutoEncryptionOptions, - ClientEncryption, type Collection, type ConnectionClosedEvent, type Db, diff --git a/test/tools/utils.ts b/test/tools/utils.ts index 549175dc910..80906d0a343 100644 --- a/test/tools/utils.ts +++ b/test/tools/utils.ts @@ -710,14 +710,7 @@ export function findLast( thisArg?: any ): unknown | undefined { if (typeof array.findLast === 'function') return array.findLast(predicate, thisArg); - - for (let i = array.length - 1; i >= 0; i--) { - if (predicate.call(thisArg, array[i], i, array)) { - return array[i]; - } - } - - return undefined; + return array.slice().reverse().find(predicate, thisArg); } // Node.js 16 doesn't make this global, but it can still be obtained. From 6e8bd464b19dd05cfedeb7e2e575ece155ac128e Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 17 Jan 2025 15:39:58 -0500 Subject: [PATCH 13/13] test: no kill cursors on lb and don't wait on connection close --- .../node-specific/abort_signal.test.ts | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 52cab22bef1..c9c9aa1560d 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -215,8 +215,18 @@ describe('AbortSignal support', () => { let signal: AbortSignal; let cursor: AbstractCursor<{ a: number }>; const commandsStarted = []; + let waitForKillCursors; beforeEach(async function () { + waitForKillCursors = + this.configuration.topologyType === 'LoadBalanced' + ? async () => null + : async () => { + for await (const [ev] of events.on(client, 'commandStarted')) { + if (ev.commandName === 'killCursors') return ev; + } + }; + commandsStarted.length = 0; const utilClient = this.configuration.newClient(); try { @@ -243,12 +253,6 @@ describe('AbortSignal support', () => { client.on('commandStarted', e => commandsStarted.push(e)); }); - const waitForKillCursors = async () => { - for await (const [ev] of events.on(client, 'commandStarted')) { - if (ev.commandName === 'killCursors') return ev; - } - }; - afterEach(async () => { await cursor?.close(); sinon.restore(); @@ -398,36 +402,30 @@ describe('AbortSignal support', () => { let controller: AbortController; let signal: AbortSignal; let cursor: AbstractCursor<{ a: number }>; - let checkedOutId; - const waitForConnectionClosed = async () => { - for await (const [ev] of events.on(client, 'connectionClosed')) { - if ((ev as ConnectionClosedEvent).connectionId === checkedOutId) return ev; - } - }; beforeEach(async function () { - checkedOutId = undefined; controller = new AbortController(); signal = controller.signal; cursor = method(filter, { signal }); }); afterEach(async function () { - checkedOutId = undefined; sinon.restore(); await cursor?.close(); }); it(`rejects ${cursorAPI.toString()}`, async () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. - const connectionClosed = waitForConnectionClosed(); - client.on('connectionCheckedOut', ev => (checkedOutId = ev.connectionId)); const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); + let cursorCommandSocket; + for (const [, server] of client.topology.s.servers) { //@ts-expect-error: private property for (const connection of server.pool.connections) { + //@ts-expect-error: private property + cursorCommandSocket = connection.socket; //@ts-expect-error: private property const stub = sinon.stub(connection.socket, 'write').callsFake(function (...args) { controller.abort(); @@ -444,7 +442,7 @@ describe('AbortSignal support', () => { expect(result).to.be.instanceOf(DOMException); - await connectionClosed; + expect(cursorCommandSocket).to.have.property('destroyed', true); }); } @@ -471,30 +469,28 @@ describe('AbortSignal support', () => { } }); - checkedOutId = undefined; controller = new AbortController(); signal = controller.signal; cursor = method(filter, { signal }); }); - let checkedOutId; - const waitForConnectionClosed = async () => { - for await (const [ev] of events.on(client, 'connectionClosed')) { - if ((ev as ConnectionClosedEvent).connectionId === checkedOutId) return ev; - } - }; - afterEach(async function () { - checkedOutId = undefined; await clearFailPoint(this.configuration); await cursor?.close(); }); it(`rejects ${cursorAPI.toString()}`, async () => { await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. - const connectionClosed = waitForConnectionClosed(); - client.on('connectionCheckedOut', ev => (checkedOutId = ev.connectionId)); + let cursorCommandSocket; + for (const [, server] of client.topology.s.servers) { + //@ts-expect-error: private property + for (const connection of server.pool.connections) { + //@ts-expect-error: private property + cursorCommandSocket = connection.socket; + } + } + client.on('commandStarted', e => e.commandName === cursorName && controller.abort()); const willBeResultBlocked = iterateUntilDocumentOrError(cursor, cursorAPI, args); @@ -502,7 +498,7 @@ describe('AbortSignal support', () => { expect(result).to.be.instanceOf(DOMException); - await connectionClosed; + expect(cursorCommandSocket).to.have.property('destroyed', true); }); }