Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Make String and StringLiteral .splitlines() return List[StringSlice] #3894

Open
wants to merge 8 commits into
base: nightly
Choose a base branch
from
10 changes: 5 additions & 5 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -743,10 +743,10 @@ struct StringLiteral(
"""
return str(self).split(sep, maxsplit)

fn splitlines(self, keepends: Bool = False) -> List[String]:
"""Split the string literal at line boundaries. This corresponds to Python's
[universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
fn splitlines(self, keepends: Bool = False) -> List[StaticString]:
"""Split the string literal at line boundaries. This corresponds to
Python's [universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
`"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.

Args:
Expand All @@ -755,7 +755,7 @@ struct StringLiteral(
Returns:
A List of Strings containing the input split by line boundaries.
"""
return _to_string_list(self.as_string_slice().splitlines(keepends))
return self.as_string_slice().splitlines(keepends)

fn count(self, substr: String) -> Int:
"""Return the number of non-overlapping occurrences of substring
Expand Down
8 changes: 5 additions & 3 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1882,10 +1882,12 @@ struct String(

return output

fn splitlines(self, keepends: Bool = False) -> List[String]:
fn splitlines(
ref self, keepends: Bool = False
) -> List[StringSlice[__origin_of(self)]]:
"""Split the string at line boundaries. This corresponds to Python's
[universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
`"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.

Args:
Expand All @@ -1894,7 +1896,7 @@ struct String(
Returns:
A List of Strings containing the input split by line boundaries.
"""
return _to_string_list(self.as_string_slice().splitlines(keepends))
return self.as_string_slice().splitlines(keepends)

fn replace(self, old: String, new: String) -> String:
"""Return a copy of the string with all occurrences of substring `old`
Expand Down
38 changes: 17 additions & 21 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1047,29 +1047,23 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
offset += b_len
return length != 0

fn splitlines[
O: ImmutableOrigin, //
](self: StringSlice[O], keepends: Bool = False) -> List[StringSlice[O]]:
fn splitlines(self, keepends: Bool = False) -> List[Self]:
"""Split the string at line boundaries. This corresponds to Python's
[universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
`"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.

Parameters:
O: The immutable origin.

Args:
keepends: If True, line breaks are kept in the resulting strings.

Returns:
A List of Strings containing the input split by line boundaries.
"""

# highly performance sensitive code, benchmark before touching
alias `\r` = UInt8(ord("\r"))
alias `\n` = UInt8(ord("\n"))

output = List[StringSlice[O]](capacity=128) # guessing
output = List[Self](capacity=128) # guessing
var ptr = self.unsafe_ptr()
var length = self.byte_length()
var offset = 0
Expand Down Expand Up @@ -1099,7 +1093,7 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
eol_start += char_len

var str_len = eol_start - offset + int(keepends) * eol_length
var s = StringSlice[O](ptr=ptr + offset, length=str_len)
var s = Self(ptr=ptr + offset, length=str_len)
output.append(s)
offset = eol_start + eol_length

Expand All @@ -1116,29 +1110,30 @@ fn _to_string_list[
len_fn: fn (T) -> Int,
unsafe_ptr_fn: fn (T) -> UnsafePointer[Byte],
](items: List[T]) -> List[String]:
i_len = len(items)
i_ptr = items.unsafe_ptr()
out_ptr = UnsafePointer[String].alloc(i_len)
var i_len = len(items)
var i_ptr = items.unsafe_ptr()
var out_ptr = UnsafePointer[String].alloc(i_len)

for i in range(i_len):
og_len = len_fn(i_ptr[i])
f_len = og_len + 1 # null terminator
p = UnsafePointer[Byte].alloc(f_len)
og_ptr = unsafe_ptr_fn(i_ptr[i])
var og_len = len_fn(i_ptr[i])
var f_len = og_len + 1 # null terminator
var p = UnsafePointer[Byte].alloc(f_len)
var og_ptr = unsafe_ptr_fn(i_ptr[i])
memcpy(p, og_ptr, og_len)
p[og_len] = 0 # null terminator
buf = String._buffer_type(ptr=p, length=f_len, capacity=f_len)
var buf = String._buffer_type(ptr=p, length=f_len, capacity=f_len)
martinvuyk marked this conversation as resolved.
Show resolved Hide resolved
(out_ptr + i).init_pointee_move(String(buf^))
return List[String](ptr=out_ptr, length=i_len, capacity=i_len)


@always_inline
fn _to_string_list[
O: ImmutableOrigin, //
fn to_string_list[
mut: Bool, O: Origin[mut], //
](items: List[StringSlice[O]]) -> List[String]:
"""Create a list of Strings **copying** the existing data.

Parameters:
mut: The mutability of the origin.
O: The origin of the data.

Args:
Expand All @@ -1158,12 +1153,13 @@ fn _to_string_list[


@always_inline
fn _to_string_list[
O: ImmutableOrigin, //
fn to_string_list[
mut: Bool, O: Origin[mut], //
](items: List[Span[Byte, O]]) -> List[String]:
"""Create a list of Strings **copying** the existing data.

Parameters:
mut: The mutability of the origin.
O: The origin of the data.

Args:
Expand Down
65 changes: 50 additions & 15 deletions stdlib/test/builtin/test_string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from testing import (
assert_raises,
assert_true,
)
from utils import StringSlice


def test_add():
Expand Down Expand Up @@ -430,42 +431,76 @@ def test_split():


def test_splitlines():
alias L = List[String]
alias S = StringSlice[StaticConstantOrigin]
alias L = List[StringSlice[StaticConstantOrigin]]

# FIXME: remove once StringSlice conforms to TestableCollectionElement
fn _assert_equal[
O1: ImmutableOrigin, O2: ImmutableOrigin
](l1: List[StringSlice[O1]], l2: List[StringSlice[O2]]) raises:
assert_equal(len(l1), len(l2))
for i in range(len(l1)):
assert_equal(str(l1[i]), str(l2[i]))

# FIXME: remove once StringSlice conforms to TestableCollectionElement
fn _assert_equal[
O1: ImmutableOrigin
](l1: List[StringSlice[O1]], l2: List[String]) raises:
assert_equal(len(l1), len(l2))
for i in range(len(l1)):
assert_equal(str(l1[i]), l2[i])

martinvuyk marked this conversation as resolved.
Show resolved Hide resolved
# Test with no line breaks
assert_equal("hello world".splitlines(), L("hello world"))
_assert_equal(S("hello world").splitlines(), L("hello world"))

# Test with line breaks
assert_equal("hello\nworld".splitlines(), L("hello", "world"))
assert_equal("hello\rworld".splitlines(), L("hello", "world"))
assert_equal("hello\r\nworld".splitlines(), L("hello", "world"))
_assert_equal(S("hello\nworld").splitlines(), L("hello", "world"))
_assert_equal(S("hello\rworld").splitlines(), L("hello", "world"))
_assert_equal(S("hello\r\nworld").splitlines(), L("hello", "world"))

# Test with multiple different line breaks
s1 = "hello\nworld\r\nmojo\rlanguage\r\n"
s1 = S("hello\nworld\r\nmojo\rlanguage\r\n")
hello_mojo = L("hello", "world", "mojo", "language")
assert_equal(s1.splitlines(), hello_mojo)
assert_equal(
_assert_equal(s1.splitlines(), hello_mojo)
_assert_equal(
s1.splitlines(keepends=True),
L("hello\n", "world\r\n", "mojo\r", "language\r\n"),
)

# Test with an empty string
assert_equal("".splitlines(), L())
_assert_equal(S("").splitlines(), L())
# test \v \f \x1c \x1d
s2 = "hello\vworld\fmojo\x1clanguage\x1d"
assert_equal(s2.splitlines(), hello_mojo)
assert_equal(
s2 = S("hello\vworld\fmojo\x1clanguage\x1d")
_assert_equal(s2.splitlines(), hello_mojo)
_assert_equal(
s2.splitlines(keepends=True),
L("hello\v", "world\f", "mojo\x1c", "language\x1d"),
)

# test \x1c \x1d \x1e
s3 = "hello\x1cworld\x1dmojo\x1elanguage\x1e"
assert_equal(s3.splitlines(), hello_mojo)
assert_equal(
s3 = S("hello\x1cworld\x1dmojo\x1elanguage\x1e")
_assert_equal(s3.splitlines(), hello_mojo)
_assert_equal(
s3.splitlines(keepends=True),
L("hello\x1c", "world\x1d", "mojo\x1e", "language\x1e"),
)

# test \x85 \u2028 \u2029
var next_line = String(List[UInt8](0xC2, 0x85, 0))
"""TODO: \\x85"""
var unicode_line_sep = String(List[UInt8](0xE2, 0x80, 0xA8, 0))
"""TODO: \\u2028"""
var unicode_paragraph_sep = String(List[UInt8](0xE2, 0x80, 0xA9, 0))
"""TODO: \\u2029"""

for i in List(next_line, unicode_line_sep, unicode_paragraph_sep):
u = i[]
item = String("").join("hello", u, "world", u, "mojo", u, "language", u)
s = StringSlice(item)
_assert_equal(s.splitlines(), hello_mojo)
items = List("hello" + u, "world" + u, "mojo" + u, "language" + u)
_assert_equal(s.splitlines(keepends=True), items)


def test_float_conversion():
assert_equal(("4.5").__float__(), 4.5)
Expand Down
65 changes: 41 additions & 24 deletions stdlib/test/collections/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -849,58 +849,75 @@ def test_split():


def test_splitlines():
alias L = List[String]
alias S = StringSlice[StaticConstantOrigin]
alias L = List[StringSlice[StaticConstantOrigin]]

# FIXME: remove once StringSlice conforms to TestableCollectionElement
fn _assert_equal[
O1: ImmutableOrigin, O2: ImmutableOrigin
](l1: List[StringSlice[O1]], l2: List[StringSlice[O2]]) raises:
assert_equal(len(l1), len(l2))
for i in range(len(l1)):
assert_equal(str(l1[i]), str(l2[i]))

# FIXME: remove once StringSlice conforms to TestableCollectionElement
fn _assert_equal[
O1: ImmutableOrigin
](l1: List[StringSlice[O1]], l2: List[String]) raises:
assert_equal(len(l1), len(l2))
for i in range(len(l1)):
assert_equal(str(l1[i]), l2[i])

# Test with no line breaks
assert_equal(String("hello world").splitlines(), L("hello world"))
_assert_equal(S("hello world").splitlines(), L("hello world"))

# Test with line breaks
assert_equal(String("hello\nworld").splitlines(), L("hello", "world"))
assert_equal(String("hello\rworld").splitlines(), L("hello", "world"))
assert_equal(String("hello\r\nworld").splitlines(), L("hello", "world"))
_assert_equal(S("hello\nworld").splitlines(), L("hello", "world"))
_assert_equal(S("hello\rworld").splitlines(), L("hello", "world"))
_assert_equal(S("hello\r\nworld").splitlines(), L("hello", "world"))

# Test with multiple different line breaks
s1 = String("hello\nworld\r\nmojo\rlanguage\r\n")
s1 = S("hello\nworld\r\nmojo\rlanguage\r\n")
hello_mojo = L("hello", "world", "mojo", "language")
assert_equal(s1.splitlines(), hello_mojo)
assert_equal(
_assert_equal(s1.splitlines(), hello_mojo)
_assert_equal(
s1.splitlines(keepends=True),
L("hello\n", "world\r\n", "mojo\r", "language\r\n"),
)

# Test with an empty string
assert_equal(String("").splitlines(), L())
_assert_equal(S("").splitlines(), L())
# test \v \f \x1c \x1d
s2 = String("hello\vworld\fmojo\x1clanguage\x1d")
assert_equal(s2.splitlines(), hello_mojo)
assert_equal(
s2 = S("hello\vworld\fmojo\x1clanguage\x1d")
_assert_equal(s2.splitlines(), hello_mojo)
_assert_equal(
s2.splitlines(keepends=True),
L("hello\v", "world\f", "mojo\x1c", "language\x1d"),
)

# test \x1c \x1d \x1e
s3 = String("hello\x1cworld\x1dmojo\x1elanguage\x1e")
assert_equal(s3.splitlines(), hello_mojo)
assert_equal(
s3 = S("hello\x1cworld\x1dmojo\x1elanguage\x1e")
_assert_equal(s3.splitlines(), hello_mojo)
_assert_equal(
s3.splitlines(keepends=True),
L("hello\x1c", "world\x1d", "mojo\x1e", "language\x1e"),
)

# test \x85 \u2028 \u2029
var next_line = List[UInt8](0xC2, 0x85, 0)
var next_line = String(List[UInt8](0xC2, 0x85, 0))
"""TODO: \\x85"""
var unicode_line_sep = List[UInt8](0xE2, 0x80, 0xA8, 0)
var unicode_line_sep = String(List[UInt8](0xE2, 0x80, 0xA8, 0))
"""TODO: \\u2028"""
var unicode_paragraph_sep = List[UInt8](0xE2, 0x80, 0xA9, 0)
var unicode_paragraph_sep = String(List[UInt8](0xE2, 0x80, 0xA9, 0))
"""TODO: \\u2029"""

for i in List(next_line, unicode_line_sep, unicode_paragraph_sep):
u = String(i[])
u = i[]
item = String("").join("hello", u, "world", u, "mojo", u, "language", u)
assert_equal(item.splitlines(), hello_mojo)
assert_equal(
item.splitlines(keepends=True),
L("hello" + u, "world" + u, "mojo" + u, "language" + u),
)
s = StringSlice(item)
_assert_equal(s.splitlines(), hello_mojo)
items = List("hello" + u, "world" + u, "mojo" + u, "language" + u)
_assert_equal(s.splitlines(keepends=True), items)


def test_isupper():
Expand Down
Loading