Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checking changes back in with a bug fix. #1341

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions export/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adds a new module base class and two new module subclasses one for TensorFlow and
one for Orbax.
- Moves the TF dependent export logic out of JaxModule and into TensorFlowModule.
- Removes the JaxModule dependency on TF and moves more logic frome ExportManager
to TensorFlowExport.
- Wires up the Orbax Model export flow.
- Adds a checkpoint path to the jax2obm_kwargs to allow specifying a checkpoing
to the Orbax export pathway.
Expand Down
22 changes: 15 additions & 7 deletions export/orbax/export/export_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@
"""Abstract base class for different export classes."""

import abc
from typing import Any

import tensorflow as tf
from typing import Any, Callable, Mapping, Sequence
from orbax.export import jax_module
from orbax.export import serving_config as osc


class ExportBase(abc.ABC):
"""Abstract base class for different export classes."""

# TODO: b/363033166 - Remove dependencies on TF in the base class.
def __init__(
self,
module: jax_module.JaxModule,
serving_configs: Sequence[osc.ServingConfig],
):
self._module = module
self._serving_configs = serving_configs

@abc.abstractmethod
def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
"""Returns a map of signature keys to serving functions."""

@abc.abstractmethod
def save(
self,
# TODO(b/363033166): Change this annotation once TF isolation is done.
jax_module: tf.Module,
model_path: str,
**kwargs: Any,
):
"""Saves the model.

Args:
jax_module: The `JaxModule` to be exported.
model_path: The path to save the model.
**kwargs: Additional arguments to pass to the `save` method. Accepted
arguments are `save_options` and `serving_signatures`.
Expand Down
138 changes: 21 additions & 117 deletions export/orbax/export/export_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Manage the exporting of a JAXModule."""

from collections.abc import Mapping, Sequence
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional, cast

from etils.epy import reraise_utils
from orbax.export import config
Expand All @@ -24,7 +24,6 @@
from orbax.export import obm_export
from orbax.export import serving_config as osc
from orbax.export import tensorflow_export
from orbax.export import utils
from orbax.export.modules import obm_module
import tensorflow as tf

