From 0f9e5628903a3856beb8d71567f703d8546c9499 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Fri, 15 Nov 2024 14:53:02 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 697008941 --- export/CHANGELOG.md | 2 + export/orbax/export/export_base.py | 22 +- export/orbax/export/export_manager.py | 138 +---- export/orbax/export/export_manager_test.py | 535 +++--------------- export/orbax/export/jax_module.py | 4 +- export/orbax/export/modules/obm_module.py | 14 + .../orbax/export/modules/tensorflow_module.py | 13 +- export/orbax/export/obm_export.py | 19 +- export/orbax/export/tensorflow_export.py | 118 +++- export/orbax/export/tensorflow_export_test.py | 456 ++++++++++++++- export/orbax/export/utils.py | 21 + 11 files changed, 726 insertions(+), 616 deletions(-) diff --git a/export/CHANGELOG.md b/export/CHANGELOG.md index 0dfe2315..b46a4a9a 100644 --- a/export/CHANGELOG.md +++ b/export/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Adds a new module base class and two new module subclasses one for TensorFlow and one for Orbax. - Moves the TF dependent export logic out of JaxModule and into TensorFlowModule. +- Removes the JaxModule dependency on TF and moves more logic frome ExportManager + to TensorFlowExport. - Wires up the Orbax Model export flow. - Adds a checkpoint path to the jax2obm_kwargs to allow specifying a checkpoing to the Orbax export pathway. diff --git a/export/orbax/export/export_base.py b/export/orbax/export/export_base.py index 7faf84d8..879c221e 100644 --- a/export/orbax/export/export_base.py +++ b/export/orbax/export/export_base.py @@ -15,27 +15,35 @@ """Abstract base class for different export classes.""" import abc -from typing import Any - -import tensorflow as tf +from typing import Any, Callable, Mapping, Sequence +from orbax.export import jax_module +from orbax.export import serving_config as osc class ExportBase(abc.ABC): """Abstract base class for different export classes.""" - # TODO: b/363033166 - Remove dependencies on TF in the base class. + def __init__( + self, + module: jax_module.JaxModule, + serving_configs: Sequence[osc.ServingConfig], + ): + self._module = module + self._serving_configs = serving_configs + + @abc.abstractmethod + def serving_signatures(self) -> Mapping[str, Callable[..., Any]]: + """Returns a map of signature keys to serving functions.""" + @abc.abstractmethod def save( self, - # TODO(b/363033166): Change this annotation once TF isolation is done. - jax_module: tf.Module, model_path: str, **kwargs: Any, ): """Saves the model. Args: - jax_module: The `JaxModule` to be exported. model_path: The path to save the model. **kwargs: Additional arguments to pass to the `save` method. Accepted arguments are `save_options` and `serving_signatures`. diff --git a/export/orbax/export/export_manager.py b/export/orbax/export/export_manager.py index 655f7f5c..3b9da718 100644 --- a/export/orbax/export/export_manager.py +++ b/export/orbax/export/export_manager.py @@ -15,7 +15,7 @@ """Manage the exporting of a JAXModule.""" from collections.abc import Mapping, Sequence -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional, cast from etils.epy import reraise_utils from orbax.export import config @@ -24,7 +24,6 @@ from orbax.export import obm_export from orbax.export import serving_config as osc from orbax.export import tensorflow_export -from orbax.export import utils from orbax.export.modules import obm_module import tensorflow as tf @@ -57,12 +56,12 @@ def __init__( f' must be the same. The former is {version}. The latter is ' f'{module.export_version()}.' ) - # TODO(b/363033166): Skip this step for OBM once TF isolation is done. - self._module = tf.Module() - self._module.computation_module = module - self._serving_signatures = {} - if version == constants.ExportModelType.ORBAX_MODEL: - self.serialization_functions = obm_export.ObmExport() + self._version = version + self._jax_module = module + if self._version == constants.ExportModelType.ORBAX_MODEL: + self._serialization_functions = obm_export.ObmExport( + self._jax_module, serving_configs + ) obm_module_ = module.orbax_module() if not isinstance(obm_module_, obm_module.ObmModule): raise ValueError( @@ -72,25 +71,26 @@ def __init__( # TODO(bdwalker): Let `ObmExport.__init__() do this `build()` step. obm_module_.build(serving_configs) else: - self.serialization_functions = tensorflow_export.TensorFlowExport() - # TODO(bdwalker): Let `TensorFlowExport.__init__() do this - # `process_serving_configs()` step. - process_serving_configs( - serving_configs, - obx_export_config.obx_export_tf_preprocess_only, # pytype: disable=attribute-error - self._module, - self._serving_signatures, + self._serialization_functions = tensorflow_export.TensorFlowExport( + self._jax_module, serving_configs ) @property def tf_module(self) -> tf.Module: """Returns the tf.module maintained by the export manager.""" - return self._module + if self._version == constants.ExportModelType.ORBAX_MODEL: + raise TypeError( + 'tf_module is not implemented for export version' + ' ExportModelType.ORBAX_MODEL.' + ) + return cast( + tensorflow_export.TensorFlowExport, + self._serialization_functions).tf_export_module() @property def serving_signatures(self) -> Mapping[str, Callable[..., Any]]: """Returns a map of signature keys to serving functions.""" - return self._serving_signatures + return self._serialization_functions.serving_signatures def save( self, @@ -107,108 +107,12 @@ def save( signature_overrides: signatures to override the self-maintained ones, or additional signatures to export. """ - serving_signatures = dict(self._serving_signatures) - if signature_overrides: - serving_signatures.update(signature_overrides) - - self.serialization_functions.save( - jax_module=self._module, + self._serialization_functions.save( model_path=model_path, save_options=save_options, - serving_signatures=serving_signatures, + signature_overrides=signature_overrides, ) def load(self, model_path: str, **kwargs: Any): - loaded = self.serialization_functions.load(model_path, **kwargs) + loaded = self._serialization_functions.load(model_path, **kwargs) return loaded - - -def make_e2e_inference_fn( - model_fn: Callable[..., Any], - serving_config: osc.ServingConfig, -) -> Callable[..., Any]: - """Creates an concrete end-to-end inference tf.function. - - Args: - model_fn: a callable in TF context for the numeric computation. - serving_config: a ServingConfig that defines the input sigature, - pre-processor and post-processor of the inference function. - - Returns: - A tf.function for end-to-end inference. - """ - infer_step_func_map = serving_config.bind(model_fn, require_numpy=False) - signature_key = serving_config.get_signature_keys()[0] - return utils.with_default_args( - infer_step_func_map[signature_key], serving_config.get_input_signature() - ) - - -def process_serving_configs( - serving_configs: Sequence[osc.ServingConfig], - obx_export_tf_preprocess_only: bool, - module: tf.Module, - serving_signatures: Dict[str, Callable[..., Any]], -): - """Processes the serving functions into their TF wrapped concrete functions. - - The function will use the serving_configs and the methods defined in the - provided module to populate the serving_signatures map with the concrete - inference functions. - - In addition, if trackable resources are provided in the serving_configs, - they will be added to the module's tf_trackable_resources property. - - Args: - serving_configs: a sequence of which each element is a `ServingConfig` - cooresponding to a serving signature of the exported SavedModel. - obx_export_tf_preprocess_only: a boolean indicating whether to export only - the preprocessor. - module: A tf module that will provide the method definitions. The module - should have a JaxModule set as a computation_module property. - serving_signatures: a map of signature keys to serving functions. This map - will be populated by this function. - """ - tf_trackable_resources = [] - for sc in serving_configs: - with maybe_reraise(f'Failed exporting signature_key={sc.signature_key} '): - if obx_export_tf_preprocess_only: - if not sc.tf_preprocessor: - raise ValueError( - 'serving_config.tf_preprocessor must be provided when' - ' in `obx_export_tf_preprocess_only` mode.' - ) - - def tf_preprocessor(*inputs): - return tf.nest.flatten(sc.tf_preprocessor(*inputs)) # pylint: disable=cell-var-from-loop - - preprocessor = utils.with_default_args( - tf_preprocessor, sc.get_input_signature() - ) - inference_fn = preprocessor - else: - method = sc.get_infer_step(module.computation_module.methods) - inference_fn = make_e2e_inference_fn(method, sc) - - if isinstance(sc.signature_key, str): - keys = [sc.signature_key] - else: - keys = sc.signature_key - - for key in keys: - if key in serving_signatures: - raise ValueError( - f'Duplicated key "{sc.signature_key}" in `serving_configs`.' - ) - serving_signatures[key] = inference_fn - - if sc.extra_trackable_resources is not None: - tf_trackable_resources.append(sc.extra_trackable_resources) - - if len(serving_configs) == 1: - # Make this module callable. Once exported, it can be loaded back in - # python and the nested input structure will be preservered. In - # contrast, signatures will flatten the TensorSpecs of the to kwargs. - module.__call__ = inference_fn - - module.tf_trackable_resources = tf_trackable_resources diff --git a/export/orbax/export/export_manager_test.py b/export/orbax/export/export_manager_test.py index 317748f9..48abb6fc 100644 --- a/export/orbax/export/export_manager_test.py +++ b/export/orbax/export/export_manager_test.py @@ -12,22 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -import os from absl.testing import parameterized -import chex import jax import jax.numpy as jnp -from orbax.export import config from orbax.export import constants from orbax.export import export_manager from orbax.export import jax_module from orbax.export import serving_config as sc -from orbax.export import utils -from orbax.export.export_manager import make_e2e_inference_fn +from orbax.export.modules import tensorflow_module import tensorflow as tf - def _from_feature_dict(feature_dict): return feature_dict['feat'] @@ -36,18 +30,9 @@ def _add_output_name(outputs): return {'outputs': outputs} -_ZERO_VAR = tf.Variable(0) - - -def _add_zero(x): - return x + _ZERO_VAR - - -def _linear(params, x, with_bias=False): - y = x @ params['w'] - if with_bias: - return y + params['b'] - return y +@jax.jit +def apply_fn(params, x): + return x + params['bias'] class ExportManagerTest(tf.test.TestCase, parameterized.TestCase): @@ -56,465 +41,111 @@ def setUp(self): super().setUp() self._output_dir = self.create_tempdir().full_path - @parameterized.named_parameters( - dict( - testcase_name='normal', - input_signature=[ - {'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')} - ], - preprocessor=_from_feature_dict, - postprocessor=_add_output_name, - inputs=[{'feat': tf.constant(1)}], - outputs={'outputs': tf.constant(2)}, - ), - dict( - testcase_name='embedded input signature', - preprocessor=tf.function( - _from_feature_dict, - [{'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')}], - ), - postprocessor=_add_output_name, - inputs=[{'feat': tf.constant(1)}], - outputs={'outputs': tf.constant(2)}, - ), - dict( - testcase_name='no preprocessor', - input_signature=[tf.TensorSpec((), tf.dtypes.int32, 'feat')], - postprocessor=_add_output_name, - inputs=[tf.constant(1)], - outputs={'outputs': tf.constant(2)}, - ), - dict( - testcase_name='no postprocessor', - input_signature=[ - {'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')} - ], - preprocessor=_from_feature_dict, - inputs=[{'feat': tf.constant(1)}], - outputs=tf.constant(2), - ), - dict( - testcase_name='core module only', - input_signature=[tf.TensorSpec((), tf.dtypes.int32, 'feat')], - inputs=[tf.constant(1)], - outputs=tf.constant(2), - ), - dict( - testcase_name='default value', - input_signature=[ - utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.dtypes.int32, 'feat'), 1 - ) - ], - inputs=[], - outputs=tf.constant(2), - ), - ) - def test_make_e2e_inference_fn( - self, - inputs, - outputs, - input_signature=None, - preprocessor=None, - postprocessor=None, - ): - method = jax_module.JaxModule( - {'bias': jnp.array(1)}, - lambda p, x: x + p['bias'], - ).methods[constants.DEFAULT_METHOD_KEY] - inference_fn = make_e2e_inference_fn( - method, - sc.ServingConfig('key', input_signature, preprocessor, postprocessor), + def test_get_tf_module_tensorflow_export(self): + serving_config = ( + sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + ], + ), ) - self.assertAllEqual(inference_fn(*inputs), outputs) - - @parameterized.named_parameters( - dict( - testcase_name='multiple signatures', - serving_configs=[ - sc.ServingConfig( - 'with_processors', - input_signature=[ - {'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')} - ], - tf_preprocessor=_from_feature_dict, - tf_postprocessor=_add_output_name, - ), - sc.ServingConfig( - 'without_processors', - input_signature=[tf.TensorSpec((), tf.dtypes.int32)], - ), - ], - expected_keys=['with_processors', 'without_processors'], - ), - dict( - testcase_name='multiple keys same signature', - serving_configs=[ - sc.ServingConfig( - ['serving_default', 'without_processors'], - input_signature=[tf.TensorSpec((), tf.dtypes.int32)], - ), - ], - expected_keys=['serving_default', 'without_processors'], - ), - dict( - testcase_name='trackables in preprocessor', - serving_configs=[ - sc.ServingConfig( - 'serving_default', - input_signature=[tf.TensorSpec((), tf.dtypes.int32)], - tf_preprocessor=_add_zero, - extra_trackable_resources=_ZERO_VAR, - ), - ], - expected_keys=['serving_default'], - ), - ) - def test_save(self, serving_configs, expected_keys): em = export_manager.ExportManager( jax_module.JaxModule( - {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] + params={'bias': jnp.array(1, jnp.int32)}, + apply_fn=apply_fn, + export_version=constants.ExportModelType.TF_SAVEDMODEL, ), - serving_configs, + serving_config, ) - em.save(self._output_dir) - loaded = tf.saved_model.load(self._output_dir, ['serve']) - self.assertCountEqual(expected_keys, em.serving_signatures.keys()) - self.assertCountEqual(expected_keys, loaded.signatures.keys()) + self.assertEqual(type(em.tf_module), tf.Module) - @parameterized.named_parameters( - dict( - testcase_name='all default', - serving_config=sc.ServingConfig( - 'serving_default', - input_signature=[ - utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.int32, 'x'), 2 - ), - utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.int32, 'y'), 3 - ), - ], - tf_preprocessor=lambda x, y: x + y, - ), - serving_inputs={}, - expected_outputs=6, - ), - dict( - testcase_name='some default', - serving_config=sc.ServingConfig( - 'serving_default', - input_signature=[ - tf.TensorSpec((), tf.int32, 'x'), - utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.int32, 'y'), 2 - ), - ], - tf_preprocessor=lambda x, y: x + y, - ), - serving_inputs={'x': 3}, - expected_outputs=6, - ), - dict( - testcase_name='override default', - serving_config=sc.ServingConfig( - 'serving_default', - input_signature=[ - tf.TensorSpec((), tf.int32, 'x'), - utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.int32, 'y'), 2 - ), - ], - tf_preprocessor=lambda x, y: x + y, - ), - serving_inputs={'x': 1, 'y': 3}, - expected_outputs=5, - ), - dict( - testcase_name='nested', - serving_config=sc.ServingConfig( - 'serving_default', - input_signature=[ - tf.TensorSpec((), tf.int32, 'x'), - { - 'y': utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.int32, 'y'), 2 - ), - 'z': utils.TensorSpecWithDefault( - tf.TensorSpec((), tf.int32, 'z'), 3 - ), - }, - ], - tf_preprocessor=lambda x, extra: x + extra['y'] + extra['z'], - ), - serving_inputs={'x': 1}, - expected_outputs=7, - ), - ) - def test_save_default_inputs( - self, serving_config, serving_inputs, expected_outputs - ): + def test_tf_export_module_attributes_tensorflow_export(self): + serving_config = ( + sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + ], + ), + ) em = export_manager.ExportManager( jax_module.JaxModule( - {'bias': jnp.array(1, jnp.int32)}, lambda p, x: x + p['bias'] + params={'bias': jnp.array(1, jnp.int32)}, + apply_fn=apply_fn, + export_version=constants.ExportModelType.TF_SAVEDMODEL, ), - [serving_config], + serving_config, ) - em.save(self._output_dir) - # TODO(b/277814477): use the TF2 API - # loaded.signature['serving_default'](**serving_inputs) - # once it supports default values. - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: - meta_graph_def = tf.compat.v1.saved_model.loader.load( - sess, ['serve'], self._output_dir - ) - signature_def = meta_graph_def.signature_def[serving_config.signature_key] - output_tensor_name = signature_def.outputs['output_0'].name - fetch = sess.graph.get_tensor_by_name(output_tensor_name) - feed_dict = { - sess.graph.get_tensor_by_name(signature_def.inputs[k].name): v - for k, v in serving_inputs.items() - } - outputs = sess.run(fetch, feed_dict=feed_dict) - self.assertAllEqual(outputs, expected_outputs) + self.assertEqual(type(em.tf_module), tf.Module) + self.assertIsNotNone(em.tf_module.__call__) + self.assertTrue(isinstance(em.tf_module.computation_module, tf.Module)) - def test_save_multiple_model_functions(self): - linear_mdl = jax_module.JaxModule( - params={ - 'w': jnp.zeros((4, 2), jnp.int32), - 'b': jnp.ones((2,), jnp.int32), - }, - apply_fn={ - 'with_bias': functools.partial(_linear, with_bias=True), - 'without_bias': functools.partial(_linear, with_bias=False), - }, - input_polymorphic_shape={'with_bias': None, 'without_bias': None}, + def test_get_tf_module_orbax_model_export(self): + serving_config = ( + sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + ], + ), ) - em = export_manager.ExportManager( - linear_mdl, - serving_configs=[ - sc.ServingConfig( - 'serving_default', - method_key='with_bias', - input_signature=[ - tf.TensorSpec(shape=(1, 4), dtype=tf.int32, name='x') - ], - tf_postprocessor=lambda out: {'y': out}, - ), - sc.ServingConfig( - 'no_bias', - method_key='without_bias', - input_signature=[ - tf.TensorSpec(shape=(1, 4), dtype=tf.int32, name='x') - ], - tf_postprocessor=lambda out: {'y': out}, - ), - ], - ) - em.save(self._output_dir) - loaded = tf.saved_model.load(self._output_dir, ['serve']) - - expected_keys = ['serving_default', 'no_bias'] - self.assertCountEqual(expected_keys, em.serving_signatures.keys()) - self.assertCountEqual(expected_keys, loaded.signatures.keys()) - - x = jnp.zeros((1, 4), jnp.int32) - self.assertAllEqual( - loaded.signatures['serving_default'](x=x)['y'], jnp.ones((1, 2)) - ) - self.assertAllEqual( - loaded.signatures['no_bias'](x=x)['y'], jnp.zeros((1, 2)) + jax_module.JaxModule( + params={'bias': jnp.array(1, jnp.int32)}, + apply_fn=apply_fn, + export_version=constants.ExportModelType.ORBAX_MODEL, + ), + serving_config, + constants.ExportModelType.ORBAX_MODEL, ) - def test_callable_module(self): - module = jax_module.JaxModule( - jnp.asarray(0.0), - lambda w, x: w + jnp.sum(x['a']['b']), - ) - dummy_inputs = {'a': {'b': jnp.ones(3, jnp.float32)}} + with self.assertRaises(TypeError): + em.tf_module # pylint: disable=pointless-statement - input_signature = jax.tree.map( - lambda x: tf.TensorSpec(dtype=x.dtype, shape=x.shape), dummy_inputs + def test_get_serving_signatures_tensorflow_export(self): + serving_config = ( + sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + ], + ), ) em = export_manager.ExportManager( - module, - [ - sc.ServingConfig( - 'serving_default', - input_signature=[input_signature], - tf_postprocessor=lambda out: {'y': out}, - ) - ], - ) - em.save(self._output_dir) - loaded = em.load(self._output_dir) - result = loaded(dummy_inputs) - self.assertAllClose(result['y'], jnp.asarray(3.0)) - - def test_save_non_differentiable_fn(self): - - def non_differetiable_fn(_, x): - _, x = jax.lax.while_loop( - cond_fun=lambda state: state[0], - body_fun=lambda state: (False, state[1] + 1), - init_val=(False, x), - ) - return x - - serving_config = sc.ServingConfig( - 'serving', [tf.TensorSpec((), tf.float32)] + jax_module.JaxModule( + params={'bias': jnp.array(1, jnp.int32)}, + apply_fn=apply_fn, + export_version=constants.ExportModelType.TF_SAVEDMODEL, + ), + serving_config, ) - # https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#saved-model-for-non-differentiable-jax-functions - with self.assertRaises(ValueError): - export_manager.ExportManager( - jax_module.JaxModule( - {'dummy': jnp.array(1.0)}, non_differetiable_fn, trainable=True - ), - [serving_config], - ).save(os.path.join(self._output_dir, '0')) - - # Okay with with_gradients=False (default). - export_manager.ExportManager( - jax_module.JaxModule({'dummy': jnp.array(1.0)}, non_differetiable_fn), - [serving_config], - ).save(os.path.join(self._output_dir, '1')) - - def test_init_invalid_arguments(self): - single_fn_module = jax_module.JaxModule({}, lambda p, x: x) - multi_fn_module = jax_module.JaxModule( - {}, - {'foo': lambda p, x: x, 'bar': lambda p, x: x}, - input_polymorphic_shape={'foo': None, 'bar': None}, + self.assertContainsExactSubsequence( + em.serving_signatures.keys(), ['serving_default'] ) - with self.assertRaisesRegex(ValueError, 'Duplicated key'): - export_manager.ExportManager( - single_fn_module, - [ - sc.ServingConfig('serving', [tf.TensorSpec((), tf.float32)]), - sc.ServingConfig('serving', [tf.TensorSpec((), tf.int32)]), - ], - ) - with self.assertRaisesRegex(ValueError, 'Duplicated key'): - export_manager.ExportManager( - single_fn_module, - [ - sc.ServingConfig( - ['serve', 'serve'], [tf.TensorSpec((), tf.float32)] - ), - ], - ) - with self.assertRaisesRegex(ValueError, '`method_key` is not specified'): - export_manager.ExportManager( - multi_fn_module, - [ - sc.ServingConfig('serving', [tf.TensorSpec((), tf.float32)]), - ], - ) - with self.assertRaisesRegex(ValueError, 'Method key "baz" is not found'): - export_manager.ExportManager( - multi_fn_module, - [ - sc.ServingConfig( - 'serving', [tf.TensorSpec((), tf.float32)], method_key='baz' - ), - ], - ) - def test_variable_update(self): - module = jax_module.JaxModule( - {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] + def test_get_serving_signatures_orbax_export(self): + serving_config = ( + sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + ], + ), ) em = export_manager.ExportManager( - module, - serving_configs=[ - sc.ServingConfig( - 'serving_default', - input_signature=[tf.TensorSpec((), tf.dtypes.int32, name='x')], - tf_postprocessor=lambda out: {'y': out}, - ), - ], - ) - em.save(self._output_dir) - loaded = tf.saved_model.load(self._output_dir, ['serve']) - res = loaded.signatures['serving_default'](x=1)['y'] - self.assertAllEqual(res, 2) - - module.update_variables({'bias': jnp.array(2)}) - em.save(self._output_dir) - loaded = tf.saved_model.load(self._output_dir, ['serve']) - res = loaded.signatures['serving_default'](x=1)['y'] - self.assertAllEqual(res, 3) - - def test_return_preprocess_only_fn(self): - - def tf_preprocessor(*inputs): - x = inputs[0] - return tf.math.sin(x) + 1 - - serving_configs = [ - sc.ServingConfig( - 'serving_1', - input_signature=[tf.TensorSpec((), tf.dtypes.float32)], - tf_preprocessor=tf_preprocessor, + jax_module.JaxModule( + params={'bias': jnp.array(1, jnp.int32)}, + apply_fn=apply_fn, + export_version=constants.ExportModelType.ORBAX_MODEL, ), - ] - inputs = [tf.random.uniform([10], dtype=tf.float32)] - dict_inputs = {f'inputs_{i}': v for i, v in enumerate(inputs)} - - with config.obx_export_tf_preprocess_only(True): - - with self.subTest('with_preprocessor'): - em = export_manager.ExportManager( - jax_module.JaxModule( - {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] - ), - serving_configs, - ) - em.save(self._output_dir) - loaded = tf.saved_model.load(self._output_dir, ['serve']) - chex.assert_trees_all_close( - loaded.signatures['serving_1'](**dict_inputs), - {'output_0': tf_preprocessor(*inputs)}, - ) - with self.subTest('without_preprocessor'): - serving_configs = [ - sc.ServingConfig( - 'serving_2', - input_signature=[tf.TensorSpec((), tf.dtypes.float32)], - ), - ] - with self.assertRaisesRegex( - ValueError, 'serving_config.tf_preprocessor must be provided' - ): - em = export_manager.ExportManager( - jax_module.JaxModule( - {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] - ), - serving_configs, - ) - em.save(self._output_dir) + serving_config, + constants.ExportModelType.ORBAX_MODEL, + ) - # TODO(bdwalker): Re-enable this test once the ObmExport is fully implemented. - # def test_export_with_obx_model_export(self): - # serving_configs = [ - # sc.ServingConfig( - # 'serving_config', - # input_signature=[tf.TensorSpec((), tf.dtypes.float32)], - # ), - # ] - # with self.assertRaises(NotImplementedError): - # em = export_manager.ExportManager( - # jax_module.JaxModule( - # {'bias': jnp.array(1)}, - # lambda p, x: x + p['bias'], - # export_version=constants.ExportModelType.ORBAX_MODEL, - # ), - # serving_configs, - # constants.ExportModelType.ORBAX_MODEL, - # ) - # em.save(self._output_dir) + with self.assertRaises(NotImplementedError): + em.serving_signatures # pylint: disable=pointless-statement if __name__ == '__main__': diff --git a/export/orbax/export/jax_module.py b/export/orbax/export/jax_module.py index b7106d05..ba4865fc 100644 --- a/export/orbax/export/jax_module.py +++ b/export/orbax/export/jax_module.py @@ -25,11 +25,11 @@ from orbax.export.modules import tensorflow_module import tensorflow as tf - PyTree = orbax_export_typing.PyTree ApplyFn = orbax_export_typing.ApplyFn +# TODO(bdwalker): Remove tf.Module base class. class JaxModule(tf.Module, orbax_module_base.OrbaxModuleBase): """An exportable module for JAX functions and parameters. @@ -44,12 +44,12 @@ def __init__( trainable: Optional[Union[bool, PyTree]] = None, input_polymorphic_shape: Union[PyTree, Mapping[str, PyTree], None] = None, jax2tf_kwargs: Optional[Mapping[str, Any]] = None, - jax2obm_kwargs: Optional[Mapping[str, Any]] = None, jit_compile: Union[bool, Mapping[str, bool]] = True, pspecs: Optional[PyTree] = None, allow_multi_axis_sharding_consolidation: Optional[bool] = None, export_version: constants.ExportModelType = constants.ExportModelType.TF_SAVEDMODEL, flatten_signature: bool = False, + jax2obm_kwargs: Optional[Mapping[str, Any]] = None, ): """JaxModule constructor. diff --git a/export/orbax/export/modules/obm_module.py b/export/orbax/export/modules/obm_module.py index 58fd4277..624b10e8 100644 --- a/export/orbax/export/modules/obm_module.py +++ b/export/orbax/export/modules/obm_module.py @@ -32,6 +32,18 @@ ApplyFn = orbax_export_typing.ApplyFn +def _to_jax_dtype(t): + if isinstance(t, tf.DType): + return t.as_numpy_dtype() + return t + + +def _to_jax_spec(tree: PyTree) -> PyTree: + return jax.tree_util.tree_map( + lambda x: jax.ShapeDtypeStruct(x.shape, _to_jax_dtype(x.dtype)), tree + ) + + def _to_sequence(a): if isinstance(a, Sequence): return a @@ -77,6 +89,8 @@ def __init__( else False ) + self._params_args_spec = _to_jax_spec(params) + # Set the Orbax checkpoint path if provided in the jax2obm_kwargs. self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs) diff --git a/export/orbax/export/modules/tensorflow_module.py b/export/orbax/export/modules/tensorflow_module.py index a8bef5a1..0d8b5a45 100644 --- a/export/orbax/export/modules/tensorflow_module.py +++ b/export/orbax/export/modules/tensorflow_module.py @@ -17,8 +17,8 @@ from collections.abc import Callable, Mapping import dataclasses from typing import Any, Optional, Sequence, Union - from absl import logging +from etils.epy import reraise_utils import jax from jax import export as jax_export from jax.experimental import jax2tf @@ -34,6 +34,7 @@ PyTree = orbax_export_typing.PyTree ApplyFn = orbax_export_typing.ApplyFn obx_export_config = config.config +maybe_reraise = reraise_utils.maybe_reraise def _same_keys(a: Mapping[str, Any], b: Mapping[str, Any]) -> bool: @@ -73,7 +74,7 @@ class _NonTrackableMetadata: allow_multi_axis_sharding_consolidation: Optional[bool] -class TensorFlowModule(orbax_module_base.OrbaxModuleBase, tf.Module): +class TensorFlowModule(tf.Module, orbax_module_base.OrbaxModuleBase): """An exportable module for JAX functions and parameters. Holds tf.Variables converted from JAX parameters, as well as TF functions @@ -83,7 +84,7 @@ class TensorFlowModule(orbax_module_base.OrbaxModuleBase, tf.Module): def __init__( self, params: PyTree, - apply_fn: Union[Callable[..., Any], Mapping[str, ApplyFn]], + apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]], **kwargs: Any, ): jax2tf_kwargs = kwargs.get('jax2tf_kwargs', None) @@ -313,10 +314,7 @@ def _to_tf_variable(x, name, trainable, pspec): names = export_utils.get_param_names(params) if pspecs is None: pspecs = jax.tree_util.tree_map(lambda x: None, params) - logging.info('pspecs: %s', pspecs) - logging.info('params shape: %s', jax.tree.map(lambda x: x.shape, params)) - logging.info('names: %s', names) - logging.info('trainable: %s', trainable) + return jax.tree_util.tree_map( _to_tf_variable, params, names, trainable, pspecs ) @@ -357,6 +355,7 @@ def _make_tf_closure( logging.vlog(3, 'jax2tf_kwargs=%s', jax2tf_kwargs) apply_fn_tf = jax2tf.convert(apply_fn, **jax2tf_kwargs) + return tf.function( lambda x: apply_fn_tf( export_utils.get_variable_tree( diff --git a/export/orbax/export/obm_export.py b/export/orbax/export/obm_export.py index 071d72ad..1c8f9e77 100644 --- a/export/orbax/export/obm_export.py +++ b/export/orbax/export/obm_export.py @@ -14,14 +14,12 @@ """Export class that implements the save and load abstract class defined in Export Base for use with the Orbax Model export format.""" -from typing import Any, cast +from typing import Any, Callable, Mapping, cast from absl import logging from orbax.export import constants from orbax.export import export_base -from orbax.export import jax_module as jax_module_lib from orbax.export.modules import obm_module -import tensorflow as tf class ObmExport(export_base.ExportBase): @@ -29,31 +27,30 @@ class ObmExport(export_base.ExportBase): def save( self, - # TODO(b/363033166): Change this annotation once TF isolation is done. - jax_module: tf.Module, model_path: str, **kwargs: Any, ): """Saves a Jax model in the Orbax Model export format. Args: - jax_module: The `JaxModule` to be exported. model_path: The path to save the model. **kwargs: Additional arguments to pass to the `save` method. Accepted arguments are `save_options` and `serving_signatures`. """ - # TODO(b/363033166): Remove this step once TF isolation is done. - jax_module_: jax_module_lib.JaxModule = jax_module.computation_module - - if jax_module_.export_version() != constants.ExportModelType.ORBAX_MODEL: + if self._module.export_version() != constants.ExportModelType.ORBAX_MODEL: raise ValueError( "JaxModule is not of type ORBAX_MODEL. Please use the correct" " export_version. Expected ORBAX_MODEL, got" - f" {jax_module_.export_version()}" + f" {self._module.export_version()}" ) def load(self, model_path: str, **kwargs: Any): """Loads the model previously saved in the Orbax Model export format.""" logging.info("Loading model using Orbax Export Model.") raise NotImplementedError("ObmExport.load not implemented yet.") + + @property + def serving_signatures(self) -> Mapping[str, Callable[..., Any]]: + """Returns a map of signature keys to serving functions.""" + raise NotImplementedError("ObmExport.load not implemented yet.") diff --git a/export/orbax/export/tensorflow_export.py b/export/orbax/export/tensorflow_export.py index 1bdb713a..8a5ca937 100644 --- a/export/orbax/export/tensorflow_export.py +++ b/export/orbax/export/tensorflow_export.py @@ -14,25 +14,45 @@ """Export class that implements the save and load abstract class defined in Export Base for use with the TensorFlow SavedModel export format.""" -from typing import Any +from typing import Any, Callable, Dict, Mapping, Sequence from absl import logging +from etils.epy import reraise_utils +from orbax.export import config from orbax.export import export_base +from orbax.export import jax_module +from orbax.export import serving_config as osc +from orbax.export import utils import tensorflow as tf +obx_export_config = config.config +maybe_reraise = reraise_utils.maybe_reraise + + class TensorFlowExport(export_base.ExportBase): """Defines the save and load methods for exporting a model using TensorFlow SavedModel.""" + def __init__( + self, + module: jax_module.JaxModule, + serving_configs: Sequence[osc.ServingConfig], + ): + self._tf_module = tf.Module() + self._tf_module.computation_module = module + self._serving_signatures = {} + self._process_serving_configs( + serving_configs, + obx_export_config.obx_export_tf_preprocess_only, # pytype: disable=attribute-error + ) + def save( self, - jax_module: tf.Module, model_path: str, **kwargs: Any, ): """Saves the model. Args: - jax_module: The `JaxModule` to be exported. model_path: The path to save the model. **kwargs: Additional arguments to pass to the `save` method. Accepted arguments are `save_options` and `serving_signatures`. @@ -50,15 +70,21 @@ def save( f'{type(save_options)}' ) save_options.experimental_custom_gradients = ( - jax_module.computation_module.with_gradient + self._tf_module.computation_module.with_gradient ) - serving_signatures = ( - kwargs['serving_signatures'] if 'serving_signatures' in kwargs else {} + serving_signatures = dict(self._serving_signatures) + signature_overrides = ( + kwargs['signature_overrides'] + if 'signature_overrides' in kwargs and kwargs['signature_overrides'] + else {} ) + if signature_overrides: + serving_signatures.update(signature_overrides) + tf.saved_model.save( - jax_module, + self._tf_module, model_path, serving_signatures, options=save_options, @@ -68,3 +94,81 @@ def load(self, model_path: str, **kwargs: Any) -> Any: """Loads the model previously saved in the TensorFlow SavedModel format.""" logging.info('Loading model using TensorFlow SavedModel.') return tf.saved_model.load(model_path, **kwargs) + + def tf_export_module(self) -> tf.Module: + """Returns the tf.Module that was exported.""" + return self._tf_module + + @property + def serving_signatures(self) -> Mapping[str, Callable[..., Any]]: + """Returns a map of signature keys to serving functions.""" + + return self._serving_signatures + + def _process_serving_configs( + self, + serving_configs: Sequence[osc.ServingConfig], + obx_export_tf_preprocess_only: bool, + ): + """Processes the serving functions into their TF wrapped concrete functions. + + The function will use the serving_configs and the methods defined in the + provided module to populate the serving_signatures map with the concrete + inference functions. + + In addition, if trackable resources are provided in the serving_configs, + they will be added to the module's tf_trackable_resources property. + + Args: + serving_configs: a sequence of which each element is a `ServingConfig` + cooresponding to a serving signature of the exported SavedModel. + obx_export_tf_preprocess_only: a boolean indicating whether to export only + the preprocessor. + module: A tf module that will provide the method definitions. The module + should have a JaxModule set as a computation_module property. + serving_signatures: a map of signature keys to serving functions. This map + will be populated by this function. + """ + tf_trackable_resources = [] + for sc in serving_configs: + with maybe_reraise(f'Failed exporting signature_key={sc.signature_key} '): + if obx_export_tf_preprocess_only: + if not sc.tf_preprocessor: + raise ValueError( + 'serving_config.tf_preprocessor must be provided when' + ' in `obx_export_tf_preprocess_only` mode.' + ) + + def tf_preprocessor(*inputs): + return tf.nest.flatten(sc.tf_preprocessor(*inputs)) # pylint: disable=cell-var-from-loop + + preprocessor = utils.with_default_args( + tf_preprocessor, sc.get_input_signature() + ) + inference_fn = preprocessor + else: + method = sc.get_infer_step(self._tf_module.computation_module.methods) + inference_fn = utils.make_e2e_inference_fn(method, sc) + + if isinstance(sc.signature_key, str): + keys = [sc.signature_key] + else: + keys = sc.signature_key + + for key in keys: + if key in self._serving_signatures: + raise ValueError( + f'Duplicated key "{sc.signature_key}" in `serving_configs`.' + ) + self._serving_signatures[key] = inference_fn + + if sc.extra_trackable_resources is not None: + tf_trackable_resources.append(sc.extra_trackable_resources) + + if len(serving_configs) == 1: + # Make this module callable. Once exported, it can be loaded back in + # python and the nested input structure will be preservered. In + # contrast, signatures will flatten the TensorSpecs of the to kwargs. + self._tf_module.__call__ = inference_fn + + self._tf_module.tf_trackable_resources = tf_trackable_resources diff --git a/export/orbax/export/tensorflow_export_test.py b/export/orbax/export/tensorflow_export_test.py index a5e77848..07a6e83a 100644 --- a/export/orbax/export/tensorflow_export_test.py +++ b/export/orbax/export/tensorflow_export_test.py @@ -12,54 +12,484 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import logging +import os from absl.testing import parameterized +import chex +import jax import jax.numpy as jnp -from orbax.export import export_manager as em +from orbax.export import config +from orbax.export import constants from orbax.export import jax_module -from orbax.export import serving_config as osc +from orbax.export import serving_config as sc from orbax.export import tensorflow_export +from orbax.export import utils import tensorflow as tf +def _from_feature_dict(feature_dict): + return feature_dict['feat'] + + +def _add_output_name(outputs): + return {'outputs': outputs} + + +def _linear(params, x, with_bias=False): + y = x @ params['w'] + if with_bias: + return y + params['b'] + return y + + +_ZERO_VAR = tf.Variable(0) + + +def _add_zero(x): + return x + _ZERO_VAR + + class TensorFlowExportTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() self._output_dir = self.create_tempdir().full_path + @parameterized.named_parameters( + dict( + testcase_name='normal', + input_signature=[ + {'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')} + ], + preprocessor=_from_feature_dict, + postprocessor=_add_output_name, + inputs=[{'feat': tf.constant(1)}], + outputs={'outputs': tf.constant(2)}, + ), + dict( + testcase_name='embedded input signature', + preprocessor=tf.function( + _from_feature_dict, + [{'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')}], + ), + postprocessor=_add_output_name, + inputs=[{'feat': tf.constant(1)}], + outputs={'outputs': tf.constant(2)}, + ), + dict( + testcase_name='no preprocessor', + input_signature=[tf.TensorSpec((), tf.dtypes.int32, 'feat')], + postprocessor=_add_output_name, + inputs=[tf.constant(1)], + outputs={'outputs': tf.constant(2)}, + ), + dict( + testcase_name='no postprocessor', + input_signature=[ + {'feat': tf.TensorSpec((), tf.dtypes.int32, 'feat')} + ], + preprocessor=_from_feature_dict, + inputs=[{'feat': tf.constant(1)}], + outputs=tf.constant(2), + ), + dict( + testcase_name='core module only', + input_signature=[tf.TensorSpec((), tf.dtypes.int32, 'feat')], + inputs=[tf.constant(1)], + outputs=tf.constant(2), + ), + dict( + testcase_name='default value', + input_signature=[ + utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.dtypes.int32, 'feat'), 1 + ) + ], + inputs=[], + outputs=tf.constant(2), + ), + ) + def test_make_e2e_inference_fn( + self, + inputs, + outputs, + input_signature=None, + preprocessor=None, + postprocessor=None, + ): + method = jax_module.JaxModule( + {'bias': jnp.array(1)}, + lambda p, x: x + p['bias'], + ).methods[constants.DEFAULT_METHOD_KEY] + inference_fn = utils.make_e2e_inference_fn( + method, + sc.ServingConfig('key', input_signature, preprocessor, postprocessor), + ) + self.assertAllEqual(inference_fn(*inputs), outputs) + @parameterized.named_parameters( dict( testcase_name='multiple signatures', serving_configs=[ - osc.ServingConfig( + sc.ServingConfig( 'without_processors', input_signature=[tf.TensorSpec((), tf.dtypes.int32)], ), ], expected_keys=['without_processors'], ), + dict( + testcase_name='multiple keys same signature', + serving_configs=[ + sc.ServingConfig( + ['serving_default', 'without_processors'], + input_signature=[tf.TensorSpec((), tf.dtypes.int32)], + ), + ], + expected_keys=['serving_default', 'without_processors'], + ), + dict( + testcase_name='trackables in preprocessor', + serving_configs=[ + sc.ServingConfig( + 'serving_default', + input_signature=[tf.TensorSpec((), tf.dtypes.int32)], + tf_preprocessor=_add_zero, + extra_trackable_resources=_ZERO_VAR, + ), + ], + expected_keys=['serving_default'], + ), ) def test_save(self, serving_configs, expected_keys): - module = tf.Module() - module.computation_module = jax_module.JaxModule( + module = jax_module.JaxModule( {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] ) serving_signatures = {} - tfe = tensorflow_export.TensorFlowExport() - em.process_serving_configs( - serving_configs, - obx_export_tf_preprocess_only=False, - module=module, - serving_signatures=serving_signatures, - ) + tfe = tensorflow_export.TensorFlowExport(module, serving_configs) tfe.save( - module, self._output_dir, serving_signatures=serving_signatures, ) loaded = tf.saved_model.load(self._output_dir, ['serve']) self.assertCountEqual(expected_keys, loaded.signatures.keys()) + def test_variable_update(self): + module = jax_module.JaxModule( + {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] + ) + em = tensorflow_export.TensorFlowExport( + module, + serving_configs=[ + sc.ServingConfig( + 'serving_default', + input_signature=[tf.TensorSpec((), tf.dtypes.int32, name='x')], + tf_postprocessor=lambda out: {'y': out}, + ), + ], + ) + em.save(self._output_dir) + loaded = tf.saved_model.load(self._output_dir, ['serve']) + res = loaded.signatures['serving_default'](x=1)['y'] + self.assertAllEqual(res, 2) + + module.update_variables({'bias': jnp.array(2)}) + em.save(self._output_dir) + loaded = tf.saved_model.load(self._output_dir, ['serve']) + res = loaded.signatures['serving_default'](x=1)['y'] + self.assertAllEqual(res, 3) + + def test_return_preprocess_only_fn(self): + + def tf_preprocessor(*inputs): + x = inputs[0] + return tf.math.sin(x) + 1 + + serving_configs = [ + sc.ServingConfig( + 'serving_1', + input_signature=[tf.TensorSpec((), tf.dtypes.float32)], + tf_preprocessor=tf_preprocessor, + ), + ] + inputs = [tf.random.uniform([10], dtype=tf.float32)] + dict_inputs = {f'inputs_{i}': v for i, v in enumerate(inputs)} + + with config.obx_export_tf_preprocess_only(True): + + with self.subTest('with_preprocessor'): + em = tensorflow_export.TensorFlowExport( + jax_module.JaxModule( + {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] + ), + serving_configs, + ) + em.save(self._output_dir) + loaded = tf.saved_model.load(self._output_dir, ['serve']) + chex.assert_trees_all_close( + loaded.signatures['serving_1'](**dict_inputs), + {'output_0': tf_preprocessor(*inputs)}, + ) + with self.subTest('without_preprocessor'): + serving_configs = [ + sc.ServingConfig( + 'serving_2', + input_signature=[tf.TensorSpec((), tf.dtypes.float32)], + ), + ] + with self.assertRaisesRegex( + ValueError, 'serving_config.tf_preprocessor must be provided' + ): + em = tensorflow_export.TensorFlowExport( + jax_module.JaxModule( + {'bias': jnp.array(1)}, lambda p, x: x + p['bias'] + ), + serving_configs, + ) + em.save(self._output_dir) + + def test_init_invalid_arguments(self): + single_fn_module = jax_module.JaxModule({}, lambda p, x: x) + multi_fn_module = jax_module.JaxModule( + {}, + {'foo': lambda p, x: x, 'bar': lambda p, x: x}, + input_polymorphic_shape={'foo': None, 'bar': None}, + ) + with self.assertRaisesRegex(ValueError, 'Duplicated key'): + tensorflow_export.TensorFlowExport( + single_fn_module, + [ + sc.ServingConfig('serving', [tf.TensorSpec((), tf.float32)]), + sc.ServingConfig('serving', [tf.TensorSpec((), tf.int32)]), + ], + ) + with self.assertRaisesRegex(ValueError, 'Duplicated key'): + tensorflow_export.TensorFlowExport( + single_fn_module, + [ + sc.ServingConfig( + ['serve', 'serve'], [tf.TensorSpec((), tf.float32)] + ), + ], + ) + with self.assertRaisesRegex(ValueError, '`method_key` is not specified'): + tensorflow_export.TensorFlowExport( + multi_fn_module, + [ + sc.ServingConfig('serving', [tf.TensorSpec((), tf.float32)]), + ], + ) + with self.assertRaisesRegex(ValueError, 'Method key "baz" is not found'): + tensorflow_export.TensorFlowExport( + multi_fn_module, + [ + sc.ServingConfig( + 'serving', [tf.TensorSpec((), tf.float32)], method_key='baz' + ), + ], + ) + + def test_save_non_differentiable_fn(self): + + def non_differetiable_fn(_, x): + _, x = jax.lax.while_loop( + cond_fun=lambda state: state[0], + body_fun=lambda state: (False, state[1] + 1), + init_val=(False, x), + ) + return x + + serving_config = sc.ServingConfig( + 'serving', [tf.TensorSpec((), tf.float32)] + ) + + # https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#saved-model-for-non-differentiable-jax-functions + with self.assertRaises(ValueError): + tensorflow_export.TensorFlowExport( + jax_module.JaxModule( + {'dummy': jnp.array(1.0)}, non_differetiable_fn, trainable=True + ), + [serving_config], + ).save(os.path.join(self._output_dir, '0')) + + # Okay with with_gradients=False (default). + tensorflow_export.TensorFlowExport( + jax_module.JaxModule({'dummy': jnp.array(1.0)}, non_differetiable_fn), + [serving_config], + ).save(os.path.join(self._output_dir, '1')) + + def test_callable_module(self): + module = jax_module.JaxModule( + jnp.asarray(0.0), + lambda w, x: w + jnp.sum(x['a']['b']), + ) + dummy_inputs = {'a': {'b': jnp.ones(3, jnp.float32)}} + + input_signature = jax.tree.map( + lambda x: tf.TensorSpec(dtype=x.dtype, shape=x.shape), dummy_inputs + ) + em = tensorflow_export.TensorFlowExport( + module, + [ + sc.ServingConfig( + 'serving_default', + input_signature=[input_signature], + tf_postprocessor=lambda out: {'y': out}, + ) + ], + ) + em.save(self._output_dir) + loaded = em.load(self._output_dir) + result = loaded(dummy_inputs) + self.assertAllClose(result['y'], jnp.asarray(3.0)) + + def test_save_multiple_model_functions(self): + linear_mdl = jax_module.JaxModule( + params={ + 'w': jnp.zeros((4, 2), jnp.int32), + 'b': jnp.ones((2,), jnp.int32), + }, + apply_fn={ + 'with_bias': functools.partial(_linear, with_bias=True), + 'without_bias': functools.partial(_linear, with_bias=False), + }, + input_polymorphic_shape={'with_bias': None, 'without_bias': None}, + ) + + em = tensorflow_export.TensorFlowExport( + linear_mdl, + serving_configs=[ + sc.ServingConfig( + 'serving_default', + method_key='with_bias', + input_signature=[ + tf.TensorSpec(shape=(1, 4), dtype=tf.int32, name='x') + ], + tf_postprocessor=lambda out: {'y': out}, + ), + sc.ServingConfig( + 'no_bias', + method_key='without_bias', + input_signature=[ + tf.TensorSpec(shape=(1, 4), dtype=tf.int32, name='x') + ], + tf_postprocessor=lambda out: {'y': out}, + ), + ], + ) + em.save(self._output_dir) + loaded = tf.saved_model.load(self._output_dir, ['serve']) + + logging.info('loaded: %s', loaded.signatures) + expected_keys = ['serving_default', 'no_bias'] + self.assertCountEqual(expected_keys, em.serving_signatures.keys()) + self.assertCountEqual(expected_keys, loaded.signatures.keys()) + + x = jnp.zeros((1, 4), jnp.int32) + self.assertAllEqual( + loaded.signatures['serving_default'](x=x)['y'], jnp.ones((1, 2)) + ) + self.assertAllEqual( + loaded.signatures['no_bias'](x=x)['y'], jnp.zeros((1, 2)) + ) + + @parameterized.named_parameters( + dict( + testcase_name='all default', + serving_config=sc.ServingConfig( + 'serving_default', + input_signature=[ + utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.int32, 'x'), 2 + ), + utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.int32, 'y'), 3 + ), + ], + tf_preprocessor=lambda x, y: x + y, + ), + serving_inputs={}, + expected_outputs=6, + ), + dict( + testcase_name='some default', + serving_config=sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.int32, 'y'), 2 + ), + ], + tf_preprocessor=lambda x, y: x + y, + ), + serving_inputs={'x': 3}, + expected_outputs=6, + ), + dict( + testcase_name='override default', + serving_config=sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.int32, 'y'), 2 + ), + ], + tf_preprocessor=lambda x, y: x + y, + ), + serving_inputs={'x': 1, 'y': 3}, + expected_outputs=5, + ), + dict( + testcase_name='nested', + serving_config=sc.ServingConfig( + 'serving_default', + input_signature=[ + tf.TensorSpec((), tf.int32, 'x'), + { + 'y': utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.int32, 'y'), 2 + ), + 'z': utils.TensorSpecWithDefault( + tf.TensorSpec((), tf.int32, 'z'), 3 + ), + }, + ], + tf_preprocessor=lambda x, extra: x + extra['y'] + extra['z'], + ), + serving_inputs={'x': 1}, + expected_outputs=7, + ), + ) + def test_save_default_inputs( + self, serving_config, serving_inputs, expected_outputs + ): + em = tensorflow_export.TensorFlowExport( + jax_module.JaxModule( + {'bias': jnp.array(1, jnp.int32)}, lambda p, x: x + p['bias'] + ), + [serving_config], + ) + em.save(self._output_dir) + # TODO(b/277814477): use the TF2 API + # loaded.signature['serving_default'](**serving_inputs) + # once it supports default values. + with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess: + meta_graph_def = tf.compat.v1.saved_model.loader.load( + sess, ['serve'], self._output_dir + ) + signature_def = meta_graph_def.signature_def[serving_config.signature_key] + output_tensor_name = signature_def.outputs['output_0'].name + fetch = sess.graph.get_tensor_by_name(output_tensor_name) + feed_dict = { + sess.graph.get_tensor_by_name(signature_def.inputs[k].name): v + for k, v in serving_inputs.items() + } + outputs = sess.run(fetch, feed_dict=feed_dict) + self.assertAllEqual(outputs, expected_outputs) + if __name__ == '__main__': tf.test.main() diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index d7f3bc74..2ecd7419 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -25,6 +25,7 @@ from jax import export as jax_export from jax import tree_util import jaxtyping +from orbax.export import serving_config as osc import tensorflow as tf ConfigProto = Any @@ -470,3 +471,23 @@ def get_variable_tree( """Returns the PyTree of the tf.Variables or obm.Variables associated with the var_treedef.""" return jax.tree_util.tree_unflatten(var_treedef, var_leaves) + +def make_e2e_inference_fn( + model_fn: Callable[..., Any], + serving_config: osc.ServingConfig, +) -> Callable[..., Any]: + """Creates an concrete end-to-end inference tf.function. + + Args: + model_fn: a callable in TF context for the numeric computation. + serving_config: a ServingConfig that defines the input sigature, + pre-processor and post-processor of the inference function. + + Returns: + A tf.function for end-to-end inference. + """ + infer_step_func_map = serving_config.bind(model_fn, require_numpy=False) + signature_key = serving_config.get_signature_keys()[0] + return with_default_args( + infer_step_func_map[signature_key], serving_config.get_input_signature() + )