diff --git a/gel/_testbase.py b/gel/_testbase.py index d7f2b77b..2052fe7f 100644 --- a/gel/_testbase.py +++ b/gel/_testbase.py @@ -33,11 +33,12 @@ import tempfile import time import unittest +import warnings import gel from gel import asyncio_client from gel import blocking_client -from gel.orm.introspection import get_schema_json +from gel.orm.introspection import get_schema_json, GelORMWarning from gel.orm.sqla import ModelGenerator as SQLAModGen from gel.orm.django.generator import ModelGenerator as DjangoModGen @@ -646,17 +647,20 @@ def setUpClass(cls): if importlib.util.find_spec("psycopg2") is None: raise unittest.SkipTest("need psycopg2 for ORM tests") - super().setUpClass() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", GelORMWarning) - class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP') - if not class_set_up: - # We'll need a temp directory to setup the generated Python - # package - cls.tmpormdir = tempfile.TemporaryDirectory() - sys.path.append(cls.tmpormdir.name) - # Now that the DB is setup, generate the ORM models from it - cls.spec = get_schema_json(cls.client) - cls.setupORM() + super().setUpClass() + + class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP') + if not class_set_up: + # We'll need a temp directory to setup the generated Python + # package + cls.tmpormdir = tempfile.TemporaryDirectory() + sys.path.append(cls.tmpormdir.name) + # Now that the DB is setup, generate the ORM models from it + cls.spec = get_schema_json(cls.client) + cls.setupORM() @classmethod def setupORM(cls): diff --git a/gel/orm/cli.py b/gel/orm/cli.py index 1ae29e55..fa50a3b5 100644 --- a/gel/orm/cli.py +++ b/gel/orm/cli.py @@ -18,11 +18,12 @@ import argparse +import warnings import gel from gel.codegen.generator import _get_conn_args -from .introspection import get_schema_json +from .introspection import get_schema_json, GelORMWarning from .sqla import ModelGenerator as SQLAModGen from .django.generator import ModelGenerator as DjangoModGen @@ -73,8 +74,15 @@ def main(): args = parser.parse_args() # setup client client = gel.create_client(**_get_conn_args(args)) - spec = get_schema_json(client) - generate_models(args, spec) + + with warnings.catch_warnings(record=True) as wlist: + warnings.simplefilter("always", GelORMWarning) + + spec = get_schema_json(client) + generate_models(args, spec) + + for w in wlist: + print(w.message) def generate_models(args, spec): diff --git a/gel/orm/django/generator.py b/gel/orm/django/generator.py index 93c45b0d..fbcdf901 100644 --- a/gel/orm/django/generator.py +++ b/gel/orm/django/generator.py @@ -1,7 +1,8 @@ import pathlib import re +import warnings -from ..introspection import get_mod_and_name, FilePrinter +from ..introspection import get_mod_and_name, GelORMWarning, FilePrinter GEL_SCALAR_MAP = { @@ -24,7 +25,7 @@ 'cal::local_date': 'DateField', 'cal::local_datetime': 'DateTimeField', 'cal::local_time': 'TimeField', - # all kinds of duration is not supported due to this error: + # all kinds of durations are not supported due to this error: # iso_8601 intervalstyle currently not supported } @@ -53,7 +54,6 @@ class GelPGMeta: 'This is a model reflected from Gel using Postgres protocol.' ''' -FK_RE = re.compile(r'''models\.ForeignKey\((.+?),''') CLOSEPAR_RE = re.compile(r'\)(?=\s+#|$)') @@ -83,19 +83,16 @@ def __init__(self, *, out): def spec_to_modules_dict(self, spec): modules = { - mod: {} for mod in sorted(spec['modules']) + mod: {'link_tables': {}, 'object_types': {}} + for mod in sorted(spec['modules']) } for rec in spec['link_tables']: mod = rec['module'] - if 'link_tables' not in modules[mod]: - modules[mod]['link_tables'] = {} modules[mod]['link_tables'][rec['table']] = rec for rec in spec['object_types']: mod, name = get_mod_and_name(rec['name']) - if 'object_types' not in modules[mod]: - modules[mod]['object_types'] = {} modules[mod]['object_types'][name] = rec return modules['default'] @@ -128,10 +125,12 @@ def build_models(self, maps): # process properties as fields for prop in rec['properties']: pname = prop['name'] - if pname == 'id': + if pname == 'id' or prop['cardinality'] == 'Many': continue - mod.props[pname] = self.render_prop(prop) + code = self.render_prop(prop) + if code: + mod.props[pname] = code # process single links as fields for link in rec['links']: @@ -142,7 +141,9 @@ def build_models(self, maps): lname = link['name'] bklink = mod.get_backlink_name(lname) - mod.links[lname] = self.render_link(link, bklink) + code = self.render_link(link, bklink) + if code: + mod.links[lname] = code modmap[mod.name] = mod @@ -153,7 +154,16 @@ def build_models(self, maps): mod.meta['unique_together'] = "(('source', 'target'),)" # Only have source and target - _, target = get_mod_and_name(rec['target']) + mtgt, target = get_mod_and_name(rec['target']) + if mtgt != 'default': + # skip this whole link table + warnings.warn( + f'Skipping link {fwname!r}: link target ' + f'{rec["target"]!r} is not supported', + GelORMWarning, + ) + continue + mod.links['source'] = ( f"LTForeignKey({source!r}, models.DO_NOTHING, " f"db_column='source', primary_key=True)" @@ -190,8 +200,11 @@ def render_prop(self, prop): try: ftype = GEL_SCALAR_MAP[target] except KeyError: - raise RuntimeError( - f'Scalar type {target} is not supported') + warnings.warn( + f'Scalar type {target} is not supported', + GelORMWarning, + ) + return '' return f'models.{ftype}({req})' @@ -201,7 +214,15 @@ def render_link(self, link, bklink=None): else: req = ', blank=True, null=True' - _, target = get_mod_and_name(link['target']['name']) + mod, target = get_mod_and_name(link['target']['name']) + + if mod != 'default': + warnings.warn( + f'Skipping link {link["name"]!r}: link target ' + f'{link["target"]["name"]!r} is not supported', + GelORMWarning, + ) + return '' if bklink: bklink = f', related_name={bklink!r}' @@ -215,23 +236,28 @@ def render_models(self, spec): # Check that there is only "default" module mods = spec['modules'] if mods[0] != 'default' or len(mods) > 1: - raise RuntimeError( - f"Django reflection doesn't support multiple modules or " - f"non-default modules." + skipped = ', '.join([repr(m) for m in mods if m != 'default']) + warnings.warn( + f"Skipping modules {skipped}: Django reflection doesn't " + f"support multiple modules or non-default modules.", + GelORMWarning, ) # Check that we don't have multiprops or link properties as they # produce models without `id` field and Django doesn't like that. It # causes Django to mistakenly use `source` as `id` and also attempt to # UPDATE `target` on link tables. if len(spec['prop_objects']) > 0: - raise RuntimeError( - f"Django reflection doesn't support multi properties as they " - f"produce models without `id` field." + warnings.warn( + f"Skipping multi properties: Django reflection doesn't " + f"support multi properties as they produce models without " + f"`id` field.", + GelORMWarning, ) if len(spec['link_objects']) > 0: - raise RuntimeError( - f"Django reflection doesn't support link properties as they " - f"produce models without `id` field." + warnings.warn( + f"Skipping link properties: Django reflection doesn't support " + f"link properties as they produce models without `id` field.", + GelORMWarning, ) maps = self.spec_to_modules_dict(spec) diff --git a/gel/orm/introspection.py b/gel/orm/introspection.py index 03c50586..97bacb3d 100644 --- a/gel/orm/introspection.py +++ b/gel/orm/introspection.py @@ -2,6 +2,7 @@ import re import collections import textwrap +import warnings INTRO_QUERY = ''' @@ -62,6 +63,10 @@ CLEAN_NAME = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$') +class GelORMWarning(Warning): + pass + + def get_sql_name(name): # Just remove the module name name = name.rsplit('::', 1)[-1] @@ -80,12 +85,16 @@ def get_mod_and_name(name): return name.rsplit('::', 1) -def check_name(name): +def valid_name(name): # Just remove module separators and check the rest name = name.replace('::', '') if not CLEAN_NAME.fullmatch(name): - raise RuntimeError( - f'Non-alphanumeric names are not supported: {name}') + warnings.warn( + f'Skipping {name!r}: non-alphanumeric names are not supported', + GelORMWarning, + ) + return False + return True def get_schema_json(client): @@ -102,6 +111,23 @@ async def async_get_schema_json(client): return _process_links(types, modules) +def _skip_invalid_names(spec_list, recurse_into=None): + valid = [] + for spec in spec_list: + # skip invalid names + if valid_name(spec['name']): + if recurse_into is not None: + for fname in recurse_into: + if fname not in spec: + continue + spec[fname] = _skip_invalid_names( + spec[fname], recurse_into) + + valid.append(spec) + + return valid + + def _process_links(types, modules): # Figure out all the backlinks, link tables, and links with link # properties that require their own intermediate objects. @@ -110,23 +136,20 @@ def _process_links(types, modules): link_objects = [] prop_objects = [] + # All the names of types, props and links are valid beyond this point. + types = _skip_invalid_names(types, ['properties', 'links']) for spec in types: - check_name(spec['name']) type_map[spec['name']] = spec spec['backlink_renames'] = {} - for prop in spec['properties']: - check_name(prop['name']) - for spec in types: mod = spec["name"].rsplit('::', 1)[0] sql_source = get_sql_name(spec["name"]) for prop in spec['properties']: + name = prop['name'] exclusive = prop['exclusive'] cardinality = prop['cardinality'] - name = prop['name'] - check_name(name) sql_name = get_sql_name(name) if cardinality == 'Many': @@ -158,11 +181,10 @@ def _process_links(types, modules): for link in spec['links']: if link['name'] != '__type__': + name = link['name'] target = link['target']['name'] cardinality = link['cardinality'] exclusive = link['exclusive'] - name = link['name'] - check_name(name) sql_name = get_sql_name(name) objtype = type_map[target] @@ -175,8 +197,6 @@ def _process_links(types, modules): 'has_link_object': False, }) - for prop in link['properties']: - check_name(prop['name']) link['has_link_object'] = False # Any link with properties should become its own intermediate diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py index 093096b1..1fdfc4d3 100644 --- a/gel/orm/sqla.py +++ b/gel/orm/sqla.py @@ -1,10 +1,11 @@ import pathlib import re +import warnings from contextlib import contextmanager from .introspection import get_sql_name, get_mod_and_name -from .introspection import FilePrinter +from .introspection import GelORMWarning, FilePrinter GEL_SCALAR_MAP = { @@ -134,6 +135,14 @@ def spec_to_modules_dict(self, spec): mod: {} for mod in sorted(spec['modules']) } + if len(spec['prop_objects']) > 0: + warnings.warn( + f"Skipping multi properties: SQLAlchemy reflection doesn't " + f"support multi properties as they produce models without a " + f"clear identity.", + GelORMWarning, + ) + for rec in spec['link_tables']: mod = rec['module'] if 'link_tables' not in modules[mod]: @@ -146,12 +155,6 @@ def spec_to_modules_dict(self, spec): modules[mod]['link_objects'] = {} modules[mod]['link_objects'][lobj['name']] = lobj - for pobj in spec['prop_objects']: - mod = pobj['module'] - if 'prop_objects' not in modules[mod]: - modules[mod]['prop_objects'] = {} - modules[mod]['prop_objects'][pobj['name']] = pobj - for rec in spec['object_types']: mod, name = get_mod_and_name(rec['name']) if 'object_types' not in modules[mod]: @@ -190,10 +193,6 @@ def render_models(self, spec): self.write() self.render_link_object(lobj, modules) - for pobj in maps.get('prop_objects', {}).values(): - self.write() - self.render_prop_object(pobj) - for rec in maps.get('object_types', {}).values(): self.write() self.render_type(rec, modules) @@ -275,50 +274,6 @@ def render_link_object(self, spec, modules): self.dedent() - def render_prop_object(self, spec): - mod = spec['module'] - name = spec['name'] - sql_name = spec['table'] - bklink = sql_name.split('.')[-1] - - self.write() - self.write(f'class {name}(Base):') - self.indent() - self.write(f'__tablename__ = {sql_name!r}') - if mod != 'default': - self.write(f'__table_args__ = {{"schema": {mod!r}}}') - # We rely on Gel for maintaining integrity and various on delete - # triggers, so the rows may be deleted in a different way from what - # SQLAlchemy expects. - self.write('__mapper_args__ = {"confirm_deleted_rows": False}') - self.write() - # No ids for these intermediate objects - - # Link to the source type (with "backlink" being the original - # property) - link = spec['links'][0] - self.write() - self.write('# Links:') - tmod, target = get_mod_and_name(link['target']['name']) - self.write(f'source_id: Mapped[uuid.UUID] = mapped_column(') - self.indent() - self.write(f'"source", Uuid(), ForeignKey("{target}.id"),') - self.write(f'primary_key=True, nullable=False,') - self.dedent() - self.write(')') - self.write( - f'source: Mapped[{target!r}] = ' - f'relationship(back_populates={bklink!r})' - ) - - # The target is the actual multi prop - prop = spec['properties'][0] - self.write() - self.write('# Properties:') - self.render_prop(prop, mod, name, {}, is_pk=True) - - self.dedent() - def render_type(self, spec, modules): # assume nice names for now mod, name = get_mod_and_name(spec['name']) @@ -384,8 +339,12 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): try: pytype, sqlatype = GEL_SCALAR_MAP[target] except KeyError: - raise RuntimeError( - f'Scalar type {target} is not supported') + warnings.warn( + f'Scalar type {target} is not supported', + GelORMWarning, + ) + # Skip rendering this one + return if is_pk: # special case of a primary key property (should only happen to @@ -395,22 +354,8 @@ def render_prop(self, spec, mod, parent, modules, *, is_pk=False): f'{sqlatype}(), primary_key=True, nullable=False)' ) elif cardinality == 'Many': - # multi property (treated as a link) - propobj = modules[mod]['prop_objects'][f'{parent}_{name}_prop'] - target = propobj['name'] - - if cardinality == 'One': - tmap = f'Mapped[{target!r}]' - elif cardinality == 'Many': - tmap = f'Mapped[List[{target!r}]]' - # We want the cascade to delete orphans here as the objects - # represent property leaves - self.write(f'{name}: {tmap} = relationship(') - self.indent() - self.write(f"back_populates='source',") - self.write(f"cascade='all, delete-orphan',") - self.dedent() - self.write(')') + # skip it + return else: # plain property diff --git a/tests/test_sqla_features.py b/tests/test_sqla_features.py index 7b79b818..52359703 100644 --- a/tests/test_sqla_features.py +++ b/tests/test_sqla_features.py @@ -74,86 +74,6 @@ def tearDown(self): super().tearDown() self.sess.rollback() - def test_sqla_multiprops_01(self): - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got one').one() - self.assertEqual( - {t.target for t in val.tags}, - {'solo tag'}, - ) - - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got many').one() - self.assertEqual( - {t.target for t in val.tags}, - {'one', 'two', 'three'}, - ) - - def test_sqla_multiprops_02(self): - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got one').one() - self.assertEqual( - {t.target for t in val.tags}, - {'solo tag'}, - ) - - # create and add a few more tags - self.sess.add( - self.sm.MultiProp_tags_prop(source=val, target='hello')) - self.sess.add( - self.sm.MultiProp_tags_prop(source=val, target='world')) - self.sess.flush() - - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got one').one() - self.assertEqual( - {t.target for t in val.tags}, - {'solo tag', 'hello', 'world'}, - ) - - def test_sqla_multiprops_03(self): - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got one').one() - self.assertEqual( - {t.target for t in val.tags}, - {'solo tag'}, - ) - - # create and add a few more tags (don't specify source explicitly) - val.tags.append( - self.sm.MultiProp_tags_prop(target='hello')) - val.tags.append( - self.sm.MultiProp_tags_prop(target='world')) - self.sess.flush() - - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got one').one() - self.assertEqual( - {t.target for t in val.tags}, - {'solo tag', 'hello', 'world'}, - ) - - def test_sqla_multiprops_04(self): - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got many').one() - self.assertEqual( - {t.target for t in val.tags}, - {'one', 'two', 'three'}, - ) - - # remove several tags - for t in list(val.tags): - if t.target != 'one': - val.tags.remove(t) - self.sess.flush() - - val = self.sess.query(self.sm.MultiProp)\ - .filter_by(name='got many').one() - self.assertEqual( - {t.target for t in val.tags}, - {'one'}, - ) - def test_sqla_linkprops_01(self): val = self.sess.query(self.sm.HasLinkPropsA).one() self.assertEqual(val.child.target.num, 0)