diff --git a/stdlib/src/bit/__init__.mojo b/stdlib/src/bit/__init__.mojo index e41c3ca52c1..31303ed8f8b 100644 --- a/stdlib/src/bit/__init__.mojo +++ b/stdlib/src/bit/__init__.mojo @@ -25,4 +25,6 @@ from .bit import ( pop_count, rotate_bits_left, rotate_bits_right, + bin, + hex, ) diff --git a/stdlib/src/bit/bit.mojo b/stdlib/src/bit/bit.mojo index 5405411ab1a..804d47b18f9 100644 --- a/stdlib/src/bit/bit.mojo +++ b/stdlib/src/bit/bit.mojo @@ -19,8 +19,10 @@ from bit import count_leading_zeros ``` """ +from memory import bitcast, unpack_bits from sys import llvm_intrinsic, sizeof from sys.info import bitwidthof +from utils import Span # ===----------------------------------------------------------------------===# # count_leading_zeros @@ -654,3 +656,79 @@ fn rotate_bits_right[ return llvm_intrinsic["llvm.fshr", __type_of(x), has_side_effect=False]( x, x, SIMD[type, width](shift) ) + + +# ===----------------------------------------------------------------------===# +# bin and hex +# ===----------------------------------------------------------------------===# + + +fn bin[dtype: DType, //](x: Scalar[dtype]) -> String: + """Converts a scalar to a binary string. + + Parameters: + dtype: The data type of the input scalar. + + Args: + x: The input scalar value. + + Returns: + A binary string representation of the input scalar value. + """ + alias len = dtype.bitwidth() + 1 + buff = String._buffer_type(capacity=len) + _write_bin(x, buff) + buff.size = len + return String(impl=buff) + + +@always_inline +fn _write_bin(x: Scalar, s: Span[Byte, _]): + alias `0` = ord("0") + + @parameter + if x.type.sizeof() == 1: + r = x + else: + r = byte_swap(x) + bytes = unpack_bits(r).cast[DType.uint8]() + s.unsafe_ptr().store(bytes + `0`) + + +# fmt: off +alias _table = SIMD[DType.uint8, 16]( + ord("0"), ord("1"), ord("2"), ord("3"), ord("4"), ord("5"), ord("6"), ord("7"), + ord("8"), ord("9"), ord("a"), ord("b"), ord("c"), ord("d"), ord("e"), ord("f"), +) +# fmt: on + + +fn hex[dtype: DType, //](x: Scalar[dtype]) -> String: + """Converts a scalar to a hexadecimal string. + + Parameters: + dtype: The data type of the input scalar. + + Args: + x: The input scalar value. + + Returns: + A hexadecimal string representation of the input scalar value. + """ + alias len = dtype.sizeof() * 2 + 1 + buff = String._buffer_type(capacity=len) + _write_hex(x, buff) + buff.size = len + return String(impl=buff) + + +@always_inline +fn _write_hex(x: Scalar, s: Span[Byte, _]): + @parameter + if x.type.sizeof() == 1: + r = x + else: + r = byte_swap(x) + bytes = bitcast[DType.uint8, x.type.sizeof()](r) + nibbles = (bytes >> 4).interleave(bytes & 0xF) + s.unsafe_ptr().store(_table._dynamic_shuffle(nibbles)) diff --git a/stdlib/src/memory/__init__.mojo b/stdlib/src/memory/__init__.mojo index cc226348fd1..668a1195517 100644 --- a/stdlib/src/memory/__init__.mojo +++ b/stdlib/src/memory/__init__.mojo @@ -16,5 +16,5 @@ from .arc import Arc from .box import Box from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation from .pointer import AddressSpace, Pointer -from .unsafe import bitcast +from .unsafe import bitcast, unpack_bits from .unsafe_pointer import UnsafePointer diff --git a/stdlib/src/memory/unsafe.mojo b/stdlib/src/memory/unsafe.mojo index 93d7e2266b2..9c495fb0539 100644 --- a/stdlib/src/memory/unsafe.mojo +++ b/stdlib/src/memory/unsafe.mojo @@ -119,3 +119,33 @@ fn bitcast[ return __mlir_op.`pop.bitcast`[ _type = __mlir_type[`!pop.scalar<`, new_type.value, `>`] ](val.value) + + +@always_inline("nodebug") +fn unpack_bits[ + dtype: DType, //, width: Int = bitwidthof[dtype]() +](res: Scalar[dtype]) -> SIMD[DType.bool, width]: + """Pack a scalar value into a SIMD vector of boolean values. + + Parameters: + dtype: The data type of the input scalar value. + width: The width of the SIMD vector. + + Constraints: + The bitwidth of the data type must be equal to the SIMD width. + + Args: + res: The input scalar value. + + Returns: + A SIMD vector where each element is a boolean value representing the + corresponding bit of the input scalar value. + """ + constrained[ + bitwidthof[dtype]() == width, + "the bitwidth of the data type must be equal to the SIMD width", + ]() + b = __mlir_op.`pop.bitcast`[ + _type = __mlir_type[`!pop.simd<`, width.value, `, bool>`] + ](res.value) + return SIMD[DType.bool, width](b)