Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions #5866

Merged
merged 22 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b9b3a9a
Overhauled the auto-diff system for dynamic dispatch
saipraveenb25 Dec 13, 2024
8923401
More fixes
saipraveenb25 Dec 13, 2024
e6c757a
remove intermediate dumps
saipraveenb25 Dec 13, 2024
1e04526
Update slang-ast-type.h
saipraveenb25 Dec 13, 2024
ce37c12
More fixes + add a workaround for existential no-diff
saipraveenb25 Dec 13, 2024
a439d29
Update reverse-control-flow-3.slang
saipraveenb25 Dec 13, 2024
51c1ae8
remove dumps
saipraveenb25 Dec 13, 2024
aefc1b1
remove more dumps
saipraveenb25 Dec 13, 2024
8c81678
Delete working-reverse-control-flow-3.hlsl
saipraveenb25 Dec 13, 2024
84feca1
Merge branch 'master' into fix-assoc-type-autodiff-2
saipraveenb25 Dec 13, 2024
cab2964
Cleanup comments + unused variables
saipraveenb25 Dec 13, 2024
b824fb3
Merge branch 'fix-assoc-type-autodiff-2' of https://github.com/saipra…
saipraveenb25 Dec 13, 2024
8b3bb2b
More comment cleanup
saipraveenb25 Dec 13, 2024
c9e27d0
Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeVal…
saipraveenb25 Dec 16, 2024
c54c544
Fix array of issues in Falcor tests.
saipraveenb25 Dec 16, 2024
0b8cb6b
Merge branch 'master' into fix-assoc-type-autodiff-2
saipraveenb25 Dec 16, 2024
f886f07
Update slang-ir-autodiff-pairs.cpp
saipraveenb25 Dec 16, 2024
aac5c47
Merge branch 'fix-assoc-type-autodiff-2' of https://github.com/saipra…
saipraveenb25 Dec 16, 2024
d27a215
More fixes for Falcor image tests
saipraveenb25 Dec 16, 2024
4718f7f
Merge branch 'master' into fix-assoc-type-autodiff-2
saipraveenb25 Dec 16, 2024
771b1fd
Merge remote-tracking branch 'official/master' into fix-assoc-type-au…
csyonghe Jan 9, 2025
d3faf93
Small fixups.
csyonghe Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9142,12 +9142,18 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
if (!decl->hasModifier<NoDiffThisAttribute>())
{
// Build decl-ref-type from interface.
auto interfaceType =
DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
// auto interfaceType =
// DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
auto thisType = DeclRefType::create(
m_astBuilder,
createDefaultSubstitutionsIfNeeded(
m_astBuilder,
this,
makeDeclRef(interfaceDecl->getThisTypeDecl())));

// If the interface is differentiable, make the this type a pair.
if (tryGetDifferentialType(getASTBuilder(), interfaceType))
reqDecl->diffThisType = getDifferentialPairType(interfaceType);
if (tryGetDifferentialType(getASTBuilder(), thisType))
reqDecl->diffThisType = getDifferentialPairType(thisType);
}

auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
Expand All @@ -9172,13 +9178,17 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
reqDecl->parentDecl = interfaceDecl;
if (!decl->hasModifier<NoDiffThisAttribute>())
{
// Build decl-ref-type from interface.
auto interfaceType =
DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
// Build decl-ref-type for this-type.
auto thisType = DeclRefType::create(
m_astBuilder,
createDefaultSubstitutionsIfNeeded(
m_astBuilder,
this,
makeDeclRef(interfaceDecl->getThisTypeDecl())));

// If the interface is differentiable, make the this type a pair.
if (tryGetDifferentialType(getASTBuilder(), interfaceType))
reqDecl->diffThisType = getDifferentialPairType(interfaceType);
if (tryGetDifferentialType(getASTBuilder(), thisType))
reqDecl->diffThisType = getDifferentialPairType(thisType);
}

auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
Expand Down
28 changes: 26 additions & 2 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,18 @@ Result linkAndOptimizeIR(
bool changed = false;
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
changed |= specializeModule(targetProgram, irModule, codeGenContext->getSink());
{
// Pre-autodiff, we will attempt to specialize as much as possible.
//
// Note: Lowered dynamic-dispatch code cannot be differentiated correctly due to
// missing information, so we defer that to after the auto-dff step.
//
SpecializationOptions specOptions;
specOptions.lowerWitnessLookups = false;
changed |=
specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
}

if (codeGenContext->getSink()->getErrorCount() != 0)
return SLANG_FAIL;
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
Expand Down Expand Up @@ -857,9 +868,20 @@ Result linkAndOptimizeIR(
reportCheckpointIntermediates(codeGenContext, sink, irModule);

// Finalization is always run so AD-related instructions can be removed,
// even the AD pass itself is not run.
// even if the AD pass itself is not run.
//
finalizeAutoDiffPass(targetProgram, irModule);
eliminateDeadCode(irModule, deadCodeEliminationOptions);

// After auto-diff, we can perform more aggressive specialization with dynamic-dispatch
// lowering.
//
if (!codeGenContext->isSpecializationDisabled())
{
SpecializationOptions specOptions;
specOptions.lowerWitnessLookups = true;
specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performing lowerWitnessLookups can open up new opportunities for specializations, so we need to run that specialization-optimization loop again.

This is really calling for cleaning up the specialization pass to be just a peephole optimization pass.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe specializeModule() already runs the witness lowering in a loop with the rest of the specialization pass?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case this is fine, but we should rewrite specialization pass to merge it with the peepholeOptimize pass as soon as possible.

}

finalizeSpecialization(irModule);

Expand Down Expand Up @@ -920,6 +942,8 @@ Result linkAndOptimizeIR(

validateIRModuleIfEnabled(codeGenContext, irModule);

inferAnyValueSizeWhereNecessary(targetProgram, irModule);

// If we have any witness tables that are marked as `KeepAlive`,
// but are not used for dynamic dispatch, unpin them so we don't
// do unnecessary work to lower them.
Expand Down
144 changes: 120 additions & 24 deletions source/slang/slang-ir-autodiff-fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns
return InstPair(primalVal, diffVal);
}

InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation(
IRBuilder* builder,
IRInst* origInst)
{
auto primalAnnotation =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to transcribe these annotation insts themselves? Shouldn't they just be considered as non-diff insts themselves?

as<IRDifferentiableTypeAnnotation>(maybeCloneForPrimalInst(builder, origInst));

IRDifferentiableTypeAnnotation* annotation = as<IRDifferentiableTypeAnnotation>(origInst);

differentiableTypeConformanceContext.addTypeToDictionary(
(IRType*)primalAnnotation->getBaseType(),
primalAnnotation->getWitness());

auto diffType = differentiateType(builder, (IRType*)annotation->getBaseType());
if (!diffType)
return InstPair(primalAnnotation, nullptr);

auto diffTypeDiffWitness =
tryGetDifferentiableWitness(builder, diffType, DiffConformanceKind::Any);

IRInst* args[] = {diffType, diffTypeDiffWitness};

auto diffAnnotation = builder->emitIntrinsicInst(
builder->getVoidType(),
kIROp_DifferentiableTypeAnnotation,
2,
args);

builder->markInstAsPrimal(diffAnnotation);
builder->markInstAsPrimal(primalAnnotation);

return InstPair(primalAnnotation, diffAnnotation);
}

InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
Expand Down Expand Up @@ -745,16 +779,15 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
while (auto attrType = as<IRAttributedType>(origType))
origType = attrType->getBaseType();
}

if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
{
auto pairPtrType = as<IRPtrTypeBase>(pairType);

auto pairValType = as<IRDifferentialPairTypeBase>(
pairPtrType ? pairPtrType->getValueType() : pairType);

auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(
&argBuilder,
pairValType);
auto diffType = differentiateType(&argBuilder, primalType);
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
Expand Down Expand Up @@ -795,7 +828,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (diffArg)
{
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential(
(IRType*)diffType,
(IRType*)as<IRPtrTypeBase>(diffType)->getValueType(),
newVal);
markDiffTypeInst(
&afterBuilder,
Expand Down Expand Up @@ -827,17 +860,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
}
}

{
// --WORKAROUND--
// This is a temporary workaround for a very specific case..
//
// If all the following are true:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't the logic above handle this case properly? What exactly is missing that requires this workaround logic? Should we just extend isNoDiffType to handle the ExtractExistentialType case, so that line 846 can handle this case?

// 1. the parameter type expects a differential pair,
// 2. the argument is derived from a no_diff type, and
// 3. the argument type is a run-time type (i.e. extract_existential_type),
// then we need to generate a differential 0, but the IR has no
// information on the diff witness.
//
// We will bypass the conformance system & brute-force the lookup for the interface
// keys, but the proper fix is to lower this key mapping during `no_diff` lowering.
//

// Condition 1
if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType)))
{
// Condition 3
if (auto extractExistentialType = as<IRExtractExistentialType>(primalType))
{
// Condition 2
if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType()))
{
// Force-differentiate the type (this will perform a search for the witness
// without going through the diff-type annotation list)
//
IRInst* witnessTable = nullptr;
auto diffType = differentiateExtractExistentialType(
&argBuilder,
extractExistentialType,
witnessTable);

auto pairType =
getOrCreateDiffPairType(&argBuilder, primalType, witnessTable);
auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst(
differentiableTypeConformanceContext.sharedContext->zeroMethodType,
witnessTable,
differentiableTypeConformanceContext.sharedContext
->zeroMethodStructKey);
auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr);
auto diffPair =
argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero);

