-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions #5866
Changes from all commits
b9b3a9a
8923401
e6c757a
1e04526
ce37c12
a439d29
51c1ae8
aefc1b1
8c81678
84feca1
cab2964
b824fb3
8b3bb2b
c9e27d0
c54c544
0b8cb6b
f886f07
aac5c47
d27a215
4718f7f
771b1fd
d3faf93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns | |
return InstPair(primalVal, diffVal); | ||
} | ||
|
||
// Transcribes a `DifferentiableTypeAnnotation` instruction.
//
// The primal side comes from `maybeCloneForPrimalInst`; its (base type, witness)
// pair is registered with `differentiableTypeConformanceContext`. When the base
// type of the original annotation has a differential type, a second annotation
// is emitted for that differential type.
//
// Returns a pair of (primal annotation, annotation for the differential type).
// The second element is null when `differentiateType` produces no type.
//
// NOTE(review): a reviewer on this change asked why annotation insts need to be
// transcribed at all instead of being treated as non-differentiable insts; the
// rationale is not visible in this function -- confirm with the author.
// NOTE(review): the witness comes from a `tryGet...` helper and is used as an
// operand without a null check -- presumably it cannot be null here; verify.
InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation(
    IRBuilder* builder,
    IRInst* origInst)
{
    // Primal copy of the annotation (or the original, as decided by the helper).
    auto clonedInst = maybeCloneForPrimalInst(builder, origInst);
    auto primalNote = as<IRDifferentiableTypeAnnotation>(clonedInst);

    // Record the annotated type together with its witness, as read from the
    // primal annotation.
    differentiableTypeConformanceContext.addTypeToDictionary(
        (IRType*)primalNote->getBaseType(),
        primalNote->getWitness());

    // The differential type is derived from the *original* annotation's base type.
    auto origNote = as<IRDifferentiableTypeAnnotation>(origInst);
    auto diffBaseType = differentiateType(builder, (IRType*)origNote->getBaseType());
    if (!diffBaseType)
    {
        // Nothing to annotate on the differential side.
        return InstPair(primalNote, nullptr);
    }

    // Operands of the new annotation: the differential type and its witness.
    IRInst* diffOperands[] = {
        diffBaseType,
        tryGetDifferentiableWitness(builder, diffBaseType, DiffConformanceKind::Any)};

    auto diffNote = builder->emitIntrinsicInst(
        builder->getVoidType(),
        kIROp_DifferentiableTypeAnnotation,
        2,
        diffOperands);

    // Both annotations are marked as primal instructions.
    builder->markInstAsPrimal(diffNote);
    builder->markInstAsPrimal(primalNote);

    return InstPair(primalNote, diffNote);
}
|
||
InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) | ||
{ | ||
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) | ||
|
@@ -752,9 +786,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig | |
auto pairValType = as<IRDifferentialPairTypeBase>( | ||
pairPtrType ? pairPtrType->getValueType() : pairType); | ||
|
||
auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType( | ||
&argBuilder, | ||
pairValType); | ||
auto diffType = differentiateType(&argBuilder, primalType); | ||
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) | ||
{ | ||
// Create temp var to pass in/out arguments. | ||
|
@@ -795,7 +827,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig | |
if (diffArg) | ||
{ | ||
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential( | ||
(IRType*)diffType, | ||
(IRType*)as<IRPtrTypeBase>(diffType)->getValueType(), | ||
newVal); | ||
markDiffTypeInst( | ||
&afterBuilder, | ||
|
@@ -827,17 +859,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig | |
} | ||
} | ||
} | ||
|
||
{ | ||
// --WORKAROUND-- | ||
// This is a temporary workaround for a very specific case.. | ||
// | ||
// If all the following are true: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why can't the logic above handle this case properly? What exactly is missing that requires this workaround logic? Should we just extend |
||
// 1. the parameter type expects a differential pair, | ||
// 2. the argument is derived from a no_diff type, and | ||
// 3. the argument type is a run-time type (i.e. extract_existential_type), | ||
// then we need to generate a differential 0, but the IR has no | ||
// information on the diff witness. | ||
// | ||
// We will bypass the conformance system & brute-force the lookup for the interface | ||
// keys, but the proper fix is to lower this key mapping during `no_diff` lowering. | ||
// | ||
|
||
// Condition 1 | ||
if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType))) | ||
{ | ||
// Condition 3 | ||
if (auto extractExistentialType = as<IRExtractExistentialType>(primalType)) | ||
{ | ||
// Condition 2 | ||
if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType())) | ||
{ | ||
// Force-differentiate the type (this will perform a search for the witness | ||
// without going through the diff-type annotation list) | ||
// | ||
IRInst* witnessTable = nullptr; | ||
auto diffType = differentiateExtractExistentialType( | ||
&argBuilder, | ||
extractExistentialType, | ||
witnessTable); | ||
|
||
auto pairType = | ||
getOrCreateDiffPairType(&argBuilder, primalType, witnessTable); | ||
auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst( | ||
differentiableTypeConformanceContext.sharedContext->zeroMethodType, | ||
witnessTable, | ||
differentiableTypeConformanceContext.sharedContext | ||
->zeroMethodStructKey); | ||
auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr); | ||
auto diffPair = | ||
argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero); | ||
|
||
args.add(diffPair); | ||
continue; | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Argument is not differentiable. | ||
// Add original/primal argument. | ||
args.add(primalArg); | ||
} | ||
|
||
IRType* diffReturnType = nullptr; | ||
diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType()); | ||
auto primalReturnType = | ||
(IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); | ||
|
||
diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType); | ||
|
||
if (!diffReturnType) | ||
{ | ||
diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); | ||
diffReturnType = primalReturnType; | ||
} | ||
|
||
auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args); | ||
|
@@ -1035,18 +1122,16 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( | |
IRInst* diffBase = nullptr; | ||
if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase)) | ||
{ | ||
auto diffType = differentiateType(builder, origSpecialize->getFullType()); | ||
if (diffBase) | ||
{ | ||
List<IRInst*> args; | ||
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) | ||
{ | ||
args.add(primalSpecialize->getArg(i)); | ||
} | ||
auto diffSpecialize = builder->emitSpecializeInst( | ||
builder->getTypeKind(), | ||
diffBase, | ||
args.getCount(), | ||
args.getBuffer()); | ||
auto diffSpecialize = | ||
builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer()); | ||
return InstPair(primalSpecialize, diffSpecialize); | ||
} | ||
else | ||
|
@@ -1572,7 +1657,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu | |
return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc()); | ||
} | ||
|
||
auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); | ||
IRFunc* diffFunc = nullptr; | ||
|
||
// If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is weird. When will a function appear inside a generic body but not be used as the return value of the generic? |
||
// insert location unchanged). If we're transcribing it as a declaration, we should | ||
// insert into the module. | ||
// | ||
auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc)); | ||
if (!origOuterGen || findInnerMostGenericReturnVal(origOuterGen) != origFunc) | ||
{ | ||
// Dealing with a declaration.. insert into module scope. | ||
IRBuilder subBuilder = *inBuilder; | ||
subBuilder.setInsertInto(inBuilder->getModule()); | ||
diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc); | ||
} | ||
else | ||
{ | ||
diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); | ||
} | ||
|
||
if (auto outerGen = findOuterGeneric(diffFunc)) | ||
{ | ||
|
@@ -1605,7 +1707,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I | |
IRBuilder builder = *inBuilder; | ||
|
||
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); | ||
|
||
differentiableTypeConformanceContext.setFunc(origFunc); | ||
|
||
auto diffFunc = builder.createFunc(); | ||
|
@@ -1632,12 +1733,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I | |
|
||
// Transfer checkpoint hint decorations | ||
copyCheckpointHints(&builder, origFunc, diffFunc); | ||
|
||
// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. | ||
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) | ||
{ | ||
cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule()); | ||
} | ||
return diffFunc; | ||
} | ||
|
||
|
@@ -2012,6 +2107,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* | |
case kIROp_Reinterpret: | ||
return transcribeReinterpret(builder, origInst); | ||
|
||
case kIROp_DifferentiableTypeAnnotation: | ||
return transcribeDifferentiableTypeAnnotation(builder, origInst); | ||
|
||
// Differentiable insts that should have been lowered in a previous pass. | ||
case kIROp_SwizzledStore: | ||
{ | ||
|
@@ -2138,13 +2236,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam( | |
|
||
if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType)) | ||
{ | ||
auto diffType = differentiateType(builder, (IRType*)origParam->getFullType()); | ||
return InstPair( | ||
builder->emitDifferentialPairGetPrimal(diffPairParam), | ||
builder->emitDifferentialPairGetDifferential( | ||
(IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( | ||
builder, | ||
as<IRDifferentialPairTypeBase>(diffPairType)), | ||
diffPairParam)); | ||
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam)); | ||
} | ||
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) | ||
{ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Performing lowerWitnessLookups can open up new opportunities for specializations, so we need to run that specialization-optimization loop again.
This is really calling for cleaning up the specialization pass to be just a peephole optimization pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe `specializeModule()` already runs the witness lowering in a loop with the rest of the specialization pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case this is fine, but we should rewrite specialization pass to merge it with the peepholeOptimize pass as soon as possible.