A dead simple Python package for creating custom JAX pytree objects.
- Strives to be minimal: the implementation is just ~100 lines of code
- Has no dependencies other than JAX
- It's compatible with both `dataclasses` and regular classes
- It has no intention of supporting Neural Network use cases (e.g. partitioning)
```bash
pip install simple-pytree
```
```python
import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
# negate every leaf (jax.tree_map is deprecated in recent JAX releases)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2
```
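Subclassing `Pytree` spares you from registering the class with JAX yourself. For comparison, here is a hand-written equivalent of the example above that uses only `jax.tree_util.register_pytree_node`. The `_flatten`/`_unflatten` helper names are illustrative, and this is a sketch of the general JAX technique, not simple_pytree's actual implementation.

```python
import jax


class Foo:
    def __init__(self, x, y):
        self.x = x
        self.y = y


def _flatten(foo):
    # Return (children, aux_data): both fields are leaves, no static data.
    return (foo.x, foo.y), None


def _unflatten(aux_data, children):
    # Rebuild a Foo from the (possibly transformed) leaves.
    x, y = children
    return Foo(x, y)


jax.tree_util.register_pytree_node(Foo, _flatten, _unflatten)

foo = jax.tree_util.tree_map(lambda v: -v, Foo(1, 2))
assert foo.x == -1 and foo.y == -2
```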
You can mark fields as static by assigning `static_field()` to a class attribute with the same name as the instance attribute:
```python
import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2
```
Static fields are not included in the pytree leaves; they are passed as pytree metadata instead.
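To make the leaves-versus-metadata distinction concrete, the sketch below registers a plain class by hand with `jax.tree_util.register_pytree_node` and returns `y` as auxiliary data instead of as a child. It illustrates the general JAX mechanism under that assumption and is not simple_pytree's code: only `x` appears among the leaves, while `y` travels in the tree structure and comes back unchanged on unflatten.

```python
import jax


class Foo:
    def __init__(self, x, y):
        self.x = x
        self.y = y


def _flatten(foo):
    # x is a child (leaf); y is returned as static auxiliary data.
    return (foo.x,), foo.y


def _unflatten(y, children):
    # The aux data (y) comes back untouched; only x may have been transformed.
    (x,) = children
    return Foo(x, y)


jax.tree_util.register_pytree_node(Foo, _flatten, _unflatten)

foo = Foo(1, 2)
leaves, treedef = jax.tree_util.tree_flatten(foo)
print(leaves)  # [1]

negated = jax.tree_util.tree_map(lambda v: -v, foo)
assert negated.x == -1 and negated.y == 2
```

Because static values are stored in the tree structure rather than as inputs, they should be hashable and comparable.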
`simple_pytree` provides a `dataclass` decorator you can use with classes that contain `static_field`s:
```python
import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)

foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2
```
`simple_pytree.dataclass` is just a wrapper around `dataclasses.dataclass`, but when it is used, static analysis tools and IDEs will understand that `static_field` is a field specifier just like `dataclasses.field`.
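PEP 681's `typing.dataclass_transform` is the standard way for a wrapper like this to tell type checkers which callables act as field specifiers. The sketch below shows the idea with hypothetical `my_dataclass` and `my_static_field` helpers. Whether simple_pytree uses exactly this mechanism is an assumption, and tagging static fields through `dataclasses.field(metadata=...)` is just one plausible design.

```python
import dataclasses
import sys

if sys.version_info >= (3, 11):
    from typing import dataclass_transform
else:
    # Older Pythons: fall back to a runtime no-op (typing_extensions provides
    # the real decorator for type checkers on these versions).
    def dataclass_transform(**kwargs):
        return lambda obj: obj


def my_static_field(default=dataclasses.MISSING, **kwargs):
    # Hypothetical field specifier: a dataclasses.field tagged as static.
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = True
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclass_transform(field_specifiers=(my_static_field, dataclasses.field))
def my_dataclass(cls):
    # At runtime this is just dataclasses.dataclass; the decorator above only
    # informs static analysis tools.
    return dataclasses.dataclass(cls)


@my_dataclass
class Foo:
    x: int
    y: int = my_static_field(default=2)


foo = Foo(1)
static_names = [f.name for f in dataclasses.fields(Foo) if f.metadata.get("static")]
print(foo, static_names)  # Foo(x=1, y=2) ['y']
```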
`Pytree` objects are immutable by default after `__init__`:
```python
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # AttributeError
```
If you want to make them mutable, you can use the `mutable` argument in the class definition:
```python
from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # OK
```
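One way to get this behavior in plain Python is sketched below. A metaclass marks each instance as "initializing" only while `__init__` runs, and `__setattr__` rejects writes afterwards unless the class was declared with `mutable=True`, a class keyword consumed by `__init_subclass__`. The `FreezableBase` name and the `_initializing`/`_mutable` attributes are hypothetical; this illustrates the pattern and is not simple_pytree's source.

```python
class _InitTracker(type):
    # Metaclass: flag the instance as initializing only while __init__ runs.
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        object.__setattr__(obj, "_initializing", True)
        try:
            obj.__init__(*args, **kwargs)
        finally:
            object.__setattr__(obj, "_initializing", False)
        return obj


class FreezableBase(metaclass=_InitTracker):
    _mutable = False

    def __init_subclass__(cls, mutable=False, **kwargs):
        # `class Foo(FreezableBase, mutable=True)` passes the keyword here.
        super().__init_subclass__(**kwargs)
        cls._mutable = mutable

    def __setattr__(self, name, value):
        if not (self._initializing or self._mutable):
            raise AttributeError(f"{type(self).__name__} is immutable; cannot set {name!r}")
        object.__setattr__(self, name, value)


class Frozen(FreezableBase):
    def __init__(self, x):
        self.x = x


class Mutable(FreezableBase, mutable=True):
    def __init__(self, x):
        self.x = x


m = Mutable(1)
m.x = 3  # OK

f = Frozen(1)
try:
    f.x = 3
except AttributeError as err:
    print(err)  # Frozen is immutable; cannot set 'x'
```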
If you want to make a copy of a `Pytree` object with some fields modified, you can use the `.replace()` method:
```python
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2
```
`replace` works for both mutable and immutable `Pytree` objects. If the class is a dataclass, `replace` internally uses `dataclasses.replace`.
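A minimal sketch of how such a `replace` can serve both cases, assuming a shallow copy is acceptable: dataclasses are delegated to `dataclasses.replace`, while other objects are copied and their `__dict__` is updated directly, which bypasses `__setattr__` and therefore also works for objects that refuse attribute assignment. The standalone `replace` function and the `Point` class are hypothetical illustrations, not simple_pytree's method.

```python
import copy
import dataclasses


def replace(obj, **changes):
    # Hypothetical helper following the behavior described above.
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # Dataclasses: delegate, which re-runs __init__ with the new values.
        return dataclasses.replace(obj, **changes)
    unknown = set(changes) - set(vars(obj))
    if unknown:
        raise ValueError(f"unknown fields: {sorted(unknown)}")
    new = copy.copy(obj)  # shallow copy; the original is left untouched
    # Writing to __dict__ directly skips __setattr__, so this also works
    # for objects that reject attribute assignment after __init__.
    new.__dict__.update(changes)
    return new


class Point:
    def __init__(self, x, y):
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)

    def __setattr__(self, name, value):
        raise AttributeError("Point is immutable")


p = Point(1, 2)
q = replace(p, x=10)
print(p.x, q.x, q.y)  # 1 10 2
```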