Skip to content

Commit

Permalink
Remove int_power and use built-in 2 ^
Browse files Browse the repository at this point in the history
Sail has built-in support for `2 ^`, and this is the only exponent we use so there's no need for the more generic `int_power`. Additionally the type of `int_power` is both way looser than `2 ^` (which is understood at the type level), and actually wrong, e.g. it will let you do `3 ^ -1` and give the result as 3.

Note I had to add spaces because this was required until a very recent Sail version - see rems-project/sail#657
  • Loading branch information
Timmmm committed Aug 27, 2024
1 parent 05b845c commit f0448b9
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 33 deletions.
4 changes: 1 addition & 3 deletions model/prelude.sail
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ function bit_str(b: bit) -> string =

overload BitStr = {bits_str, bit_str}

val int_power = {ocaml: "int_power", interpreter: "int_power", lem: "pow", coq: "pow", c: "pow_int"} : (int, int) -> int

overload operator ^ = {xor_vec, int_power, concat_str}
overload operator ^ = {xor_vec, concat_str}

val sub_vec = {c: "sub_bits", _: "sub_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)

Expand Down
2 changes: 1 addition & 1 deletion model/riscv_insts_mext.sail
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ function clause execute (DIVW(rs2, rs1, rd, s)) = {
let rs2_int : int = if s then signed(rs2_val) else unsigned(rs2_val);
let q : int = if rs2_int == 0 then -1 else quot_round_zero(rs1_int, rs2_int);
/* check for signed overflow */
let q': int = if s & q > (2 ^ 31 - 1) then (0 - 2^31) else q;
let q': int = if s & q > (2 ^ 31 - 1) then (0 - (2 ^ 31)) else q;
X(rd) = sign_extend(to_bits(32, q'));
RETIRE_SUCCESS
}
Expand Down
12 changes: 6 additions & 6 deletions model/riscv_insts_vext_arith.sail
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function clause execute(VVTYPE(funct6, vm, vs2, vs1, vd)) = {
VV_VRGATHER => {
if (vs1 == vd | vs2 == vd) then { handle_illegal(); return RETIRE_FAIL };
let idx = unsigned(vs1_val[i]);
let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow);
let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow);
assert(VLMAX <= 'n);
if idx < VLMAX then vs2_val[idx] else zeros()
},
Expand All @@ -125,7 +125,7 @@ function clause execute(VVTYPE(funct6, vm, vs2, vs1, vd)) = {
/* vrgatherei16.vv uses SEW/LMUL for the data in vs2 but EEW=16 and EMUL = (16/SEW)*LMUL for the indices in vs1 */
let vs1_new : vector('n, dec, bits(16)) = read_vreg(num_elem, 16, 4 + LMUL_pow - SEW_pow, vs1);
let idx = unsigned(vs1_new[i]);
let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow);
let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow);
assert(VLMAX <= 'n);
if idx < VLMAX then vs2_val[idx] else zeros()
}
Expand Down Expand Up @@ -694,13 +694,13 @@ function clause execute(VXSG(funct6, vm, vs2, rs1, vd)) = {
if i >= rs1_val then vs2_val[i - rs1_val] else vd_val[i]
},
VX_VSLIDEDOWN => {
let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow);
let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow);
assert(VLMAX > 0 & VLMAX <= 'n);
if i + rs1_val < VLMAX then vs2_val[i + rs1_val] else zeros()
},
VX_VRGATHER => {
if (vs2 == vd) then { handle_illegal(); return RETIRE_FAIL };
let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow);
let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow);
assert(VLMAX > 0 & VLMAX <= 'n);
if rs1_val < VLMAX then vs2_val[rs1_val] else zeros()
}
Expand Down Expand Up @@ -1084,13 +1084,13 @@ function clause execute(VISG(funct6, vm, vs2, simm, vd)) = {
if i >= imm_val then vs2_val[i - imm_val] else vd_val[i]
},
VI_VSLIDEDOWN => {
let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow);
let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow);
assert(VLMAX > 0 & VLMAX <= 'n);
if i + imm_val < VLMAX then vs2_val[i + imm_val] else zeros()
},
VI_VRGATHER => {
if (vs2 == vd) then { handle_illegal(); return RETIRE_FAIL };
let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow);
let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow);
assert(VLMAX > 0 & VLMAX <= 'n);
if imm_val < VLMAX then vs2_val[imm_val] else zeros()
}
Expand Down
4 changes: 2 additions & 2 deletions model/riscv_insts_vext_fp_utils.sail
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ function rsqrt7 (v, sub) = {
};
assert(idx >= 0 & idx < 128);
let out_sig = to_bits(s, table[(127 - idx)]) << (s - 7);
let out_exp = to_bits(e, (3 * (2^(e - 1) - 1) - 1 - signed(normalized_exp)) / 2);
let out_exp = to_bits(e, (3 * (2 ^ (e - 1) - 1) - 1 - signed(normalized_exp)) / 2);
zero_extend(64, sign @ out_exp @ out_sig)
}

