Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Catch promotion script up with external changes #76

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 115 additions & 62 deletions deployment/promotion/promote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
self-contained builds of the model version are found.
- If you need to create a self-contained build of a model version, use the create_scb.py script.

Configuration is done via environment variables. All are mandatory except VERTA_DEST_REGISTERED_MODEL:
Configuration is done via environment variables. All are mandatory except VERTA_DEST_REGISTERED_MODEL_ID:

- VERTA_SOURCE_MODEL_VERSION_ID: The ID of the model version to promote
- VERTA_SOURCE_HOST: The source Verta instance to promote from
- VERTA_SOURCE_EMAIL: The email address for authentication to the source Verta instance
- VERTA_SOURCE_DEV_KEY: The dev key associated to the email address on the source Verta instance
- VERTA_SOURCE_WORKSPACE: The workspace associated with the build on the source Verta instance
- VERTA_SOURCE_WORKSPACE_0: The workspace associated with the build on the source Verta instance
- VERTA_DEST_HOST: The destination Verta instance to promote to
- VERTA_DEST_EMAIL: The email address for authentication to the destination Verta instance
- VERTA_DEST_DEV_KEY: The dev key associated to the email address on the destination Verta instance
- VERTA_DEST_WORKSPACE: The name of the workspace associated with the build on the destination Verta instance
- VERTA_DEST_REGISTERED_MODEL: [optional] The name of the registered model to promote to. If missing, we'll create a new registered model
- VERTA_DEST_WORKSPACE: The workspace associated with the build on the destination Verta instance
- VERTA_DEST_REGISTERED_MODEL_ID: [optional] The ID of the registered model to promote to. If missing, we'll create a new registered model

Optional environment variables to configure curl usage:
VERTA_CURL_OPTS: Options to pass to curl. Defaults to '-O'
Expand All @@ -35,12 +35,52 @@
import os
import datetime

