Skip to content

Commit

Permalink
Fix iv of constant (#2141)
Browse files Browse the repository at this point in the history
* Fix iv of constant

* fix

* fix

* fix

* fix

* Fix

* fix

* fix
  • Loading branch information
wsmoses authored Nov 1, 2024
1 parent 5f841a2 commit de7c147
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
39 changes: 23 additions & 16 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,12 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
void forwardModeInvertedPointerFallback(llvm::Instruction &I) {
using namespace llvm;

if (gutils->isConstantValue(&I))
return;
auto found = gutils->invertedPointers.find(&I);
if (gutils->isConstantValue(&I)) {
assert(found == gutils->invertedPointers.end());
return;
}

assert(found != gutils->invertedPointers.end());
auto placeholder = cast<PHINode>(&*found->second);
gutils->invertedPointers.erase(found);
Expand All @@ -324,6 +327,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {

auto toset = gutils->invertPointerM(&I, Builder2, /*nullShadow*/ true);

assert(toset != placeholder);

gutils->replaceAWithB(placeholder, toset);
placeholder->replaceAllUsesWith(toset);
gutils->erase(placeholder);
Expand Down Expand Up @@ -2145,18 +2150,6 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
void visitBinaryOperator(llvm::BinaryOperator &BO) {
eraseIfUnused(BO);

size_t size = 1;
if (BO.getType()->isSized())
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
BO.getType()) +
7) /
8;

if (BO.getType()->isIntOrIntVectorTy() &&
TR.intType(size, &BO, /*errifnotfound*/ false) == BaseType::Pointer) {
return;
}

if (BO.getOpcode() == llvm::Instruction::FDiv &&
(Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined) &&
Expand Down Expand Up @@ -2289,6 +2282,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}

void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO) {
if (gutils->isConstantInstruction(&BO)) {
return;
}
using namespace llvm;

IRBuilder<> Builder2(&BO);
Expand Down Expand Up @@ -2770,8 +2766,19 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
auto rval = EmitNoDerivativeError(ss.str(), BO, gutils, Builder2);
if (!rval)
rval = Constant::getNullValue(gutils->getShadowType(BO.getType()));
if (!gutils->isConstantValue(&BO))
setDiffe(&BO, rval, Builder2);
auto ifound = gutils->invertedPointers.find(&BO);
if (!gutils->isConstantValue(&BO)) {
if (ifound != gutils->invertedPointers.end()) {
auto placeholder = cast<PHINode>(&*ifound->second);
gutils->invertedPointers.erase(ifound);
gutils->replaceAWithB(placeholder, rval);
gutils->erase(placeholder);
gutils->invertedPointers.insert(std::make_pair(
(const Value *)&BO, InvertedPointerVH(gutils, rval)));
}
} else {
assert(ifound == gutils->invertedPointers.end());
}
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/benchmarks/ReverseMode/adbench/gmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ int main(const int argc, const char* argv[]) {
getTests(paths, "data/2.5k", "2.5k/");
getTests(paths, "data/10k", "10k/");
}

std::ofstream jsonfile("results.json", std::ofstream::trunc);
json test_results;

Expand Down

0 comments on commit de7c147

Please sign in to comment.