From 7f68327e9f1237e9d9670994eed3ab197816e8c1 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Wed, 17 Jul 2024 12:42:24 +0100 Subject: [PATCH] remove widths parameters from FunctionalIR factory methods and from functionalir.cc --- kernel/functionalir.cc | 391 +++++++++++++++++++---------------------- kernel/functionalir.h | 66 ++++--- 2 files changed, 216 insertions(+), 241 deletions(-) diff --git a/kernel/functionalir.cc b/kernel/functionalir.cc index 9b8076a0ef7..48baad7168b 100644 --- a/kernel/functionalir.cc +++ b/kernel/functionalir.cc @@ -98,87 +98,75 @@ std::string FunctionalIR::Node::to_string(std::function np) return visit(PrintVisitor(np)); } -template class CellSimplifier { - Factory &factory; - T reduce_shift_width(T b, int b_width, int y_width, int &reduced_b_width) { + using Node = FunctionalIR::Node; + FunctionalIR::Factory &factory; + Node reduce_shift_width(Node b, int y_width) { log_assert(y_width > 0); int new_width = ceil_log2(y_width + 1); - if (b_width <= new_width) { - reduced_b_width = b_width; + if (b.width() <= new_width) { return b; } else { - reduced_b_width = new_width; - T lower_b = factory.slice(b, b_width, 0, new_width); - T overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b_width)), b_width); - return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow, new_width); + Node lower_b = factory.slice(b, 0, new_width); + Node overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b.width()))); + return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow); } } - T sign(T a, int a_width) { - return factory.slice(a, a_width, a_width - 1, 1); + Node sign(Node a) { + return factory.slice(a, a.width() - 1, 1); } - T neg_if(T a, int a_width, T s) { - return factory.mux(a, factory.unary_minus(a, a_width), s, a_width); + Node neg_if(Node a, Node s) { + return factory.mux(a, factory.unary_minus(a), s); } - T abs(T a, int a_width) { - return neg_if(a, a_width, sign(a, a_width)); + Node abs(Node a) { + return neg_if(a, sign(a)); } public: - T extend(T a, int in_width, int out_width, bool is_signed) { - if(in_width == out_width) - return a; - if(in_width > out_width) - return factory.slice(a, in_width, 0, out_width); - return factory.extend(a, in_width, out_width, is_signed); - } - T logical_shift_left(T a, T b, int y_width, int b_width) { - int reduced_b_width; - T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); - return factory.logical_shift_left(a, reduced_b, y_width, reduced_b_width); + Node logical_shift_left(Node a, Node b) { + Node reduced_b = reduce_shift_width(b, a.width()); + return factory.logical_shift_left(a, reduced_b); } - T logical_shift_right(T a, T b, int y_width, int b_width) { - int reduced_b_width; - T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); - return factory.logical_shift_right(a, reduced_b, y_width, reduced_b_width); + Node logical_shift_right(Node a, Node b) { + Node reduced_b = reduce_shift_width(b, a.width()); + return factory.logical_shift_right(a, reduced_b); } - T arithmetic_shift_right(T a, T b, int y_width, int b_width) { - int reduced_b_width; - T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); - return factory.arithmetic_shift_right(a, reduced_b, y_width, reduced_b_width); + Node arithmetic_shift_right(Node a, Node b) { + Node reduced_b = reduce_shift_width(b, a.width()); + return factory.arithmetic_shift_right(a, reduced_b); } - T bitwise_mux(T a, T b, T s, int width) { - T aa = factory.bitwise_and(a, factory.bitwise_not(s, width), width); - T bb = factory.bitwise_and(b, s, width); - return factory.bitwise_or(aa, bb, width); + Node bitwise_mux(Node a, Node b, Node s) { + Node aa = factory.bitwise_and(a, factory.bitwise_not(s)); + Node bb = factory.bitwise_and(b, s); + return factory.bitwise_or(aa, bb); } - CellSimplifier(Factory &f) : factory(f) {} + CellSimplifier(FunctionalIR::Factory &f) : factory(f) {} private: - T handle_pow(T a0, int a_width, T b, int b_width, int y_width, bool is_signed) { - T a = extend(a0, a_width, y_width, is_signed); - T r = factory.constant(Const(1, y_width)); - for(int i = 0; i < b_width; i++) { - T b_bit = factory.slice(b, b_width, i, 1); - r = factory.mux(r, factory.mul(r, a, y_width), b_bit, y_width); - a = factory.mul(a, a, y_width); + Node handle_pow(Node a0, Node b, int y_width, bool is_signed) { + Node a = factory.extend(a0, y_width, is_signed); + Node r = factory.constant(Const(1, y_width)); + for(int i = 0; i < b.width(); i++) { + Node b_bit = factory.slice(b, i, 1); + r = factory.mux(r, factory.mul(r, a), b_bit); + a = factory.mul(a, a); } if (is_signed) { - T a_ge_1 = factory.unsigned_greater_than(abs(a0, a_width), factory.constant(Const(1, a_width)), a_width); - T zero_result = factory.bitwise_and(a_ge_1, sign(b, b_width), 1); - r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result, y_width); + Node a_ge_1 = factory.unsigned_greater_than(abs(a0), factory.constant(Const(1, a0.width()))); + Node zero_result = factory.bitwise_and(a_ge_1, sign(b)); + r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result); } return r; } - T handle_bmux(T a, T s, int a_width, int a_offset, int width, int s_width, int sn) { + Node handle_bmux(Node a, Node s, int a_offset, int width, int sn) { if(sn < 1) - return factory.slice(a, a_width, a_offset, width); + return factory.slice(a, a_offset, width); else { - T y0 = handle_bmux(a, s, a_width, a_offset, width, s_width, sn - 1); - T y1 = handle_bmux(a, s, a_width, a_offset + (width << (sn - 1)), width, s_width, sn - 1); - return factory.mux(y0, y1, factory.slice(s, s_width, sn - 1, 1), width); + Node y0 = handle_bmux(a, s, a_offset, width, sn - 1); + Node y1 = handle_bmux(a, s, a_offset + (width << (sn - 1)), width, sn - 1); + return factory.mux(y0, y1, factory.slice(s, sn - 1, 1)); } } public: - T handle(IdString cellType, dict parameters, dict inputs) + Node handle(IdString cellType, dict parameters, dict inputs) { int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int(); int b_width = parameters.at(ID(B_WIDTH), Const(-1)).as_int(); @@ -187,208 +175,202 @@ class CellSimplifier { bool b_signed = parameters.at(ID(B_SIGNED), Const(0)).as_bool(); if(cellType.in({ID($add), ID($sub), ID($and), ID($or), ID($xor), ID($xnor), ID($mul)})){ bool is_signed = a_signed && b_signed; - T a = extend(inputs.at(ID(A)), a_width, y_width, is_signed); - T b = extend(inputs.at(ID(B)), b_width, y_width, is_signed); + Node a = factory.extend(inputs.at(ID(A)), y_width, is_signed); + Node b = factory.extend(inputs.at(ID(B)), y_width, is_signed); if(cellType == ID($add)) - return factory.add(a, b, y_width); + return factory.add(a, b); else if(cellType == ID($sub)) - return factory.sub(a, b, y_width); + return factory.sub(a, b); else if(cellType == ID($mul)) - return factory.mul(a, b, y_width); + return factory.mul(a, b); else if(cellType == ID($and)) - return factory.bitwise_and(a, b, y_width); + return factory.bitwise_and(a, b); else if(cellType == ID($or)) - return factory.bitwise_or(a, b, y_width); + return factory.bitwise_or(a, b); else if(cellType == ID($xor)) - return factory.bitwise_xor(a, b, y_width); + return factory.bitwise_xor(a, b); else if(cellType == ID($xnor)) - return factory.bitwise_not(factory.bitwise_xor(a, b, y_width), y_width); + return factory.bitwise_not(factory.bitwise_xor(a, b)); else log_abort(); }else if(cellType.in({ID($eq), ID($ne), ID($eqx), ID($nex), ID($le), ID($lt), ID($ge), ID($gt)})){ bool is_signed = a_signed && b_signed; int width = max(a_width, b_width); - T a = extend(inputs.at(ID(A)), a_width, width, is_signed); - T b = extend(inputs.at(ID(B)), b_width, width, is_signed); + Node a = factory.extend(inputs.at(ID(A)), width, is_signed); + Node b = factory.extend(inputs.at(ID(B)), width, is_signed); if(cellType.in({ID($eq), ID($eqx)})) - return extend(factory.equal(a, b, width), 1, y_width, false); + return factory.extend(factory.equal(a, b), y_width, false); else if(cellType.in({ID($ne), ID($nex)})) - return extend(factory.not_equal(a, b, width), 1, y_width, false); + return factory.extend(factory.not_equal(a, b), y_width, false); else if(cellType == ID($lt)) - return extend(is_signed ? factory.signed_greater_than(b, a, width) : factory.unsigned_greater_than(b, a, width), 1, y_width, false); + return factory.extend(is_signed ? factory.signed_greater_than(b, a) : factory.unsigned_greater_than(b, a), y_width, false); else if(cellType == ID($le)) - return extend(is_signed ? factory.signed_greater_equal(b, a, width) : factory.unsigned_greater_equal(b, a, width), 1, y_width, false); + return factory.extend(is_signed ? factory.signed_greater_equal(b, a) : factory.unsigned_greater_equal(b, a), y_width, false); else if(cellType == ID($gt)) - return extend(is_signed ? factory.signed_greater_than(a, b, width) : factory.unsigned_greater_than(a, b, width), 1, y_width, false); + return factory.extend(is_signed ? factory.signed_greater_than(a, b) : factory.unsigned_greater_than(a, b), y_width, false); else if(cellType == ID($ge)) - return extend(is_signed ? factory.signed_greater_equal(a, b, width) : factory.unsigned_greater_equal(a, b, width), 1, y_width, false); + return factory.extend(is_signed ? factory.signed_greater_equal(a, b) : factory.unsigned_greater_equal(a, b), y_width, false); else log_abort(); }else if(cellType.in({ID($logic_or), ID($logic_and)})){ - T a = factory.reduce_or(inputs.at(ID(A)), a_width); - T b = factory.reduce_or(inputs.at(ID(B)), b_width); - T y = cellType == ID($logic_and) ? factory.bitwise_and(a, b, 1) : factory.bitwise_or(a, b, 1); - return extend(y, 1, y_width, false); + Node a = factory.reduce_or(inputs.at(ID(A))); + Node b = factory.reduce_or(inputs.at(ID(B))); + Node y = cellType == ID($logic_and) ? factory.bitwise_and(a, b) : factory.bitwise_or(a, b); + return factory.extend(y, y_width, false); }else if(cellType == ID($not)){ - T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed); - return factory.bitwise_not(a, y_width); + Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed); + return factory.bitwise_not(a); }else if(cellType == ID($pos)){ - return extend(inputs.at(ID(A)), a_width, y_width, a_signed); + return factory.extend(inputs.at(ID(A)), y_width, a_signed); }else if(cellType == ID($neg)){ - T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed); - return factory.unary_minus(a, y_width); + Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed); + return factory.unary_minus(a); }else if(cellType == ID($logic_not)){ - T a = factory.reduce_or(inputs.at(ID(A)), a_width); - T y = factory.bitwise_not(a, 1); - return extend(y, 1, y_width, false); + Node a = factory.reduce_or(inputs.at(ID(A))); + Node y = factory.bitwise_not(a); + return factory.extend(y, y_width, false); }else if(cellType.in({ID($reduce_or), ID($reduce_bool)})){ - T a = factory.reduce_or(inputs.at(ID(A)), a_width); - return extend(a, 1, y_width, false); + Node a = factory.reduce_or(inputs.at(ID(A))); + return factory.extend(a, y_width, false); }else if(cellType == ID($reduce_and)){ - T a = factory.reduce_and(inputs.at(ID(A)), a_width); - return extend(a, 1, y_width, false); + Node a = factory.reduce_and(inputs.at(ID(A))); + return factory.extend(a, y_width, false); }else if(cellType.in({ID($reduce_xor), ID($reduce_xnor)})){ - T a = factory.reduce_xor(inputs.at(ID(A)), a_width); - T y = cellType == ID($reduce_xnor) ? factory.bitwise_not(a, 1) : a; - return extend(y, 1, y_width, false); + Node a = factory.reduce_xor(inputs.at(ID(A))); + Node y = cellType == ID($reduce_xnor) ? factory.bitwise_not(a) : a; + return factory.extend(y, y_width, false); }else if(cellType == ID($shl) || cellType == ID($sshl)){ - T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed); - T b = inputs.at(ID(B)); - return logical_shift_left(a, b, y_width, b_width); + Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed); + Node b = inputs.at(ID(B)); + return logical_shift_left(a, b); }else if(cellType == ID($shr) || cellType == ID($sshr)){ int width = max(a_width, y_width); - T a = extend(inputs.at(ID(A)), a_width, width, a_signed); - T b = inputs.at(ID(B)); - T y = a_signed && cellType == ID($sshr) ? - arithmetic_shift_right(a, b, width, b_width) : - logical_shift_right(a, b, width, b_width); - return extend(y, width, y_width, a_signed); + Node a = factory.extend(inputs.at(ID(A)), width, a_signed); + Node b = inputs.at(ID(B)); + Node y = a_signed && cellType == ID($sshr) ? + arithmetic_shift_right(a, b) : + logical_shift_right(a, b); + return factory.extend(y, y_width, a_signed); }else if(cellType == ID($shiftx) || cellType == ID($shift)){ int width = max(a_width, y_width); - T a = extend(inputs.at(ID(A)), a_width, width, cellType == ID($shift) && a_signed); - T b = inputs.at(ID(B)); - T shr = logical_shift_right(a, b, width, b_width); + Node a = factory.extend(inputs.at(ID(A)), width, cellType == ID($shift) && a_signed); + Node b = inputs.at(ID(B)); + Node shr = logical_shift_right(a, b); if(b_signed) { - T sign_b = sign(b, b_width); - T shl = logical_shift_left(a, factory.unary_minus(b, b_width), width, b_width); - T y = factory.mux(shr, shl, sign_b, width); - return extend(y, width, y_width, false); + Node shl = logical_shift_left(a, factory.unary_minus(b)); + Node y = factory.mux(shr, shl, sign(b)); + return factory.extend(y, y_width, false); } else { - return extend(shr, width, y_width, false); + return factory.extend(shr, y_width, false); } }else if(cellType == ID($mux)){ - int width = parameters.at(ID(WIDTH)).as_int(); - return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)), width); + return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S))); }else if(cellType == ID($pmux)){ - int width = parameters.at(ID(WIDTH)).as_int(); - int s_width = parameters.at(ID(S_WIDTH)).as_int(); - return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)), width, s_width); + return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S))); }else if(cellType == ID($concat)){ - T a = inputs.at(ID(A)); - T b = inputs.at(ID(B)); - return factory.concat(a, a_width, b, b_width); + Node a = inputs.at(ID(A)); + Node b = inputs.at(ID(B)); + return factory.concat(a, b); }else if(cellType == ID($slice)){ int offset = parameters.at(ID(OFFSET)).as_int(); - T a = inputs.at(ID(A)); - return factory.slice(a, a_width, offset, y_width); + Node a = inputs.at(ID(A)); + return factory.slice(a, offset, y_width); }else if(cellType.in({ID($div), ID($mod), ID($divfloor), ID($modfloor)})) { int width = max(a_width, b_width); bool is_signed = a_signed && b_signed; - T a = extend(inputs.at(ID(A)), a_width, width, is_signed); - T b = extend(inputs.at(ID(B)), b_width, width, is_signed); + Node a = factory.extend(inputs.at(ID(A)), width, is_signed); + Node b = factory.extend(inputs.at(ID(B)), width, is_signed); if(is_signed) { if(cellType == ID($div)) { // divide absolute values, then flip the sign if input signs differ // but extend the width first, to handle the case (most negative value) / (-1) - T abs_y = factory.unsigned_div(abs(a, width), abs(b, width), width); - T out_sign = factory.not_equal(sign(a, width), sign(b, width), 1); - return neg_if(extend(abs_y, width, y_width, false), y_width, out_sign); + Node abs_y = factory.unsigned_div(abs(a), abs(b)); + Node out_sign = factory.not_equal(sign(a), sign(b)); + return neg_if(factory.extend(abs_y, y_width, false), out_sign); } else if(cellType == ID($mod)) { // similar to division but output sign == divisor sign - T abs_y = factory.unsigned_mod(abs(a, width), abs(b, width), width); - return neg_if(extend(abs_y, width, y_width, false), y_width, sign(a, width)); + Node abs_y = factory.unsigned_mod(abs(a), abs(b)); + return neg_if(factory.extend(abs_y, y_width, false), sign(a)); } else if(cellType == ID($divfloor)) { // if b is negative, flip both signs so that b is positive - T b_sign = sign(b, width); - T a1 = neg_if(a, width, b_sign); - T b1 = neg_if(b, width, b_sign); + Node b_sign = sign(b); + Node a1 = neg_if(a, b_sign); + Node b1 = neg_if(b, b_sign); // if a is now negative, calculate ~((~a) / b) = -((-a - 1) / b + 1) // which equals the negative of (-a) / b with rounding up rather than down // note that to handle the case where a = most negative value properly, - // we have to calculate a1_sign from the original values rather than using sign(a1, width) - T a1_sign = factory.bitwise_and(factory.not_equal(sign(a, width), sign(b, width), 1), factory.reduce_or(a, width), 1); - T a2 = factory.mux(a1, factory.bitwise_not(a1, width), a1_sign, width); - T y1 = factory.unsigned_div(a2, b1, width); - T y2 = extend(y1, width, y_width, false); - return factory.mux(y2, factory.bitwise_not(y2, y_width), a1_sign, y_width); + // we have to calculate a1_sign from the original values rather than using sign(a1) + Node a1_sign = factory.bitwise_and(factory.not_equal(sign(a), sign(b)), factory.reduce_or(a)); + Node a2 = factory.mux(a1, factory.bitwise_not(a1), a1_sign); + Node y1 = factory.unsigned_div(a2, b1); + Node y2 = factory.extend(y1, y_width, false); + return factory.mux(y2, factory.bitwise_not(y2), a1_sign); } else if(cellType == ID($modfloor)) { // calculate |a| % |b| and then subtract from |b| if input signs differ and the remainder is non-zero - T abs_b = abs(b, width); - T abs_y = factory.unsigned_mod(abs(a, width), abs_b, width); - T flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a, width), sign(b, width), 1), factory.reduce_or(abs_y, width), 1); - T y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y, width), flip_y, width); + Node abs_b = abs(b); + Node abs_y = factory.unsigned_mod(abs(a), abs_b); + Node flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a), sign(b)), factory.reduce_or(abs_y)); + Node y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y), flip_y); // since y_flipped is strictly less than |b|, the top bit is always 0 and we can just sign extend the flipped result - T y = neg_if(y_flipped, width, sign(b, b_width)); - return extend(y, width, y_width, true); + Node y = neg_if(y_flipped, sign(b)); + return factory.extend(y, y_width, true); } else log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str()); } else { if(cellType.in({ID($mod), ID($modfloor)})) - return extend(factory.unsigned_mod(a, b, width), width, y_width, false); + return factory.extend(factory.unsigned_mod(a, b), y_width, false); else - return extend(factory.unsigned_div(a, b, width), width, y_width, false); + return factory.extend(factory.unsigned_div(a, b), y_width, false); } } else if(cellType == ID($pow)) { - return handle_pow(inputs.at(ID(A)), a_width, inputs.at(ID(B)), b_width, y_width, a_signed && b_signed); + return handle_pow(inputs.at(ID(A)), inputs.at(ID(B)), y_width, a_signed && b_signed); } else if (cellType == ID($lut)) { int width = parameters.at(ID(WIDTH)).as_int(); Const lut_table = parameters.at(ID(LUT)); lut_table.extu(1 << width); - return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 1 << width, 0, 1, width, width); + return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 0, 1, width); } else if (cellType == ID($bwmux)) { - int width = parameters.at(ID(WIDTH)).as_int(); - T a = inputs.at(ID(A)); - T b = inputs.at(ID(B)); - T s = inputs.at(ID(S)); + Node a = inputs.at(ID(A)); + Node b = inputs.at(ID(B)); + Node s = inputs.at(ID(S)); return factory.bitwise_or( - factory.bitwise_and(a, factory.bitwise_not(s, width), width), - factory.bitwise_and(b, s, width), width); + factory.bitwise_and(a, factory.bitwise_not(s)), + factory.bitwise_and(b, s)); } else if (cellType == ID($bweqx)) { - int width = parameters.at(ID(WIDTH)).as_int(); - T a = inputs.at(ID(A)); - T b = inputs.at(ID(B)); - return factory.bitwise_not(factory.bitwise_xor(a, b, width), width); + Node a = inputs.at(ID(A)); + Node b = inputs.at(ID(B)); + return factory.bitwise_not(factory.bitwise_xor(a, b)); } else if(cellType == ID($bmux)) { int width = parameters.at(ID(WIDTH)).as_int(); int s_width = parameters.at(ID(S_WIDTH)).as_int(); - return handle_bmux(inputs.at(ID(A)), inputs.at(ID(S)), width << s_width, 0, width, s_width, s_width); + return handle_bmux(inputs.at(ID(A)), inputs.at(ID(S)), 0, width, s_width); } else if(cellType == ID($demux)) { int width = parameters.at(ID(WIDTH)).as_int(); int s_width = parameters.at(ID(S_WIDTH)).as_int(); int y_width = width << s_width; int b_width = ceil_log2(y_width + 1); - T a = extend(inputs.at(ID(A)), width, y_width, false); - T s = factory.extend(inputs.at(ID(S)), s_width, b_width, false); - T b = factory.mul(s, factory.constant(Const(width, b_width)), b_width); - return factory.logical_shift_left(a, b, y_width, b_width); + Node a = factory.extend(inputs.at(ID(A)), y_width, false); + Node s = factory.extend(inputs.at(ID(S)), b_width, false); + Node b = factory.mul(s, factory.constant(Const(width, b_width))); + return factory.logical_shift_left(a, b); } else { log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str()); } } }; -template class FunctionalIRConstruction { + using Node = FunctionalIR::Node; std::deque queue; - dict graph_nodes; + dict graph_nodes; idict cells; DriverMap driver_map; - Factory& factory; - CellSimplifier simplifier; + FunctionalIR::Factory& factory; + CellSimplifier simplifier; vector memories_vector; dict memories; - T enqueue(DriveSpec const &spec) + Node enqueue(DriveSpec const &spec) { auto it = graph_nodes.find(spec); if(it == graph_nodes.end()){ @@ -400,7 +382,7 @@ class FunctionalIRConstruction { return it->second; } public: - FunctionalIRConstruction(Factory &f) : factory(f), simplifier(f) {} + FunctionalIRConstruction(FunctionalIR::Factory &f) : factory(f), simplifier(f) {} void add_module(Module *module) { driver_map.add(module); @@ -410,7 +392,7 @@ class FunctionalIRConstruction { } for (auto wire : module->wires()) { if (wire->port_output) { - T node = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))); + Node node = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))); factory.declare_output(node, wire->name, wire->width); } } @@ -420,37 +402,34 @@ class FunctionalIRConstruction { memories[mem.cell] = &mem; } } - T concatenate_read_results(Mem *, vector results) + Node concatenate_read_results(Mem *, vector results) { /* TODO: write code to check that this is ok to do */ if(results.size() == 0) return factory.undriven(0); - T node = results[0]; - int size = results[0].width(); - for(size_t i = 1; i < results.size(); i++) { - node = factory.concat(node, size, results[i], results[i].width()); - size += results[i].width(); - } + Node node = results[0]; + for(size_t i = 1; i < results.size(); i++) + node = factory.concat(node, results[i]); return node; } - T handle_memory(Mem *mem) + Node handle_memory(Mem *mem) { - vector read_results; + vector read_results; int addr_width = ceil_log2(mem->size); int data_width = mem->width; - T node = factory.state_memory(mem->cell->name, addr_width, data_width); + Node node = factory.state_memory(mem->cell->name, addr_width, data_width); for (auto &rd : mem->rd_ports) { log_assert(!rd.clk_enable); - T addr = enqueue(driver_map(DriveSpec(rd.addr))); - read_results.push_back(factory.memory_read(node, addr, addr_width, data_width)); + Node addr = enqueue(driver_map(DriveSpec(rd.addr))); + read_results.push_back(factory.memory_read(node, addr)); } for (auto &wr : mem->wr_ports) { - T en = enqueue(driver_map(DriveSpec(wr.en))); - T addr = enqueue(driver_map(DriveSpec(wr.addr))); - T new_data = enqueue(driver_map(DriveSpec(wr.data))); - T old_data = factory.memory_read(node, addr, addr_width, data_width); - T wr_data = simplifier.bitwise_mux(old_data, new_data, en, data_width); - node = factory.memory_write(node, addr, wr_data, addr_width, data_width); + Node en = enqueue(driver_map(DriveSpec(wr.en))); + Node addr = enqueue(driver_map(DriveSpec(wr.addr))); + Node new_data = enqueue(driver_map(DriveSpec(wr.data))); + Node old_data = factory.memory_read(node, addr); + Node wr_data = simplifier.bitwise_mux(old_data, new_data, en); + node = factory.memory_write(node, addr, wr_data); } factory.declare_state_memory(node, mem->cell->name, addr_width, data_width); return concatenate_read_results(mem, read_results); @@ -459,16 +438,13 @@ class FunctionalIRConstruction { { for (; !queue.empty(); queue.pop_front()) { DriveSpec spec = queue.front(); - T pending = graph_nodes.at(spec); + Node pending = graph_nodes.at(spec); if (spec.chunks().size() > 1) { auto chunks = spec.chunks(); - T node = enqueue(chunks[0]); - int width = chunks[0].size(); - for(size_t i = 1; i < chunks.size(); i++) { - node = factory.concat(node, width, enqueue(chunks[i]), chunks[i].size()); - width += chunks[i].size(); - } + Node node = enqueue(chunks[0]); + for(size_t i = 1; i < chunks.size(); i++) + node = factory.concat(node, enqueue(chunks[i])); factory.update_pending(pending, node); } else if (spec.chunks().size() == 1) { DriveChunk chunk = spec.chunks()[0]; @@ -476,18 +452,18 @@ class FunctionalIRConstruction { DriveChunkWire wire_chunk = chunk.wire(); if (wire_chunk.is_whole()) { if (wire_chunk.wire->port_input) { - T node = factory.input(wire_chunk.wire->name, wire_chunk.width); + Node node = factory.input(wire_chunk.wire->name, wire_chunk.width); factory.suggest_name(node, wire_chunk.wire->name); factory.update_pending(pending, node); } else { DriveSpec driver = driver_map(DriveSpec(wire_chunk)); - T node = enqueue(driver); + Node node = enqueue(driver); factory.suggest_name(node, wire_chunk.wire->name); factory.update_pending(pending, node); } } else { DriveChunkWire whole_wire(wire_chunk.wire, 0, wire_chunk.wire->width); - T node = factory.slice(enqueue(whole_wire), wire_chunk.wire->width, wire_chunk.offset, wire_chunk.width); + Node node = factory.slice(enqueue(whole_wire), wire_chunk.offset, wire_chunk.width); factory.update_pending(pending, node); } } else if (chunk.is_port()) { @@ -497,21 +473,22 @@ class FunctionalIRConstruction { if (port_chunk.cell->type.in(ID($dff), ID($ff))) { Cell *cell = port_chunk.cell; - T node = factory.state(cell->name, port_chunk.width); + Node node = factory.state(cell->name, port_chunk.width); factory.suggest_name(node, port_chunk.cell->name); factory.update_pending(pending, node); for (auto const &conn : cell->connections()) { if (driver_map.celltypes.cell_input(cell->type, conn.first)) { - T node = enqueue(DriveChunkPort(cell, conn)); + Node node = enqueue(DriveChunkPort(cell, conn)); factory.declare_state(node, cell->name, port_chunk.width); } } } else { - T cell = enqueue(DriveChunkMarker(cells(port_chunk.cell), 0, port_chunk.width)); + Node cell = enqueue(DriveChunkMarker(cells(port_chunk.cell), 0, port_chunk.width)); factory.suggest_name(cell, port_chunk.cell->name); - T node = factory.cell_output(cell, port_chunk.cell->type, port_chunk.port, port_chunk.width); + //Node node = factory.cell_output(cell, port_chunk.cell->type, port_chunk.port, port_chunk.width); + Node node = cell; factory.suggest_name(node, port_chunk.cell->name.str() + "$" + port_chunk.port.str()); factory.update_pending(pending, node); } @@ -521,37 +498,37 @@ class FunctionalIRConstruction { } } else { DriveChunkPort whole_port(port_chunk.cell, port_chunk.port, 0, GetSize(port_chunk.cell->connections().at(port_chunk.port))); - T node = factory.slice(enqueue(whole_port), whole_port.width, port_chunk.offset, port_chunk.width); + Node node = factory.slice(enqueue(whole_port), port_chunk.offset, port_chunk.width); factory.update_pending(pending, node); } } else if (chunk.is_constant()) { - T node = factory.constant(chunk.constant()); + Node node = factory.constant(chunk.constant()); factory.suggest_name(node, "$const" + std::to_string(chunk.size()) + "b" + chunk.constant().as_string()); factory.update_pending(pending, node); } else if (chunk.is_multiple()) { - vector args; + vector args; for (auto const &driver : chunk.multiple().multiple()) args.push_back(enqueue(driver)); - T node = factory.multiple(args, chunk.size()); + Node node = factory.multiple(args, chunk.size()); factory.update_pending(pending, node); } else if (chunk.is_marker()) { Cell *cell = cells[chunk.marker().marker]; if (cell->is_mem_cell()) { Mem *mem = memories.at(cell, nullptr); log_assert(mem != nullptr); - T node = handle_memory(mem); + Node node = handle_memory(mem); factory.update_pending(pending, node); } else { - dict connections; + dict connections; for(auto const &conn : cell->connections()) { if(driver_map.celltypes.cell_input(cell->type, conn.first)) connections.insert({ conn.first, enqueue(DriveChunkPort(cell, conn)) }); } - T node = simplifier.handle(cell->type, cell->parameters, connections); + Node node = simplifier.handle(cell->type, cell->parameters, connections); factory.update_pending(pending, node); } } else if (chunk.is_none()) { - T node = factory.undriven(chunk.size()); + Node node = factory.undriven(chunk.size()); factory.update_pending(pending, node); } else { log_error("unhandled drivespec: %s\n", log_signal(chunk)); @@ -567,7 +544,7 @@ class FunctionalIRConstruction { FunctionalIR FunctionalIR::from_module(Module *module) { FunctionalIR ir; auto factory = ir.factory(); - FunctionalIRConstruction ctor(factory); + FunctionalIRConstruction ctor(factory); ctor.add_module(module); ctor.process_queue(); ir.topological_sort(); diff --git a/kernel/functionalir.h b/kernel/functionalir.h index e586695695a..44e3589db2d 100644 --- a/kernel/functionalir.h +++ b/kernel/functionalir.h @@ -385,78 +385,79 @@ class FunctionalIR { void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); } void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } public: - Node slice(Node a, int, int offset, int out_width) { + Node slice(Node a, int offset, int out_width) { log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); if(offset == 0 && out_width == a.width()) return a; return add(NodeData(Fn::slice, offset), Sort(out_width), {a}); } - Node extend(Node a, int, int out_width, bool is_signed) { + // extend will either extend or truncate the provided value to reach the desired width + Node extend(Node a, int out_width, bool is_signed) { int in_width = a.sort().width(); log_assert(a.sort().is_signal()); if(in_width == out_width) return a; - if(in_width < out_width) - return slice(a, in_width, 0, out_width); + if(in_width > out_width) + return slice(a, 0, out_width); if(is_signed) return add(Fn::sign_extend, Sort(out_width), {a}); else return add(Fn::zero_extend, Sort(out_width), {a}); } - Node concat(Node a, int, Node b, int) { + Node concat(Node a, Node b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b}); } - Node add(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } - Node sub(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } - Node mul(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } - Node unsigned_div(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } - Node unsigned_mod(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } - Node bitwise_and(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } - Node bitwise_or(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } - Node bitwise_xor(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } - Node bitwise_not(Node a, int) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } - Node unary_minus(Node a, int) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } - Node reduce_and(Node a, int) { + Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } + Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } + Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } + Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } + Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } + Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } + Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } + Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } + Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } + Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } + Node reduce_and(Node a) { check_unary(a); if(a.width() == 1) return a; return add(Fn::reduce_and, Sort(1), {a}); } - Node reduce_or(Node a, int) { + Node reduce_or(Node a) { check_unary(a); if(a.width() == 1) return a; return add(Fn::reduce_or, Sort(1), {a}); } - Node reduce_xor(Node a, int) { + Node reduce_xor(Node a) { check_unary(a); if(a.width() == 1) return a; return add(Fn::reduce_xor, Sort(1), {a}); } - Node equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } - Node not_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } - Node signed_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } - Node signed_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } - Node unsigned_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } - Node unsigned_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } - Node logical_shift_left(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } - Node logical_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } - Node arithmetic_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } - Node mux(Node a, Node b, Node s, int) { + Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } + Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } + Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } + Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } + Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } + Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } + Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } + Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } + Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } + Node mux(Node a, Node b, Node s) { log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); return add(Fn::mux, a.sort(), {a, b, s}); } - Node pmux(Node a, Node b, Node s, int, int) { + Node pmux(Node a, Node b, Node s) { log_assert(a.sort().is_signal() && b.sort().is_signal() && s.sort().is_signal() && a.sort().width() * s.sort().width() == b.sort().width()); return add(Fn::pmux, a.sort(), {a, b, s}); } - Node memory_read(Node mem, Node addr, int, int) { + Node memory_read(Node mem, Node addr) { log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr}); } - Node memory_write(Node mem, Node addr, Node data, int, int) { + Node memory_write(Node mem, Node addr, Node data) { log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() && mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width()); return add(Fn::memory_write, mem.sort(), {mem, addr, data}); @@ -484,9 +485,6 @@ class FunctionalIR { _ir.add_state(name, Sort(addr_width, data_width)); return add(NodeData(Fn::state, name), Sort(addr_width, data_width), {}); } - Node cell_output(Node node, IdString, IdString, int) { - return node; - } Node multiple(vector args, int width) { auto node = add(Fn::multiple, Sort(width), {}); for(const auto &arg : args)