Skip to content

Commit

Permalink
use var again
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Nov 4, 2024
1 parent a619944 commit 7213788
Showing 1 changed file with 52 additions and 40 deletions.
92 changes: 52 additions & 40 deletions stdlib/src/utils/span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ struct Span[
Notes:
The function works on an empty span, always returning `-1`.
"""
_len = len(self)
var _len = len(self)

if not subseq:

Expand All @@ -410,8 +410,8 @@ struct Span[
n_s = start
else:
n_s = normalize_index["Span", ignore_zero_length=True](start, self)
s_ptr = self.unsafe_ptr()
haystack = __type_of(self)(unsafe_ptr=s_ptr + n_s, len=_len - n_s)
var s_ptr = self.unsafe_ptr()
var haystack = __type_of(self)(unsafe_ptr=s_ptr + n_s, len=_len - n_s)
var loc: UnsafePointer[Scalar[D]]

@parameter
Expand Down Expand Up @@ -534,15 +534,15 @@ fn _memchr[
](span: Span[Scalar[type], O], char: Scalar[type]) -> UnsafePointer[
Scalar[type]
] as output:
haystack = span.unsafe_ptr()
length = len(span)
var haystack = span.unsafe_ptr()
var length = len(span)
alias bool_mask_width = simdwidthof[DType.bool]()
first_needle = SIMD[type, bool_mask_width](char)
vectorized_end = _align_down(length, bool_mask_width)
var first_needle = SIMD[type, bool_mask_width](char)
var vectorized_end = _align_down(length, bool_mask_width)

for i in range(0, vectorized_end, bool_mask_width):
bool_mask = haystack.load[width=bool_mask_width](i) == first_needle
mask = pack_bits(bool_mask)
var bool_mask = haystack.load[width=bool_mask_width](i) == first_needle
var mask = pack_bits(bool_mask)
if mask:
output = haystack + int(i + count_trailing_zeros(mask))
return
Expand All @@ -561,10 +561,10 @@ fn _memmem[
](
haystack_span: Span[Scalar[type], O], needle_span: Span[Scalar[type]]
) -> UnsafePointer[Scalar[type]] as output:
haystack = haystack_span.unsafe_ptr()
haystack_len = len(haystack_span)
needle = needle_span.unsafe_ptr()
needle_len = len(needle_span)
var haystack = haystack_span.unsafe_ptr()
var haystack_len = len(haystack_span)
var needle = needle_span.unsafe_ptr()
var needle_len = len(needle_span)
debug_assert(needle_len > 0, "needle_len must be > 0")
if needle_len == 1:
output = _memchr[type](haystack_span, needle[0])
Expand All @@ -574,20 +574,26 @@ fn _memmem[
return

alias bool_mask_width = simdwidthof[DType.bool]()
vectorized_end = _align_down(haystack_len - needle_len + 1, bool_mask_width)
var vectorized_end = _align_down(
haystack_len - needle_len + 1, bool_mask_width
)

first_needle = SIMD[type, bool_mask_width](needle[0])
last_needle = SIMD[type, bool_mask_width](needle[needle_len - 1])
var first_needle = SIMD[type, bool_mask_width](needle[0])
var last_needle = SIMD[type, bool_mask_width](needle[needle_len - 1])

for i in range(0, vectorized_end, bool_mask_width):
first_block = haystack.load[width=bool_mask_width](i)
last_block = haystack.load[width=bool_mask_width](i + needle_len - 1)
var first_block = haystack.load[width=bool_mask_width](i)
var last_block = haystack.load[width=bool_mask_width](
i + needle_len - 1
)

bool_mask = (first_needle == first_block) & (last_needle == last_block)
mask = pack_bits(bool_mask)
var bool_mask = (first_needle == first_block) & (
last_needle == last_block
)
var mask = pack_bits(bool_mask)

while mask:
offset = int(i + count_trailing_zeros(mask))
var offset = int(i + count_trailing_zeros(mask))
if memcmp(haystack + offset + 1, needle + 1, needle_len - 1) == 0:
output = haystack + offset
return
Expand All @@ -609,22 +615,22 @@ fn _memrchr[
](span: Span[Scalar[type], O], char: Scalar[type]) -> UnsafePointer[
Scalar[type]
] as output:
haystack = span.unsafe_ptr()
length = len(span)
var haystack = span.unsafe_ptr()
var length = len(span)
alias bool_mask_width = simdwidthof[DType.bool]()
first_needle = SIMD[type, bool_mask_width](char)
vectorized_end = _align_down(length, bool_mask_width)
var first_needle = SIMD[type, bool_mask_width](char)
var vectorized_end = _align_down(length, bool_mask_width)

for i in reversed(range(vectorized_end, length)):
if haystack[i] == char:
output = haystack + i
return

for i in reversed(range(0, vectorized_end, bool_mask_width)):
bool_mask = haystack.load[width=bool_mask_width](i) == first_needle
mask = pack_bits(bool_mask)
var bool_mask = haystack.load[width=bool_mask_width](i) == first_needle
var mask = pack_bits(bool_mask)
if mask:
zeros = int(count_leading_zeros(mask)) + 1
var zeros = int(count_leading_zeros(mask)) + 1
output = haystack + (i + bool_mask_width - zeros)
return

Expand All @@ -637,10 +643,10 @@ fn _memrmem[
](
haystack_span: Span[Scalar[type], O], needle_span: Span[Scalar[type]]
) -> UnsafePointer[Scalar[type]] as output:
haystack = haystack_span.unsafe_ptr()
haystack_len = len(haystack_span)
needle = needle_span.unsafe_ptr()
needle_len = len(needle_span)
var haystack = haystack_span.unsafe_ptr()
var haystack_len = len(haystack_span)
var needle = needle_span.unsafe_ptr()
var needle_len = len(needle_span)
debug_assert(needle_len > 0, "needle_len must be > 0")

if needle_len == 1:
Expand All @@ -651,7 +657,9 @@ fn _memrmem[
return

alias bool_mask_width = simdwidthof[DType.bool]()
vectorized_end = _align_down(haystack_len - needle_len + 1, bool_mask_width)
var vectorized_end = _align_down(
haystack_len - needle_len + 1, bool_mask_width
)

for i in reversed(range(vectorized_end, haystack_len - needle_len + 1)):
if haystack[i] != needle[0]:
Expand All @@ -661,18 +669,22 @@ fn _memrmem[
output = haystack + i
return

first_needle = SIMD[type, bool_mask_width](needle[0])
last_needle = SIMD[type, bool_mask_width](needle[needle_len - 1])
var first_needle = SIMD[type, bool_mask_width](needle[0])
var last_needle = SIMD[type, bool_mask_width](needle[needle_len - 1])

for i in reversed(range(0, vectorized_end, bool_mask_width)):
first_block = haystack.load[width=bool_mask_width](i)
last_block = haystack.load[width=bool_mask_width](i + needle_len - 1)
var first_block = haystack.load[width=bool_mask_width](i)
var last_block = haystack.load[width=bool_mask_width](
i + needle_len - 1
)

bool_mask = (first_needle == first_block) & (last_needle == last_block)
mask = pack_bits(bool_mask)
var bool_mask = (first_needle == first_block) & (
last_needle == last_block
)
var mask = pack_bits(bool_mask)

while mask:
offset = i + bool_mask_width - int(count_leading_zeros(mask))
var offset = i + bool_mask_width - int(count_leading_zeros(mask))
if memcmp(haystack + offset, needle + 1, needle_len - 1) == 0:
output = haystack + offset - 1
return
Expand Down

0 comments on commit 7213788

Please sign in to comment.