Skip to content

Commit

Permalink
Refactor ModelBuilderBase.calculate_cache_key(). securityContext for …
Browse files Browse the repository at this point in the history
…Workflow (#1245)

* securityContext for Workflow

* Black reformatting

* Fix RuntimeError in pydantic

* Fix test_workflow_generator.py tests

* test_security_context

* Black reformatting

* _parse_version() -> parse_version()

* ModelBuilder.gordo_version

* gordo/serializer/utils.py

* gordo/serializer/utils.py

* Black reformatting

* test_validate_locate

* black reformatting

* Additional test in test_import_locate

* --model-builder-class

* Fix tests/gordo/test_version.py

* Black reformatting

* Fix tests

* Refactor MlFlowReporter

* Black reformatting

* add gordo/builder/utils.py

* --model-builder-class gordo workflow

* Fix test_build_exit_code()

* Fix test_build_exit_code() second attempt

* Container securityContext

* Fixing test_security_context

* Black reformatting

* --pod-security-context

* Black reformatting
  • Loading branch information
koropets authored May 19, 2022
1 parent 3ef0800 commit 6346184
Show file tree
Hide file tree
Showing 14 changed files with 389 additions and 73 deletions.
4 changes: 2 additions & 2 deletions gordo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__version__ = DEFAULT_VERSION


def _parse_version(input_version: str) -> Tuple[int, int, bool]:
def parse_version(input_version: str) -> Tuple[int, int, bool]:
"""
Takes a string which starts with standard major.minor.patch.
and returns the split of major and minor version as integers
Expand Down Expand Up @@ -44,7 +44,7 @@ def _parse_version(input_version: str) -> Tuple[int, int, bool]:
return result


MAJOR_VERSION, MINOR_VERSION, IS_UNSTABLE_VERSION = _parse_version(__version__)
MAJOR_VERSION, MINOR_VERSION, IS_UNSTABLE_VERSION = parse_version(__version__)

try:
# FIXME(https://github.com/abseil/abseil-py/issues/99)
Expand Down
24 changes: 14 additions & 10 deletions gordo/builder/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
from gordo import (
serializer,
__version__,
MAJOR_VERSION,
MINOR_VERSION,
IS_UNSTABLE_VERSION,
parse_version,
)
from gordo_dataset.dataset import get_dataset
from gordo.machine.model.base import GordoBase
Expand All @@ -49,7 +47,7 @@
class ModelBuilder:
def __init__(self, machine: Machine):
"""
Build a model for a given :class:`gordo.workflow.config_elements.machine.Machine`
Build a model for a given :class:`gordo.machine.Machine`
Parameters
----------
Expand Down Expand Up @@ -89,6 +87,10 @@ def cached_model_path(self) -> Union[os.PathLike, str, None]:
def cached_model_path(self, value):
self._cached_model_path = value

@property
def gordo_version(self):
return __version__

def build(
self,
output_dir: Optional[Union[os.PathLike, str]] = None,
Expand Down Expand Up @@ -303,7 +305,7 @@ def _build(self) -> Tuple[sklearn.base.BaseEstimator, Machine]:
model_creation_date=str(
datetime.datetime.now(datetime.timezone.utc).astimezone()
),
model_builder_version=__version__,
model_builder_version=self.gordo_version,
model_training_duration_sec=time_elapsed_model,
cross_validation=CrossValidationMetaData(
cv_duration_sec=cv_duration_sec,
Expand Down Expand Up @@ -553,8 +555,7 @@ def _extract_metadata_from_model(
def cache_key(self) -> str:
return self.calculate_cache_key(self.machine)

@staticmethod
def calculate_cache_key(machine: Machine) -> str:
def calculate_cache_key(self, machine: Machine) -> str:
"""
Calculates a hash-key from the model and data-config.
Expand Down Expand Up @@ -588,15 +589,18 @@ def calculate_cache_key(machine: Machine) -> str:
# Sets a lot of the parameters to json.dumps explicitly to ensure that we get
# consistent hash-values even if json.dumps changes their default values
# (and as such might generate different json which again gives different hash)
gordo_version = __version__ if IS_UNSTABLE_VERSION else ""
major_version, minor_version, is_unstable_version = parse_version(
self.gordo_version
)
gordo_version = self.gordo_version if is_unstable_version else ""
json_rep = json.dumps(
{
"name": machine.name,
"model_config": machine.model,
"data_config": machine.dataset.to_dict(),
"evaluation_config": machine.evaluation,
"gordo-major-version": MAJOR_VERSION,
"gordo-minor-version": MINOR_VERSION,
"gordo-major-version": major_version,
"gordo-minor-version": minor_version,
"gordo_version": gordo_version,
},
sort_keys=True,
Expand Down
15 changes: 15 additions & 0 deletions gordo/builder/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Optional, Type

from gordo.serializer.utils import validate_locate, import_locate

from .build_model import ModelBuilder


def create_model_builder(model_builder_class: Optional[str]) -> Type[ModelBuilder]:
    """
    Resolve a ``ModelBuilder`` subclass from a dotted import path.

    Parameters
    ----------
    model_builder_class: Optional[str]
        Import path of the class to load (e.g. ``"my_pkg.builders.MyBuilder"``).
        ``None`` selects the default :class:`ModelBuilder`.

    Returns
    -------
    Type[ModelBuilder]
        The resolved class, guaranteed to be ``ModelBuilder`` or a subclass.

    Raises
    ------
    ValueError
        If the located object is not a subclass of ``ModelBuilder``.
    """
    if model_builder_class is None:
        return ModelBuilder
    validate_locate(model_builder_class)
    cls = import_locate(model_builder_class)
    # Bug fix: the check was inverted (it rejected valid subclasses and
    # accepted everything else), and the message's %s placeholders were
    # never filled in.
    if not isinstance(cls, type) or not issubclass(cls, ModelBuilder):
        raise ValueError(
            '"%s" class should be subclass of "%s"'
            % (model_builder_class, ModelBuilder.__name__)
        )
    return cls
22 changes: 16 additions & 6 deletions gordo/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@
import logging
import sys
import traceback
import jinja2
import yaml
import click

from gordo_dataset.data_provider.providers import NoSuitableDataProviderError
from gordo_dataset.sensor_tag import SensorTagNormalizationError
from gordo_dataset.base import ConfigurationError
from gordo_dataset.exceptions import ConfigException, InsufficientDataError
from gunicorn.glogging import Logger
from azure.datalake.store.exceptions import DatalakeIncompleteTransferException

import jinja2
import yaml
import click
from typing import Tuple, List, Any, cast
from gordo.builder.utils import create_model_builder

from gordo.builder.build_model import ModelBuilder
from gordo import serializer
from gordo.server import server
from gordo import __version__
Expand All @@ -35,13 +34,15 @@
_exceptions_reporter = ExceptionsReporter(
(
(Exception, 1),
(ValueError, 2),
(PermissionError, 20),
(FileNotFoundError, 30),
(DatalakeIncompleteTransferException, 40),
(SensorTagNormalizationError, 60),
(NoSuitableDataProviderError, 70),
(InsufficientDataError, 80),
(ConfigurationError, 81),
(ImportError, 85),
(ReporterException, 90),
(ConfigException, 100),
)
Expand Down Expand Up @@ -89,6 +90,12 @@ def gordo(gordo_ctx: click.Context, **ctx):
exists=False, file_okay=False, dir_okay=True, writable=True, readable=True
),
)
@click.option(
"--model-builder-class",
help="ModelBuilder class import path. "
"This should be a subclass of gordo.builder.build_model.ModelBuilder",
envvar="MODEL_BUILDER_CLASS",
)
@click.option(
"--print-cv-scores", help="Prints CV scores to stdout", is_flag=True, default=False
)
Expand Down Expand Up @@ -117,6 +124,7 @@ def build(
machine_config: dict,
output_dir: str,
model_register_dir: click.Path,
model_builder_class: str,
print_cv_scores: bool,
model_parameter: List[Tuple[str, Any]],
exceptions_reporter_file: str,
Expand All @@ -137,6 +145,7 @@ def build(
Path to a directory which will index existing models and their locations, used
for re-using old models instead of rebuilding them. If omitted then always
rebuild
model_builder_class: str
print_cv_scores: bool
Print cross validation scores to stdout
model_parameter: List[Tuple[str, Any]
Expand Down Expand Up @@ -168,7 +177,8 @@ def build(
)
logger.info(f"Fully expanded model config: {machine.model}")

builder = ModelBuilder(machine=machine)
cls = create_model_builder(model_builder_class)
builder = cls(machine=machine)

_, machine_out = builder.build(output_dir, model_register_dir) # type: ignore

Expand Down
86 changes: 74 additions & 12 deletions gordo/cli/workflow_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pkg_resources
import os

from typing import Dict, Any, TypeVar, Type, List, Tuple, cast
from typing import Dict, Any, TypeVar, Type, List, Tuple, Optional, Generic, cast

import click
import json
Expand All @@ -13,9 +13,15 @@
from gordo import __version__
from gordo.workflow.config_elements.normalized_config import NormalizedConfig
from gordo.workflow.workflow_generator import workflow_generator as wg
from gordo.workflow.config_elements.schemas import (
SecurityContext,
PodSecurityContext,
EnvVar,
)
from gordo.cli.exceptions_reporter import ReportLevel
from gordo.util.version import parse_version
from gordo.dependencies import configure_once
from gordo.serializer.utils import validate_locate


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,16 +61,26 @@ def get_builder_exceptions_report_level(config: NormalizedConfig) -> ReportLevel
T = TypeVar("T")


def parse_json(value: str, schema: Type[T]):
try:
data = json.loads(value)
except json.JSONDecodeError as e:
raise click.ClickException('Malformed JSON string: "%s"' % str(e))
try:
obj = parse_obj_as(schema, data)
except ValidationError as e:
raise click.ClickException('Schema validation error: "%s"' % str(e))
return obj
class JSONParam(click.ParamType, Generic[T]):
    """
    Click parameter type that parses a JSON string and validates it
    against a pydantic schema.

    ``convert`` returns ``None`` for a ``None`` value (unset option),
    otherwise an instance of ``schema`` built via ``parse_obj_as``.
    """

    name = "JSON"

    def __init__(self, schema: Type[T]):
        # Target type for pydantic validation of the decoded JSON
        self.schema = schema

    def convert(
        self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]
    ) -> Optional[T]:
        if value is None:
            return None
        try:
            data = json.loads(value)
        except json.JSONDecodeError as e:
            # fail() raises click.UsageError itself (NoReturn); the original
            # `raise self.fail(...)` was dead code. Forward param/ctx so click
            # can report which option was malformed.
            self.fail("Malformed JSON string - %s" % str(e), param, ctx)
        try:
            obj = parse_obj_as(self.schema, data)
        except ValidationError as e:
            self.fail("Schema validation error - %s" % str(e), param, ctx)
        return obj


DEFAULT_CUSTOM_MODEL_BUILDER_ENVS = """
Expand Down Expand Up @@ -303,6 +319,7 @@ def workflow_cli(gordo_ctx):
help="List of custom environment variables in ",
envvar=f"{PREFIX}_CUSTOM_MODEL_BUILDER_ENVS",
default=DEFAULT_CUSTOM_MODEL_BUILDER_ENVS,
type=JSONParam(List[EnvVar]),
)
@click.option(
"--prometheus-server-address",
Expand Down Expand Up @@ -361,6 +378,23 @@ def workflow_cli(gordo_ctx):
type=int,
default=600,
)
@click.option(
"--security-context",
help="Containers securityContext in JSON format",
envvar=f"{PREFIX}_SECURITY_CONTEXT",
type=JSONParam(SecurityContext),
)
@click.option(
"--pod-security-context",
help="Global Workflow securityContext in JSON format",
envvar=f"{PREFIX}_POD_SECURITY_CONTEXT",
type=JSONParam(PodSecurityContext),
)
@click.option(
"--model-builder-class",
help="ModelBuilder class",
envvar="MODEL_BUILDER_CLASS",
)
@click.pass_context
def workflow_generator_cli(gordo_ctx, **ctx):
"""
Expand All @@ -382,11 +416,27 @@ def workflow_generator_cli(gordo_ctx, **ctx):

validate_generate_context(context)

if context["model_builder_class"]:
validate_locate(context["model_builder_class"])

context["resources_labels"] = prepare_resources_labels(context["resources_labels"])

if context["pod_security_context"]:
pod_security_context = cast(PodSecurityContext, context["pod_security_context"])
context["pod_security_context"] = pod_security_context.dict(exclude_none=True)

if context["security_context"]:
security_context = cast(SecurityContext, context["security_context"])
context["security_context"] = security_context.dict(exclude_none=True)

model_builder_env = None
if context["custom_model_builder_envs"]:
model_builder_env = json.loads(context["custom_model_builder_envs"])
custom_model_builder_envs = cast(
List[EnvVar], context["custom_model_builder_envs"]
)
model_builder_env = [
env_var.dict(exclude_none=True) for env_var in custom_model_builder_envs
]
# Create normalized config
config = NormalizedConfig(
yaml_content,
Expand Down Expand Up @@ -428,6 +478,18 @@ def workflow_generator_cli(gordo_ctx, **ctx):

context["builder_runtime"] = builder_runtime

builder_runtime_env = []
if "env" in builder_runtime:
builder_runtime_env = builder_runtime["env"]

if builder_runtime_env:
if context["model_builder_class"]:
builder_runtime_env.append(
{"name": "MODEL_BUILDER_CLASS", "value": context["model_builder_class"]}
)

context["builder_runtime_env"] = builder_runtime_env

context["server_resources"] = config.globals["runtime"]["server"]["resources"]
context["server_image"] = config.globals["runtime"]["server"]["image"]

Expand Down
22 changes: 17 additions & 5 deletions gordo/reporters/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import tempfile
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Union, Tuple, Optional, Type, cast
from uuid import uuid4

from azureml.core import Workspace
Expand All @@ -22,6 +22,8 @@
from gordo.machine.machine import MachineEncoder
from gordo.util.utils import capture_args
from gordo_dataset.sensor_tag import extract_tag_name
from gordo.builder.utils import create_model_builder

from .base import BaseReporter
from .exceptions import ReporterException

Expand Down Expand Up @@ -480,14 +482,24 @@ def log_machine(mlflow_client: MlflowClient, run_id: str, machine: Machine):

class MlFlowReporter(BaseReporter):
@capture_args
def __init__(self, *args, **kwargs):
pass
def __init__(
    self,
    *args,
    model_builder_class: Optional[Union[str, Type[ModelBuilder]]] = None,
    **kwargs,
):
    """
    Reporter that logs build results to MLflow.

    Parameters
    ----------
    model_builder_class: Optional[Union[str, Type[ModelBuilder]]]
        Either a ``ModelBuilder`` subclass, its dotted import path, or
        ``None`` to use the default :class:`ModelBuilder`.
    """
    # isinstance instead of `type(...) is str`: also accepts str subclasses
    # and is the idiomatic type check.
    if isinstance(model_builder_class, str):
        model_builder_class = create_model_builder(model_builder_class)
    if model_builder_class is None:
        model_builder_class = ModelBuilder
    self.model_builder_class = cast(Type[ModelBuilder], model_builder_class)

def report(self, machine: Machine):

workspace_kwargs = get_workspace_kwargs()
service_principal_kwargs = get_spauth_kwargs()
cache_key = ModelBuilder.calculate_cache_key(machine)
# TODO something better here
model_builder = self.model_builder_class(machine)
cache_key = model_builder.calculate_cache_key(machine)

with mlflow_context(
machine.name, cache_key, workspace_kwargs, service_principal_kwargs
Expand Down
Loading

0 comments on commit 6346184

Please sign in to comment.