diff --git a/lib/types.js b/lib/types.js index 1594e71d..1e15acfc 100644 --- a/lib/types.js +++ b/lib/types.js @@ -111,6 +111,8 @@ class Type { wrapUnions = 'auto'; } else if (typeof wrapUnions == 'string') { wrapUnions = wrapUnions.toLowerCase(); + } else if (typeof wrapUnions === 'function') { + wrapUnions = 'auto'; } switch (wrapUnions) { case 'always': @@ -196,11 +198,20 @@ class Type { let types = schema.map((obj) => { return Type.forSchema(obj, opts); }); + let projectionFn; if (!UnionType) { - UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType; + if (typeof opts.wrapUnions === 'function') { + // we have a projection function + projectionFn = opts.wrapUnions(types); + UnionType = typeof projectionFn !== 'undefined' + ? UnwrappedUnionType + : WrappedUnionType; + } else { + UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType; + } } LOGICAL_TYPE = logicalType; - type = new UnionType(types, opts); + type = new UnionType(types, opts, projectionFn); } else { // New type definition. type = (function (typeName) { let Type = TYPES[typeName]; @@ -341,10 +352,10 @@ class Type { return branchTypes[name]; }), opts); } catch (err) { - opts.wrapUnions = wrapUnions; throw err; + } finally { + opts.wrapUnions = wrapUnions; } - opts.wrapUnions = wrapUnions; return unionType; } @@ -1226,6 +1237,60 @@ UnionType.prototype._branchConstructor = function () { throw new Error('unions cannot be directly wrapped'); }; + +function generateProjectionIndexer(projectionFn) { + return (val) => { + const index = projectionFn(val); + if (typeof index !== 'number') { + throw new Error(`Projected index '${index}' is not valid`); + } + return index; + }; +} + +function generateDefaultIndexer(types) { + const dynamicBranches = []; + const bucketIndices = {}; + + const getBranchIndex = (any, index) => { + let logicalBranches = dynamicBranches; + for (let i = 0, l = logicalBranches.length; i < l; i++) { + let branch = logicalBranches[i]; + if (branch.type._check(any)) { + if (index === undefined) { + index = branch.index; + } else { + // More than one branch matches the value so we aren't guaranteed to + // infer the correct type. We throw rather than corrupt data. This can + // be fixed by "tightening" the logical types. + throw new Error('ambiguous conversion'); + } + } + } + return index; + } + + types.forEach(function (type, index) { + if (Type.isType(type, 'abstract', 'logical')) { + dynamicBranches.push({index, type}); + } else { + let bucket = getTypeBucket(type); + if (bucketIndices[bucket] !== undefined) { + throw new Error(`ambiguous unwrapped union: ${j(this)}`); + } + bucketIndices[bucket] = index; + } + }); + return (val) => { + let index = bucketIndices[getValueBucket(val)]; + if (dynamicBranches.length) { + // Slower path, we must run the value through all branches. + index = getBranchIndex(val, index); + } + return index; + }; +} + /** * "Natural" union type. * @@ -1246,54 +1311,17 @@ UnionType.prototype._branchConstructor = function () { * + `map`, `record` */ class UnwrappedUnionType extends UnionType { - constructor (schema, opts) { + constructor (schema, opts, /* @private parameter */ _projectionFn) { super(schema, opts); - this._dynamicBranches = null; - this._bucketIndices = {}; - this.types.forEach(function (type, index) { - if (Type.isType(type, 'abstract', 'logical')) { - if (!this._dynamicBranches) { - this._dynamicBranches = []; - } - this._dynamicBranches.push({index, type}); - } else { - let bucket = getTypeBucket(type); - if (this._bucketIndices[bucket] !== undefined) { - throw new Error(`ambiguous unwrapped union: ${j(this)}`); - } - this._bucketIndices[bucket] = index; - } - }, this); - - Object.freeze(this); - } - - _getIndex (val) { - let index = this._bucketIndices[getValueBucket(val)]; - if (this._dynamicBranches) { - // Slower path, we must run the value through all branches. - index = this._getBranchIndex(val, index); + if (!_projectionFn && opts && typeof opts.wrapUnions === 'function') { + _projectionFn = opts.wrapUnions(this.types); } - return index; - } + this._getIndex = _projectionFn + ? generateProjectionIndexer(_projectionFn) + : generateDefaultIndexer(this.types); - _getBranchIndex (any, index) { - let logicalBranches = this._dynamicBranches; - for (let i = 0, l = logicalBranches.length; i < l; i++) { - let branch = logicalBranches[i]; - if (branch.type._check(any)) { - if (index === undefined) { - index = branch.index; - } else { - // More than one branch matches the value so we aren't guaranteed to - // infer the correct type. We throw rather than corrupt data. This can - // be fixed by "tightening" the logical types. - throw new Error('ambiguous conversion'); - } - } - } - return index; + Object.freeze(this); } _check (val, flags, hook, path) { @@ -1355,16 +1383,18 @@ class UnwrappedUnionType extends UnionType { // Using the `coerceBuffers` option can cause corruption and erroneous // failures with unwrapped unions (in rare cases when the union also // contains a record which matches a buffer's JSON representation). - if (isJsonBuffer(val) && this._bucketIndices.buffer !== undefined) { - index = this._bucketIndices.buffer; - } else { - index = this._getIndex(val); + if (isJsonBuffer(val)) { + let bufIndex = this.types.findIndex(t => getTypeBucket(t) === 'buffer'); + if (bufIndex !== -1) { + index = bufIndex; + } } + index ??= this._getIndex(val); break; case 2: // Decoding from JSON, we must unwrap the value. if (val === null) { - index = this._bucketIndices['null']; + index = this._getIndex(null); } else if (typeof val === 'object') { let keys = Object.keys(val); if (keys.length === 1) { diff --git a/test/test_types.js b/test/test_types.js index 0ad09672..e7675761 100644 --- a/test/test_types.js +++ b/test/test_types.js @@ -3505,6 +3505,57 @@ suite('types', () => { assert(Type.isType(t.field('unwrapped').type, 'union:unwrapped')); }); + test('union projection', () => { + const Dog = { + type: 'record', + name: 'Dog', + fields: [ + { type: 'string', name: 'bark' } + ], + }; + const Cat = { + type: 'record', + name: 'Cat', + fields: [ + { type: 'string', name: 'meow' } + ], + }; + const animalTypes = [Dog, Cat]; + + let callsToWrapUnions = 0; + const wrapUnions = (types) => { + callsToWrapUnions++; + assert.deepEqual(types.map(t => t.name), ['Dog', 'Cat']); + return (animal) => { + const animalType = ((animal) => { + if ('bark' in animal) { + return 'Dog'; + } else if ('meow' in animal) { + return 'Cat'; + } + throw new Error('Unknown animal'); + })(animal); + return types.indexOf(types.find(type => type.name === animalType)); + } + }; + + // Ambiguous, but we have a projection function + const Animal = Type.forSchema(animalTypes, { wrapUnions }); + Animal.toBuffer({ meow: '🐈' }); + assert.equal(callsToWrapUnions, 1); + assert.throws(() => Animal.toBuffer({ snap: '🐊' }), /Unknown animal/) + }); + + test('union projection with fallback', () => { + let t = Type.forSchema({ + type: 'record', + fields: [ + {name: 'wrapped', type: ['int', 'double' ]}, // Ambiguous. + ] + }, {wrapUnions: () => undefined }); + assert(Type.isType(t.field('wrapped').type, 'union:wrapped')); + }); + test('invalid wrap unions option', () => { assert.throws(() => { Type.forSchema('string', {wrapUnions: 'FOO'}); diff --git a/types/index.d.ts b/types/index.d.ts index 76a8019b..3c1f5aba 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -95,6 +95,21 @@ interface EncoderOptions { syncMarker: Buffer; } +/** + * A projection function that is used when unwrapping unions. + * This function is called at schema parsing time on each union with its branches' + * types. + * If it returns a non-null (function) value, that function will be called each + * time a value's branch needs to be inferred and should return the branch's + * index. + * The index muss be a number between 0 and length-1 of the passed types. + * In this case (a branch index) the union will use an unwrapped representation. + * Otherwise (undefined), the union will be wrapped. + */ +type BranchProjection = (types: ReadonlyArray) => + | ((val: unknown) => number) + | undefined; + interface ForSchemaOptions { assertLogicalTypes: boolean; logicalTypes: { [type: string]: new (schema: Schema, opts?: any) => types.LogicalType; }; @@ -103,7 +118,7 @@ interface ForSchemaOptions { omitRecordMethods: boolean; registry: { [name: string]: Type }; typeHook: (schema: Schema | string, opts: ForSchemaOptions) => Type | undefined; - wrapUnions: boolean | 'auto' | 'always' | 'never'; + wrapUnions: BranchProjection | boolean | 'auto' | 'always' | 'never'; } interface TypeOptions extends ForSchemaOptions {