Skip to content

Commit

Permalink
simplify further
Browse files Browse the repository at this point in the history
  • Loading branch information
pca006132 committed Dec 29, 2024
1 parent d0df1fb commit 486c762
Showing 1 changed file with 85 additions and 148 deletions.
233 changes: 85 additions & 148 deletions src/sdf/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct AffineValue {
bool operator==(const AffineValue &other) const {

Check warning on line 33 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L32-L33

Added lines #L32 - L33 were not covered by tests
return var == other.var && a == other.a && b == other.b;
}
AffineValue operator+(double d) { return AffineValue(var, a, b + d); }
AffineValue operator*(double d) { return AffineValue(var, a * d, b * d); }
};

template <>
Expand Down Expand Up @@ -183,17 +185,42 @@ void Context::optimizeAffine() {
return affineValues[operand.toInstIndex()].b;
return {};
};

auto replaceInst = [&](int from, int to) {
auto fromInst = Operand::fromInstIndex(from);
auto toInst = Operand::fromInstIndex(to);
for (auto use : opUses[from]) {
for (auto use : opUses[from])
for (auto &operand : instructions[use].operands)
if (operand == fromInst) operand = toInst;
}
if (operand == Operand::fromInstIndex(from))
operand = Operand::fromInstIndex(to);
opUses[from].clear();
instructions[from] = {OpCode::NOP, {none, none, none}};
};
auto handleAdd = [&](Operand x, Operand y,
bool sub) -> std::optional<AffineValue> {
auto lhs = getConstant(x);
auto rhs = getConstant(y);
if (lhs.has_value() && rhs.has_value()) {
return AffineValue(lhs.value() + rhs.value() * (sub ? -1 : 1));
} else if (lhs.has_value() && y.isResult()) {
return affineValues[y.toInstIndex()] * (sub ? -1 : 1) + lhs.value();
} else if (rhs.has_value() && x.isResult()) {
return affineValues[x.toInstIndex()] + rhs.value() * (sub ? -1 : 1);
} else if (x.isResult() && y.isResult()) {
if (affineValues[x.toInstIndex()].var ==
affineValues[y.toInstIndex()].var) {
auto other = affineValues[y.toInstIndex()];
auto result = affineValues[x.toInstIndex()];
if (sub) other = other * -1;
result.a += other.a;
result.b += other.b;

Check warning on line 213 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L209-L213

Added lines #L209 - L213 were not covered by tests
return result;
}
}
return {};
};
auto constWithUse = [&](double constant, size_t inst) {
auto result = addConstant(constant);
addUse(result, inst);
return result;
};