env_vars = ['VERTA_SOURCE_MODEL_VERSION_ID', 'VERTA_SOURCE_HOST', 'VERTA_SOURCE_EMAIL', 'VERTA_SOURCE_DEV_KEY',
'VERTA_SOURCE_WORKSPACE', 'VERTA_DEST_HOST', 'VERTA_DEST_EMAIL', 'VERTA_DEST_DEV_KEY',
env_vars = ['VERTA_SOURCE_MODEL_VERSION_ID', 'VERTA_SOURCE_HOST', 'VERTA_SOURCE_EMAIL',
'VERTA_SOURCE_DEV_KEY',
'VERTA_SOURCE_WORKSPACE_0', 'VERTA_DEST_HOST', 'VERTA_DEST_EMAIL',
'VERTA_DEST_DEV_KEY',
'VERTA_DEST_WORKSPACE']
opt_env_vars = ['VERTA_DEST_REGISTERED_MODEL']

opt_env_vars = ['VERTA_DEST_REGISTERED_MODEL_ID']

params = {}

proxies = {
"http": None,
"https": None
}

if not os.environ.get('VERTA_DEST_WORKSPACE'):
host = 'https://' + os.environ.get(
'VERTA_DEST_HOST') + '/api/v1/uac-proxy/workspace/getVisibleWorkspaces'
headers_dict = {'grpc-metadata-source': 'PythonClient',
'grpc-metadata-email': os.environ.get('VERTA_DEST_EMAIL'),
'grpc-metadata-developer_key': os.environ.get('VERTA_DEST_DEV_KEY')}
workspaces_dest = requests.get(host, headers=headers_dict, proxies=proxies)

source_workspace_id = os.environ.get('VERTA_SOURCE_WORKSPACE_0')
host = 'https://' + os.environ.get(
'VERTA_SOURCE_HOST') + '/api/v1/uac-proxy/workspace/getVisibleWorkspaces'
headers_dict = {'grpc-metadata-source': 'PythonClient',
'grpc-metadata-email': os.environ.get('VERTA_SOURCE_EMAIL'),
'grpc-metadata-developer_key': os.environ.get('VERTA_SOURCE_DEV_KEY')}
workspaces_source = requests.get(host, headers=headers_dict, proxies=proxies)

for item in workspaces_source.json()['workspace']:
if 'id' in item.keys() and item['id'] == source_workspace_id:
if 'org_name' in item.keys():
source_workspace = item['org_name']
else:
source_workspace = item['username']

if source_workspace == None:
print('Source workspace ID could not be matched')

for item in workspaces_dest.json()['workspace']:
if 'org_name' in item.keys() and item['org_name'] == source_workspace:
os.environ['VERTA_DEST_WORKSPACE'] = item['org_name']
elif 'username' in item.keys() and item['username'] == source_workspace:
os.environ['VERTA_DEST_WORKSPACE'] = item['username']

for param_name in env_vars:
param = os.environ.get(param_name)
Expand All @@ -57,21 +97,23 @@
params['VERTA_CURL_OPTS'] = curl_opts
else:
params['VERTA_CURL_OPTS'] = ''
params['VERTA_CURL_OPTS'] += f' -H @curl_headers'

config = {
'source': {
'model_version_id': atoi(params['VERTA_SOURCE_MODEL_VERSION_ID']),
'model_version_id': atoi(params['VERTA_SOURCE_MODEL_VERSION_ID'][2:-2]),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nickchappell-verta can you ask why this is here? I don't understand why they'd strip the first and last 2 characters from the ID

'host': params['VERTA_SOURCE_HOST'],
'email': params['VERTA_SOURCE_EMAIL'],
'devkey': params['VERTA_SOURCE_DEV_KEY'],
'workspace': params['VERTA_SOURCE_WORKSPACE']
'workspace': params['VERTA_SOURCE_WORKSPACE_0'],
'workspace_name': source_workspace
},
'dest': {
'host': params['VERTA_DEST_HOST'],
'email': params['VERTA_DEST_EMAIL'],
'devkey': params['VERTA_DEST_DEV_KEY'],
'workspace': params['VERTA_DEST_WORKSPACE'],
'registered_model_name': params['VERTA_DEST_REGISTERED_MODEL'] # Will be empty if no destination RM was provided
'registered_model_id': params['VERTA_DEST_REGISTERED_MODEL_ID']
}
}

Expand All @@ -84,7 +126,8 @@ def copy_fields(fields, src, dest):

def auth_context(host, email, devkey, workspace):
return {'headers': {'Grpc-metadata-scheme': 'https', 'Grpc-metadata-source': 'PythonClient',
'Grpc-metadata-email': email, 'Grpc-metadata-developer_key': devkey}, 'host': host,
'Grpc-metadata-email': email, 'Grpc-metadata-developer_key': devkey},
'host': host,
'workspace': workspace
}

Expand All @@ -93,7 +136,8 @@ def post(auth, path, body):
auth['headers']['Content-Type'] = 'application/json'
body['workspaceName'] = auth['workspace']
try:
res = requests.post("https://{}{}".format(auth["host"], path), headers=auth['headers'], json=body)
res = requests.post("https://{}{}".format(auth["host"], path), headers=auth['headers'],
json=body, proxies=proxies)
res.raise_for_status()
except requests.exceptions.RequestException as e:
if e.response.text:
Expand All @@ -104,7 +148,8 @@ def post(auth, path, body):

def get(auth, path):
try:
res = requests.get("https://{}{}".format(auth["host"], path), headers=auth['headers'])
res = requests.get("https://{}{}".format(auth["host"], path), headers=auth['headers'],
proxies=proxies)
res.raise_for_status()
except requests.exceptions.RequestException as e:
if e.response.text:
Expand All @@ -115,7 +160,8 @@ def get(auth, path):

def put(auth, path, body):
try:
res = requests.put("https://{}{}".format(auth["host"], path), headers=auth['headers'], json=body)
res = requests.put("https://{}{}".format(auth["host"], path), headers=auth['headers'],
json=body, proxies=proxies)
res.raise_for_status()
except requests.exceptions.RequestException as e:
if e.response.text:
Expand All @@ -128,7 +174,8 @@ def put(auth, path, body):

def patch(auth, path, body):
try:
res = requests.patch("https://{}{}".format(auth["host"], path), headers=auth['headers'], json=body)
res = requests.patch("https://{}{}".format(auth["host"], path), headers=auth['headers'],
json=body, proxies=proxies)
res.raise_for_status()
except requests.exceptions.RequestException as e:
if e.response.text:
Expand All @@ -143,29 +190,22 @@ def get_build(auth, build_id):


