From df0608a4bc5fe28a14d1ef2affc694023371ff7e Mon Sep 17 00:00:00 2001 From: Max Brylski Date: Thu, 1 Aug 2024 11:02:35 -0500 Subject: [PATCH] Refactor `/format_int.mojo` to use Formatter, and take care of corresponding TODO's. Signed-off-by: Max Brylski --- stdlib/src/builtin/format_int.mojo | 234 +++++++++++++++++--------- stdlib/src/builtin/object.mojo | 22 +-- stdlib/src/builtin/simd.mojo | 8 +- stdlib/src/memory/unsafe_pointer.mojo | 6 +- 4 files changed, 169 insertions(+), 101 deletions(-) diff --git a/stdlib/src/builtin/format_int.mojo b/stdlib/src/builtin/format_int.mojo index 0ea4fac700d..702a43b5fdf 100644 --- a/stdlib/src/builtin/format_int.mojo +++ b/stdlib/src/builtin/format_int.mojo @@ -28,7 +28,34 @@ alias _DEFAULT_DIGIT_CHARS = "0123456789abcdefghijklmnopqrstuvwxyz" # ===----------------------------------------------------------------------===# -fn bin(num: Scalar, /, *, prefix: StaticString = "0b") -> String: +fn bin[ + WriterType: Writer, // +](value: Scalar, /, *, inout writer: WriterType, prefix: StaticString = "0b"): + """Writes the binary string representation an integral value to a formatter. + + ```mojo + print(bin(123)) + print(bin(-123)) + ``` + ```plaintext + '0b1111011' + '-0b1111011' + ``` + + Args: + value: An integral scalar value. + writer: The formatter to write to. + prefix: The prefix of the formatted int. + """ + + @parameter + if value.type is DType.bool: + bin(value.cast[DType.int8](), writer=writer, prefix=prefix) + else: + _try_write_int(writer, value, 2, prefix=prefix) + + +fn bin(value: Scalar, /, *, prefix: StaticString = "0b") -> String: """Return the binary string representation an integral value. ```mojo @@ -41,44 +68,48 @@ fn bin(num: Scalar, /, *, prefix: StaticString = "0b") -> String: ``` Args: - num: An integral scalar value. + value: An integral scalar value. prefix: The prefix of the formatted int. Returns: The binary string representation of num. """ - return _try_format_int(num, 2, prefix=prefix) + var result = String() + bin(value, writer=result, prefix=prefix) + return result^ -# Need this until we have constraints to stop the compiler from matching this -# directly to bin[type: DType](num: Scalar[type]). -fn bin(b: Scalar[DType.bool], /, *, prefix: StaticString = "0b") -> String: - """Returns the binary representation of a scalar bool. +fn bin[ + T: Indexer, WriterType: Writer, // +](value: T, /, *, inout writer: WriterType, prefix: StaticString = "0b"): + """Writes the binary representation of an indexer type to a formatter. + + Parameters: + T: The Indexer type. + WriterType: The type of the `writer` argument. Args: - b: A scalar bool value. + value: An indexer value. + writer: The formatter to write to. prefix: The prefix of the formatted int. - - Returns: - The binary string representation of b. """ - return bin(b.cast[DType.int8](), prefix=prefix) + bin(Scalar[DType.index](index(value)), writer=writer, prefix=prefix) -fn bin[T: Indexer, //](num: T, /, *, prefix: StaticString = "0b") -> String: +fn bin[T: Indexer, //](value: T, /, *, prefix: StaticString = "0b") -> String: """Returns the binary representation of an indexer type. Parameters: T: The Indexer type. Args: - num: An indexer value. + value: An indexer value. prefix: The prefix of the formatted int. Returns: The binary string representation of num. """ - return bin(Scalar[DType.index](index(num)), prefix=prefix) + return bin(Scalar[DType.index](index(value)), prefix=prefix) # ===----------------------------------------------------------------------===# @@ -86,6 +117,29 @@ fn bin[T: Indexer, //](num: T, /, *, prefix: StaticString = "0b") -> String: # ===----------------------------------------------------------------------===# +fn hex[ + WriterType: Writer, // +](value: Scalar, /, *, inout writer: WriterType, prefix: StaticString = "0x"): + """Writes the hex string representation of the given integer to a formatter. + + The hexadecimal representation is a base-16 encoding of the integer value. + + The formatted string will be prefixed with "0x" to indicate that the + subsequent digits are hex. + + Args: + value: The integer value to format. + writer: The formatter to write to. + prefix: The prefix of the formatted int. + """ + + @parameter + if value.type is DType.bool: + hex(value.cast[DType.int8](), writer=writer, prefix=prefix) + else: + _try_write_int(writer, value, 16, prefix=prefix) + + fn hex(value: Scalar, /, *, prefix: StaticString = "0x") -> String: """Returns the hex string representation of the given integer. @@ -101,46 +155,52 @@ fn hex(value: Scalar, /, *, prefix: StaticString = "0x") -> String: Returns: A string containing the hex representation of the given integer. """ - return _try_format_int(value, 16, prefix=prefix) + var result = String() + hex(value, writer=result, prefix=prefix) + return result^ -fn hex[T: Indexer, //](value: T, /, *, prefix: StaticString = "0x") -> String: - """Returns the hex string representation of the given integer. +fn hex[ + T: Indexer, WriterType: Writer, // +](value: T, /, *, inout writer: WriterType, prefix: StaticString = "0x"): + """Writes the hex string representation of the given integer to a formatter. The hexadecimal representation is a base-16 encoding of the integer value. - The returned string will be prefixed with "0x" to indicate that the + The formatted string will be prefixed with "0x" to indicate that the subsequent digits are hex. Parameters: T: The indexer type to represent in hexadecimal. + WriterType: The type of the `writer` argument. Args: value: The integer value to format. + writer: The formatter to write to. prefix: The prefix of the formatted int. - - Returns: - A string containing the hex representation of the given integer. """ - return hex(Scalar[DType.index](index(value)), prefix=prefix) + hex(Scalar[DType.index](index(value)), writer=writer, prefix=prefix) -fn hex(value: Scalar[DType.bool], /, *, prefix: StaticString = "0x") -> String: - """Returns the hex string representation of the given scalar bool. +fn hex[T: Indexer, //](value: T, /, *, prefix: StaticString = "0x") -> String: + """Returns the hex string representation of the given integer. - The hexadecimal representation is a base-16 encoding of the bool. + The hexadecimal representation is a base-16 encoding of the integer value. The returned string will be prefixed with "0x" to indicate that the subsequent digits are hex. + Parameters: + T: The indexer type to represent in hexadecimal. + Args: - value: The bool value to format. + value: The integer value to format. prefix: The prefix of the formatted int. Returns: - A string containing the hex representation of the given bool. + A string containing the hex representation of the given integer. """ - return hex(value.cast[DType.int8](), prefix=prefix) + return hex(Scalar[DType.index](index(value)), prefix=prefix) # ===----------------------------------------------------------------------===# @@ -148,6 +208,29 @@ fn hex(value: Scalar[DType.bool], /, *, prefix: StaticString = "0x") -> String: # ===----------------------------------------------------------------------===# +fn oct[ + WriterType: Writer, // +](value: Scalar, /, *, inout writer: WriterType, prefix: StaticString = "0o"): + """Writes the octal string representation of the given integer to a formatter. + + The octal representation is a base-8 encoding of the integer value. + + The formatted string will be prefixed with "0o" to indicate that the + subsequent digits are octal. + + Args: + value: The integer value to format. + writer: The formatter to write to. + prefix: The prefix of the formatted int. + """ + + @parameter + if value.type is DType.bool: + oct(value.cast[DType.int8](), writer=writer, prefix=prefix) + else: + _try_write_int(writer, value, 8, prefix=prefix) + + fn oct(value: Scalar, /, *, prefix: StaticString = "0o") -> String: """Returns the octal string representation of the given integer. @@ -163,46 +246,52 @@ fn oct(value: Scalar, /, *, prefix: StaticString = "0o") -> String: Returns: A string containing the octal representation of the given integer. """ - return _try_format_int(value, 8, prefix=prefix) + var result = String() + oct(value, writer=result, prefix=prefix) + return result^ -fn oct[T: Indexer, //](value: T, /, *, prefix: StaticString = "0o") -> String: - """Returns the octal string representation of the given integer. +fn oct[ + T: Indexer, WriterType: Writer, // +](value: T, /, *, inout writer: WriterType, prefix: StaticString = "0o"): + """Writes the octal string representation of the given integer to a formatter. The octal representation is a base-8 encoding of the integer value. - The returned string will be prefixed with "0o" to indicate that the + The formatted string will be prefixed with "0o" to indicate that the subsequent digits are octal. Parameters: T: The indexer type to represent in octal. + WriterType: The type of the `writer` argument. Args: value: The integer value to format. + writer: The formatter to write to. prefix: The prefix of the formatted int. - - Returns: - A string containing the octal representation of the given integer. """ - return oct(Scalar[DType.index](index(value)), prefix=prefix) + oct(Scalar[DType.index](index(value)), writer=writer, prefix=prefix) -fn oct(value: Scalar[DType.bool], /, *, prefix: StaticString = "0o") -> String: - """Returns the octal string representation of the given scalar bool. +fn oct[T: Indexer, //](value: T, /, *, prefix: StaticString = "0o") -> String: + """Returns the octal string representation of the given integer. - The octal representation is a base-8 encoding of the bool. + The octal representation is a base-8 encoding of the integer value. The returned string will be prefixed with "0o" to indicate that the subsequent digits are octal. + Parameters: + T: The indexer type to represent in octal. + Args: - value: The bool value to format. + value: The integer value to format. prefix: The prefix of the formatted int. Returns: - A string containing the octal representation of the given bool. + A string containing the octal representation of the given integer. """ - return oct(value.cast[DType.int8](), prefix=prefix) + return oct(Scalar[DType.index](index(value)), prefix=prefix) # ===----------------------------------------------------------------------===# @@ -215,17 +304,12 @@ fn _try_format_int( /, radix: Int = 10, *, + digit_chars: StaticString = _DEFAULT_DIGIT_CHARS, prefix: StaticString = "", ) -> String: - try: - return _format_int(value, radix, prefix=prefix) - except e: - # This should not be reachable as _format_int only throws if we pass - # incompatible radix and custom digit chars, which we aren't doing - # above. - return abort[String]( - "unexpected exception formatting value as hexadecimal: " + str(e) - ) + var string = String() + _try_write_int(string, value, radix, digit_chars=digit_chars, prefix=prefix) + return string^ fn _format_int[ @@ -237,44 +321,44 @@ fn _format_int[ digit_chars: StaticString = _DEFAULT_DIGIT_CHARS, prefix: StaticString = "", ) raises -> String: - var output = String() + var string = String() + _write_int(string, value, radix, digit_chars=digit_chars, prefix=prefix) + return string^ - _write_int(output, value, radix, digit_chars=digit_chars, prefix=prefix) - return output^ - - -fn _write_int[ +fn _try_write_int[ type: DType, - W: Writer, + WriterType: Writer, //, ]( - inout writer: W, + inout writer: WriterType, value: Scalar[type], /, radix: Int = 10, *, digit_chars: StaticString = _DEFAULT_DIGIT_CHARS, prefix: StaticString = "", -) raises: - var err = _try_write_int( - writer, value, radix, digit_chars=digit_chars, prefix=prefix - ) - if err: - raise err.value() +): + try: + _write_int(writer, value, radix, digit_chars=digit_chars, prefix=prefix) + except e: + # This should not be reachable as _format_int only throws if we pass + # incompatible radix and custom digit chars, which we aren't doing + # above. + abort("unexpected exception formatting value as hexadecimal: " + str(e)) -fn _try_write_int[ +fn _write_int[ type: DType, - W: Writer, + WriterType: Writer, ]( - inout writer: W, + inout writer: WriterType, value: Scalar[type], /, radix: Int = 10, *, digit_chars: StaticString = _DEFAULT_DIGIT_CHARS, prefix: StaticString = "", -) -> Optional[Error]: +) raises: """Writes a formatted string representation of the given integer using the specified radix. @@ -285,16 +369,16 @@ fn _try_write_int[ # Check that the radix and available digit characters are valid if radix < 2: - return Error("Unable to format integer to string with radix < 2") + raise Error("Unable to format integer to string with radix < 2") if radix > digit_chars.byte_length(): - return Error( + raise Error( "Unable to format integer to string when provided radix is larger " "than length of available digit value characters" ) if not digit_chars.byte_length() >= 2: - return Error( + raise Error( "Unable to format integer to string when provided digit_chars" " mapping len is not >= 2" ) @@ -332,8 +416,6 @@ fn _try_write_int[ ) writer.write(zero) - return None - # Create a buffer to store the formatted value # Stack allocate enough bytes to store any formatted 64-bit integer @@ -410,5 +492,3 @@ fn _try_write_int[ ) writer.write(str_slice) - - return None diff --git a/stdlib/src/builtin/object.mojo b/stdlib/src/builtin/object.mojo index f7587a8ac63..ea33ce11eb2 100644 --- a/stdlib/src/builtin/object.mojo +++ b/stdlib/src/builtin/object.mojo @@ -572,29 +572,25 @@ struct _ObjectImpl( writer.write("None") return if self.is_bool(): - writer.write(str(self.get_as_bool())) + writer.write(self.get_as_bool()) return if self.is_int(): - writer.write(str(self.get_as_int())) + writer.write(self.get_as_int()) return if self.is_float(): - writer.write(str(self.get_as_float())) + writer.write(self.get_as_float()) return if self.is_str(): writer.write( - "'" - + str( - StringRef( - self.get_as_string().data, self.get_as_string().length - ) - ) - + "'" + "'", + StringRef( + self.get_as_string().data, self.get_as_string().length + ), + "'", ) return if self.is_func(): - writer.write( - "Function at address " + hex(int(self.get_as_func().value)) - ) + writer.write("Function at address ", self.get_as_func().value) return if self.is_list(): writer.write(String("[")) diff --git a/stdlib/src/builtin/simd.mojo b/stdlib/src/builtin/simd.mojo index e8eafc1525f..dda6ce9556b 100644 --- a/stdlib/src/builtin/simd.mojo +++ b/stdlib/src/builtin/simd.mojo @@ -1691,13 +1691,7 @@ struct SIMD[type: DType, size: Int]( # an ABI mismatch. _printf["%g"](element.cast[DType.float64]()) elif type.is_integral(): - var err = _try_write_int(writer, element) - if err: - abort( - "unreachable: unexpected write int failure" - " condition: " - + str(err.value()) - ) + _try_write_int(writer, element) else: _printf[_get_dtype_printf_format[type]()](element) else: diff --git a/stdlib/src/memory/unsafe_pointer.mojo b/stdlib/src/memory/unsafe_pointer.mojo index 548ea8ccbd6..525d565cbff 100644 --- a/stdlib/src/memory/unsafe_pointer.mojo +++ b/stdlib/src/memory/unsafe_pointer.mojo @@ -408,7 +408,7 @@ struct UnsafePointer[ Returns: The string representation of the pointer. """ - return hex(int(self)) + return String.write(self) @no_inline fn write_to[W: Writer](self, inout writer: W): @@ -421,9 +421,7 @@ struct UnsafePointer[ Args: writer: The object to write to. """ - - # TODO: Avoid intermediate String allocation. - writer.write(str(self)) + hex(int(self), writer=writer) # ===-------------------------------------------------------------------===# # Methods