Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(NODE-6258): add signal support to cursor APIs #4364

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
20 changes: 12 additions & 8 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { kDecorateResult } from '../constants';
import { getMongoDBClientEncryption } from '../deps';
import { MongoRuntimeError } from '../error';
import { MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { MongoDBCollectionNamespace } from '../utils';
import { autoSelectSocketOptions } from './client_encryption';
import * as cryptoCallbacks from './crypto_callbacks';
Expand Down Expand Up @@ -372,8 +373,10 @@ export class AutoEncrypter {
async encrypt(
ns: string,
cmd: Document,
options: CommandOptions = {}
options: CommandOptions & Abortable = {}
): Promise<Document | Uint8Array> {
options.signal?.throwIfAborted();

if (this._bypassEncryption) {
// If `bypassAutoEncryption` has been specified, don't encrypt
return cmd;
Expand All @@ -398,7 +401,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return deserialize(await stateMachine.execute(this, context, options.timeoutContext), {
return deserialize(await stateMachine.execute(this, context, options), {
promoteValues: false,
promoteLongs: false
});
Expand All @@ -407,7 +410,12 @@ export class AutoEncrypter {
/**
* Decrypt a command response
*/
async decrypt(response: Uint8Array, options: CommandOptions = {}): Promise<Uint8Array> {
async decrypt(
response: Uint8Array,
options: CommandOptions & Abortable = {}
): Promise<Uint8Array> {
options.signal?.throwIfAborted();

const context = this._mongocrypt.makeDecryptionContext(response);

context.id = this._contextCounter++;
Expand All @@ -419,11 +427,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return await stateMachine.execute(
this,
context,
options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined
);
return await stateMachine.execute(this, context, options);
}

/**
Expand Down
10 changes: 6 additions & 4 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ export class ClientEncryption {
TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }));

const dataKey = deserialize(
await stateMachine.execute(this, context, timeoutContext)
await stateMachine.execute(this, context, { timeoutContext })
) as DataKey;

const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
Expand Down Expand Up @@ -293,7 +293,9 @@ export class ClientEncryption {
resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })
);

const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v: dataKeys } = deserialize(
await stateMachine.execute(this, context, { timeoutContext })
);
if (dataKeys.length === 0) {
return {};
}
Expand Down Expand Up @@ -696,7 +698,7 @@ export class ClientEncryption {
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;

const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));

return v;
}
Expand Down Expand Up @@ -780,7 +782,7 @@ export class ClientEncryption {
this._timeoutMS != null
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;
const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));
return v;
}
}
Expand Down
116 changes: 80 additions & 36 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
import { getSocks, type SocksLib } from '../deps';
import { MongoOperationTimeoutError } from '../error';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
import {
addAbortListener,
BufferPool,
kDispose,
MongoDBCollectionNamespace,
promiseWithResolvers
} from '../utils';
import { autoSelectSocketOptions, type DataKey } from './client_encryption';
import { MongoCryptError } from './errors';
import { type MongocryptdManager } from './mongocryptd_manager';
Expand Down Expand Up @@ -189,7 +196,7 @@ export class StateMachine {
async execute(
executor: StateMachineExecutable,
context: MongoCryptContext,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const keyVaultNamespace = executor._keyVaultNamespace;
const keyVaultClient = executor._keyVaultClient;
Expand All @@ -199,6 +206,7 @@ export class StateMachine {
let result: Uint8Array | null = null;

while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) {
options.signal?.throwIfAborted();
debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`);

switch (context.state) {
Expand All @@ -214,7 +222,7 @@ export class StateMachine {
metaDataClient,
context.ns,
filter,
timeoutContext
options
);
if (collInfo) {
context.addMongoOperationResponse(collInfo);
Expand All @@ -235,9 +243,9 @@ export class StateMachine {
// When we are using the shared library, we don't have a mongocryptd manager.
const markedCommand: Uint8Array = mongocryptdManager
? await mongocryptdManager.withRespawn(
this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext)
this.markCommand.bind(this, mongocryptdClient, context.ns, command, options)
)
: await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext);
: await this.markCommand(mongocryptdClient, context.ns, command, options);

context.addMongoOperationResponse(markedCommand);
context.finishMongoOperation();
Expand All @@ -246,12 +254,7 @@ export class StateMachine {

case MONGOCRYPT_CTX_NEED_MONGO_KEYS: {
const filter = context.nextMongoOperation();
const keys = await this.fetchKeys(
keyVaultClient,
keyVaultNamespace,
filter,
timeoutContext
);
const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options);

if (keys.length === 0) {
// See docs on EMPTY_V
Expand All @@ -273,7 +276,7 @@ export class StateMachine {
}

case MONGOCRYPT_CTX_NEED_KMS: {
await Promise.all(this.requests(context, timeoutContext));
await Promise.all(this.requests(context, options));
context.finishKMSRequests();
break;
}
Expand Down Expand Up @@ -315,11 +318,13 @@ export class StateMachine {
* @param kmsContext - A C++ KMS context returned from the bindings
* @returns A promise that resolves when the KMS reply has be fully parsed
*/
async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise<void> {
async kmsRequest(
request: MongoCryptKMSRequest,
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<void> {
const parsedUrl = request.endpoint.split(':');
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {});
const options: tls.ConnectionOptions & {
const socketOptions: tls.ConnectionOptions & {
host: string;
port: number;
autoSelectFamily?: boolean;
Expand All @@ -328,7 +333,7 @@ export class StateMachine {
host: parsedUrl[0],
servername: parsedUrl[0],
port,
...socketOptions
...autoSelectSocketOptions(this.options.socketOptions || {})
};
const message = request.message;
const buffer = new BufferPool();
Expand Down Expand Up @@ -363,7 +368,7 @@ export class StateMachine {
throw error;
}
try {
await this.setTlsOptions(providerTlsOptions, options);
await this.setTlsOptions(providerTlsOptions, socketOptions);
} catch (err) {
throw onerror(err);
}
Expand All @@ -380,23 +385,25 @@ export class StateMachine {
.once('close', () => rejectOnNetSocketError(onclose()))
.once('connect', () => resolveOnNetSocketConnect());

let abortListener;

try {
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
const netSocketOptions = {
...socketOptions,
host: this.options.proxyOptions.proxyHost,
port: this.options.proxyOptions.proxyPort || 1080,
...socketOptions
port: this.options.proxyOptions.proxyPort || 1080
};
netSocket.connect(netSocketOptions);
await willConnect;

try {
socks ??= loadSocks();
options.socket = (
socketOptions.socket = (
await socks.SocksClient.createConnection({
existing_socket: netSocket,
command: 'connect',
destination: { host: options.host, port: options.port },
destination: { host: socketOptions.host, port: socketOptions.port },
proxy: {
// host and port are ignored because we pass existing_socket
host: 'iLoveJavaScript',
Expand All @@ -412,7 +419,7 @@ export class StateMachine {
}
}

socket = tls.connect(options, () => {
socket = tls.connect(socketOptions, () => {
socket.write(message);
});

Expand All @@ -422,6 +429,11 @@ export class StateMachine {
resolve
} = promiseWithResolvers<void>();

abortListener = addAbortListener(options?.signal, function () {
destroySockets();
rejectOnTlsSocketError(this.reason);
});

socket
.once('error', err => rejectOnTlsSocketError(onerror(err)))
.once('close', () => rejectOnTlsSocketError(onclose()))
Expand All @@ -436,8 +448,11 @@ export class StateMachine {
resolve();
}
});
await (timeoutContext?.csotEnabled()
? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)])
await (options?.timeoutContext?.csotEnabled()
? Promise.all([
willResolveKmsRequest,
Timeout.expires(options.timeoutContext?.remainingTimeMS)
])
: willResolveKmsRequest);
} catch (error) {
if (error instanceof TimeoutError)
Expand All @@ -446,16 +461,17 @@ export class StateMachine {
} finally {
// There's no need for any more activity on this socket at this point.
destroySockets();
abortListener?.[kDispose]();
}
}

*requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) {
*requests(context: MongoCryptContext, options?: { timeoutContext?: TimeoutContext } & Abortable) {
for (
let request = context.nextKMSRequest();
request != null;
request = context.nextKMSRequest()
) {
yield this.kmsRequest(request, timeoutContext);
yield this.kmsRequest(request, options);
}
}

Expand Down Expand Up @@ -516,14 +532,16 @@ export class StateMachine {
client: MongoClient,
ns: string,
filter: Document,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array | null> {
const { db } = MongoDBCollectionNamespace.fromString(ns);

const cursor = client.db(db).listCollections(filter, {
promoteLongs: false,
promoteValues: false,
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
timeoutContext:
options?.timeoutContext && new CursorTimeoutContext(options?.timeoutContext, Symbol()),
signal: options?.signal
});

// There is always exactly zero or one matching documents, so this should always exhaust the cursor
Expand All @@ -547,17 +565,30 @@ export class StateMachine {
client: MongoClient,
ns: string,
command: Uint8Array,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const { db } = MongoDBCollectionNamespace.fromString(ns);
const bsonOptions = { promoteLongs: false, promoteValues: false };
const rawCommand = deserialize(command, bsonOptions);

const commandOptions: {
timeoutMS?: number;
signal?: AbortSignal;
} = {
timeoutMS: undefined,
signal: undefined
};

if (options?.timeoutContext?.csotEnabled()) {
commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS;
}
if (options?.signal) {
commandOptions.signal = options.signal;
}

const response = await client.db(db).command(rawCommand, {
...bsonOptions,
...(timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS }
: undefined)
...commandOptions
});

return serialize(response, this.bsonOptions);
Expand All @@ -575,17 +606,30 @@ export class StateMachine {
client: MongoClient,
keyVaultNamespace: string,
filter: Uint8Array,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Array<DataKey>> {
const { db: dbName, collection: collectionName } =
MongoDBCollectionNamespace.fromString(keyVaultNamespace);

const commandOptions: {
timeoutContext?: CursorTimeoutContext;
signal?: AbortSignal;
} = {
timeoutContext: undefined,
signal: undefined
};

if (options?.timeoutContext != null) {
commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol());
}
if (options?.signal != null) {
commandOptions.signal = options.signal;
}

return client
.db(dbName)
.collection<DataKey>(collectionName, { readConcern: { level: 'majority' } })
.find(deserialize(filter), {
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
})
.find(deserialize(filter), commandOptions)
.toArray();
}
}
Loading
Loading