Skip to content

Commit

Permalink
Add __iter__ (by ref for dict keys and list elements)
Browse files Browse the repository at this point in the history
Signed-off-by: rd4com <[email protected]>
  • Loading branch information
rd4com committed Nov 5, 2024
1 parent b782aee commit f213f48
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
121 changes: 121 additions & 0 deletions stdlib/src/builtin/object.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ from sys.ffi import OpaquePointer
from builtin.builtin_list import _lit_mut_cast
from memory import Arc, memcmp, memcpy, UnsafePointer

from collections.dict import _DictEntryIter
from collections.list import _ListIter

from utils import StringRef, Variant

# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -2161,3 +2164,121 @@ struct object(
return hash(self._value.get_as_float())
# FIXME: hash(repr(self)) as fallback
return hash(repr(self))

fn __iter__(
ref [_]self,
) raises -> _ObjectIter[__origin_of(self._value.value)]:
"""Iterate over the object.
Returns:
An iterator object.
Raises:
If the object is not iterable.
Example:
```mojo
x = object([0, True, 2.0, "three"])
for i in x:
print(i[])
```
Note: iteration of lists and dicts keys are currently implemented.
"""
# Note: we iterate by reference, it should be fast for the user.
if self._value.is_list():
return _ObjectIter[__origin_of(self._value.value)](
_ObjectIter[__origin_of(self._value.value)].list_iter_hint,
self._value.get_as_list().__iter__(),
)
if self._value.is_dict():
return _ObjectIter[__origin_of(self._value.value)](
_ObjectIter[__origin_of(self._value.value)].dict_iter_hint,
self._value.get_as_dict().items(),
)
raise (
"'"
+ self._value._get_type_name()
+ "'"
+ " don't implement __iter__"
)


# ===----------------------------------------------------------------------=== #
# _ObjectIter
# ===----------------------------------------------------------------------=== #


@value
struct _ObjectIter[
iter_mutability: Bool, //,
iter_origin: Origin[iter_mutability].type,
]:
"""Iterator for object (list or dict).
Parameters:
iter_mutability: Whether the reference to the iterated value is mutable.
iter_origin: The origin of the iterator
"""

# Note: Return type is not mutable, because dict keys should not be mutated.

alias list_iter = _ListIter[object, False, iter_origin]
"""The iterator type for List"""
alias dict_iter = _DictEntryIter[object, object, iter_origin]
"""The iterator type for dict"""
alias list_iter_hint: Int = 0
"""The value used as a type hint for list iterators"""
alias dict_iter_hint: Int = 1
"""The value used as a type hint for dict iterators"""

var hint_type: Int
"Specifies the type of the iterator in the variant"
alias storage_type = Variant[Self.list_iter, Self.dict_iter]
"Variant type used to store an iterator"

var iterator: Self.storage_type

fn __iter__(self) -> Self:
return self

fn __next__(
inout self,
) raises -> Pointer[object, _lit_mut_cast[iter_origin, False].result]:
"""Return the next item and update to point to subsequent item.
Returns:
The next item in the iterable object that this iterator points to.
"""
if self.hint_type == Self.list_iter_hint:
return Pointer[
object, _lit_mut_cast[iter_origin, False].result
].address_of(
UnsafePointer.address_of(
self.iterator[Self.list_iter].__next__()[]
)[]
)
if self.hint_type == Self.dict_iter_hint:
return Pointer[
object, _lit_mut_cast[iter_origin, False].result
].address_of(
UnsafePointer.address_of(
self.iterator[Self.dict_iter].__next__()[].key
)[]
)
raise "Error in _ObjectIter.__next__"

@always_inline
fn __hasmore__(self) -> Bool:
if self.hint_type == Self.list_iter_hint:
return self.iterator[Self.list_iter].__hasmore__()
if self.hint_type == Self.dict_iter_hint:
return self.iterator[Self.dict_iter].__hasmore__()
return False

fn __len__(self) raises -> Int:
if self.hint_type == Self.list_iter_hint:
return self.iterator[Self.list_iter].__len__()
if self.hint_type == Self.dict_iter_hint:
return self.iterator[Self.dict_iter].__len__()
raise "Error in _ObjectIter.__len__"
38 changes: 38 additions & 0 deletions stdlib/test/builtin/test_object.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,43 @@ def test_object_getattr():
x.new_attr = 1


def test_object_iter():
var value = [0, 1]
x = object([0, True, 2.0, "three"])
x.append(value)
# asap del of value
assert_equal(x._value.get_as_list()[4]._value.ref_count(), 1)
i = 0
y = object([])
for element in x:
assert_equal(element[], x[i])
y.append(element[])
if i == 0:
x[4].append(3)
i += 1
assert_equal(i, 5)
assert_equal(x, y)
assert_equal(repr(y), "[0, True, 2.0, 'three', [0, 1, 3]]")
assert_equal(x._value.get_as_list()[4]._value.ref_count(), 2)
_ = y

x = object.dict()
x["one"] = 1
x[2] = 2.0
i = 0
results = object([])
for element in x:
if i == 0:
assert_equal(x["one"], x[element[]])
assert_equal(element[], object("one"))
elif i == 1:
assert_equal(element[], object(2))
results.append(element[])
i += 1
assert_equal(i, 2)
assert_equal(repr(results), "['one', 2]")


def main():
test_object_ctors()
test_comparison_ops()
Expand All @@ -752,3 +789,4 @@ def main():
test_object_tuple_add()
test_object_get_type_id()
test_object_getattr()
test_object_iter()

0 comments on commit f213f48

Please sign in to comment.