From 07895b0c2795405f3ea19022a6de07fd3240cc69 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Thu, 4 Jul 2024 16:58:07 +0100 Subject: [PATCH] add support for $mul, $div, $divfloor, $mod, $modfloor, $pow in functional backend --- backends/functional/cxx.cc | 3 + backends/functional/cxx_runtime/sim.h | 89 +++++++++++++++++++++++++-- backends/functional/smtlib.cc | 6 ++ kernel/functionalir.cc | 73 +++++++++++++++++++++- kernel/functionalir.h | 17 ++++- 5 files changed, 177 insertions(+), 11 deletions(-) diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index 21e287a96c4..81af54c073e 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -105,6 +105,9 @@ template struct CxxPrintVisitor { void concat(Node, Node a, int, Node b, int) { print("{}.concat({})", a, b); } void add(Node, Node a, Node b, int) { print("{} + {}", a, b); } void sub(Node, Node a, Node b, int) { print("{} - {}", a, b); } + void mul(Node, Node a, Node b, int) { print("{} * {}", a, b); } + void unsigned_div(Node, Node a, Node b, int) { print("{} / {}", a, b); } + void unsigned_mod(Node, Node a, Node b, int) { print("{} % {}", a, b); } void bitwise_and(Node, Node a, Node b, int) { print("{} & {}", a, b); } void bitwise_or(Node, Node a, Node b, int) { print("{} | {}", a, b); } void bitwise_xor(Node, Node a, Node b, int) { print("{} ^ {}", a, b); } diff --git a/backends/functional/cxx_runtime/sim.h b/backends/functional/cxx_runtime/sim.h index 1985322f1c8..040c6aefce8 100644 --- a/backends/functional/cxx_runtime/sim.h +++ b/backends/functional/cxx_runtime/sim.h @@ -22,6 +22,8 @@ #include #include +#include +#include template class Signal { @@ -149,6 +151,48 @@ class Signal { uint32_t as_int() const { return as_numeric(); } +private: + std::string as_string_p2(int b) const { + std::string ret; + for(int i = (n - 1) - (n - 1) % b; i >= 0; i -= b) + ret += "0123456789abcdef"[(*this >> Signal<32>(i)).as_int() & ((1< t = *this; + Signal b = 10; + do{ + ret += (char)('0' + (t % b).as_int()); + t = t / b; + }while(t.any()); + std::reverse(ret.begin(), ret.end()); + return ret; + } +public: + std::string as_string(int base = 16, bool showbase = true) const { + std::string ret; + if(showbase) { + ret += std::to_string(n); + switch(base) { + case 2: ret += "'b"; break; + case 8: ret += "'o"; break; + case 10: ret += "'d"; break; + case 16: ret += "'h"; break; + default: assert(0); + } + } + switch(base) { + case 2: return ret + as_string_p2(1); + case 8: return ret + as_string_p2(3); + case 10: return ret + as_string_b10(); + case 16: return ret + as_string_p2(4); + default: assert(0); + } + } + friend std::ostream &operator << (std::ostream &os, Signal const &s) { return os << s.as_string(); } + Signal operator ~() const { Signal ret; @@ -160,11 +204,11 @@ class Signal { Signal operator -() const { Signal ret; - bool carry = true; + int x = 1; for(size_t i = 0; i < n; i++) { - int r = !_bits[i] + carry; - ret._bits[i] = (r & 1) != 0; - carry = (r >> 1) != 0; + x += (int)!_bits[i]; + ret._bits[i] = (x & 1) != 0; + x >>= 1; } return ret; } @@ -172,9 +216,8 @@ class Signal { Signal operator +(Signal const &b) const { Signal ret; - size_t i; int x = 0; - for(i = 0; i < n; i++){ + for(size_t i = 0; i < n; i++){ x += (int)_bits[i] + (int)b._bits[i]; ret._bits[i] = x & 1; x >>= 1; @@ -194,6 +237,40 @@ class Signal { return ret; } + Signal operator *(Signal const &b) const + { + Signal ret; + int x = 0; + for(size_t i = 0; i < n; i++){ + for(size_t j = 0; j <= i; j++) + x += (int)_bits[j] & (int)b._bits[i-j]; + ret._bits[i] = x & 1; + x >>= 1; + } + return ret; + } + +private: + Signal divmod(Signal const &b, bool modulo) const + { + if(!b.any()) return 0; + Signal q = 0; + Signal r = 0; + for(size_t i = n; i-- != 0; ){ + r = r << Signal<1>(1); + r._bits[0] = _bits[i]; + if(r >= b){ + r = r - b; + q._bits[i] = true; + } + } + return modulo ? r : q; + } +public: + + Signal operator /(Signal const &b) const { return divmod(b, false); } + Signal operator %(Signal const &b) const { return divmod(b, true); } + bool operator ==(Signal const &b) const { for(size_t i = 0; i < n; i++) diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc index afddbd83440..55efc83c1c2 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib.cc @@ -119,6 +119,12 @@ template struct SmtPrintVisitor { std::string sub(Node, Node a, Node b, int) { return format("(bvsub %0 %1)", np(a), np(b)); } + std::string mul(Node, Node a, Node b, int) { return format("(bvmul %0 %1)", np(a), np(b)); } + + std::string unsigned_div(Node, Node a, Node b, int) { return format("(bvudiv %0 %1)", np(a), np(b)); } + + std::string unsigned_mod(Node, Node a, Node b, int) { return format("(bvurem %0 %1)", np(a), np(b)); } + std::string bitwise_and(Node, Node a, Node b, int) { return format("(bvand %0 %1)", np(a), np(b)); } std::string bitwise_or(Node, Node a, Node b, int) { return format("(bvor %0 %1)", np(a), np(b)); } diff --git a/kernel/functionalir.cc b/kernel/functionalir.cc index 3fd2d8240e7..6498ef63629 100644 --- a/kernel/functionalir.cc +++ b/kernel/functionalir.cc @@ -37,6 +37,15 @@ class CellSimplifier { return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow, new_width); } } + T sign(T a, int a_width) { + return factory.slice(a, a_width, a_width - 1, 1); + } + T neg_if(T a, int a_width, T s) { + return factory.mux(a, factory.unary_minus(a, a_width), s, a_width); + } + T abs(T a, int a_width) { + return neg_if(a, a_width, sign(a, a_width)); + } public: T reduce_or(T a, int width) { if (width == 1) @@ -71,6 +80,23 @@ class CellSimplifier { return factory.bitwise_or(aa, bb, width); } CellSimplifier(Factory &f) : factory(f) {} +private: + T handle_pow(T a0, int a_width, T b, int b_width, int y_width, bool is_signed) { + T a = extend(a0, a_width, y_width, is_signed); + T r = factory.constant(Const(1, y_width)); + for(int i = 0; i < b_width; i++) { + T b_bit = factory.slice(b, b_width, i, 1); + r = factory.mux(r, factory.mul(r, a, y_width), b_bit, y_width); + a = factory.mul(a, a, y_width); + } + if (is_signed) { + T a_ge_1 = factory.unsigned_greater_than(abs(a0, a_width), factory.constant(Const(1, a_width)), a_width); + T zero_result = factory.bitwise_and(a_ge_1, sign(b, b_width), 1); + r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result, y_width); + } + return r; + } +public: T handle(IdString cellType, dict parameters, dict inputs) { int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int(); @@ -78,7 +104,7 @@ class CellSimplifier { int y_width = parameters.at(ID(Y_WIDTH), Const(-1)).as_int(); bool a_signed = parameters.at(ID(A_SIGNED), Const(0)).as_bool(); bool b_signed = parameters.at(ID(B_SIGNED), Const(0)).as_bool(); - if(cellType.in({ID($add), ID($sub), ID($and), ID($or), ID($xor), ID($xnor)})){ + if(cellType.in({ID($add), ID($sub), ID($and), ID($or), ID($xor), ID($xnor), ID($mul)})){ bool is_signed = a_signed && b_signed; T a = extend(inputs.at(ID(A)), a_width, y_width, is_signed); T b = extend(inputs.at(ID(B)), b_width, y_width, is_signed); @@ -86,6 +112,8 @@ class CellSimplifier { return factory.add(a, b, y_width); else if(cellType == ID($sub)) return factory.sub(a, b, y_width); + else if(cellType == ID($mul)) + return factory.mul(a, b, y_width); else if(cellType == ID($and)) return factory.bitwise_and(a, b, y_width); else if(cellType == ID($or)) @@ -160,7 +188,7 @@ class CellSimplifier { T b = inputs.at(ID(B)); T shr = logical_shift_right(a, b, width, b_width); if(b_signed) { - T sign_b = factory.slice(b, b_width, b_width - 1, 1); + T sign_b = sign(b, b_width); T shl = logical_shift_left(a, factory.unary_minus(b, b_width), width, b_width); T y = factory.mux(shr, shl, sign_b, width); return extend(y, width, y_width, false); @@ -182,7 +210,46 @@ class CellSimplifier { int offset = parameters.at(ID(OFFSET)).as_int(); T a = inputs.at(ID(A)); return factory.slice(a, a_width, offset, y_width); - }else{ + }else if(cellType.in({ID($div), ID($mod), ID($divfloor), ID($modfloor)})) { + int width = max(a_width, b_width); + bool is_signed = a_signed && b_signed; + T a = extend(inputs.at(ID(A)), a_width, width, is_signed); + T b = extend(inputs.at(ID(B)), b_width, width, is_signed); + if(is_signed) { + if(cellType == ID($div)) { + T abs_y = factory.unsigned_div(abs(a, width), abs(b, width), width); + T out_sign = factory.not_equal(sign(a, width), sign(b, width), 1); + return neg_if(extend(abs_y, width, y_width, true), y_width, out_sign); + } else if(cellType == ID($mod)) { + T abs_y = factory.unsigned_mod(abs(a, width), abs(b, width), width); + return neg_if(extend(abs_y, width, y_width, true), y_width, sign(a, width)); + } else if(cellType == ID($divfloor)) { + T b_sign = sign(b, width); + T a1 = neg_if(a, width, b_sign); + T b1 = neg_if(b, width, b_sign); + T a1_sign = sign(a1, width); + T a2 = factory.mux(a1, factory.bitwise_not(a1, width), a1_sign, width); + T y1 = factory.unsigned_div(a2, b1, width); + T y2 = factory.mux(y1, factory.bitwise_not(y1, width), a1_sign, width); + return extend(y2, width, y_width, true); + } else if(cellType == ID($modfloor)) { + T abs_b = abs(b, width); + T abs_y = factory.unsigned_mod(abs(a, width), abs_b, width); + T flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a, width), sign(b, width), 1), factory.reduce_or(abs_y, width), 1); + T y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y, width), flip_y, width); + T y = neg_if(y_flipped, width, sign(b, b_width)); + return extend(y, width, y_width, true); + } else + log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str()); + } else { + if(cellType.in({ID($mod), ID($modfloor)})) + return extend(factory.unsigned_mod(a, b, width), width, y_width, false); + else + return extend(factory.unsigned_div(a, b, width), width, y_width, false); + } + } else if(cellType == ID($pow)) { + return handle_pow(inputs.at(ID(A)), a_width, inputs.at(ID(B)), b_width, y_width, a_signed && b_signed); + } else{ log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str()); } } diff --git a/kernel/functionalir.h b/kernel/functionalir.h index 2c0d0f55c2c..3e7c55ef4ec 100644 --- a/kernel/functionalir.h +++ b/kernel/functionalir.h @@ -39,6 +39,9 @@ class FunctionalIR { concat, add, sub, + mul, + unsigned_div, + unsigned_mod, bitwise_and, bitwise_or, bitwise_xor, @@ -145,6 +148,9 @@ class FunctionalIR { std::string concat(Node, Node a, int, Node b, int) { return "concat(" + np(a) + ", " + np(b) + ")"; } std::string add(Node, Node a, Node b, int) { return "add(" + np(a) + ", " + np(b) + ")"; } std::string sub(Node, Node a, Node b, int) { return "sub(" + np(a) + ", " + np(b) + ")"; } + std::string mul(Node, Node a, Node b, int) { return "mul(" + np(a) + ", " + np(b) + ")"; } + std::string unsigned_div(Node, Node a, Node b, int) { return "unsigned_div(" + np(a) + ", " + np(b) + ")"; } + std::string unsigned_mod(Node, Node a, Node b, int) { return "unsigned_mod(" + np(a) + ", " + np(b) + ")"; } std::string bitwise_and(Node, Node a, Node b, int) { return "bitwise_and(" + np(a) + ", " + np(b) + ")"; } std::string bitwise_or(Node, Node a, Node b, int) { return "bitwise_or(" + np(a) + ", " + np(b) + ")"; } std::string bitwise_xor(Node, Node a, Node b, int) { return "bitwise_xor(" + np(a) + ", " + np(b) + ")"; } @@ -193,11 +199,14 @@ class FunctionalIR { case Fn::concat: return v.concat(*this, arg(0), arg(0).width(), arg(1), arg(1).width()); break; case Fn::add: return v.add(*this, arg(0), arg(1), sort().width()); break; case Fn::sub: return v.sub(*this, arg(0), arg(1), sort().width()); break; + case Fn::mul: return v.mul(*this, arg(0), arg(1), sort().width()); break; + case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1), sort().width()); break; + case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1), sort().width()); break; case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1), sort().width()); break; case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1), sort().width()); break; case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1), sort().width()); break; case Fn::bitwise_not: return v.bitwise_not(*this, arg(0), sort().width()); break; - case Fn::unary_minus: return v.bitwise_not(*this, arg(0), sort().width()); break; + case Fn::unary_minus: return v.unary_minus(*this, arg(0), sort().width()); break; case Fn::reduce_and: return v.reduce_and(*this, arg(0), arg(0).width()); break; case Fn::reduce_or: return v.reduce_or(*this, arg(0), arg(0).width()); break; case Fn::reduce_xor: return v.reduce_xor(*this, arg(0), arg(0).width()); break; @@ -257,6 +266,9 @@ class FunctionalIR { } Node add(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } Node sub(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } + Node mul(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } + Node unsigned_div(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } + Node unsigned_mod(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } Node bitwise_and(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } Node bitwise_or(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } Node bitwise_xor(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } @@ -298,7 +310,8 @@ class FunctionalIR { return add(Fn::buf, Sort(width), {}); } void update_pending(Node node, Node value) { - log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0 && node.sort() == value.sort()); + log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0); + log_assert(node.sort() == value.sort()); node._ref.append_arg(value._ref); } Node input(IdString name, int width) {