def get_builds(auth, source):
path = "/api/v1/deployment/builds/?workspaceName={}&model_version_id={}".format(source['workspace'], source['model_version_id'])
path = "/api/v1/deployment/builds/?workspaceName={}&model_version_id={}".format(
source['workspace_name'], source['model_version_id'])
print(f"\n\nPATH = {path}\n\n")
builds = get(auth, path)
print(f"\n\nBUILDS = {builds}\n\n")
return get(auth, path)


def get_model_version(auth, model_version_id):
return get(auth, '/api/v1/registry/model_versions/{}'.format(model_version_id))['model_version']
return get(auth, '/api/v1/registry/model_versions/{}'.format(model_version_id))[
'model_version']


def get_registered_model(auth, registered_model_id):
return get(auth, '/api/v1/registry/registered_models/{}'.format(registered_model_id))['registered_model']


def get_registered_models_by_name(auth, registered_model_name):
path = '/api/v1/registry/workspaces/{}/registered_models/find'.format(auth['workspace'])
predicates = {
'predicates': [{
"key": "name",
"operator": "EQ",
"value": registered_model_name,
"value_type": "STRING"
}]
}
return post(auth, path, predicates)['registered_models']
return get(auth, '/api/v1/registry/registered_models/{}'.format(registered_model_id))[
'registered_model']


def signed_artifact_url(auth, model_version_id, artifact):
Expand Down Expand Up @@ -193,7 +233,8 @@ def download_artifact(auth, model_version_id, artifact):
key = artifact['key']
url = signed_artifact_url(auth, model_version_id, artifact)
print("Downloading artifact '%s'" % key)
curl_cmd = "curl %s -o %s '%s'" % (params['VERTA_CURL_OPTS'], key, url)
curl_cmd = "curl --cacert %s -o %s %s '%s'" % (
os.environ['REQUESTS_CA_BUNDLE'], key, params['VERTA_CURL_OPTS'], url)
os.system(curl_cmd)


Expand All @@ -208,7 +249,8 @@ def download_artifacts(auth, model_version_id, artifacts, model_artifact):
}
copy_fields(['artifact_type', 'key'], artifact, artifact_request)
download_artifact(auth, model_version_id, artifact_request)
downloaded_artifacts.append({'key': artifact['key'], 'artifact_type': artifact['artifact_type']})
downloaded_artifacts.append(
{'key': artifact['key'], 'artifact_type': artifact['artifact_type']})

