Skip to content

Commit

Permalink
Checking changes back in with a bug fix.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696909802
  • Loading branch information
Orbax Authors committed Nov 15, 2024
1 parent 8c02199 commit e8a407b
Show file tree
Hide file tree
Showing 11 changed files with 726 additions and 616 deletions.
2 changes: 2 additions & 0 deletions export/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adds a new module base class and two new module subclasses one for TensorFlow and
one for Orbax.
- Moves the TF dependent export logic out of JaxModule and into TensorFlowModule.
- Removes the JaxModule dependency on TF and moves more logic frome ExportManager
to TensorFlowExport.
- Wires up the Orbax Model export flow.
- Adds a checkpoint path to the jax2obm_kwargs to allow specifying a checkpoing
to the Orbax export pathway.
Expand Down
22 changes: 15 additions & 7 deletions export/orbax/export/export_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@
"""Abstract base class for different export classes."""

import abc
from typing import Any

import tensorflow as tf
from typing import Any, Callable, Mapping, Sequence
from orbax.export import jax_module
from orbax.export import serving_config as osc


class ExportBase(abc.ABC):
"""Abstract base class for different export classes."""

# TODO: b/363033166 - Remove dependencies on TF in the base class.
def __init__(
self,
module: jax_module.JaxModule,
serving_configs: Sequence[osc.ServingConfig],
):
self._module = module
self._serving_configs = serving_configs

@abc.abstractmethod
def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
"""Returns a map of signature keys to serving functions."""

@abc.abstractmethod
def save(
self,
# TODO(b/363033166): Change this annotation once TF isolation is done.
jax_module: tf.Module,
model_path: str,
**kwargs: Any,
):
"""Saves the model.
Args:
jax_module: The `JaxModule` to be exported.
model_path: The path to save the model.
**kwargs: Additional arguments to pass to the `save` method. Accepted
arguments are `save_options` and `serving_signatures`.
Expand Down
138 changes: 21 additions & 117 deletions export/orbax/export/export_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Manage the exporting of a JAXModule."""

from collections.abc import Mapping, Sequence
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional, cast

from etils.epy import reraise_utils
from orbax.export import config
Expand All @@ -24,7 +24,6 @@
from orbax.export import obm_export
from orbax.export import serving_config as osc
from orbax.export import tensorflow_export
from orbax.export import utils
from orbax.export.modules import obm_module
import tensorflow as tf