args.add(diffPair);
continue;
}
}
}
}

// Argument is not differentiable.
// Add original/primal argument.
args.add(primalArg);
}

IRType* diffReturnType = nullptr;
diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType());
auto primalReturnType =
(IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());

diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType);

if (!diffReturnType)
{
diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
diffReturnType = primalReturnType;
}

auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args);
Expand Down Expand Up @@ -1035,18 +1123,16 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(
IRInst* diffBase = nullptr;
if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase))
{
auto diffType = differentiateType(builder, origSpecialize->getFullType());
if (diffBase)
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
args.add(primalSpecialize->getArg(i));
}
auto diffSpecialize = builder->emitSpecializeInst(
builder->getTypeKind(),
diffBase,
args.getCount(),
args.getBuffer());
auto diffSpecialize =
builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
else
Expand Down Expand Up @@ -1572,7 +1658,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc());
}

auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
IRFunc* diffFunc = nullptr;

// If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is weird. When will a function appear inside a generic body but not used as the return value of the generic?

// insert location unchanges). If we're transcribing it as a declaration, we should
// insert into the module.
//
auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc))
{
// Dealing with a declaration.. insert into module scope.
IRBuilder subBuilder = *inBuilder;
subBuilder.setInsertInto(inBuilder->getModule());
diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc);
}
else
{
diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
}

if (auto outerGen = findOuterGeneric(diffFunc))
{
Expand Down Expand Up @@ -1605,7 +1708,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
IRBuilder builder = *inBuilder;

maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);

differentiableTypeConformanceContext.setFunc(origFunc);

auto diffFunc = builder.createFunc();
Expand All @@ -1632,12 +1734,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I

// Transfer checkpoint hint decorations
copyCheckpointHints(&builder, origFunc, diffFunc);

// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
{
cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule());
}
return diffFunc;
}

Expand Down Expand Up @@ -2012,6 +2108,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Reinterpret:
return transcribeReinterpret(builder, origInst);

case kIROp_DifferentiableTypeAnnotation:
return transcribeDifferentiableTypeAnnotation(builder, origInst);

// Differentiable insts that should have been lowered in a previous pass.
case kIROp_SwizzledStore:
{
Expand Down Expand Up @@ -2138,13 +2237,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(

if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType))
{
auto diffType = differentiateType(builder, (IRType*)origParam->getFullType());
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
builder->emitDifferentialPairGetDifferential(
(IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
builder,
as<IRDifferentialPairTypeBase>(diffPairType)),
diffPairParam));
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
{
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-autodiff-fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase

InstPair transcribeReinterpret(IRBuilder* builder, IRInst* origInst);

InstPair transcribeDifferentiableTypeAnnotation(IRBuilder* builder, IRInst* origInst);

virtual IRFuncType* differentiateFunctionType(
IRBuilder* builder,
IRInst* func,
Expand Down
Loading
Loading