Expand Down Expand Up @@ -57,12 +56,12 @@ def __init__(
f' must be the same. The former is {version}. The latter is '
f'{module.export_version()}.'
)
# TODO(b/363033166): Skip this step for OBM once TF isolation is done.
self._module = tf.Module()
self._module.computation_module = module
self._serving_signatures = {}
if version == constants.ExportModelType.ORBAX_MODEL:
self.serialization_functions = obm_export.ObmExport()
self._version = version
self._jax_module = module
if self._version == constants.ExportModelType.ORBAX_MODEL:
self._serialization_functions = obm_export.ObmExport(
self._jax_module, serving_configs
)
obm_module_ = module.orbax_module()
if not isinstance(obm_module_, obm_module.ObmModule):
raise ValueError(
Expand All @@ -72,25 +71,26 @@ def __init__(
# TODO(bdwalker): Let `ObmExport.__init__() do this `build()` step.
obm_module_.build(serving_configs)
else:
self.serialization_functions = tensorflow_export.TensorFlowExport()
# TODO(bdwalker): Let `TensorFlowExport.__init__() do this
# `process_serving_configs()` step.
process_serving_configs(
serving_configs,
obx_export_config.obx_export_tf_preprocess_only, # pytype: disable=attribute-error
self._module,
self._serving_signatures,
self._serialization_functions = tensorflow_export.TensorFlowExport(
self._jax_module, serving_configs
)

@property
def tf_module(self) -> tf.Module:
"""Returns the tf.module maintained by the export manager."""
return self._module
if self._version == constants.ExportModelType.ORBAX_MODEL:
raise TypeError(
'tf_module is not implemented for export version'
' ExportModelType.ORBAX_MODEL.'
)
return cast(
tensorflow_export.TensorFlowExport,
self._serialization_functions).tf_export_module()

@property
def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
"""Returns a map of signature keys to serving functions."""
return self._serving_signatures
return self._serialization_functions.serving_signatures

def save(
self,
Expand All @@ -107,108 +107,12 @@ def save(
signature_overrides: signatures to override the self-maintained ones, or
additional signatures to export.
"""
serving_signatures = dict(self._serving_signatures)
if signature_overrides:
serving_signatures.update(signature_overrides)

self.serialization_functions.save(
jax_module=self._module,
self._serialization_functions.save(
model_path=model_path,
save_options=save_options,
serving_signatures=serving_signatures,
signature_overrides=signature_overrides,
)

def load(self, model_path: str, **kwargs: Any):
loaded = self.serialization_functions.load(model_path, **kwargs)
loaded = self._serialization_functions.load(model_path, **kwargs)
return loaded


def make_e2e_inference_fn(
model_fn: Callable[..., Any],
serving_config: osc.ServingConfig,
) -> Callable[..., Any]:
"""Creates an concrete end-to-end inference tf.function.
Args:
model_fn: a callable in TF context for the numeric computation.
serving_config: a ServingConfig that defines the input sigature,
pre-processor and post-processor of the inference function.
Returns:
A tf.function for end-to-end inference.
"""
infer_step_func_map = serving_config.bind(model_fn, require_numpy=False)
signature_key = serving_config.get_signature_keys()[0]
return utils.with_default_args(
infer_step_func_map[signature_key], serving_config.get_input_signature()
)


def process_serving_configs(
serving_configs: Sequence[osc.ServingConfig],
obx_export_tf_preprocess_only: bool,
module: tf.Module,
serving_signatures: Dict[str, Callable[..., Any]],
):
"""Processes the serving functions into their TF wrapped concrete functions.
The function will use the serving_configs and the methods defined in the
provided module to populate the serving_signatures map with the concrete
inference functions.
In addition, if trackable resources are provided in the serving_configs,
they will be added to the module's tf_trackable_resources property.
Args:
serving_configs: a sequence of which each element is a `ServingConfig`
cooresponding to a serving signature of the exported SavedModel.
obx_export_tf_preprocess_only: a boolean indicating whether to export only
the preprocessor.
module: A tf module that will provide the method definitions. The module
should have a JaxModule set as a computation_module property.
serving_signatures: a map of signature keys to serving functions. This map
will be populated by this function.
"""
tf_trackable_resources = []
for sc in serving_configs:
with maybe_reraise(f'Failed exporting signature_key={sc.signature_key} '):
if obx_export_tf_preprocess_only:
if not sc.tf_preprocessor:
raise ValueError(
'serving_config.tf_preprocessor must be provided when'
' in `obx_export_tf_preprocess_only` mode.'
)

def tf_preprocessor(*inputs):
return tf.nest.flatten(sc.tf_preprocessor(*inputs)) # pylint: disable=cell-var-from-loop

preprocessor = utils.with_default_args(
tf_preprocessor, sc.get_input_signature()
)
inference_fn = preprocessor
else:
method = sc.get_infer_step(module.computation_module.methods)
inference_fn = make_e2e_inference_fn(method, sc)

if isinstance(sc.signature_key, str):
keys = [sc.signature_key]
else:
keys = sc.signature_key

for key in keys:
if key in serving_signatures:
raise ValueError(
f'Duplicated key "{sc.signature_key}" in `serving_configs`.'
)
serving_signatures[key] = inference_fn

if sc.extra_trackable_resources is not None:
tf_trackable_resources.append(sc.extra_trackable_resources)

if len(serving_configs) == 1:
# Make this module callable. Once exported, it can be loaded back in
# python and the nested input structure will be preservered. In
# contrast, signatures will flatten the TensorSpecs of the to kwargs.
module.__call__ = inference_fn

module.tf_trackable_resources = tf_trackable_resources
Loading
Loading