// abstract interpretation to figure out affine values for each instruction,
// and replace them as appropriate
Expand All @@ -203,134 +230,46 @@ void Context::optimizeAffine() {
auto &inst = instructions[i];
AffineValue result = AffineValue(Operand::fromInstIndex(i), 1, 0);
switch (inst.op) {
case OpCode::NOP:
case OpCode::RETURN:
case OpCode::CONSTANT:
case OpCode::LOAD:
case OpCode::STORE:
break;
// notably, neg is special among these unary opcode
case OpCode::ABS:
case OpCode::EXP:
case OpCode::LOG:
case OpCode::SQRT:
case OpCode::FLOOR:
case OpCode::CEIL:
case OpCode::ROUND:
case OpCode::SIN:
case OpCode::COS:
case OpCode::TAN:
case OpCode::ASIN:
case OpCode::ACOS:
case OpCode::ATAN: {
auto x = getConstant(inst.operands[0]);
if (x.has_value())
result = AffineValue(
EvalContext<double>::handle_unary(inst.op, x.value()));
break;
}
case OpCode::NEG:
if (inst.operands[0].isConst())
result = AffineValue(-constants[inst.operands[0].toConstIndex()]);
else if (inst.operands[0].isResult()) {
auto av = affineValues[inst.operands[0].toInstIndex()];
result = AffineValue(av.var, -av.a, -av.b);
}
else if (inst.operands[0].isResult())
result = affineValues[inst.operands[0].toInstIndex()] * -1;

Check warning on line 237 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L233-L237

Added lines #L233 - L237 were not covered by tests
break;
case OpCode::DIV: {
// TODO: handle the case where lhs is divisible by rhs despite rhs is
// not a constant
auto rhs = getConstant(inst.operands[1]);
if (rhs.has_value()) {
if (inst.operands[0].isConst()) {
result = AffineValue(constants[inst.operands[0].toConstIndex()] /
rhs.value());
} else if (inst.operands[0].isResult()) {
auto av = affineValues[inst.operands[0].toInstIndex()];
result =
AffineValue(av.var, av.a / rhs.value(), av.b / rhs.value());
}
if (inst.operands[0].isConst())
result = constants[inst.operands[0].toConstIndex()] / rhs.value();

Check warning on line 245 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L245

Added line #L245 was not covered by tests
else if (inst.operands[0].isResult())
result = affineValues[inst.operands[0].toInstIndex()] *
(1 / rhs.value());
}
break;
}
case OpCode::MOD:
case OpCode::MIN:
case OpCode::MAX:
case OpCode::EQ:
case OpCode::GT: {
// TODO: we can do better than just constant propagation...
auto lhs = getConstant(inst.operands[0]);
auto rhs = getConstant(inst.operands[1]);
if (lhs.has_value() && rhs.has_value())
result = AffineValue(EvalContext<double>::handle_binary(
inst.op, lhs.value(), rhs.value()));
break;
}
case OpCode::ADD: {
auto x = inst.operands[0];
auto y = inst.operands[1];
auto lhs = getConstant(x);
auto rhs = getConstant(y);
if (lhs.has_value() && rhs.has_value()) {
result = AffineValue(lhs.value() + rhs.value());
} else if (lhs.has_value() && y.isResult()) {
result = affineValues[y.toInstIndex()];
result.b += lhs.value();
} else if (rhs.has_value() && x.isResult()) {
result = affineValues[x.toInstIndex()];
result.b += rhs.value();
} else if (x.isResult() && y.isResult()) {
if (affineValues[x.toInstIndex()].var ==
affineValues[y.toInstIndex()].var) {
auto other = affineValues[y.toInstIndex()];
result = affineValues[x.toInstIndex()];
result.a += other.a;
result.b += other.b;
}
}
auto r = handleAdd(inst.operands[0], inst.operands[1], false);
if (r.has_value()) result = r.value();
break;
}
case OpCode::SUB: {
auto x = inst.operands[0];
auto y = inst.operands[1];
auto lhs = getConstant(x);
auto rhs = getConstant(y);
if (lhs.has_value() && rhs.has_value()) {
result = AffineValue(lhs.value() - rhs.value());
} else if (lhs.has_value() && y.isResult()) {
result = affineValues[y.toInstIndex()];
result.a = -result.a;
result.b = lhs.value() - result.b;
} else if (rhs.has_value() && x.isResult()) {
result = affineValues[x.toInstIndex()];
result.b -= rhs.value();
} else if (x.isResult() && y.isResult()) {
if (affineValues[x.toInstIndex()].var ==
affineValues[y.toInstIndex()].var) {
auto other = affineValues[y.toInstIndex()];
result = affineValues[x.toInstIndex()];
result.a -= other.a;
result.b -= other.b;
}
}
auto r = handleAdd(inst.operands[0], inst.operands[1], true);
if (r.has_value()) result = r.value();
break;
}
case OpCode::MUL: {
auto x = inst.operands[0];
auto y = inst.operands[1];
auto lhs = getConstant(x);
auto rhs = getConstant(y);
if (lhs.has_value() && rhs.has_value()) {
if (lhs.has_value() && rhs.has_value())
result = AffineValue(lhs.value() * rhs.value());

Check warning on line 268 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L268

Added line #L268 was not covered by tests
} else if (lhs.has_value() && y.isResult()) {
result = affineValues[y.toInstIndex()];
result.a *= lhs.value();
result.b *= lhs.value();
} else if (rhs.has_value() && x.isResult()) {
result = affineValues[x.toInstIndex()];
result.a *= rhs.value();
result.b *= rhs.value();
}
else if (lhs.has_value() && y.isResult())
result = affineValues[y.toInstIndex()] * lhs.value();
else if (rhs.has_value() && x.isResult())
result = affineValues[x.toInstIndex()] * rhs.value();

Check warning on line 272 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L272

Added line #L272 was not covered by tests
break;
}
case OpCode::FMA: {
Expand All @@ -341,18 +280,12 @@ void Context::optimizeAffine() {
auto b = getConstant(y);
auto c = getConstant(z);

Check warning on line 281 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L275-L281

Added lines #L275 - L281 were not covered by tests
// various cases...
if (b.has_value() && c.has_value()) {
result = affineValues[x.toInstIndex()];
result.a *= b.value();
result.b = result.b * b.value() + c.value();
} else if (a.has_value() && c.has_value()) {
result = affineValues[y.toInstIndex()];
result.a *= a.value();
result.b = result.b * a.value() + c.value();
} else if (a.has_value() && b.has_value()) {
result = affineValues[z.toInstIndex()];
result.b += a.value() * b.value();
}
if (b.has_value() && c.has_value())
result = affineValues[x.toInstIndex()] * b.value() + c.value();
else if (a.has_value() && c.has_value())
result = affineValues[y.toInstIndex()] * a.value() + c.value();
else if (a.has_value() && b.has_value())
result = affineValues[z.toInstIndex()] + a.value() * c.value();

Check warning on line 288 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L283-L288

Added lines #L283 - L288 were not covered by tests
break;
}
case OpCode::CHOICE: {
Expand All @@ -367,6 +300,20 @@ void Context::optimizeAffine() {
}
break;
}
default: {
using ectx = EvalContext<double>;
if (inst.op >= OpCode::ABS && inst.op <= OpCode::ATAN) {
auto x = getConstant(inst.operands[0]);
if (x.has_value()) result = ectx::handle_unary(inst.op, x.value());
} else if (inst.op >= OpCode::DIV && inst.op <= OpCode::GT) {
// TODO: we can do better than just constant propagation...
auto lhs = getConstant(inst.operands[0]);
auto rhs = getConstant(inst.operands[1]);
if (lhs.has_value() && rhs.has_value())
result = ectx::handle_binary(inst.op, lhs.value(), rhs.value());

Check warning on line 313 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L313

Added line #L313 was not covered by tests
}
break;
}
}
affineValues.push_back(result);
if (result.var != Operand::fromInstIndex(i)) {
Expand All @@ -385,31 +332,24 @@ void Context::optimizeAffine() {
replaceInst(static_cast<int>(i),
static_cast<int>(result.var.toInstIndex()));
} else if (result.a == 1.0) {
auto constant = addConstant(result.b);
addUse(constant, i);
instructions[i] = {OpCode::ADD, {constant, result.var, none}};
instructions[i] = {OpCode::ADD,
{constWithUse(result.b, i), result.var, none}};

Check warning on line 336 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L335-L336

Added lines #L335 - L336 were not covered by tests
} else if (result.a == -1.0 && result.b == 0.0) {
instructions[i] = {OpCode::NEG, {result.var, none, none}};

Check warning on line 338 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L338

Added line #L338 was not covered by tests
} else if (result.a == -1.0) {
auto constant = addConstant(result.b);
addUse(constant, i);
instructions[i] = {OpCode::SUB, {constant, result.var, none}};
instructions[i] = {OpCode::SUB,
{constWithUse(result.b, i), result.var, none}};
} else if (result.b == 0.0) {
auto constant = addConstant(result.a);
addUse(constant, i);
instructions[i] = {OpCode::MUL, {constant, result.var, none}};
instructions[i] = {OpCode::MUL,
{constWithUse(result.a, i), result.var, none}};
} else if (result.a == 0.0) {
auto a = addConstant(0.0);
auto b = addConstant(result.b);
addUse(a, i);
addUse(b, i);
instructions[i] = {OpCode::ADD, {b, a, none}};
instructions[i] = {
OpCode::ADD,
{constWithUse(result.b, i), constWithUse(0.0, i), none}};
} else {
auto a = addConstant(result.a);
auto b = addConstant(result.b);
addUse(a, i);
addUse(b, i);
instructions[i] = {OpCode::FMA, {a, result.var, b}};
instructions[i] = {OpCode::FMA,
{constWithUse(result.a, i), result.var,
constWithUse(result.b, i)}};
}
}
}
Expand All @@ -425,12 +365,10 @@ void Context::schedule() {
std::vector<size_t> levelMap;
levelMap.reserve(oldInstructions.size());
for (size_t i = 0; i < oldInstructions.size(); i++) {
const auto &inst = oldInstructions[i];
size_t maxLevel = 0;
for (auto operand : inst.operands) {
if (!operand.isResult()) continue;
maxLevel = std::max(maxLevel, levelMap[operand.toInstIndex()]);
}
for (auto operand : oldInstructions[i].operands)
maxLevel = std::max(
maxLevel, operand.isResult() ? levelMap[operand.toInstIndex()] : 0);
levelMap.push_back(maxLevel + 1);
}

Expand All @@ -443,8 +381,7 @@ void Context::schedule() {
return operand.isResult() && computedInst[operand.toInstIndex()].isNone();
};
auto toNewOperand = [&computedInst](Operand old) {

Check warning on line 383 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L383

Added line #L383 was not covered by tests
if (old.isResult()) return computedInst[old.toInstIndex()];
return old;
return old.isResult() ? computedInst[old.toInstIndex()] : old;
};

while (!stack.empty()) {
Expand Down

0 comments on commit 486c762

Please sign in to comment.