Skip to content

Commit

Permalink
fix a few bugs in the functional backend and refactor the testing
Browse files Browse the repository at this point in the history
  • Loading branch information
aiju committed Jul 16, 2024
1 parent 4e2bcd3 commit 5800801
Show file tree
Hide file tree
Showing 6 changed files with 367 additions and 271 deletions.
40 changes: 18 additions & 22 deletions kernel/functionalir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,26 +226,36 @@ class CellSimplifier {
T b = extend(inputs.at(ID(B)), b_width, width, is_signed);
if(is_signed) {
if(cellType == ID($div)) {
// divide absolute values, then flip the sign if input signs differ
// but extend the width first, to handle the case (most negative value) / (-1)
T abs_y = factory.unsigned_div(abs(a, width), abs(b, width), width);
T out_sign = factory.not_equal(sign(a, width), sign(b, width), 1);
return neg_if(extend(abs_y, width, y_width, true), y_width, out_sign);
return neg_if(extend(abs_y, width, y_width, false), y_width, out_sign);
} else if(cellType == ID($mod)) {
// similar to division but output sign == divisor sign
T abs_y = factory.unsigned_mod(abs(a, width), abs(b, width), width);
return neg_if(extend(abs_y, width, y_width, true), y_width, sign(a, width));
return neg_if(extend(abs_y, width, y_width, false), y_width, sign(a, width));
} else if(cellType == ID($divfloor)) {
// if b is negative, flip both signs so that b is positive
T b_sign = sign(b, width);
T a1 = neg_if(a, width, b_sign);
T b1 = neg_if(b, width, b_sign);
T a1_sign = sign(a1, width);
// if a is now negative, calculate ~((~a) / b) = -((-a - 1) / b + 1)
// which equals the negative of (-a) / b with rounding up rather than down
// note that to handle the case where a = most negative value properly,
// we have to calculate a1_sign from the original values rather than using sign(a1, width)
T a1_sign = factory.bitwise_and(factory.not_equal(sign(a, width), sign(b, width), 1), reduce_or(a, width), 1);
T a2 = factory.mux(a1, factory.bitwise_not(a1, width), a1_sign, width);
T y1 = factory.unsigned_div(a2, b1, width);
T y2 = factory.mux(y1, factory.bitwise_not(y1, width), a1_sign, width);
return extend(y2, width, y_width, true);
T y2 = extend(y1, width, y_width, false);
return factory.mux(y2, factory.bitwise_not(y2, y_width), a1_sign, y_width);
} else if(cellType == ID($modfloor)) {
// calculate |a| % |b| and then subtract from |b| if input signs differ and the remainder is non-zero
T abs_b = abs(b, width);
T abs_y = factory.unsigned_mod(abs(a, width), abs_b, width);
T flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a, width), sign(b, width), 1), factory.reduce_or(abs_y, width), 1);
T y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y, width), flip_y, width);
// since y_flipped is strictly less than |b|, the top bit is always 0 and we can just sign extend the flipped result
T y = neg_if(y_flipped, width, sign(b, b_width));
return extend(y, width, y_width, true);
} else
Expand All @@ -261,22 +271,8 @@ class CellSimplifier {
} else if (cellType == ID($lut)) {
int width = parameters.at(ID(WIDTH)).as_int();
Const lut_table = parameters.at(ID(LUT));
T a = inputs.at(ID(A));
// Output initialization
T y = factory.constant(Const(0, 1));
// Iterate over each possible input combination
for (int i = 0; i < (1 << width); ++i) {
// Create a constant representing the value of i
T i_val = factory.constant(Const(i, width));
// Check if the input matches this value
T match = factory.equal(a, i_val, width);
// Get the corresponding LUT value
bool lut_val = lut_table.bits[i] == State::S1;
T lut_output = factory.constant(Const(lut_val, 1));
// Use a multiplexer to select the correct output based on the match
y = factory.mux(y, lut_output, match, 1);
}
return y;
lut_table.extu(1 << width);
return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 1 << width, 0, 1, width, width);
} else if (cellType == ID($bwmux)) {
int width = parameters.at(ID(WIDTH)).as_int();
T a = inputs.at(ID(A));
Expand Down Expand Up @@ -526,7 +522,7 @@ void FunctionalIR::topological_sort() {
if(scc) log_error("combinational loops, aborting\n");
}

IdString merge_name(IdString a, IdString b) {
static IdString merge_name(IdString a, IdString b) {
if(a[0] == '$' && b[0] == '\\')
return b;
else
Expand Down
26 changes: 21 additions & 5 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
import pytest
from rtlil_cells import generate_test_cases
import random

random_seed = random.getrandbits(32)

def pytest_addoption(parser):
parser.addoption(
"--per-cell", type=int, default=None, help="run only N tests per cell"
)
parser.addoption("--per-cell", type=int, default=None, help="run only N tests per cell")
parser.addoption("--steps", type=int, default=1000, help="run each test for N steps")
parser.addoption("--seed", type=int, default=random_seed, help="seed for random number generation, use random seed if unspecified")

def pytest_collection_finish(session):
print('random seed: {}'.format(session.config.getoption("seed")))

@pytest.fixture
def num_steps(request):
return request.config.getoption("steps")

@pytest.fixture
def rnd(request):
seed1 = request.config.getoption("seed")
return lambda seed2: random.Random('{}-{}'.format(seed1, seed2))

def pytest_generate_tests(metafunc):
if "cell" in metafunc.fixturenames:
print(dir(metafunc.config))
per_cell = metafunc.config.getoption("per_cell", default=None)
names, cases = generate_test_cases(per_cell)
seed1 = metafunc.config.getoption("seed")
rnd = lambda seed2: random.Random('{}-{}'.format(seed1, seed2))
names, cases = generate_test_cases(per_cell, rnd)
metafunc.parametrize("cell,parameters", cases, ids=names)
Loading

0 comments on commit 5800801

Please sign in to comment.