Skip to content

Commit

Permalink
fix benchmark and splitlines overhead
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Oct 23, 2024
1 parent aea3930 commit 323f4d6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
26 changes: 12 additions & 14 deletions stdlib/benchmarks/collections/bench_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from benchmark import Bench, BenchConfig, Bencher, BenchId, Unit, keep, run
from random import random_si64, seed
from pathlib import _dir_of_current_file
from collections import Optional
from collections import Optional, Dict
from os import abort
from stdlib.collections.string import String
from stdlib.utils._utf8_validation import _is_valid_utf8
from collections.string import String
from utils._utf8_validation import _is_valid_utf8


# ===----------------------------------------------------------------------===#
Expand Down Expand Up @@ -224,16 +224,6 @@ fn bench_string_is_valid_utf8[
def main():
seed()
var m = Bench(BenchConfig(num_repetitions=5))
# NOTE: A proper way to run a benchmark like this is:
# 1. Run the benchmark on the nightly branch with num_repetitions=5 and take
# the **median** value for each function, length, and language that is to be
# measured.
# 2. Then run the benchmark with num_repetitions=1 if you want faster results
# during development of your branch.
# 3. When ready to make statements about speed improvements, first run the
# benchmark again with num_repetitions=5 and take the **median** of that.
# 4. Make a markdown table reporting the new **median** numbers and the
# percentage improvement over the nightly version: (new - nightly) / nightly.
alias filenames = (
"UN_charter_EN",
"UN_charter_ES",
Expand Down Expand Up @@ -281,4 +271,12 @@ def main():
m.bench_function[bench_string_is_valid_utf8[length, fname]](
BenchId("bench_string_is_valid_utf8" + suffix)
)
m.dump_report()

# Aggregate the reported mean time (in ms) per benchmark name and print one
# CSV row per name. Entries that share a name are combined with a true
# arithmetic mean (sum / count); repeatedly halving `(previous + time)` would
# weight later entries more heavily once a name shows up more than twice.
totals = Dict[String, Float64]()
counts = Dict[String, Int]()
for info in m.info_vec:
    name = info[].name
    totals[name] = totals.get(name).or_else(Float64(0)) + info[].result.mean(
        "ms"
    )
    counts[name] = counts.get(name).or_else(0) + 1
print("")
for entry in totals.items():
    # Every key in `totals` was also inserted into `counts`, so the fallback
    # of 1 is never used; it only avoids unwrapping the Optional unchecked.
    mean_ms = entry[].value / Float64(counts.get(entry[].key).or_else(1))
    print(entry[].key, mean_ms, sep=",")
38 changes: 25 additions & 13 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ fn _unicode_codepoint_utf8_byte_length(c: Int) -> Int:
return int((sizes < c).cast[DType.uint8]().reduce_add())


@always_inline
fn _utf8_first_byte_sequence_length(b: Byte) -> Int:
    """Get the byte length of the UTF-8 sequence that starts with `b`.

    The result is the number of leading one bits in `b`, except that a byte
    whose top bit is clear (ASCII) yields 1.

    Args:
        b: The first byte of a UTF-8 sequence. It must not be a continuation
            byte (`0b10xx_xxxx`); the result is not meaningful for those and
            a debug assertion guards against it.

    Returns:
        The length of the sequence in bytes.
    """

    debug_assert(
        (b & 0b1100_0000) != 0b1000_0000,
        (
            "Function `_utf8_first_byte_sequence_length()` does not work"
            " correctly if given a continuation byte."
        ),
    )
    # Inverting the byte turns its leading ones into leading zeros, which can
    # then be counted directly.
    var inverted = ~b
    var leading_ones = count_leading_zeros(inverted)
    # For ASCII the count above is 0; the inverted top bit (1 only when the
    # original top bit was 0) bumps that case up to a length of 1.
    var ascii_bump = inverted >> 7
    return int(leading_ones + ascii_bump)


fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int):
"""Shift unicode to utf8 representation.
Expand Down Expand Up @@ -965,23 +981,19 @@ struct StringSlice[is_mutable: Bool, //, origin: Origin[is_mutable].type,](
while offset < length:
eol_start = offset
eol_length = 0
iterator = Self(
unsafe_from_utf8_ptr=ptr + offset, len=length - offset
).__iter__()

while eol_start < length:
char = iterator.__next__()
c_len = char.byte_length()
b0 = ptr[eol_start]
char_len = _utf8_first_byte_sequence_length(b0)
char = Self(unsafe_from_utf8_ptr=ptr + eol_start, len=char_len)
if char.isnewline[single_character=True]():
if c_len == 1 and char.unsafe_ptr()[0] == `\r`:
next_char = iterator.__next__()
if next_char.byte_length() == 1:
isnewline = next_char.unsafe_ptr()[0] == `\n`
eol_length = 1 + int(isnewline)
break
eol_length = c_len
char_end = eol_start + char_len
if b0 == `\r` and char_end < length:
debug_assert(char_len == 1, "corrupted byte sequence")
char_len += int(ptr[char_end] == `\n`)
eol_length = char_len
break
eol_start += c_len
eol_start += char_len

str_len = eol_start - offset + int(keepends) * eol_length
s = StringSlice[O](unsafe_from_utf8_ptr=ptr + offset, len=str_len)
Expand Down

0 comments on commit 323f4d6

Please sign in to comment.