model_artifact_request = {
'method': 'GET',
Expand All @@ -223,23 +265,32 @@ def download_artifacts(auth, model_version_id, artifacts, model_artifact):
def upload_artifact(auth, model_version_id, artifact):
key = artifact['key']
print("Uploading artifact '%s'" % key)

print(artifact)

artifact_request = {
'method': 'PUT',
'model_version_id': model_version_id,
'key': key
}
put_url = signed_artifact_url(auth, model_version_id, artifact_request)
data = open(key, 'rb')
put_response = requests.put(put_url, data=data, headers={'Content-type': 'application/octet-stream'})
headers_dict = {
'Grpc-metadata-source': 'PythonClient',
'Content-type': 'application/octet-stream',
'Grpc-metadata-email': os.environ['VERTA_DEST_EMAIL'],
'Grpc-metadata-developer_key': os.environ['VERTA_DEST_DEV_KEY']
}
put_response = requests.put(put_url, data=data, headers=headers_dict)

if not put_response.ok:
raise Exception("Failed to put artifact (%d %s). Key: %s\tURL: %s\tText: %s" % (put_response.status_code,
put_response.reason, key, put_url, put_response.text))

check_url = signed_artifact_url(auth, model_version_id, {'method': 'GET', 'model_version_id': model_version_id,
'key': key})
check = requests.get(check_url)
raise Exception("Failed to put artifact (%d %s). Key: %s\tURL: %s\tText: %s" % (
put_response.status_code,
put_response.reason, key, put_url, put_response.text))

check_url = signed_artifact_url(auth, model_version_id,
{'method': 'GET', 'model_version_id': model_version_id,
'key': key})
check = requests.get(check_url, headers=headers_dict)
if not check.ok:
raise Exception("Failed to verify artifact '%s' upload at URL %s" % (key, check_url))

Expand All @@ -263,7 +314,8 @@ def get_promotion_data(_config):
model_version_id = source['model_version_id']

print("Fetching promotion data for model version %d" % source['model_version_id'])
source_auth = auth_context(source['host'], source['email'], source['devkey'], source['workspace'])
source_auth = auth_context(source['host'], source['email'], source['devkey'],
source['workspace'])
model_version = get_model_version(source_auth, model_version_id)

all_builds = get_builds(source_auth, source)
Expand All @@ -273,24 +325,28 @@ def get_promotion_data(_config):
build = None
latest_date = None
for b in all_builds['builds']:
if 'self_contained' in b['creator_request'] and b['creator_request']['self_contained']:
print(f"\n\nBUILDS = {b}\n\n")
if 'self_contained' in b['creator_request']:
build_date = datetime.datetime.strptime(b['date_created'], time_format)
if not latest_date or build_date > latest_date:
latest_date = build_date
build = b

if not build or not latest_date:
print("No self contained builds found for model version id %d, promotion stopped." % source['model_version_id'])
print(
"No self contained builds found for model version id %d, promotion stopped." % source[
'model_version_id'])
raise SystemExit(1)

model = get_registered_model(source_auth, model_version['registered_model_id'])
artifacts = download_artifacts(source_auth, model_version_id, model_version['artifacts'], model_version['model'])
model = get_registered_model(source_auth, model_version['registered_model_id'])
artifacts = download_artifacts(source_auth, model_version_id, model_version['artifacts'],
model_version['model'])

promotion = {
'build': build,
'model_version': model_version,
'model': model,
'artifacts': artifacts
'build': build,
'model_version': model_version,
'model': model,
'artifacts': artifacts
}
return promotion

Expand All @@ -301,7 +357,8 @@ def create_model(auth, source_model, source_artifacts):
model = {
'artifacts': source_artifacts
}
copy_fields(['labels', 'custom_permission', 'name', 'readme_text', 'resource_visibility', 'description'], source_model, model)
copy_fields(['labels', 'custom_permission', 'name', 'readme_text', 'resource_visibility',
'description'], source_model, model)
return post(auth, path, model)['registered_model']


Expand All @@ -312,15 +369,17 @@ def create_model_version(auth, source_model_version, promoted_model):
if 'labels' in source_model_version.keys():
model_version['labels'] = source_model_version['labels']

fields = ['artifacts', 'attributes', 'environment', 'version', 'readme_text', 'model', 'description', 'labels']
fields = ['artifacts', 'attributes', 'environment', 'version', 'readme_text', 'model',
'description', 'labels']
copy_fields(fields, source_model_version, model_version)
return post(auth, path, model_version)['model_version']


def patch_model(auth, registered_model_id, model_version_id, model):
print("Updating model artifact for model version '%s'" % model_version_id)

path = '/api/v1/registry/registered_models/{}/model_versions/{}'.format(registered_model_id, model_version_id)
path = '/api/v1/registry/registered_models/{}/model_versions/{}'.format(registered_model_id,
model_version_id)
update = {'model': model}
return patch(auth, path, update)

Expand Down Expand Up @@ -360,21 +419,15 @@ def upload_build(source_build):

def create_promotion(_config, promotion):
dest = _config['dest']

dest_auth = auth_context(dest['host'], dest['email'], dest['devkey'], dest['workspace'])

print("Starting promotion")
build_location = upload_build(promotion['build'])
if not dest['registered_model_name']:
if not dest['registered_model_id']:
model = create_model(dest_auth, promotion['model'], promotion['artifacts'])
else:
models = get_registered_models_by_name(dest_auth, dest['registered_model_name'])
if len(models) > 1:
print("WARNING: Multiple registered models with name '%s' found, using first one with id '%s'" % (dest['registered_model_name'], models[0]["id"]))
elif len(models) == 0:
print("ERROR: Registered model with name '%s' not found" % dest['registered_model_name'])
return
model = models[0]
model = get_registered_model(dest_auth, dest['registered_model_id'])
print("Using existing registered model '%s'" % model['name'])
model_version = create_model_version(dest_auth, promotion['model_version'], model)

Expand Down