From f0448b9a6ac3519cf7b5c057b6227e8909426462 Mon Sep 17 00:00:00 2001 From: Tim Hutt Date: Thu, 8 Aug 2024 10:50:00 +0100 Subject: [PATCH] Remove int_power and use built-in 2 ^ Sail has built-in support for `2 ^`, and this is the only exponent we use so there's no need for the more generic `int_power`. Additionally the type of `int_power` is both way looser than `2 ^` (which is understood at the type level), and actually wrong, e.g. it will let you do `3 ^ -1` and give the result as 3. Note I had to add spaces because this was required until a very recent Sail version - see https://github.com/rems-project/sail/issues/657 --- model/prelude.sail | 4 +--- model/riscv_insts_mext.sail | 2 +- model/riscv_insts_vext_arith.sail | 12 ++++++------ model/riscv_insts_vext_fp_utils.sail | 4 ++-- model/riscv_insts_vext_mem.sail | 14 +++++++------- model/riscv_insts_vext_utils.sail | 20 ++++++++++---------- model/riscv_insts_vext_vset.sail | 6 +++--- model/riscv_vext_regs.sail | 2 +- 8 files changed, 31 insertions(+), 33 deletions(-) diff --git a/model/prelude.sail b/model/prelude.sail index 45d49c0ed..2f6ae1b32 100644 --- a/model/prelude.sail +++ b/model/prelude.sail @@ -39,9 +39,7 @@ function bit_str(b: bit) -> string = overload BitStr = {bits_str, bit_str} -val int_power = {ocaml: "int_power", interpreter: "int_power", lem: "pow", coq: "pow", c: "pow_int"} : (int, int) -> int - -overload operator ^ = {xor_vec, int_power, concat_str} +overload operator ^ = {xor_vec, concat_str} val sub_vec = {c: "sub_bits", _: "sub_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n) diff --git a/model/riscv_insts_mext.sail b/model/riscv_insts_mext.sail index b153e792b..5c70eb20c 100644 --- a/model/riscv_insts_mext.sail +++ b/model/riscv_insts_mext.sail @@ -138,7 +138,7 @@ function clause execute (DIVW(rs2, rs1, rd, s)) = { let rs2_int : int = if s then signed(rs2_val) else unsigned(rs2_val); let q : int = if rs2_int == 0 then -1 else quot_round_zero(rs1_int, rs2_int); /* check for signed overflow */ - let q': int = if s & q > (2 ^ 31 - 1) then (0 - 2^31) else q; + let q': int = if s & q > (2 ^ 31 - 1) then (0 - (2 ^ 31)) else q; X(rd) = sign_extend(to_bits(32, q')); RETIRE_SUCCESS } diff --git a/model/riscv_insts_vext_arith.sail b/model/riscv_insts_vext_arith.sail index b351b01bb..14700ca34 100644 --- a/model/riscv_insts_vext_arith.sail +++ b/model/riscv_insts_vext_arith.sail @@ -116,7 +116,7 @@ function clause execute(VVTYPE(funct6, vm, vs2, vs1, vd)) = { VV_VRGATHER => { if (vs1 == vd | vs2 == vd) then { handle_illegal(); return RETIRE_FAIL }; let idx = unsigned(vs1_val[i]); - let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow); + let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow); assert(VLMAX <= 'n); if idx < VLMAX then vs2_val[idx] else zeros() }, @@ -125,7 +125,7 @@ function clause execute(VVTYPE(funct6, vm, vs2, vs1, vd)) = { /* vrgatherei16.vv uses SEW/LMUL for the data in vs2 but EEW=16 and EMUL = (16/SEW)*LMUL for the indices in vs1 */ let vs1_new : vector('n, dec, bits(16)) = read_vreg(num_elem, 16, 4 + LMUL_pow - SEW_pow, vs1); let idx = unsigned(vs1_new[i]); - let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow); + let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow); assert(VLMAX <= 'n); if idx < VLMAX then vs2_val[idx] else zeros() } @@ -694,13 +694,13 @@ function clause execute(VXSG(funct6, vm, vs2, rs1, vd)) = { if i >= rs1_val then vs2_val[i - rs1_val] else vd_val[i] }, VX_VSLIDEDOWN => { - let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow); + let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow); assert(VLMAX > 0 & VLMAX <= 'n); if i + rs1_val < VLMAX then vs2_val[i + rs1_val] else zeros() }, VX_VRGATHER => { if (vs2 == vd) then { handle_illegal(); return RETIRE_FAIL }; - let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow); + let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow); assert(VLMAX > 0 & VLMAX <= 'n); if rs1_val < VLMAX then vs2_val[rs1_val] else zeros() } @@ -1084,13 +1084,13 @@ function clause execute(VISG(funct6, vm, vs2, simm, vd)) = { if i >= imm_val then vs2_val[i - imm_val] else vd_val[i] }, VI_VSLIDEDOWN => { - let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow); + let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow); assert(VLMAX > 0 & VLMAX <= 'n); if i + imm_val < VLMAX then vs2_val[i + imm_val] else zeros() }, VI_VRGATHER => { if (vs2 == vd) then { handle_illegal(); return RETIRE_FAIL }; - let VLMAX = int_power(2, LMUL_pow + VLEN_pow - SEW_pow); + let VLMAX = 2 ^ (LMUL_pow + VLEN_pow - SEW_pow); assert(VLMAX > 0 & VLMAX <= 'n); if imm_val < VLMAX then vs2_val[imm_val] else zeros() } diff --git a/model/riscv_insts_vext_fp_utils.sail b/model/riscv_insts_vext_fp_utils.sail index 2abbe5f6c..8d7ccf8af 100755 --- a/model/riscv_insts_vext_fp_utils.sail +++ b/model/riscv_insts_vext_fp_utils.sail @@ -500,7 +500,7 @@ function rsqrt7 (v, sub) = { }; assert(idx >= 0 & idx < 128); let out_sig = to_bits(s, table[(127 - idx)]) << (s - 7); - let out_exp = to_bits(e, (3 * (2^(e - 1) - 1) - 1 - signed(normalized_exp)) / 2); + let out_exp = to_bits(e, (3 * (2 ^ (e - 1) - 1) - 1 - signed(normalized_exp)) / 2); zero_extend(64, sign @ out_exp @ out_sig) } @@ -593,7 +593,7 @@ function recip7 (v, rm_3b, sub) = { 64 => unsigned(normalized_sig[51 .. 45]) }; assert(idx >= 0 & idx < 128); - let mid_exp = to_bits(e, 2 * (2^(e - 1) - 1) - 1 - signed(normalized_exp)); + let mid_exp = to_bits(e, 2 * (2 ^ (e - 1) - 1) - 1 - signed(normalized_exp)); let mid_sig = to_bits(s, table[(127 - idx)]) << (s - 7); let (out_exp, out_sig)= diff --git a/model/riscv_insts_vext_mem.sail b/model/riscv_insts_vext_mem.sail index 9ffb56eff..449d600a9 100644 --- a/model/riscv_insts_vext_mem.sail +++ b/model/riscv_insts_vext_mem.sail @@ -69,7 +69,7 @@ mapping clause encdec = VLSEGTYPE(nf, vm, rs1, width, vd) if extensionEnabled(Ex val process_vlseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, int('p), int('n)) -> Retired function process_vlseg (nf, vm, vd, load_width_bytes, rs1, EMUL_pow, num_elem) = { - let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow); + let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow); let width_type : word_width = size_bytes(load_width_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vd_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vd); @@ -135,7 +135,7 @@ mapping clause encdec = VLSEGFFTYPE(nf, vm, rs1, width, vd) if extensionEnabled( val process_vlsegff : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, int('p), int('n)) -> Retired function process_vlsegff (nf, vm, vd, load_width_bytes, rs1, EMUL_pow, num_elem) = { - let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow); + let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow); let width_type : word_width = size_bytes(load_width_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vd_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vd); @@ -240,7 +240,7 @@ mapping clause encdec = VSSEGTYPE(nf, vm, rs1, width, vs3) if extensionEnabled(E val process_vsseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, int('p), int('n)) -> Retired function process_vsseg (nf, vm, vs3, load_width_bytes, rs1, EMUL_pow, num_elem) = { - let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow); + let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow); let width_type : word_width = size_bytes(load_width_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vs3_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vs3); @@ -309,7 +309,7 @@ mapping clause encdec = VLSSEGTYPE(nf, vm, rs2, rs1, width, vd) if extensionEnab val process_vlsseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, regidx, int('p), int('n)) -> Retired function process_vlsseg (nf, vm, vd, load_width_bytes, rs1, rs2, EMUL_pow, num_elem) = { - let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow); + let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow); let width_type : word_width = size_bytes(load_width_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vd_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vd); @@ -376,7 +376,7 @@ mapping clause encdec = VSSSEGTYPE(nf, vm, rs2, rs1, width, vs3) if extensionEna val process_vssseg : forall 'f 'b 'n 'p, (0 < 'f & 'f <= 8) & ('b in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('b), regidx, regidx, int('p), int('n)) -> Retired function process_vssseg (nf, vm, vs3, load_width_bytes, rs1, rs2, EMUL_pow, num_elem) = { - let EMUL_reg : int = if EMUL_pow <= 0 then 1 else int_power(2, EMUL_pow); + let EMUL_reg : int = if EMUL_pow <= 0 then 1 else 2 ^ (EMUL_pow); let width_type : word_width = size_bytes(load_width_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vs3_seg : vector('n, dec, bits('f * 'b * 8)) = read_vreg_seg(num_elem, load_width_bytes * 8, EMUL_pow, nf, vs3); @@ -446,7 +446,7 @@ mapping clause encdec = VLUXSEGTYPE(nf, vm, vs2, rs1, width, vd) if extensionEna val process_vlxseg : forall 'f 'ib 'db 'ip 'dp 'n, (0 < 'f & 'f <= 8) & ('ib in {1, 2, 4, 8}) & ('db in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('ib), int('db), int('ip), int('dp), regidx, regidx, int('n), int) -> Retired function process_vlxseg (nf, vm, vd, EEW_index_bytes, EEW_data_bytes, EMUL_index_pow, EMUL_data_pow, rs1, vs2, num_elem, mop) = { - let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else int_power(2, EMUL_data_pow); + let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else 2 ^ (EMUL_data_pow); let width_type : word_width = size_bytes(EEW_data_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vd_seg : vector('n, dec, bits('f * 'db * 8)) = read_vreg_seg(num_elem, EEW_data_bytes * 8, EMUL_data_pow, nf, vd); @@ -538,7 +538,7 @@ mapping clause encdec = VSUXSEGTYPE(nf, vm, vs2, rs1, width, vs3) if extensionEn val process_vsxseg : forall 'f 'ib 'db 'ip 'dp 'n, (0 < 'f & 'f <= 8) & ('ib in {1, 2, 4, 8}) & ('db in {1, 2, 4, 8}) & ('n >= 0). (int('f), bits(1), regidx, int('ib), int('db), int('ip), int('dp), regidx, regidx, int('n), int) -> Retired function process_vsxseg (nf, vm, vs3, EEW_index_bytes, EEW_data_bytes, EMUL_index_pow, EMUL_data_pow, rs1, vs2, num_elem, mop) = { - let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else int_power(2, EMUL_data_pow); + let EMUL_data_reg : int = if EMUL_data_pow <= 0 then 1 else 2 ^ (EMUL_data_pow); let width_type : word_width = size_bytes(EEW_data_bytes); let vm_val : vector('n, dec, bool) = read_vmask(num_elem, vm, 0b00000); let vs3_seg : vector('n, dec, bits('f * 'db * 8)) = read_vreg_seg(num_elem, EEW_data_bytes * 8, EMUL_data_pow, nf, vs3); diff --git a/model/riscv_insts_vext_utils.sail b/model/riscv_insts_vext_utils.sail index 99d575b8c..4bf570671 100755 --- a/model/riscv_insts_vext_utils.sail +++ b/model/riscv_insts_vext_utils.sail @@ -22,7 +22,7 @@ mapping maybe_vmask : string <-> bits(1) = { */ val valid_eew_emul : (int, int) -> bool function valid_eew_emul(EEW, EMUL_pow) = { - let ELEN = int_power(2, get_elen_pow()); + let ELEN = 2 ^ get_elen_pow(); EEW >= 8 & EEW <= ELEN & EMUL_pow >= -3 & EMUL_pow <= 3 } @@ -60,8 +60,8 @@ function valid_rd_mask(rd, vm) = { */ val valid_reg_overlap : (regidx, regidx, int, int) -> bool function valid_reg_overlap(rs, rd, EMUL_pow_rs, EMUL_pow_rd) = { - let rs_group = if EMUL_pow_rs > 0 then int_power(2, EMUL_pow_rs) else 1; - let rd_group = if EMUL_pow_rd > 0 then int_power(2, EMUL_pow_rd) else 1; + let rs_group = if EMUL_pow_rs > 0 then 2 ^ (EMUL_pow_rs) else 1; + let rd_group = if EMUL_pow_rd > 0 then 2 ^ (EMUL_pow_rd) else 1; let rs_int = unsigned(rs); let rd_int = unsigned(rd); if EMUL_pow_rs < EMUL_pow_rd then { @@ -78,8 +78,8 @@ function valid_reg_overlap(rs, rd, EMUL_pow_rs, EMUL_pow_rd) = { */ val valid_segment : (int, int) -> bool function valid_segment(nf, EMUL_pow) = { - if EMUL_pow < 0 then nf / int_power(2, 0 - EMUL_pow) <= 8 - else nf * int_power(2, EMUL_pow) <= 8 + if EMUL_pow < 0 then nf / (2 ^ (0 - EMUL_pow)) <= 8 + else nf * 2 ^ (EMUL_pow) <= 8 } /* ******************************************************************************* */ @@ -209,7 +209,7 @@ function init_masked_result(num_elem, SEW, LMUL_pow, vd_val, vm_val) = { result : vector('n, dec, bits('m)) = undefined; /* Determine the actual number of elements when lmul < 1 */ - let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow); + let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow)); assert(num_elem >= real_num_elem); foreach (i from 0 to (num_elem - 1)) { @@ -259,7 +259,7 @@ function init_masked_source(num_elem, LMUL_pow, vm_val) = { mask : vector('n, dec, bool) = undefined; /* Determine the actual number of elements when lmul < 1 */ - let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow); + let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow)); assert(num_elem >= real_num_elem); foreach (i from 0 to (num_elem - 1)) { @@ -294,7 +294,7 @@ function init_masked_result_carry(num_elem, SEW, LMUL_pow, vd_val) = { result : vector('n, dec, bool) = undefined; /* Determine the actual number of elements when lmul < 1 */ - let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow); + let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow)); assert(num_elem >= real_num_elem); foreach (i from 0 to (num_elem - 1)) { @@ -331,7 +331,7 @@ function init_masked_result_cmp(num_elem, SEW, LMUL_pow, vd_val, vm_val) = { result : vector('n, dec, bool) = undefined; /* Determine the actual number of elements when lmul < 1 */ - let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / int_power(2, 0 - LMUL_pow); + let real_num_elem = if LMUL_pow >= 0 then num_elem else num_elem / (2 ^ (0 - LMUL_pow)); assert(num_elem >= real_num_elem); foreach (i from 0 to (num_elem - 1)) { @@ -372,7 +372,7 @@ function init_masked_result_cmp(num_elem, SEW, LMUL_pow, vd_val, vm_val) = { val read_vreg_seg : forall 'n 'm 'p 'q, 'n >= 0 & 'q >= 0. (int('n), int('m), int('p), int('q), regidx) -> vector('n, dec, bits('q * 'm)) function read_vreg_seg(num_elem, SEW, LMUL_pow, nf, vrid) = { assert('q * 'm > 0); - let LMUL_reg : int = if LMUL_pow <= 0 then 1 else int_power(2, LMUL_pow); + let LMUL_reg : int = if LMUL_pow <= 0 then 1 else 2 ^ (LMUL_pow); vreg_list : vector('q, dec, vector('n, dec, bits('m))) = undefined; result : vector('n, dec, bits('q * 'm)) = undefined; foreach (j from 0 to (nf - 1)) { diff --git a/model/riscv_insts_vext_vset.sail b/model/riscv_insts_vext_vset.sail index 16adf2806..7e73d5724 100644 --- a/model/riscv_insts_vext_vset.sail +++ b/model/riscv_insts_vext_vset.sail @@ -85,7 +85,7 @@ function clause execute VSETVLI(ma, ta, sew, lmul, rs1, rd) = { let LMUL_pow_new = get_lmul_pow(); let SEW_pow_new = get_sew_pow(); if SEW_pow_new > (LMUL_pow_new + ELEN_pow) then { handle_illegal_vtype(); return RETIRE_SUCCESS }; - let VLMAX = int_power(2, VLEN_pow + LMUL_pow_new - SEW_pow_new); + let VLMAX = 2 ^ (VLEN_pow + LMUL_pow_new - SEW_pow_new); /* set vl according to VLMAX and AVL */ if (rs1 != 0b00000) then { /* normal stripmining */ @@ -136,7 +136,7 @@ function clause execute VSETVL(rs2, rs1, rd) = { let LMUL_pow_new = get_lmul_pow(); let SEW_pow_new = get_sew_pow(); if SEW_pow_new > (LMUL_pow_new + ELEN_pow) then { handle_illegal_vtype(); return RETIRE_SUCCESS }; - let VLMAX = int_power(2, VLEN_pow + LMUL_pow_new - SEW_pow_new); + let VLMAX = 2 ^ (VLEN_pow + LMUL_pow_new - SEW_pow_new); /* set vl according to VLMAX and AVL */ if (rs1 != 0b00000) then { /* normal stripmining */ @@ -183,7 +183,7 @@ function clause execute VSETIVLI(ma, ta, sew, lmul, uimm, rd) = { let LMUL_pow_new = get_lmul_pow(); let SEW_pow_new = get_sew_pow(); if SEW_pow_new > (LMUL_pow_new + ELEN_pow) then { handle_illegal_vtype(); return RETIRE_SUCCESS }; - let VLMAX = int_power(2, VLEN_pow + LMUL_pow_new - SEW_pow_new); + let VLMAX = 2 ^ (VLEN_pow + LMUL_pow_new - SEW_pow_new); /* set vl according to VLMAX and AVL */ let AVL = unsigned(uimm); /* AVL is encoded as 5-bit zero-extended imm in the rs1 field */ diff --git a/model/riscv_vext_regs.sail b/model/riscv_vext_regs.sail index 8d96f12a0..511e2aca1 100644 --- a/model/riscv_vext_regs.sail +++ b/model/riscv_vext_regs.sail @@ -239,7 +239,7 @@ function get_num_elem(LMUL_pow, SEW) = { let LMUL_pow_reg = if LMUL_pow < 0 then 0 else LMUL_pow; /* Ignore lmul < 1 so that the entire vreg is read, allowing all masking to * be handled in init_masked_result */ - let num_elem = int_power(2, LMUL_pow_reg) * VLEN / SEW; + let num_elem = 2 ^ (LMUL_pow_reg) * VLEN / SEW; assert(num_elem > 0); num_elem }