Skip to content

Commit

Permalink
Merge pull request #97 from SFDO-Tooling/feature/nicer-plugin-exceptions
Browse files Browse the repository at this point in the history
Throw plugin exceptions while line# is available
  • Loading branch information
prescod authored Sep 21, 2020
2 parents 7db175d + 4cfb1dc commit b6535a2
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 26 deletions.
8 changes: 8 additions & 0 deletions snowfakery/data_gen_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class DataGenValueError(DataGenError):
pass


class DataGenImportError(DataGenError):
pass


class DataGenTypeError(DataGenError):
pass


def fix_exception(message, parentobj, e):
"""Add filename and linenumber to an exception if needed"""
filename, line_num = parentobj.filename, parentobj.line_num
Expand Down
23 changes: 4 additions & 19 deletions snowfakery/data_generator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import warnings
from typing import IO, Tuple, Mapping, List, Dict, TextIO, Union
from importlib import import_module
from click.utils import LazyFile

from yaml import safe_dump, safe_load
from faker.providers import BaseProvider as FakerProvider

from .data_gen_exceptions import DataGenNameError
from .output_streams import DebugOutputStream, OutputStream
from .parse_recipe_yaml import parse_recipe
from .data_generator_runtime import output_batches, StoppingCriteria, Globals
from .data_gen_exceptions import DataGenError
from . import SnowfakeryPlugin
from faker.providers import BaseProvider as FakerProvider


# This tool is essentially a three stage interpreter.
Expand Down Expand Up @@ -83,32 +82,18 @@ def save_continuation_yaml(continuation_data: Globals, continuation_file: TextIO
safe_dump(continuation_data, continuation_file)


def resolve_plugin(plugin: str) -> object:
"Resolve a plugin to a class"
module_name, class_name = plugin.rsplit(".", 1)
module = import_module(module_name)
cls = getattr(module, class_name)
if issubclass(cls, FakerProvider):
return (FakerProvider, cls)
elif issubclass(cls, SnowfakeryPlugin):
return (SnowfakeryPlugin, cls)
else:
raise TypeError(f"{cls} is not a Faker Provider nor Snowfakery Plugin")


def process_plugins(plugins: List[str]) -> Tuple[List[object], Mapping[str, object]]:
def process_plugins(plugins: List) -> Tuple[List[object], Mapping[str, object]]:
"""Resolve a list of names for SnowfakeryPlugins and Faker Providers to objects
The Providers are returned as a list of objects.
The Plugins are a mapping of ClassName:object so they can be namespaced.
"""
plugin_classes = [resolve_plugin(plugin) for plugin in plugins]
faker_providers = [
provider for baseclass, provider in plugin_classes if baseclass == FakerProvider
provider for baseclass, provider in plugins if baseclass == FakerProvider
]
snowfakery_plugins = {
plugin.__name__: plugin
for baseclass, plugin in plugin_classes
for baseclass, plugin in plugins
if baseclass == SnowfakeryPlugin
}
return (faker_providers, snowfakery_plugins)
Expand Down
12 changes: 8 additions & 4 deletions snowfakery/parse_recipe_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ReferenceValue,
)

from snowfakery.plugins import resolve_plugin
import snowfakery.data_gen_exceptions as exc

SHARED_OBJECT = "#SHARED_OBJECT"
Expand All @@ -32,9 +33,7 @@


class ParseResult:
def __init__(
self, options, tables: Mapping, templates, plugins: Sequence[str] = ()
):
def __init__(self, options, tables: Mapping, templates, plugins: Sequence = ()):
self.options = options
self.tables = tables
self.templates = templates
Expand Down Expand Up @@ -485,7 +484,12 @@ def parse_top_level_elements(path: Path, data: List, context: ParseContext):
templates.extend(parse_included_files(path, data, context))
context.options.extend(top_level_objects["option"])
context.macros.update({obj["macro"]: obj for obj in top_level_objects["macro"]})
context.plugins.extend([obj["plugin"] for obj in top_level_objects["plugin"]])
context.plugins.extend(
[
resolve_plugin(obj["plugin"], obj["__line__"])
for obj in top_level_objects["plugin"]
]
)
templates.extend(top_level_objects["object"])
return templates

Expand Down
27 changes: 27 additions & 0 deletions snowfakery/plugins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, Callable, Mapping
from importlib import import_module

import yaml
from yaml.representer import Representer
from faker.providers import BaseProvider as FakerProvider

import snowfakery.data_gen_exceptions as exc


class SnowfakeryPlugin:
Expand Down Expand Up @@ -80,6 +84,29 @@ def lazy(func: Any) -> Callable:
return func


def resolve_plugin(plugin: str, lineinfo) -> object:
"Resolve a plugin to a class"
module_name, class_name = plugin.rsplit(".", 1)

try:
module = import_module(module_name)
except ModuleNotFoundError as e:
raise exc.DataGenImportError(
f"Cannot find plugin: {e}", lineinfo.filename, lineinfo.line_num
)
cls = getattr(module, class_name)
if issubclass(cls, FakerProvider):
return (FakerProvider, cls)
elif issubclass(cls, SnowfakeryPlugin):
return (SnowfakeryPlugin, cls)
else:
raise exc.DataGenTypeError(
f"{cls} is not a Faker Provider nor Snowfakery Plugin",
lineinfo.filename,
lineinfo.line_num,
)


class PluginResult:
"""`PluginResult` objects expose a namespace that other code can access through dot-notation.
Expand Down
11 changes: 8 additions & 3 deletions tests/test_custom_plugins_and_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from snowfakery.data_generator import generate
from snowfakery import SnowfakeryPlugin, lazy
from snowfakery.data_gen_exceptions import DataGenError
from snowfakery.data_gen_exceptions import (
DataGenError,
DataGenTypeError,
DataGenImportError,
)

from unittest import mock
import pytest
Expand Down Expand Up @@ -61,9 +65,10 @@ def test_bogus_plugin(self):
fields:
service_name: saascrmlightning
"""
with pytest.raises(TypeError) as e:
with pytest.raises(DataGenTypeError) as e:
generate(StringIO(yaml), {})
assert "TestCustomPlugin" in str(e.value)
assert ":2" in str(e.value)

def test_missing_plugin(self):
yaml = """
Expand All @@ -72,7 +77,7 @@ def test_missing_plugin(self):
fields:
service_name: saascrmlightning
"""
with pytest.raises(ImportError) as e:
with pytest.raises(DataGenImportError) as e:
generate(StringIO(yaml), {})
assert "xyzzy" in str(e.value)

Expand Down

0 comments on commit b6535a2

Please sign in to comment.