Skip to content

Commit

Permalink
Use custom HttpRequest builder to set customized useragent header for…
Browse files Browse the repository at this point in the history
… TFX.

PiperOrigin-RevId: 385225372
  • Loading branch information
zhitaoli authored and dhruvesh09 committed Jul 22, 2021
1 parent f6d6fdd commit e630b08
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 44 deletions.
3 changes: 2 additions & 1 deletion tfx/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def make_pipeline_sdk_required_install_packages():
'portpicker>=1.3.1,<2',
'protobuf>=3.13,<4',
'docker>=4.1,<5',
'google-apitools>=0.5,<1',
'google-api-python-client>=1.8,<2',
# TODO(b/176812386): Deprecate usage of jinja2 for placeholders.
'jinja2>=2.7.3,<3',
]
Expand All @@ -74,7 +76,6 @@ def make_required_install_packages():
'apache-beam[gcp]>=2.29,<3',
'attrs>=19.3.0,<21',
'click>=7,<8',
'google-api-python-client>=1.7.8,<2',
'google-cloud-aiplatform>=0.5.0,<0.8',
# TODO(b/193571051): remove 2.21 cap after TF 2.6 becomes available.
'google-cloud-bigquery>=1.28.0,<2.21',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
if exec_properties.get('output_example_spec'):
proto_utils.json_to_proto(exec_properties['output_example_spec'],
output_example_spec)
api = discovery.build(service_name, api_version)
api = discovery.build(
service_name,
api_version,
requestBuilder=telemetry_utils.TFXHttpRequest,
)
new_model_created = False
try:
new_model_created = runner.create_model_for_aip_prediction_if_not_exist(
Expand Down
9 changes: 5 additions & 4 deletions tfx/extensions/google_cloud_ai_platform/pusher/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Custom executor to push TFX model to AI Platform."""

import time
from typing import Any, Dict, List, Text
from typing import Any, Dict, List

from google.api_core import client_options # pylint: disable=unused-import
from googleapiclient import discovery
Expand Down Expand Up @@ -51,9 +51,9 @@
class Executor(tfx_pusher_executor.Executor):
"""Deploy a model to Google Cloud AI Platform serving."""

def Do(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]],
exec_properties: Dict[Text, Any]):
def Do(self, input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any]):
"""Overrides the tfx_pusher_executor.
Args:
Expand Down Expand Up @@ -125,6 +125,7 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
api = discovery.build(
service_name,
api_version,
requestBuilder=telemetry_utils.TFXHttpRequest,
client_options=client_options.ClientOptions(api_endpoint=endpoint),
)
runner.deploy_model_for_aip_prediction(
Expand Down
6 changes: 5 additions & 1 deletion tfx/extensions/google_cloud_ai_platform/training_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ def create_client(self) -> None:
only be used for one job, as each instance stores variables (e.g. job_id)
specific to each job.
"""
self._client = discovery.build('ml', 'v1')
self._client = discovery.build(
'ml',
'v1',
requestBuilder=telemetry_utils.TFXHttpRequest,
)

def create_training_args(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from tfx.proto import tuner_pb2
from tfx.types import standard_artifacts
from tfx.utils import path_utils
from tfx.utils import telemetry_utils


class KubeflowGCPIntegrationTest(kubeflow_test_utils.BaseKubeflowTest):
Expand Down Expand Up @@ -342,7 +343,11 @@ def _pusher(model_importer, model_blessing_importer):

# Use default service_name / api_version.
service_name, api_version = runner.get_service_name_and_api_version({})
api = discovery.build(service_name, api_version)
api = discovery.build(
service_name,
api_version,
requestBuilder=telemetry_utils.TFXHttpRequest,
)

# The model should be NotFound yet.
with self.assertRaisesRegex(googleapiclient_errors.HttpError,
Expand Down
15 changes: 0 additions & 15 deletions tfx/tools/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,6 @@ LABEL maintainer="[email protected]"

RUN python -m pip install --upgrade pip

# TODO(b/151392812): Remove `google-api-python-client` and `google-apitools`
# when patching is not needed any more.
RUN python -m pip install \
"google-api-python-client==1.8.0" \
"google-apitools==0.5.30"

COPY --from=wheel-builder /tfx/src/dist/*.whl /tfx/src/dist/
WORKDIR /tfx/src

Expand All @@ -93,12 +87,3 @@ RUN MLSDK_WHEEL=$(find dist -name "ml_pipelines_sdk-*.whl"); \
python -m pip install ${MLSDK_WHEEL} ${TFX_WHEEL}[docker-image] ; \
fi && \
echo "Installed python packages:\n" && python -m pip list

# Patch http.py in googleapiclient and base_api.py in apitools
# to use our own UserAgent.
# TODO(b/151392812): Remove this when other telemetries become available.
COPY --from=wheel-builder /tfx/src/tfx/tools/docker/patches /tfx/src/tfx/tools/docker/patches
RUN patch `python -c 'import googleapiclient; print(googleapiclient.__path__[0])'`/http.py \
/tfx/src/tfx/tools/docker/patches/http.patch && \
patch `python -c 'import apitools; print(apitools.__path__[0])'`/base/py/base_api.py \
/tfx/src/tfx/tools/docker/patches/base_api.patch
10 changes: 0 additions & 10 deletions tfx/tools/docker/patches/base_api.patch

This file was deleted.

11 changes: 0 additions & 11 deletions tfx/tools/docker/patches/http.patch

This file was deleted.

26 changes: 26 additions & 0 deletions tfx/utils/telemetry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Dict, List

from absl import logging
from googleapiclient import http
from tfx import version

# Common label names used.
Expand Down Expand Up @@ -101,3 +102,28 @@ def make_beam_labels_args() -> List[str]:
for k in sorted(labels):
result.extend(['--labels', '%s=%s' % (k, labels[k])])
return result


class TFXHttpRequest(http.HttpRequest):
  """HttpRequest builder that sets a customized user-agent header for TFX.

  This is used to track the usage of TFX on Cloud AI.
  """

  def __init__(self, *args, **kwargs):
    """Constructs a HttpRequest that carries the TFX user-agent header.

    Any 'user-agent' entry already present in the `headers` keyword argument
    is overwritten. A caller-supplied headers dict is updated in place.

    Args:
      *args: Positional arguments to pass to the base class constructor.
      **kwargs: Keyword arguments to pass to the base class constructor. A
        missing or None `headers` entry is replaced with a new dict holding
        the TFX user-agent.
    """
    headers = kwargs.get('headers')
    if headers is None:
      # Covers both an absent key and an explicit `headers=None`; with
      # `setdefault` alone the latter would leave None in place and make the
      # item assignment below raise TypeError.
      headers = {}
      kwargs['headers'] = headers
    # See Mozilla standard User Agent header Syntax:
    # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/User-Agent
    # TODO(b/193915978): Stop relying on '-tfxpipeline-' suffix and use
    # tfx/version instead. More labels set in scoped_labels can also be added
    # to the comments variable below, as needed.
    comments = '(client_context:tfxpipeline;)'
    user_agent = f'tfx/{version.__version__} {comments}'
    headers['user-agent'] = user_agent
    super().__init__(*args, **kwargs)
10 changes: 10 additions & 0 deletions tfx/utils/telemetry_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Tests for tfx.utils.telemetry_utils."""

import sys
from googleapiclient import http
import tensorflow as tf

from tfx import version
Expand Down Expand Up @@ -88,6 +89,15 @@ def testScopedLabels(self):
},
**orig_labels))

def testTFXHttpRequest(self):
  """Checks the user-agent header that TFXHttpRequest attaches."""
  expected_fragments = ['tfx/', 'client_context:tfxpipeline;']
  request = telemetry_utils.TFXHttpRequest(
      http=http.build_http(),
      postproc=None,
      uri='http://example.com',
  )
  user_agent = request.headers['user-agent']
  # The fragments must appear in this order within the header value.
  self.assertContainsInOrder(expected_fragments, user_agent)


if __name__ == '__main__':
tf.test.main()

0 comments on commit e630b08

Please sign in to comment.