Expand Down Expand Up @@ -593,7 +593,7 @@ function recip7 (v, rm_3b, sub) = {
64 => unsigned(normalized_sig[51 .. 45])
};
assert(idx >= 0 & idx < 128);
let mid_exp = to_bits(e, 2 * (2^(e - 1) - 1) - 1 - signed(normalized_exp));
let mid_exp = to_bits(e, 2 * (2 ^ (e - 1) - 1) - 1 - signed(normalized_exp));
let mid_sig = to_bits(s, table[(127 - idx)]) << (s - 7);

let (out_exp, out_sig)=
Expand Down
14 changes: 7 additions & 7 deletions model/riscv_insts_vext_mem.sail
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ mapping clause encdec = VLSEGTYPE(nf, vm, rs1, width, vd) if extensionEnabled(Ex

val process_vlseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, int('p), int('n)) -> Retired
function process_vlseg (nf, vm, vd, load_width_bytes, rs1, EMUL_pow, num_elem) = {
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow);
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow);
let width_type : word_width = size_bytes(load_width_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vd_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vd);
Expand Down Expand Up @@ -135,7 +135,7 @@ mapping clause encdec = VLSEGFFTYPE(nf, vm, rs1, width, vd) if extensionEnabled(

val process_vlsegff : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, int('p), int('n)) -> Retired
function process_vlsegff (nf, vm, vd, load_width_bytes, rs1, EMUL_pow, num_elem) = {
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow);
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow);
let width_type : word_width = size_bytes(load_width_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vd_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vd);
Expand Down Expand Up @@ -240,7 +240,7 @@ mapping clause encdec = VSSEGTYPE(nf, vm, rs1, width, vs3) if extensionEnabled(E

val process_vsseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, int('p), int('n)) -> Retired
function process_vsseg (nf, vm, vs3, load_width_bytes, rs1, EMUL_pow, num_elem) = {
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow);
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow);
let width_type : word_width = size_bytes(load_width_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vs3_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vs3);
Expand Down Expand Up @@ -309,7 +309,7 @@ mapping clause encdec = VLSSEGTYPE(nf, vm, rs2, rs1, width, vd) if extensionEnab

val process_vlsseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, regidx, int('p), int('n)) -> Retired
function process_vlsseg (nf, vm, vd, load_width_bytes, rs1, rs2, EMUL_pow, num_elem) = {
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow);
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow);
let width_type : word_width = size_bytes(load_width_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vd_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vd);
Expand Down Expand Up @@ -376,7 +376,7 @@ mapping clause encdec = VSSSEGTYPE(nf, vm, rs2, rs1, width, vs3) if extensionEna

val process_vssseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, regidx, int('p), int('n)) -> Retired
function process_vssseg (nf, vm, vs3, load_width_bytes, rs1, rs2, EMUL_pow, num_elem) = {
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow);
let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow);
let width_type : word_width = size_bytes(load_width_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vs3_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vs3);
Expand Down Expand Up @@ -446,7 +446,7 @@ mapping clause encdec = VLUXSEGTYPE(nf, vm, vs2, rs1, width, vd) if extensionEna

