Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module Slicing #115

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/Module.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
- set_parameters
- reset
- init
- slice

1 change: 1 addition & 0 deletions docs/api/nn/EMAParamsTree.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
- set_parameters
- reset
- init
- slice

1 change: 1 addition & 0 deletions docs/api/nn/Sequential.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
- set_parameters
- reset
- init
- slice

1 change: 1 addition & 0 deletions elegy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OutputStates,
Mode,
)
from . import module_slicing
from .generalized_module.generalized_module import GeneralizedModule
from .generalized_optimizer.generalized_optimizer import GeneralizedOptimizer

Expand Down
4 changes: 2 additions & 2 deletions elegy/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def add_summary(
path: types.Path,
module: tp.Any,
value: tp.Any,
input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None,
) -> None:
"""
A hook that lets you define a summary in the current module. Its primary
Expand All @@ -127,8 +128,7 @@ def call(self, x):

if not summaries_active():
return

LOCAL.summaries.append(types.Summary(path, module, value))
LOCAL.summaries.append(types.Summary(path, module, value, input_values))


def get_losses() -> types.Logs:
Expand Down
4 changes: 2 additions & 2 deletions elegy/hooks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_summaries(self):
elegy.hooks.add_summary(("a", 0, "b"), None, 2.0)
summaries = elegy.hooks.get_summaries()

assert summaries[0] == (("a", 0, "b"), None, 2.0)
assert summaries[0] == (("a", 0, "b"), None, 2.0, None)

def test_no_summaries(self):
assert not elegy.hooks.summaries_active()
Expand Down Expand Up @@ -65,4 +65,4 @@ def f(x):
assert x == 6
assert losses["x_loss"] == 6
assert metrics["x"] == 7
assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8)
assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8, None)
4 changes: 3 additions & 1 deletion elegy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def summary_step(

entries: tp.List[types.SummaryTableEntry] = []

for path, module, value in summaries:
for path, module, value, input_values in summaries:

module_params, module_states = self.api_module.get_summary_params(
path=path,
Expand All @@ -240,7 +240,9 @@ def summary_step(
module_type_name=(
module.__class__.__name__ if is_generalizable(module) else ""
),
module=module,
output_value=value,
input_value=input_values,
trainable_params_count=(
utils.parameters_count(module_params)
if module_params is not None
Expand Down
6 changes: 6 additions & 0 deletions elegy/model/model_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def summary(
depth: int = 2,
tablefmt: str = "fancy_grid",
return_repr: bool = False,
return_raw_entries: bool = False,
**tablulate_kwargs,
) -> tp.Optional[str]:
"""
Expand Down Expand Up @@ -529,6 +530,9 @@ def summary(
total_entry = entries[-1]
entries = entries[:-1]

if return_raw_entries:
return entries

depth_groups: tp.Dict[str, tp.List[types.SummaryTableEntry]] = toolz.groupby(
lambda entry: "/".join(entry.path.split("/")[:depth]), entries
)
Expand All @@ -541,7 +545,9 @@ def get_grouped_entry(
return types.SummaryTableEntry(
path=entry.path,
module_type_name=entry.module_type_name,
module=entry.module,
output_value=entry.output_value,
input_value=entry.input_value,
trainable_params_count=sum(
entry_.trainable_params_count for entry_ in group
),
Expand Down
65 changes: 61 additions & 4 deletions elegy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from elegy import hooks, utils
from elegy import types

# placeholder for module module_slicing.py
# injected from inside the module because of a circular dependency
module_slicing = None


__all__ = [
"Module",
"to_module",
Expand Down Expand Up @@ -158,6 +163,7 @@ class Module(metaclass=ModuleMeta):
"set_parameters",
"reset",
"init",
"slice",
]

def __init__(self, name: tp.Optional[str] = None, dtype: tp.Any = jnp.float32):
Expand Down Expand Up @@ -368,19 +374,25 @@ def __call__(self, *args, **kwargs) -> tp.Any:
if hooks.summaries_active():
path = get_module_path(self)
assert path is not None
hooks.add_summary(path, self, outputs)
hooks.add_summary(path, self, outputs, (args, kwargs))

return outputs

@abstractmethod
def call(self, *args, **kwargs):
...

def add_summary(self, name: str, f: tp.Any, value: tp.Any):
def add_summary(
self,
name: str,
f: tp.Any,
value: tp.Any,
input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None,
):
if hooks.summaries_active():
path = get_module_path(self) + (name,)
assert path is not None
hooks.add_summary(path, f, value)
hooks.add_summary(path, f, value, input_values)

def init(
self,
Expand Down Expand Up @@ -569,9 +581,54 @@ def add_parameter(
name=utils.get_name(regularizer),
value=regularizer(value),
)

return value

def slice(
self,
start_module: tp.Union["Module", str, None],
end_module: tp.Union[
"Module", str, None, tp.List[tp.Union["Module", str, None]]
],
sample_input: np.ndarray,
) -> "Module":
"""
Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`.

Current limitations:

- all operations between `start_module` and `end_module` must be performed by modules
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- only one `start_module` is supported
- all modules between `start_module` and `end_module` must have a single output
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on your comment this is not a limitation now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still is.
It's now possible to have inner modules that have multiple inputs but the result module still must have only one input. Single output limitation also holds.



Example usage:
```
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
submodule = resnet.slice(
start_module=None,
end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ],
sample_input=x,
)
outputs = elegy.Model(submodule).predict(x)
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)
```

Arguments:
start_module: Child module or name of a child module which will be the input module of the resulting module.
If `None`, the first module is used.
end_module: Child module, name of child module, `None` or a list thereof which will be the output module(s) of the resulting module.
If `None`, the last module is used.
sample_input: An array representing a sample input to the parent module.
"""
return module_slicing.slice_module_from_to(
self, start_module, end_module, sample_input
)

def update_parameter(self, name: str, value: tp.Any) -> None:
"""
Update a parameter of the current module.
Expand Down
Loading