Expand Down Expand Up @@ -57,12 +56,12 @@ def __init__(
f' must be the same. The former is {version}. The latter is '
f'{module.export_version()}.'
)
# TODO(b/363033166): Skip this step for OBM once TF isolation is done.
self._module = tf.Module()
self._module.computation_module = module
self._serving_signatures = {}
if version == constants.ExportModelType.ORBAX_MODEL:
self.serialization_functions = obm_export.ObmExport()
self._version = version
self._jax_module = module
if self._version == constants.ExportModelType.ORBAX_MODEL:
self._serialization_functions = obm_export.ObmExport(
self._jax_module, serving_configs
)
obm_module_ = module.orbax_module()
if not isinstance(obm_module_, obm_module.ObmModule):
raise ValueError(
Expand All @@ -72,25 +71,26 @@ def __init__(
# TODO(bdwalker): Let `ObmExport.__init__() do this `build()` step.
obm_module_.build(serving_configs)
else:
self.serialization_functions = tensorflow_export.TensorFlowExport()
# TODO(bdwalker): Let `TensorFlowExport.__init__() do this
# `process_serving_configs()` step.
process_serving_configs(
serving_configs,
obx_export_config.obx_export_tf_preprocess_only, # pytype: disable=attribute-error
self._module,
self._serving_signatures,
self._serialization_functions = tensorflow_export.TensorFlowExport(
self._jax_module, serving_configs
)

@property
def tf_module(self) -> tf.Module:
"""Returns the tf.module maintained by the export manager."""
return self._module
if self._version == constants.ExportModelType.ORBAX_MODEL:
raise TypeError(
'tf_module is not implemented for export version'
' ExportModelType.ORBAX_MODEL.'
)
return cast(
tensorflow_export.TensorFlowExport,
self._serialization_functions).tf_export_module()

@property
def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
"""Returns a map of signature keys to serving functions."""
return self._serving_signatures
return self._serialization_functions.serving_signatures

def save(
self,
Expand All @@ -107,108 +107,12 @@ def save(
signature_overrides: signatures to override the self-maintained ones, or
additional signatures to export.
"""
serving_signatures = dict(self._serving_signatures)
if signature_overrides:
serving_signatures.update(signature_overrides)

self.serialization_functions.save(
jax_module=self._module,
self._serialization_functions.save(
model_path=model_path,
save_options=save_options,
serving_signatures=serving_signatures,
signature_overrides=signature_overrides,
)

def load(self, model_path: str, **kwargs: Any):
loaded = self.serialization_functions.load(model_path, **kwargs)
loaded = self._serialization_functions.load(model_path, **kwargs)
return loaded


def make_e2e_inference_fn(
model_fn: Callable[..., Any],
serving_config: osc.ServingConfig,
) -> Callable[..., Any]:
"""Creates an concrete end-to-end inference tf.function.
Args:
model_fn: a callable in TF context for the numeric computation.
serving_config: a ServingConfig that defines the input sigature,
pre-processor and post-processor of the inference function.
Returns:
A tf.function for end-to-end inference.
"""
infer_step_func_map = serving_config.bind(model_fn, require_numpy=False)
signature_key = serving_config.get_signature_keys()[0]
return utils.with_default_args(
infer_step_func_map[signature_key], serving_config.get_input_signature()
)


def process_serving_configs(
serving_configs: Sequence[osc.ServingConfig],
obx_export_tf_preprocess_only: bool,
module: tf.Module,
serving_signatures: Dict[str, Callable[..., Any]],
):
"""Processes the serving functions into their TF wrapped concrete functions.
The function will use the serving_configs and the methods defined in the
provided module to populate the serving_signatures map with the concrete
inference functions.
In addition, if trackable resources are provided in the serving_configs,
they will be added to the module's tf_trackable_resources property.
Args:
serving_configs: a sequence of which each element is a `ServingConfig`
cooresponding to a serving signature of the exported SavedModel.
obx_export_tf_preprocess_only: a boolean indicating whether to export only
the preprocessor.
module: A tf module that will provide the method definitions. The module
should have a JaxModule set as a computation_module property.
serving_signatures: a map of signature keys to serving functions. This map
will be populated by this function.
"""
tf_trackable_resources = []
for sc in serving_configs:
with maybe_reraise(f'Failed exporting signature_key={sc.signature_key} '):
if obx_export_tf_preprocess_only:
if not sc.tf_preprocessor:
raise ValueError(
'serving_config.tf_preprocessor must be provided when'
' in `obx_export_tf_preprocess_only` mode.'
)

def tf_preprocessor(*inputs):
return tf.nest.flatten(sc.tf_preprocessor(*inputs)) # pylint: disable=cell-var-from-loop

preprocessor = utils.with_default_args(
tf_preprocessor, sc.get_input_signature()
)
inference_fn = preprocessor
else:
method = sc.get_infer_step(module.computation_module.methods)
inference_fn = make_e2e_inference_fn(method, sc)

if isinstance(sc.signature_key, str):
keys = [sc.signature_key]
else:
keys = sc.signature_key

for key in keys:
if key in serving_signatures:
raise ValueError(
f'Duplicated key "{sc.signature_key}" in `serving_configs`.'
)
serving_signatures[key] = inference_fn

if sc.extra_trackable_resources is not None:
tf_trackable_resources.append(sc.extra_trackable_resources)

if len(serving_configs) == 1:
# Make this module callable. Once exported, it can be loaded back in
# python and the nested input structure will be preservered. In
# contrast, signatures will flatten the TensorSpecs of the to kwargs.
module.__call__ = inference_fn

module.tf_trackable_resources = tf_trackable_resources
Loading

0 comments on commit e8a407b

Please sign in to comment.