val process_vlxseg : forall 'f 'ib 'db 'ip 'dp 'n, (0 < 'f & 'f <= 8) & ('ib in {1, 2, 4, 8}) & ('db in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('ib), int('db), int('ip), int('dp), regidx, regidx, int('n), int) -> Retired
function process_vlxseg (nf, vm, vd, EEW_index_bytes, EEW_data_bytes, EMUL_index_pow, EMUL_data_pow, rs1, vs2, num_elem, mop) = {
let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else int_power(2, EMUL_data_pow);
let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else 2 ^ (EMUL_data_pow);
let width_type : word_width = size_bytes(EEW_data_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vd_seg : vector('n, dec, bits('f * 'db * 8)) = read_vreg_seg(num_elem, EEW_data_bytes * 8, EMUL_data_pow, nf, vd);
Expand Down Expand Up @@ -538,7 +538,7 @@ mapping clause encdec = VSUXSEGTYPE(nf, vm, vs2, rs1, width, vs3) if extensionEn

val process_vsxseg : forall 'f 'ib 'db 'ip 'dp 'n, (0 < 'f & 'f <= 8) & ('ib in {1, 2, 4, 8}) & ('db in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('ib), int('db), int('ip), int('dp), regidx, regidx, int('n), int) -> Retired
function process_vsxseg (nf, vm, vs3, EEW_index_bytes, EEW_data_bytes, EMUL_index_pow, EMUL_data_pow, rs1, vs2, num_elem, mop) = {
let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else int_power(2, EMUL_data_pow);
let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else 2 ^ (EMUL_data_pow);
let width_type : word_width = size_bytes(EEW_data_bytes);
let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000);
let vs3_seg : vector('n, dec, bits('f * 'db * 8)) = read_vreg_seg(num_elem, EEW_data_bytes * 8, EMUL_data_pow, nf, vs3);
Expand Down
20 changes: 10 additions & 10 deletions model/riscv_insts_vext_utils.sail
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mapping maybe_vmask : string <-> bits(1) = {
*/
val valid_eew_emul : (int, int) -> bool
function valid_eew_emul(EEW, EMUL_pow) = {
let ELEN = int_power(2, get_elen_pow());
let ELEN = 2 ^ get_elen_pow();
EEW >= 8 & EEW <= ELEN & EMUL_pow >= -3 & EMUL_pow <= 3
}

Expand Down Expand Up @@ -60,8 +60,8 @@ function valid_rd_mask(rd, vm) = {
*/
val valid_reg_overlap : (regidx, regidx, int, int) -> bool
function valid_reg_overlap(rs, rd, EMUL_pow_rs, EMUL_pow_rd) = {
let rs_group = if EMUL_pow_rs > 0 then int_power(2, EMUL_pow_rs) else 1;
let rd_group = if EMUL_pow_rd > 0 then int_power(2, EMUL_pow_rd) else 1;
let rs_group = if EMUL_pow_rs > 0 then 2 ^ (EMUL_pow_rs) else 1;
let rd_group = if EMUL_pow_rd > 0 then 2 ^ (EMUL_pow_rd) else 1;
let rs_int = unsigned(rs);
let rd_int = unsigned(rd);
if EMUL_pow_rs < EMUL_pow_rd then {
Expand All @@ -78,8 +78,8 @@ function valid_reg_overlap(rs, rd, EMUL_pow_rs, EMUL_pow_rd) = {
*/
val valid_segment : (int, int) -> bool
function valid_segment(nf, EMUL_pow) = {
if EMUL_pow < 0 then nf / int_power(2, 0 - EMUL_pow) <= 8
else nf * int_power(2, EMUL_pow) <= 8
if EMUL_pow < 0 then nf / (2 ^ (0 - EMUL_pow)) <= 8
else nf * 2 ^ (EMUL_pow) <= 8
}

/* ******************************************************************************* */
Expand Down Expand Up @@ -209,7 +209,7 @@ function init_masked_result(num_elem, SEW, LMUL_pow, vd_val, vm_val) = {
result : vector('n, dec, bits('m)) = undefined;

/* Determine the actual number of elements when lmul < 1 */
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow);
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow));
assert(num_elem >= real_num_elem);

foreach (i from 0 to (num_elem - 1)) {
Expand Down Expand Up @@ -259,7 +259,7 @@ function init_masked_source(num_elem, LMUL_pow, vm_val) = {
mask : vector('n, dec, bool) = undefined;

/* Determine the actual number of elements when lmul < 1 */
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow);
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow));
assert(num_elem >= real_num_elem);

foreach (i from 0 to (num_elem - 1)) {
Expand Down Expand Up @@ -294,7 +294,7 @@ function init_masked_result_carry(num_elem, SEW, LMUL_pow, vd_val) = {
result : vector('n, dec, bool) = undefined;

/* Determine the actual number of elements when lmul < 1 */
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow);
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow));
assert(num_elem >= real_num_elem);

foreach (i from 0 to (num_elem - 1)) {
Expand Down Expand Up @@ -331,7 +331,7 @@ function init_masked_result_cmp(num_elem, SEW, LMUL_pow, vd_val, vm_val) = {
result : vector('n, dec, bool) = undefined;

/* Determine the actual number of elements when lmul < 1 */
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow);
let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow));
assert(num_elem >= real_num_elem);

foreach (i from 0 to (num_elem - 1)) {
Expand Down Expand Up @@ -372,7 +372,7 @@ function init_masked_result_cmp(num_elem, SEW, LMUL_pow, vd_val, vm_val) = {
val read_vreg_seg : forall 'n 'm 'p 'q, 'n >= 0 & 'q >= 0. (int('n), int('m), int('p), int('q), regidx) -> vector('n, dec, bits('q * 'm))
function read_vreg_seg(num_elem, SEW, LMUL_pow, nf, vrid) = {
assert('q * 'm > 0);
let LMUL_reg : int = if LMUL_pow <= 0 then 1 else int_power(2, LMUL_pow);
let LMUL_reg : int = if LMUL_pow <= 0 then 1 else 2 ^ (LMUL_pow);
vreg_list : vector('q, dec, vector('n, dec, bits('m))) = undefined;
result : vector('n, dec, bits('q * 'm)) = undefined;
foreach (j from 0 to (nf - 1)) {
Expand Down
6 changes: 3 additions & 3 deletions model/riscv_insts_vext_vset.sail
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function clause execute VSETVLI(ma, ta, sew, lmul, rs1, rd) = {
let LMUL_pow_new = get_lmul_pow();
let SEW_pow_new = get_sew_pow();
if SEW_pow_new > (LMUL_pow_new + ELEN_pow) then { handle_illegal_vtype(); return RETIRE_SUCCESS };
let VLMAX = int_power(2, VLEN_pow + LMUL_pow_new - SEW_pow_new);
let VLMAX = 2 ^ (VLEN_pow + LMUL_pow_new - SEW_pow_new);

/* set vl according to VLMAX and AVL */
if (rs1 != 0b00000) then { /* normal stripmining */
Expand Down Expand Up @@ -136,7 +136,7 @@ function clause execute VSETVL(rs2, rs1, rd) = {
let LMUL_pow_new = get_lmul_pow();
let SEW_pow_new = get_sew_pow();
if SEW_pow_new > (LMUL_pow_new + ELEN_pow) then { handle_illegal_vtype(); return RETIRE_SUCCESS };
let VLMAX = int_power(2, VLEN_pow + LMUL_pow_new - SEW_pow_new);
let VLMAX = 2 ^ (VLEN_pow + LMUL_pow_new - SEW_pow_new);

/* set vl according to VLMAX and AVL */
if (rs1 != 0b00000) then { /* normal stripmining */
Expand Down Expand Up @@ -183,7 +183,7 @@ function clause execute VSETIVLI(ma, ta, sew, lmul, uimm, rd) = {
let LMUL_pow_new = get_lmul_pow();
let SEW_pow_new = get_sew_pow();
if SEW_pow_new > (LMUL_pow_new + ELEN_pow) then { handle_illegal_vtype(); return RETIRE_SUCCESS };
let VLMAX = int_power(2, VLEN_pow + LMUL_pow_new - SEW_pow_new);
let VLMAX = 2 ^ (VLEN_pow + LMUL_pow_new - SEW_pow_new);

/* set vl according to VLMAX and AVL */
let AVL = unsigned(uimm); /* AVL is encoded as 5-bit zero-extended imm in the rs1 field */
Expand Down
2 changes: 1 addition & 1 deletion model/riscv_vext_regs.sail
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ function get_num_elem(LMUL_pow, SEW) = {
let LMUL_pow_reg = if LMUL_pow < 0 then 0 else LMUL_pow;
/* Ignore lmul < 1 so that the entire vreg is read, allowing all masking to
* be handled in init_masked_result */
let num_elem = int_power(2, LMUL_pow_reg) * VLEN / SEW;
let num_elem = 2 ^ (LMUL_pow_reg) * VLEN / SEW;
assert(num_elem > 0);
num_elem
}
Expand Down

0 comments on commit f0448b9

Please sign in to comment.