Skip to content

Commit

Permalink
Use generic versions of some float utility functions
Browse files Browse the repository at this point in the history
Make `fsplit`/`fmake`, `f_is_...` and `f_negate` generic over the float size. The vector code actually already used functions like this.

This reduces a lot of copy/pasted code and will allow even more of the code to be deduplicated later.
  • Loading branch information
Timmmm committed Apr 4, 2024
1 parent 8133f43 commit 38fc8ac
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 602 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ SAIL_SYS_SRCS += riscv_next_regs.sail
SAIL_SYS_SRCS += riscv_sys_exceptions.sail # default basic helpers for exception handling
SAIL_SYS_SRCS += riscv_sync_exception.sail # define the exception structure used in the model
SAIL_SYS_SRCS += riscv_next_control.sail # helpers for the 'N' extension
SAIL_SYS_SRCS += riscv_softfloat_interface.sail riscv_fdext_regs.sail riscv_fdext_control.sail
SAIL_SYS_SRCS += riscv_float.sail riscv_softfloat_interface.sail riscv_fdext_regs.sail riscv_fdext_control.sail
SAIL_SYS_SRCS += riscv_csr_ext.sail # access to CSR extensions
SAIL_SYS_SRCS += riscv_sys_control.sail # general exception handling

Expand Down
111 changes: 111 additions & 0 deletions model/riscv_float.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// ---------------------------------------------------------------------------

// Split a floating point bitvec up into its sign, exponent, mantissa parts.
val fsplit : forall 'n, 'n in {16, 32, 64}.
bits('n) -> (
bits(1),
bits(if 'n == 16 then 5 else (if 'n == 32 then 8 else 11)),
bits(if 'n == 16 then 10 else (if 'n == 32 then 23 else 52)),
)
function fsplit(x) = {
if 'n == 16 then (x[15..15], x[14..10], x[9..0])
else if 'n == 32 then (x[31..31], x[30..23], x[22..0])
else (x[63..63], x[62..52], x[51..0])
}

// Join sign, exponent, mantissa parts back into a single bit vector.
val fmake : forall 'e, 'e in {5, 8, 11}.
(
bits(1),
bits('e),
bits(if 'e == 5 then 10 else (if 'e == 8 then 23 else 52)),
) -> bits(if 'e == 5 then 16 else (if 'e == 8 then 32 else 64))
function fmake(sign, exp, mant) = sign @ exp @ mant

// ---------------------------------------------------------------------------
// Floating point property functions.

// Bit vector type for floating points - restricted to f16, f32, f64.
type fbits = { 'n, 'n in {16, 32, 64}. bits('n) }

function f_is_neg_inf(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == 0b1)
& (exp == ones())
& (mant == zeros()))
}

function f_is_neg_norm(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == 0b1)
& (exp != zeros())
& (exp != ones()))
}

function f_is_neg_subnorm(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == 0b1)
& (exp == zeros())
& (mant != zeros()))
}

function f_is_neg_zero(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == ones())
& (exp == zeros())
& (mant == zeros()))
}

function f_is_pos_zero(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == zeros())
& (exp == zeros())
& (mant == zeros()))
}

function f_is_pos_subnorm(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == zeros())
& (exp == zeros())
& (mant != zeros()))
}

function f_is_pos_norm(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == zeros())
& (exp != zeros())
& (exp != ones()))
}

function f_is_pos_inf(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (sign == zeros())
& (exp == ones())
& (mant == zeros()))
}

function f_is_SNaN(x : fbits) -> bool = {
let (sign, exp, 'mant) = fsplit(x);
( (exp == ones())
& (mant['mant - 1] == bitzero)
& (mant != zeros()))
}

function f_is_QNaN(x : fbits) -> bool = {
let (sign, exp, 'mant) = fsplit(x);
( (exp == ones())
& (mant['mant - 1] == bitone))
}

// Either QNaN or SNan
function f_is_NaN(x : fbits) -> bool = {
let (sign, exp, mant) = fsplit(x);
( (exp == ones())
& (mant != zeros()))
}

// ---------------------------------------------------------------------------

// Negation (invert the sign bit which is always the top bit).
val f_negate : forall 'n, 'n in {16, 32, 64}. bits('n) -> bits('n)
function f_negate(x) = ~(x['n - 1 .. 'n - 1]) @ x['n - 2 .. 0]
Loading

0 comments on commit 38fc8ac

Please sign in to comment.