diff --git a/.flake8 b/.flake8 index ae5a762d..b4001848 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] ignore = B008,B023,B306,E203,E402,E731,D100,D101,D102,D103,D104,D105,W503,W504,E252,F999,F541 -exclude = .git,__pycache__,build,dist,.eggs +exclude = .git,__pycache__,build,dist,.eggs,generated diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8f393b81..1f7f9201 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -78,9 +78,15 @@ jobs: - name: Install Python Deps if: steps.release.outputs.version == 0 + env: + PYTHON_VERSION: ${{ matrix.python-version }} run: | python -m pip install --upgrade setuptools pip wheel - python -m pip install -e .[test] + if [ "${PYTHON_VERSION}" = "3.10" -o "${PYTHON_VERSION}" = "3.11" -o "${PYTHON_VERSION}" = "3.12" ]; then + python -m pip install -e .[test,sqltest] + else + python -m pip install -e .[test] + fi - name: Test if: steps.release.outputs.version == 0 diff --git a/gel/_testbase.py b/gel/_testbase.py index f379bcea..775f1a0c 100644 --- a/gel/_testbase.py +++ b/gel/_testbase.py @@ -21,6 +21,7 @@ import atexit import contextlib import functools +import importlib.util import inspect import json import logging @@ -35,6 +36,8 @@ import gel from gel import asyncio_client from gel import blocking_client +from gel.orm.introspection import get_schema_json +from gel.orm.sqla import ModelGenerator log = logging.getLogger(__name__) @@ -444,6 +447,7 @@ class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin): SETUP = None TEARDOWN = None SCHEMA = None + DEFAULT_MODULE = 'test' SETUP_METHOD = None TEARDOWN_METHOD = None @@ -521,15 +525,18 @@ def get_database_name(cls): @classmethod def get_setup_script(cls): script = '' + schema = [] # Look at all SCHEMA entries and potentially create multiple - # modules, but always create the 'test' module. - schema = ['\nmodule test {}'] + # modules, but always create the test module, if not `default`. + if cls.DEFAULT_MODULE != 'default': + schema.append(f'\nmodule {cls.DEFAULT_MODULE} {{}}') for name, val in cls.__dict__.items(): m = re.match(r'^SCHEMA(?:_(\w+))?', name) if m: - module_name = (m.group(1) or 'test').lower().replace( - '__', '.') + module_name = ( + m.group(1) or cls.DEFAULT_MODULE + ).lower().replace('_', '::') with open(val, 'r') as sf: module = sf.read() @@ -623,6 +630,54 @@ def adapt_call(cls, result): return result +class SQLATestCase(SyncQueryTestCase): + SQLAPACKAGE = None + DEFAULT_MODULE = 'default' + + @classmethod + def setUpClass(cls): + # SQLAlchemy relies on psycopg2 to connect to Postgres and thus we + # need it to run tests. Unfortunately not all test environemnts might + # have psycopg2 installed, as long as we run this in the test + # environments that have this, it is fine since we're not expecting + # different functionality based on flavours of psycopg2. + if importlib.util.find_spec("psycopg2") is None: + raise unittest.SkipTest("need psycopg2 for ORM tests") + + super().setUpClass() + + class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP') + if not class_set_up: + # Now that the DB is setup, generate the SQLAlchemy models from it + spec = get_schema_json(cls.client) + # We'll need a temp directory to setup the generated Python + # package + cls.tmpsqladir = tempfile.TemporaryDirectory() + gen = ModelGenerator( + outdir=os.path.join(cls.tmpsqladir.name, cls.SQLAPACKAGE), + basemodule=cls.SQLAPACKAGE, + ) + gen.render_models(spec) + sys.path.append(cls.tmpsqladir.name) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + # cleanup the temp modules + sys.path.remove(cls.tmpsqladir.name) + cls.tmpsqladir.cleanup() + + @classmethod + def get_dsn_for_sqla(cls): + cargs = cls.get_connect_args(database=cls.get_database_name()) + dsn = ( + f'postgresql://{cargs["user"]}:{cargs["password"]}' + f'@{cargs["host"]}:{cargs["port"]}/{cargs["database"]}' + ) + + return dsn + + _lock_cnt = 0 diff --git a/gel/orm/__init__.py b/gel/orm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/gel/orm/cli.py b/gel/orm/cli.py new file mode 100644 index 00000000..5f387879 --- /dev/null +++ b/gel/orm/cli.py @@ -0,0 +1,86 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import argparse + +import gel + +from gel.codegen.generator import _get_conn_args +from .introspection import get_schema_json +from .sqla import ModelGenerator + + +class ArgumentParser(argparse.ArgumentParser): + def error(self, message): + self.exit( + 2, + f"error: {message:s}\n", + ) + + +parser = ArgumentParser( + description="Generate Python ORM code for accessing a Gel database." +) +parser.add_argument( + "orm", + choices=['sqlalchemy', 'django'], + help="Pick which ORM to generate models for.", +) +parser.add_argument("--dsn") +parser.add_argument("--credentials-file", metavar="PATH") +parser.add_argument("-I", "--instance", metavar="NAME") +parser.add_argument("-H", "--host") +parser.add_argument("-P", "--port") +parser.add_argument("-d", "--database", metavar="NAME") +parser.add_argument("-u", "--user") +parser.add_argument("--password") +parser.add_argument("--password-from-stdin", action="store_true") +parser.add_argument("--tls-ca-file", metavar="PATH") +parser.add_argument( + "--tls-security", + choices=["default", "strict", "no_host_verification", "insecure"], +) +parser.add_argument( + "--out", + help="The output directory for the generated files.", + required=True, +) +parser.add_argument( + "--mod", + help="The fullname of the Python module corresponding to the output " + "directory.", + required=True, +) + + +def main(): + args = parser.parse_args() + # setup client + client = gel.create_client(**_get_conn_args(args)) + spec = get_schema_json(client) + + match args.orm: + case 'sqlalchemy': + gen = ModelGenerator( + outdir=args.out, + basemodule=args.mod, + ) + gen.render_models(spec) + case 'django': + print('Not available yet. Coming soon!') diff --git a/gel/orm/introspection.py b/gel/orm/introspection.py new file mode 100644 index 00000000..f19bffd7 --- /dev/null +++ b/gel/orm/introspection.py @@ -0,0 +1,234 @@ +import json +import re +import collections + + +INTRO_QUERY = ''' +with module schema +select ObjectType { + name, + links: { + name, + readonly, + required, + cardinality, + exclusive := exists ( + select .constraints + filter .name = 'std::exclusive' + ), + target: {name}, + + properties: { + name, + readonly, + required, + cardinality, + exclusive := exists ( + select .constraints + filter .name = 'std::exclusive' + ), + target: {name}, + }, + } filter .name != '__type__', + properties: { + name, + readonly, + required, + cardinality, + exclusive := exists ( + select .constraints + filter .name = 'std::exclusive' + ), + target: {name}, + }, + backlinks := >[], +} +filter + not .builtin + and + not .internal + and + not re_test('^(std|cfg|sys|schema)::', .name); +''' + +MODULE_QUERY = ''' +with + module schema, + m := (select `Module` filter not .builtin) +select m.name; +''' + +CLEAN_NAME = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$') + + +def get_sql_name(name): + # Just remove the module name + name = name.rsplit('::', 1)[-1] + + return name + + +def check_name(name): + # Just remove module separators and check the rest + name = name.replace('::', '') + if not CLEAN_NAME.fullmatch(name): + raise RuntimeError( + f'Non-alphanumeric names are not supported: {name}') + + +def get_schema_json(client): + types = json.loads(client.query_json(INTRO_QUERY)) + modules = json.loads(client.query_json(MODULE_QUERY)) + + return _process_links(types, modules) + + +async def async_get_schema_json(client): + types = json.loads(await client.query_json(INTRO_QUERY)) + modules = json.loads(client.query_json(MODULE_QUERY)) + + return _process_links(types, modules) + + +def _process_links(types, modules): + # Figure out all the backlinks, link tables, and links with link + # properties that require their own intermediate objects. + type_map = {} + link_tables = [] + link_objects = [] + prop_objects = [] + + for spec in types: + check_name(spec['name']) + type_map[spec['name']] = spec + spec['backlink_renames'] = {} + + for prop in spec['properties']: + check_name(prop['name']) + + for spec in types: + mod = spec["name"].rsplit('::', 1)[0] + sql_source = get_sql_name(spec["name"]) + + for prop in spec['properties']: + exclusive = prop['exclusive'] + cardinality = prop['cardinality'] + name = prop['name'] + check_name(name) + sql_name = get_sql_name(name) + + if cardinality == 'Many': + # Multi property will make its own "link table". But since it + # doesn't link to any other object the link table itself must + # be reflected as an object. + pobj = { + 'module': mod, + 'name': f'{sql_source}_{sql_name}_prop', + 'table': f'{sql_source}.{sql_name}', + 'links': [{ + 'name': 'source', + 'required': True, + 'cardinality': 'One' if exclusive else 'Many', + 'exclusive': cardinality == 'One', + 'target': {'name': spec['name']}, + 'has_link_object': False, + }], + 'properties': [{ + 'name': 'target', + 'required': True, + 'cardinality': 'One', + 'exclusive': False, + 'target': prop['target'], + 'has_link_object': False, + }], + } + prop_objects.append(pobj) + + for link in spec['links']: + if link['name'] != '__type__': + target = link['target']['name'] + cardinality = link['cardinality'] + exclusive = link['exclusive'] + name = link['name'] + check_name(name) + sql_name = get_sql_name(name) + + objtype = type_map[target] + objtype['backlinks'].append({ + 'name': f'backlink_via_{sql_name}', + # flip cardinality and exclusivity + 'cardinality': 'One' if exclusive else 'Many', + 'exclusive': cardinality == 'One', + 'target': {'name': spec['name']}, + 'has_link_object': False, + }) + + for prop in link['properties']: + check_name(prop['name']) + + link['has_link_object'] = False + # Any link with properties should become its own intermediate + # object, since ORMs generally don't have a special convenient + # way of exposing this as just a link table. + if len(link['properties']) > 2: + # more than just 'source' and 'target' properties + lobj = { + 'module': mod, + 'name': f'{sql_source}_{sql_name}_link', + 'table': f'{sql_source}.{sql_name}', + 'links': [], + 'properties': [], + } + for prop in link['properties']: + if prop['name'] in {'source', 'target'}: + lobj['links'].append(prop) + else: + lobj['properties'].append(prop) + + link_objects.append(lobj) + link['has_link_object'] = True + objtype['backlinks'][-1]['has_link_object'] = True + + elif cardinality == 'Many': + # Add a link table for One-to-Many and Many-to-Many + link_tables.append({ + 'module': mod, + 'name': f'{sql_source}_{sql_name}_table', + 'table': f'{sql_source}.{sql_name}', + 'source': spec["name"], + 'target': target, + }) + + # Go over backlinks and resolve any name collisions using the type map. + for spec in types: + mod = spec["name"].rsplit('::', 1)[0] + sql_source = get_sql_name(spec["name"]) + + # Find collisions in backlink names + bk = collections.defaultdict(list) + for link in spec['backlinks']: + if link['name'].startswith('backlink_via_'): + bk[link['name']].append(link) + + for bklinks in bk.values(): + if len(bklinks) > 1: + # We have a collision, so each backlink in it must now be + # disambiguated. + for link in bklinks: + origsrc = get_sql_name(link['target']['name']) + lname = link['name'] + link['name'] = f'{lname}_from_{origsrc}' + # Also update the original source of the link with the + # special backlink name. + source = type_map[link['target']['name']] + fwname = lname.replace('backlink_via_', '', 1) + link['fwname'] = fwname + source['backlink_renames'][fwname] = link['name'] + + return { + 'modules': modules, + 'object_types': types, + 'link_tables': link_tables, + 'link_objects': link_objects, + 'prop_objects': prop_objects, + } diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py new file mode 100644 index 00000000..bf18f264 --- /dev/null +++ b/gel/orm/sqla.py @@ -0,0 +1,573 @@ +import pathlib +import re +import textwrap + +from contextlib import contextmanager + +from .introspection import get_sql_name + + +INDENT = ' ' * 4 + +GEL_SCALAR_MAP = { + 'std::bool': ('bool', 'Boolean'), + 'std::str': ('str', 'String'), + 'std::int16': ('int', 'Integer'), + 'std::int32': ('int', 'Integer'), + 'std::int64': ('int', 'Integer'), + 'std::float32': ('float', 'Float'), + 'std::float64': ('float', 'Float'), + 'std::uuid': ('uuid.UUID', 'Uuid'), +} + +CLEAN_RE = re.compile(r'[^A-Za-z0-9]+') + +COMMENT = '''\ +# +# Automatically generated from Gel schema. +#\ +''' + +BASE_STUB = f'''\ +{COMMENT} + +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass\ +''' + +MODELS_STUB = f'''\ +{COMMENT} + +import uuid + +from typing import List +from typing import Optional + +from sqlalchemy import MetaData, Table, Column, ForeignKey +from sqlalchemy import String, Uuid, Integer, Float, Boolean +from sqlalchemy.orm import Mapped, mapped_column, relationship +''' + + +def get_mod_and_name(name): + # Assume the names are already validated to be properly formed + # alphanumeric identifiers that may be prefixed by a module. If the module + # is present assume it is safe to drop it (currently only defualt module + # is allowed). + + # Split on module separator. Potentially if we ever handle more unusual + # names, there may be more processing done. + return name.rsplit('::', 1) + + +class ModelGenerator(object): + INDENT = ' ' * 4 + + def __init__(self, *, outdir=None, basemodule=None): + # set the output to be stdout by default, but this is generally + # expected to be overridden by appropriate files in the `outdir` + if outdir is not None: + self.outdir = pathlib.Path(outdir) + else: + self.outdir = None + + self.basemodule = basemodule + self.out = None + self._indent_level = 0 + + def indent(self): + self._indent_level += 1 + + def dedent(self): + if self._indent_level > 0: + self._indent_level -= 1 + + def reset_indent(self): + self._indent_level -= 0 + + def write(self, text=''): + print( + textwrap.indent(text, prefix=self.INDENT * self._indent_level), + file=self.out, + ) + + def init_dir(self, dirpath): + if not dirpath: + # nothing to initialize + return + + path = pathlib.Path(dirpath).resolve() + + # ensure `path` directory exists + if not path.exists(): + path.mkdir() + elif not path.is_dir(): + raise NotADirectoryError( + f'{path!r} exists, but it is not a directory') + + # ensure `path` directory contains `__init__.py` + (path / '__init__.py').touch() + + def init_sqlabase(self): + with open(self.outdir / '_sqlabase.py', 'wt') as f: + self.out = f + self.write(BASE_STUB) + + @contextmanager + def init_module(self, mod, modules): + if any(m.startswith(f'{mod}::') for m in modules): + # This is a prefix in another module, thus it is part of a nested + # module structure. + dirpath = mod.split('::') + filename = '__init__.py' + else: + # This is a leaf module, so we just need to create a corresponding + # .py file. + *dirpath, filename = mod.split('::') + filename = f'{filename}.py' + + # Along the dirpath we need to ensure that all packages are created + path = self.outdir + for el in dirpath: + path = path / el + self.init_dir(path) + + with open(path / filename, 'wt') as f: + try: + self.out = f + self.write(MODELS_STUB) + relimport = '.' * len(dirpath) + self.write(f'from {relimport}._sqlabase import Base') + self.write(f'from {relimport}._tables import *') + yield f + finally: + self.out = None + + def get_fk(self, mod, table, curmod): + if mod == curmod: + # No need for anything fancy within the same schema + return f'ForeignKey("{table}.id")' + else: + return f'ForeignKey("{mod}.{table}.id")' + + def get_py_name(self, mod, name, curmod): + if False and mod == curmod: + # No need for anything fancy within the same module + return f"'{name}'" + else: + mod = mod.replace('::', '.') + return f"'{self.basemodule}.{mod}.{name}'" + + def render_models(self, spec): + # The modules dict will be populated with the respective types, link + # tables, etc., since they will need to be put in their own files. We + # sort the modules so that nested modules are initialized from root to + # leaf. + modules = { + mod: {} for mod in sorted(spec['modules']) + } + + for rec in spec['link_tables']: + mod = rec['module'] + if 'link_tables' not in modules[mod]: + modules[mod]['link_tables'] = [] + modules[mod]['link_tables'].append(rec) + + for lobj in spec['link_objects']: + mod = lobj['module'] + if 'link_objects' not in modules[mod]: + modules[mod]['link_objects'] = {} + modules[mod]['link_objects'][lobj['name']] = lobj + + for pobj in spec['prop_objects']: + mod = pobj['module'] + if 'prop_objects' not in modules[mod]: + modules[mod]['prop_objects'] = {} + modules[mod]['prop_objects'][pobj['name']] = pobj + + for rec in spec['object_types']: + mod, name = get_mod_and_name(rec['name']) + if 'object_types' not in modules[mod]: + modules[mod]['object_types'] = {} + modules[mod]['object_types'][name] = rec + + # Initialize the base directory + self.init_dir(self.outdir) + self.init_sqlabase() + + with open(self.outdir / '_tables.py', 'wt') as f: + self.out = f + self.write(MODELS_STUB) + self.write(f'from ._sqlabase import Base') + + for rec in spec['link_tables']: + self.write() + self.render_link_table(rec) + + for mod, maps in modules.items(): + with self.init_module(mod, modules): + if not maps: + # skip apparently empty modules + continue + + for lobj in maps.get('link_objects', {}).values(): + self.write() + self.render_link_object(lobj, modules) + + for pobj in maps.get('prop_objects', {}).values(): + self.write() + self.render_prop_object(pobj) + + for rec in maps.get('object_types', {}).values(): + self.write() + self.render_type(rec, modules) + + def render_link_table(self, spec): + mod, source = get_mod_and_name(spec["source"]) + tmod, target = get_mod_and_name(spec["target"]) + s_fk = self.get_fk(mod, source, 'default') + t_fk = self.get_fk(tmod, target, 'default') + + self.write() + self.write(f'{spec["name"]} = Table(') + self.indent() + self.write(f'{spec["table"]!r},') + self.write(f'Base.metadata,') + # source is in the same module as this table + self.write(f'Column("source", {s_fk}),') + self.write(f'Column("target", {t_fk}),') + self.write(f'schema={mod!r},') + self.dedent() + self.write(f')') + + def render_link_object(self, spec, modules): + mod = spec['module'] + name = spec['name'] + sql_name = spec['table'] + source_name, source_link = sql_name.split('.') + + self.write() + self.write(f'class {name}(Base):') + self.indent() + self.write(f'__tablename__ = {sql_name!r}') + if mod != 'default': + self.write(f'__table_args__ = {{"schema": {mod!r}}}') + # We rely on Gel for maintaining integrity and various on delete + # triggers, so the rows may be deleted in a different way from what + # SQLAlchemy expects. + self.write('__mapper_args__ = {"confirm_deleted_rows": False}') + self.write() + + # No ids for these intermediate objects + if spec['links']: + self.write() + self.write('# Links:') + + for link in spec['links']: + lname = link['name'] + tmod, target = get_mod_and_name(link['target']['name']) + fk = self.get_fk(tmod, target, mod) + pyname = self.get_py_name(tmod, target, mod) + self.write(f'{lname}_id: Mapped[uuid.UUID] = mapped_column(') + self.indent() + self.write(f'{lname!r}, Uuid(), {fk},') + self.write(f'primary_key=True, nullable=False,') + self.dedent() + self.write(')') + + if lname == 'source': + bklink = source_link + else: + src = modules[mod]['object_types'][source_name] + bklink = src['backlink_renames'].get( + source_link, + f'backlink_via_{source_link}', + ) + + self.write( + f'{lname}: Mapped[{pyname}] = ' + f'relationship(back_populates={bklink!r})' + ) + + if spec['properties']: + self.write() + self.write('# Properties:') + + for prop in spec['properties']: + if prop['name'] != 'id': + self.render_prop(prop, mod, name, {}) + + self.dedent() + + def render_prop_object(self, spec): + mod = spec['module'] + name = spec['name'] + sql_name = spec['table'] + bklink = sql_name.split('.')[-1] + + self.write() + self.write(f'class {name}(Base):') + self.indent() + self.write(f'__tablename__ = {sql_name!r}') + if mod != 'default': + self.write(f'__table_args__ = {{"schema": {mod!r}}}') + # We rely on Gel for maintaining integrity and various on delete + # triggers, so the rows may be deleted in a different way from what + # SQLAlchemy expects. + self.write('__mapper_args__ = {"confirm_deleted_rows": False}') + self.write() + # No ids for these intermediate objects + + # Link to the source type (with "backlink" being the original + # property) + link = spec['links'][0] + self.write() + self.write('# Links:') + tmod, target = get_mod_and_name(link['target']['name']) + self.write(f'source_id: Mapped[uuid.UUID] = mapped_column(') + self.indent() + self.write(f'"source", Uuid(), ForeignKey("{target}.id"),') + self.write(f'primary_key=True, nullable=False,') + self.dedent() + self.write(')') + self.write( + f'source: Mapped[{target!r}] = ' + f'relationship(back_populates={bklink!r})' + ) + + # The target is the actual multi prop + prop = spec['properties'][0] + self.write() + self.write('# Properties:') + self.render_prop(prop, mod, name, {}, is_pk=True) + + self.dedent() + + def render_type(self, spec, modules): + # assume nice names for now + mod, name = get_mod_and_name(spec['name']) + sql_name = get_sql_name(spec['name']) + + self.write() + self.write(f'class {name}(Base):') + self.indent() + self.write(f'__tablename__ = {sql_name!r}') + if mod != 'default': + self.write(f'__table_args__ = {{"schema": {mod!r}}}') + # We rely on Gel for maintaining integrity and various on delete + # triggers, so the rows may be deleted in a different way from what + # SQLAlchemy expects. + self.write('__mapper_args__ = {"confirm_deleted_rows": False}') + self.write() + + # Add two fields that all objects have + self.write(f'id: Mapped[uuid.UUID] = mapped_column(') + self.indent() + self.write( + f"Uuid(), primary_key=True, server_default='uuid_generate_v4()')") + self.dedent() + + # This is maintained entirely by Gel, the server_default simply + # indicates to SQLAlchemy that this value may be omitted. + self.write(f'gel_type_id: Mapped[uuid.UUID] = mapped_column(') + self.indent() + self.write( + f"'__type__', Uuid(), unique=True, server_default='PLACEHOLDER')") + self.dedent() + + if spec['properties']: + self.write() + self.write('# Properties:') + + for prop in spec['properties']: + if prop['name'] != 'id': + self.render_prop(prop, mod, name, modules) + + if spec['links']: + self.write() + self.write('# Links:') + + for link in spec['links']: + self.render_link(link, mod, name, modules) + + if spec['backlinks']: + self.write() + self.write('# Back-links:') + + for link in spec['backlinks']: + self.render_backlink(link, mod, modules) + + self.dedent() + + def render_prop(self, spec, mod, parent, modules, *, is_pk=False): + name = spec['name'] + nullable = not spec['required'] + cardinality = spec['cardinality'] + + target = spec['target']['name'] + try: + pytype, sqlatype = GEL_SCALAR_MAP[target] + except KeyError: + raise RuntimeError( + f'Scalar type {target} is not supported') + + if is_pk: + # special case of a primary key property (should only happen to + # 'target' in multi property table) + self.write( + f'{name}: Mapped[{pytype}] = mapped_column(' + f'{sqlatype}(), primary_key=True, nullable=False)' + ) + elif cardinality == 'Many': + # multi property (treated as a link) + propobj = modules[mod]['prop_objects'][f'{parent}_{name}_prop'] + target = propobj['name'] + + if cardinality == 'One': + tmap = f'Mapped[{target!r}]' + elif cardinality == 'Many': + tmap = f'Mapped[List[{target!r}]]' + # We want the cascade to delete orphans here as the objects + # represent property leaves + self.write(f'{name}: {tmap} = relationship(') + self.indent() + self.write(f"back_populates='source',") + self.write(f"cascade='all, delete-orphan',") + self.dedent() + self.write(')') + + else: + # plain property + self.write( + f'{name}: Mapped[{pytype}] = ' + f'mapped_column({sqlatype}(), nullable={nullable})' + ) + + def render_link(self, spec, mod, parent, modules): + name = spec['name'] + nullable = not spec['required'] + tmod, target = get_mod_and_name(spec['target']['name']) + source = modules[mod]['object_types'][parent] + cardinality = spec['cardinality'] + bklink = source['backlink_renames'].get(name, f'backlink_via_{name}') + + if spec.get('has_link_object'): + # intermediate object will have the actual source and target + # links, so the link here needs to be treated similar to a + # back-link. + linkobj = modules[mod]['link_objects'][f'{parent}_{name}_link'] + target = linkobj['name'] + tmod = linkobj['module'] + pyname = self.get_py_name(tmod, target, mod) + + if cardinality == 'One': + self.write( + f'{name}: Mapped[{pyname}] = ' + f"relationship(back_populates='source')" + ) + elif cardinality == 'Many': + self.write( + f'{name}: Mapped[List[{pyname}]] = ' + f"relationship(back_populates='source')" + ) + + if cardinality == 'One': + tmap = f'Mapped[{pyname}]' + elif cardinality == 'Many': + tmap = f'Mapped[List[{pyname}]]' + # We want the cascade to delete orphans here as the intermediate + # objects represent links and must not exist without source. + self.write(f'{name}: {tmap} = relationship(') + self.indent() + self.write(f"back_populates='source',") + self.write(f"cascade='all, delete-orphan',") + self.dedent() + self.write(')') + + else: + fk = self.get_fk(tmod, target, mod) + pyname = self.get_py_name(tmod, target, mod) + + if cardinality == 'One': + self.write( + f'{name}_id: Mapped[uuid.UUID] = ' + f'mapped_column(Uuid(), ' + f'{fk}, nullable={nullable})' + ) + self.write( + f'{name}: Mapped[{pyname}] = ' + f'relationship(back_populates={bklink!r})' + ) + + elif cardinality == 'Many': + secondary = f'{parent}_{name}_table' + self.write( + f'{name}: Mapped[List[{pyname}]] = relationship(') + self.indent() + self.write( + f'{pyname}, secondary={secondary}, ' + f'back_populates={bklink!r},' + ) + self.dedent() + self.write(')') + + def render_backlink(self, spec, mod, modules): + name = spec['name'] + tmod, target = get_mod_and_name(spec['target']['name']) + cardinality = spec['cardinality'] + exclusive = spec['exclusive'] + bklink = spec.get('fwname', name.replace('backlink_via_', '', 1)) + + if spec.get('has_link_object'): + # intermediate object will have the actual source and target + # links, so the link here needs to refer to the intermediate + # object and 'target' as back-link. + linkobj = modules[tmod]['link_objects'][f'{target}_{bklink}_link'] + target = linkobj['name'] + tmod = linkobj['module'] + pyname = self.get_py_name(tmod, target, mod) + + if cardinality == 'One': + tmap = f'Mapped[{pyname}]' + elif cardinality == 'Many': + tmap = f'Mapped[List[{pyname}]]' + # We want the cascade to delete orphans here as the intermediate + # objects represent links and must not exist without target. + self.write(f'{name}: {tmap} = relationship(') + self.indent() + self.write(f"back_populates='target',") + self.write(f"cascade='all, delete-orphan',") + self.dedent() + self.write(')') + + else: + pyname = self.get_py_name(tmod, target, mod) + if exclusive: + # This is a backlink from a single link. There is no link table + # involved. + if cardinality == 'One': + self.write( + f'{name}: Mapped[{pyname}] = ' + f'relationship(back_populates={bklink!r})' + ) + elif cardinality == 'Many': + self.write( + f'{name}: Mapped[List[{pyname}]] = ' + f'relationship(back_populates={bklink!r})' + ) + + else: + # This backlink involves a link table, so we still treat it as + # a Many-to-Many. + secondary = f'{target}_{bklink}_table' + self.write(f'{name}: Mapped[List[{pyname}]] = relationship(') + self.indent() + self.write( + f'{pyname}, secondary={secondary}, ' + f'back_populates={bklink!r},' + ) + self.dedent() + self.write(')') diff --git a/setup.py b/setup.py index f4f445a8..f4ea7b0c 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,15 @@ 'flake8-bugbear~=24.4.26', 'flake8~=7.0.0', 'uvloop>=0.15.1; platform_system != "Windows"', + 'SQLAlchemy>=2.0.0', +] + +# This is needed specifically to test ORM reflection because the ORMs tend to +# use this library to access Postgres. It's not always avaialable as a +# pre-built package and we don't necessarily want to try and build it from +# source. +SQLTEST_DEPENDENCIES = [ + 'psycopg2-binary>=2.9.10', ] # Dependencies required to build documentation. @@ -68,11 +77,12 @@ 'ai': AI_DEPENDENCIES, 'docs': DOC_DEPENDENCIES, 'test': TEST_DEPENDENCIES, + 'sqltest': SQLTEST_DEPENDENCIES, # Dependencies required to develop edgedb. 'dev': [ CYTHON_DEPENDENCY, 'pytest>=3.6.0', - ] + DOC_DEPENDENCIES + TEST_DEPENDENCIES + ] + DOC_DEPENDENCIES + TEST_DEPENDENCIES + SQLTEST_DEPENDENCIES } @@ -354,6 +364,7 @@ def finalize_options(self): "console_scripts": [ "edgedb-py=gel.codegen.cli:main", "gel-py=gel.codegen.cli:main", + "gel-orm=gel.orm.cli:main", ] } ) diff --git a/tests/dbsetup/base.edgeql b/tests/dbsetup/base.edgeql new file mode 100644 index 00000000..74959e71 --- /dev/null +++ b/tests/dbsetup/base.edgeql @@ -0,0 +1,44 @@ +insert User {name := 'Alice'}; +insert User {name := 'Billie'}; +insert User {name := 'Cameron'}; +insert User {name := 'Dana'}; +insert User {name := 'Elsa'}; +insert User {name := 'Zoe'}; + +insert UserGroup { + name := 'red', + users := (select User filter .name != 'Zoe'), +}; +insert UserGroup { + name := 'green', + users := (select User filter .name in {'Alice', 'Billie'}), +}; +insert UserGroup { + name := 'blue', +}; + +insert GameSession { + num := 123, + players := (select User filter .name in {'Alice', 'Billie'}), +}; +insert GameSession { + num := 456, + players := (select User filter .name in {'Dana'}), +}; + +insert Post { + author := assert_single((select User filter .name = 'Alice')), + body := 'Hello', +}; +insert Post { + author := assert_single((select User filter .name = 'Alice')), + body := "I'm Alice", +}; +insert Post { + author := assert_single((select User filter .name = 'Cameron')), + body := "I'm Cameron", +}; +insert Post { + author := assert_single((select User filter .name = 'Elsa')), + body := '*magic stuff*', +}; diff --git a/tests/dbsetup/base.esdl b/tests/dbsetup/base.esdl new file mode 100644 index 00000000..4a5c02c0 --- /dev/null +++ b/tests/dbsetup/base.esdl @@ -0,0 +1,23 @@ +abstract type Named { + required name: str; +} + +type UserGroup extending Named { + # many-to-many + multi link users: User; +} + +type GameSession { + required num: int64; + # one-to-many + multi link players: User { + constraint exclusive; + }; +} + +type User extending Named; + +type Post { + required body: str; + required link author: User; +} diff --git a/tests/dbsetup/features.edgeql b/tests/dbsetup/features.edgeql new file mode 100644 index 00000000..7e89aa6e --- /dev/null +++ b/tests/dbsetup/features.edgeql @@ -0,0 +1,82 @@ +insert Child {num := 0}; +insert Child {num := 1}; + +insert HasLinkPropsA { + child := (select Child{@a := 'single'} filter .num = 0) +}; + +insert HasLinkPropsB; +update HasLinkPropsB +set { + children += (select Child{@b := 'hello'} filter .num = 0) +}; +update HasLinkPropsB +set { + children += (select Child{@b := 'world'} filter .num = 1) +}; + +insert MultiProp { + name := 'got one', + tags := {'solo tag'}, +}; + +insert MultiProp { + name := 'got many', + tags := {'one', 'two', 'three'}, +}; + +insert other::nested::Leaf { + num := 10 +}; + +insert other::nested::Leaf { + num := 20 +}; + +insert other::nested::Leaf { + num := 30 +}; + +insert other::Branch { + val := 'big', + leaves := (select other::nested::Leaf filter .num != 10), +}; + +insert other::Branch { + val := 'small', + leaves := (select other::nested::Leaf filter .num = 10), +}; + +insert Theme { + color := 'green', + branch := ( + select other::Branch{@note := 'fresh'} filter .val = 'big' + ) +}; + +insert Theme { + color := 'orange', + branch := ( + select other::Branch{@note := 'fall'} filter .val = 'big' + ) +}; + +insert Foo { + name := 'foo' +}; + +insert Foo { + name := 'oof' +}; + +insert Bar { + n := 123, + foo := assert_single((select Foo filter .name = 'foo')), + many_foo := Foo, +}; + +insert Who { + x := 456, + foo := assert_single((select Foo filter .name = 'oof')), + many_foo := (select Foo{@note := 'just one'} filter .name = 'foo'), +}; \ No newline at end of file diff --git a/tests/dbsetup/features_default.esdl b/tests/dbsetup/features_default.esdl new file mode 100644 index 00000000..59b4bd6d --- /dev/null +++ b/tests/dbsetup/features_default.esdl @@ -0,0 +1,47 @@ +type Child { + required property num: int64 { + constraint exclusive; + } +}; + +type HasLinkPropsA { + link child: Child { + property a: str; + } +}; + +type HasLinkPropsB { + multi link children: Child { + property b: str; + } +}; + +type MultiProp { + required property name: str; + multi property tags: str; +}; + +type Theme { + required property color: str; + link branch: other::Branch { + property note: str; + } +}; + +type Foo { + required property name: str; +}; + +type Bar { + link foo: Foo; + multi link many_foo: Foo; + required property n: int64; +}; + +type Who { + link foo: Foo; + multi link many_foo: Foo { + property note: str; + }; + required property x: int64; +}; diff --git a/tests/dbsetup/features_other.esdl b/tests/dbsetup/features_other.esdl new file mode 100644 index 00000000..5f87435f --- /dev/null +++ b/tests/dbsetup/features_other.esdl @@ -0,0 +1,7 @@ +type Branch { + required property val: str { + constraint exclusive; + } + + multi link leaves: other::nested::Leaf; +}; diff --git a/tests/dbsetup/features_other_nested.esdl b/tests/dbsetup/features_other_nested.esdl new file mode 100644 index 00000000..4f02cd4b --- /dev/null +++ b/tests/dbsetup/features_other_nested.esdl @@ -0,0 +1,5 @@ +type Leaf { + required property num: int64 { + constraint exclusive; + } +}; diff --git a/tests/test_sqla_basic.py b/tests/test_sqla_basic.py new file mode 100644 index 00000000..a9b274a5 --- /dev/null +++ b/tests/test_sqla_basic.py @@ -0,0 +1,571 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import uuid + +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session + +from gel import _testbase as tb + + +class TestSQLABasic(tb.SQLATestCase): + SCHEMA = os.path.join(os.path.dirname(__file__), 'dbsetup', + 'base.esdl') + + SETUP = os.path.join(os.path.dirname(__file__), 'dbsetup', + 'base.edgeql') + + SQLAPACKAGE = 'basemodels' + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.engine = create_engine(cls.get_dsn_for_sqla()) + cls.sess = Session(cls.engine, autobegin=False) + + from basemodels import default + cls.sm = default + + def setUp(self): + super().setUp() + + if self.client.query_required_single(''' + select sys::get_version().major < 6 + '''): + self.skipTest("Test needs SQL DML queries") + + self.sess.begin() + + def tearDown(self): + super().tearDown() + self.sess.rollback() + + def test_sqla_read_models_01(self): + vals = {r.name for r in self.sess.query(self.sm.User).all()} + self.assertEqual( + vals, {'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa', 'Zoe'}) + + vals = {r.name for r in self.sess.query(self.sm.UserGroup).all()} + self.assertEqual( + vals, {'red', 'green', 'blue'}) + + vals = {r.num for r in self.sess.query(self.sm.GameSession).all()} + self.assertEqual(vals, {123, 456}) + + vals = {r.body for r in self.sess.query(self.sm.Post).all()} + self.assertEqual( + vals, {'Hello', "I'm Alice", "I'm Cameron", '*magic stuff*'}) + + # Read from the abstract type + vals = {r.name for r in self.sess.query(self.sm.Named).all()} + self.assertEqual( + vals, + { + 'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa', 'Zoe', + 'red', 'green', 'blue', + } + ) + + def test_sqla_read_models_02(self): + # test single link and the one-to-many backlink + # using load-on-demand + + res = self.sess.query(self.sm.Post).all() + vals = {(p.author.name, p.body) for p in res} + self.assertEqual( + vals, + { + ('Alice', 'Hello'), + ('Alice', "I'm Alice"), + ('Cameron', "I'm Cameron"), + ('Elsa', '*magic stuff*'), + } + ) + + # use backlink + res = self.sess.query(self.sm.User).order_by(self.sm.User.name).all() + vals = [ + (u.name, {p.body for p in u.backlink_via_author}) + for u in res + ] + self.assertEqual( + vals, + [ + ('Alice', {'Hello', "I'm Alice"}), + ('Billie', set()), + ('Cameron', {"I'm Cameron"}), + ('Dana', set()), + ('Elsa', {'*magic stuff*'}), + ('Zoe', set()), + ] + ) + + def test_sqla_read_models_03(self): + # test single link and the one-to-many backlink + + res = self.sess.execute( + select(self.sm.Post, self.sm.User) + .join(self.sm.Post.author) + ) + vals = {(p.author.name, p.body) for (p, _) in res} + self.assertEqual( + vals, + { + ('Alice', 'Hello'), + ('Alice', "I'm Alice"), + ('Cameron', "I'm Cameron"), + ('Elsa', '*magic stuff*'), + } + ) + + # join via backlink + res = self.sess.execute( + select(self.sm.Post, self.sm.User) + .join(self.sm.User.backlink_via_author) + .order_by(self.sm.Post.body) + ) + # We'll get a cross-product, so we need to jump through some hoops to + # remove duplicates + vals = { + (u.name, tuple(p.body for p in u.backlink_via_author)) + for (_, u) in res + } + self.assertEqual( + vals, + { + ('Alice', ('Hello', "I'm Alice")), + ('Cameron', ("I'm Cameron",)), + ('Elsa', ('*magic stuff*',)), + } + ) + + # LEFT OUTER join via backlink + res = self.sess.execute( + select(self.sm.Post, self.sm.User) + .join(self.sm.User.backlink_via_author, isouter=True) + .order_by(self.sm.Post.body) + ) + vals = { + (u.name, tuple(p.body for p in u.backlink_via_author)) + for (p, u) in res + } + self.assertEqual( + vals, + { + ('Alice', ('Hello', "I'm Alice")), + ('Billie', ()), + ('Cameron', ("I'm Cameron",)), + ('Dana', ()), + ('Elsa', ('*magic stuff*',)), + ('Zoe', ()), + } + ) + + def test_sqla_read_models_04(self): + # test exclusive multi link and its backlink + # using load-on-demand + + res = self.sess.query( + self.sm.GameSession + ).order_by(self.sm.GameSession.num).all() + + vals = [(g.num, {u.name for u in g.players}) for g in res] + self.assertEqual( + vals, + [ + (123, {'Alice', 'Billie'}), + (456, {'Dana'}), + ] + ) + + # use backlink + res = self.sess.query(self.sm.User).all() + vals = { + (u.name, tuple(g.num for g in u.backlink_via_players)) + for u in res + } + self.assertEqual( + vals, + { + ('Alice', (123,)), + ('Billie', (123,)), + ('Cameron', ()), + ('Dana', (456,)), + ('Elsa', ()), + ('Zoe', ()), + } + ) + + def test_sqla_read_models_05(self): + # test exclusive multi link and its backlink + + res = self.sess.execute( + select(self.sm.GameSession, self.sm.User) + .join(self.sm.GameSession.players) + ) + # We'll get a cross-product, so we need to jump through some hoops to + # remove duplicates + vals = { + (g.num, tuple(sorted(u.name for u in g.players))) + for (g, _) in res + } + self.assertEqual( + vals, + { + (123, ('Alice', 'Billie')), + (456, ('Dana',)), + } + ) + + # LEFT OUTER join via backlink + res = self.sess.execute( + select(self.sm.GameSession, self.sm.User) + .join(self.sm.User.backlink_via_players, isouter=True) + ) + vals = { + (u.name, tuple(g.num for g in u.backlink_via_players)) + for (_, u) in res + } + self.assertEqual( + vals, + { + ('Alice', (123,)), + ('Billie', (123,)), + ('Cameron', ()), + ('Dana', (456,)), + ('Elsa', ()), + ('Zoe', ()), + } + ) + + def test_sqla_read_models_06(self): + # test multi link and its backlink + # using load-on-demand + + res = self.sess.query( + self.sm.UserGroup + ).order_by(self.sm.UserGroup.name).all() + + vals = [(g.name, {u.name for u in g.users}) for g in res] + self.assertEqual( + vals, + [ + ('blue', set()), + ('green', {'Alice', 'Billie'}), + ('red', {'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa'}), + ] + ) + + # use backlink + res = self.sess.query(self.sm.User).order_by(self.sm.User.name).all() + vals = [ + (u.name, {g.name for g in u.backlink_via_users}) + for u in res + ] + self.assertEqual( + vals, + [ + ('Alice', {'red', 'green'}), + ('Billie', {'red', 'green'}), + ('Cameron', {'red'}), + ('Dana', {'red'}), + ('Elsa', {'red'}), + ('Zoe', set()), + ] + ) + + def test_sqla_read_models_07(self): + # test exclusive multi link and its backlink + + res = self.sess.execute( + select(self.sm.UserGroup, self.sm.User) + .join(self.sm.UserGroup.users, isouter=True) + .order_by(self.sm.UserGroup.name) + ) + # We'll get a cross-product, so we need to jump through some hoops to + # remove duplicates + vals = { + (g.name, tuple(sorted(u.name for u in g.users))) + for (g, _) in res + } + self.assertEqual( + vals, + { + ('blue', ()), + ('green', ('Alice', 'Billie')), + ('red', ('Alice', 'Billie', 'Cameron', 'Dana', 'Elsa')), + } + ) + + # LEFT OUTER join via backlink + res = self.sess.execute( + select(self.sm.UserGroup, self.sm.User) + .join(self.sm.User.backlink_via_users, isouter=True) + ) + vals = { + (u.name, tuple(sorted(g.name for g in u.backlink_via_users))) + for (_, u) in res + } + self.assertEqual( + vals, + { + ('Alice', ('green', 'red')), + ('Billie', ('green', 'red')), + ('Cameron', ('red',)), + ('Dana', ('red',)), + ('Elsa', ('red',)), + ('Zoe', ()), + } + ) + + def test_sqla_create_models_01(self): + vals = self.sess.query(self.sm.User).filter_by(name='Yvonne').all() + self.assertEqual(list(vals), []) + + self.sess.add(self.sm.User(name='Yvonne')) + self.sess.flush() + + user = self.sess.query(self.sm.User).filter_by(name='Yvonne').one() + self.assertEqual(user.name, 'Yvonne') + self.assertIsInstance(user.id, uuid.UUID) + + def test_sqla_create_models_02(self): + cyan = self.sm.UserGroup( + name='cyan', + users=[ + self.sm.User(name='Yvonne'), + self.sm.User(name='Xander'), + ], + ) + + self.sess.add(cyan) + self.sess.flush() + + for name in ['Yvonne', 'Xander']: + user = self.sess.query(self.sm.User).filter_by(name=name).one() + + self.assertEqual(user.name, name) + self.assertEqual(user.backlink_via_users[0].name, 'cyan') + self.assertIsInstance(user.id, uuid.UUID) + + def test_sqla_create_models_03(self): + user0 = self.sm.User(name='Yvonne') + user1 = self.sm.User(name='Xander') + cyan = self.sm.UserGroup(name='cyan') + + user0.backlink_via_users.append(cyan) + user1.backlink_via_users.append(cyan) + + self.sess.add(cyan) + self.sess.flush() + + for name in ['Yvonne', 'Xander']: + user = self.sess.query(self.sm.User).filter_by(name=name).one() + + self.assertEqual(user.name, name) + self.assertEqual(user.backlink_via_users[0].name, 'cyan') + self.assertIsInstance(user.id, uuid.UUID) + + def test_sqla_create_models_04(self): + user = self.sm.User(name='Yvonne') + self.sm.Post(body='this is a test', author=user) + self.sm.Post(body='also a test', author=user) + + self.sess.add(user) + self.sess.flush() + + res = self.sess.execute( + select(self.sm.Post) + .join(self.sm.Post.author) + .where(self.sm.User.name == 'Yvonne') + ) + self.assertEqual( + {p.body for (p,) in res}, + {'this is a test', 'also a test'}, + ) + + def test_sqla_delete_models_01(self): + user = self.sess.query(self.sm.User).filter_by(name='Zoe').one() + self.assertEqual(user.name, 'Zoe') + self.assertIsInstance(user.id, uuid.UUID) + + self.sess.delete(user) + self.sess.flush() + + vals = self.sess.query(self.sm.User).filter_by(name='Zoe').all() + self.assertEqual(list(vals), []) + + def test_sqla_delete_models_02(self): + post = ( + self.sess.query(self.sm.Post) + .join(self.sm.Post.author) + .filter(self.sm.User.name == 'Elsa') + .one() + ) + user_id = post.author.id + + self.sess.delete(post) + self.sess.flush() + + vals = ( + self.sess.query(self.sm.Post) + .join(self.sm.Post.author) + .filter(self.sm.User.name == 'Elsa') + .all() + ) + self.assertEqual(list(vals), []) + + user = self.sess.get(self.sm.User, user_id) + self.assertEqual(user.name, 'Elsa') + + def test_sqla_delete_models_03(self): + post = ( + self.sess.query(self.sm.Post) + .join(self.sm.Post.author) + .filter(self.sm.User.name == 'Elsa') + .one() + ) + user = post.author + + self.sess.delete(post) + self.sess.delete(user) + self.sess.flush() + + vals = ( + self.sess.query(self.sm.Post) + .join(self.sm.Post.author) + .filter(self.sm.User.name == 'Elsa') + .all() + ) + self.assertEqual(list(vals), []) + + vals = self.sess.query(self.sm.User).filter_by(name='Elsa').all() + self.assertEqual(list(vals), []) + + def test_sqla_delete_models_04(self): + group = self.sess.query( + self.sm.UserGroup).filter_by(name='green').one() + names = {u.name for u in group.users} + + self.sess.delete(group) + self.sess.flush() + + vals = self.sess.query( + self.sm.UserGroup).filter_by(name='green').all() + self.assertEqual(list(vals), []) + + users = self.sess.query(self.sm.User).all() + for name in names: + self.assertIn(name, {u.name for u in users}) + + def test_sqla_delete_models_05(self): + group = self.sess.query( + self.sm.UserGroup).filter_by(name='green').one() + for u in group.users: + if u.name == 'Billie': + user = u + break + + self.sess.delete(group) + self.sess.delete(user) + self.sess.flush() + + vals = self.sess.query( + self.sm.UserGroup).filter_by(name='green').all() + self.assertEqual(list(vals), []) + + users = self.sess.query(self.sm.User).all() + self.assertNotIn('Billie', {u.name for u in users}) + + def test_sqla_update_models_01(self): + user = self.sess.query(self.sm.User).filter_by(name='Alice').one() + self.assertEqual(user.name, 'Alice') + self.assertIsInstance(user.id, uuid.UUID) + + user.name = 'Xander' + self.sess.add(user) + self.sess.flush() + + vals = self.sess.query(self.sm.User).filter_by(name='Alice').all() + self.assertEqual(list(vals), []) + other = self.sess.query(self.sm.User).filter_by(name='Xander').one() + self.assertEqual(user, other) + + def test_sqla_update_models_02(self): + red = self.sess.query(self.sm.UserGroup).filter_by(name='red').one() + blue = self.sess.query(self.sm.UserGroup).filter_by(name='blue').one() + user = self.sm.User(name='Yvonne') + + self.sess.add(user) + red.users.append(user) + blue.users.append(user) + self.sess.flush() + + self.assertEqual( + {g.name for g in user.backlink_via_users}, + {'red', 'blue'}, + ) + self.assertEqual(user.name, 'Yvonne') + self.assertIsInstance(user.id, uuid.UUID) + + group = [g for g in user.backlink_via_users if g.name == 'red'][0] + self.assertEqual( + {u.name for u in group.users}, + {'Alice', 'Billie', 'Cameron', 'Dana', 'Elsa', 'Yvonne'}, + ) + + def test_sqla_update_models_03(self): + user0 = self.sess.query(self.sm.User).filter_by(name='Elsa').one() + user1 = self.sess.query(self.sm.User).filter_by(name='Zoe').one() + # Replace the author or a post + post = user0.backlink_via_author[0] + body = post.body + post.author = user1 + + self.sess.add(post) + self.sess.flush() + + res = self.sess.execute( + select(self.sm.Post) + .join(self.sm.Post.author) + .where(self.sm.User.name == 'Zoe') + ) + self.assertEqual( + {p.body for (p,) in res}, + {body}, + ) + + def test_sqla_update_models_04(self): + user = self.sess.query(self.sm.User).filter_by(name='Zoe').one() + post = ( + self.sess.query(self.sm.Post) + .join(self.sm.Post.author) + .filter(self.sm.User.name == 'Elsa') + .one() + ) + # Replace the author or a post + post_id = post.id + post.author = user + + self.sess.add(post) + self.sess.flush() + + post = self.sess.get(self.sm.Post, post_id) + self.assertEqual(post.author.name, 'Zoe') diff --git a/tests/test_sqla_features.py b/tests/test_sqla_features.py new file mode 100644 index 00000000..4c1e42e6 --- /dev/null +++ b/tests/test_sqla_features.py @@ -0,0 +1,318 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2024-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from gel import _testbase as tb + + +class TestSQLAFeatures(tb.SQLATestCase): + SCHEMA = os.path.join( + os.path.dirname(__file__), 'dbsetup', 'features_default.esdl') + + SCHEMA_OTHER = os.path.join( + os.path.dirname(__file__), 'dbsetup', 'features_other.esdl') + + SCHEMA_OTHER_NESTED = os.path.join( + os.path.dirname(__file__), 'dbsetup', 'features_other_nested.esdl') + + SETUP = os.path.join(os.path.dirname(__file__), 'dbsetup', + 'features.edgeql') + + SQLAPACKAGE = 'fmodels' + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.engine = create_engine(cls.get_dsn_for_sqla()) + cls.sess = Session(cls.engine, autobegin=False) + + from fmodels import default, other + from fmodels.other import nested + cls.sm = default + cls.sm_o = other + cls.sm_on = nested + + def setUp(self): + super().setUp() + + if self.client.query_required_single(''' + select sys::get_version().major < 6 + '''): + self.skipTest("Test needs SQL DML queries") + + self.sess.begin() + + def tearDown(self): + super().tearDown() + self.sess.rollback() + + def test_sqla_multiprops_01(self): + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got one').one() + self.assertEqual( + {t.target for t in val.tags}, + {'solo tag'}, + ) + + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got many').one() + self.assertEqual( + {t.target for t in val.tags}, + {'one', 'two', 'three'}, + ) + + def test_sqla_multiprops_02(self): + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got one').one() + self.assertEqual( + {t.target for t in val.tags}, + {'solo tag'}, + ) + + # create and add a few more tags + self.sess.add( + self.sm.MultiProp_tags_prop(source=val, target='hello')) + self.sess.add( + self.sm.MultiProp_tags_prop(source=val, target='world')) + self.sess.flush() + + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got one').one() + self.assertEqual( + {t.target for t in val.tags}, + {'solo tag', 'hello', 'world'}, + ) + + def test_sqla_multiprops_03(self): + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got one').one() + self.assertEqual( + {t.target for t in val.tags}, + {'solo tag'}, + ) + + # create and add a few more tags (don't specify source explicitly) + val.tags.append( + self.sm.MultiProp_tags_prop(target='hello')) + val.tags.append( + self.sm.MultiProp_tags_prop(target='world')) + self.sess.flush() + + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got one').one() + self.assertEqual( + {t.target for t in val.tags}, + {'solo tag', 'hello', 'world'}, + ) + + def test_sqla_multiprops_04(self): + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got many').one() + self.assertEqual( + {t.target for t in val.tags}, + {'one', 'two', 'three'}, + ) + + # remove several tags + for t in list(val.tags): + if t.target != 'one': + val.tags.remove(t) + self.sess.flush() + + val = self.sess.query(self.sm.MultiProp)\ + .filter_by(name='got many').one() + self.assertEqual( + {t.target for t in val.tags}, + {'one'}, + ) + + def test_sqla_linkprops_01(self): + val = self.sess.query(self.sm.HasLinkPropsA).one() + self.assertEqual(val.child.target.num, 0) + self.assertEqual(val.child.a, 'single') + + def test_sqla_linkprops_02(self): + val = self.sess.query(self.sm.HasLinkPropsA).one() + self.assertEqual(val.child.target.num, 0) + self.assertEqual(val.child.a, 'single') + + # replace the single child with a different one + ch = self.sess.query(self.sm.Child).filter_by(num=1).one() + val.child = self.sm.HasLinkPropsA_child_link(a='replaced', target=ch) + self.sess.flush() + + val = self.sess.query(self.sm.HasLinkPropsA).one() + self.assertEqual(val.child.target.num, 1) + self.assertEqual(val.child.a, 'replaced') + + # make sure there's only one link object still + vals = self.sess.query(self.sm.HasLinkPropsA_child_link).all() + self.assertEqual( + [(val.a, val.target.num) for val in vals], + [('replaced', 1)] + ) + + def test_sqla_linkprops_03(self): + val = self.sess.query(self.sm.HasLinkPropsA).one() + self.assertEqual(val.child.target.num, 0) + self.assertEqual(val.child.a, 'single') + + # delete the child object + val = self.sess.query(self.sm.Child).filter_by(num=0).one() + self.sess.delete(val) + self.sess.flush() + + val = self.sess.query(self.sm.HasLinkPropsA).one() + self.assertEqual(val.child, None) + + # make sure there's only one link object still + vals = self.sess.query(self.sm.HasLinkPropsA_child_link).all() + self.assertEqual(vals, []) + + def test_sqla_linkprops_04(self): + val = self.sess.query(self.sm.HasLinkPropsB).one() + self.assertEqual( + {(c.b, c.target.num) for c in val.children}, + {('hello', 0), ('world', 1)}, + ) + + def test_sqla_linkprops_05(self): + val = self.sess.query(self.sm.HasLinkPropsB).one() + self.assertEqual( + {(c.b, c.target.num) for c in val.children}, + {('hello', 0), ('world', 1)}, + ) + + # Remove one of the children + for t in list(val.children): + if t.b != 'hello': + val.children.remove(t) + self.sess.flush() + + val = self.sess.query(self.sm.HasLinkPropsB).one() + self.assertEqual( + {(c.b, c.target.num) for c in val.children}, + {('hello', 0)}, + ) + + def test_sqla_linkprops_06(self): + val = self.sess.query(self.sm.HasLinkPropsB).one() + self.assertEqual( + {(c.b, c.target.num) for c in val.children}, + {('hello', 0), ('world', 1)}, + ) + + # Remove one of the children + val = self.sess.query(self.sm.Child).filter_by(num=0).one() + self.sess.delete(val) + self.sess.flush() + + val = self.sess.query(self.sm.HasLinkPropsB).one() + self.assertEqual( + {(c.b, c.target.num) for c in val.children}, + {('world', 1)}, + ) + + def test_sqla_module_01(self): + vals = self.sess.query(self.sm_o.Branch).all() + self.assertEqual( + {(r.val, tuple(sorted(lf.num for lf in r.leaves))) for r in vals}, + { + ('big', (20, 30)), + ('small', (10,)), + }, + ) + + vals = self.sess.query(self.sm_on.Leaf).all() + self.assertEqual( + {r.num for r in vals}, + {10, 20, 30}, + ) + + vals = self.sess.query(self.sm.Theme).all() + self.assertEqual( + { + (r.color, r.branch.note, r.branch.target.val) + for r in vals + }, + { + ('green', 'fresh', 'big'), + ('orange', 'fall', 'big'), + }, + ) + + def test_sqla_module_02(self): + val = self.sess.query(self.sm.Theme).filter_by(color='orange').one() + self.assertEqual( + (val.color, val.branch.note, val.branch.target.val), + ('orange', 'fall', 'big'), + ) + + # swap the branch for 'small' + br = self.sess.query(self.sm_o.Branch).filter_by(val='small').one() + # can't update link tables (Gel limitation), so we will create a new + # one + val.branch = self.sm.Theme_branch_link( + note='swapped', target=br) + self.sess.add(val) + self.sess.flush() + + vals = self.sess.query(self.sm.Theme).all() + self.assertEqual( + { + (r.color, r.branch.note, r.branch.target.val) + for r in vals + }, + { + ('green', 'fresh', 'big'), + ('orange', 'swapped', 'small'), + }, + ) + + def test_sqla_bklink_01(self): + # test backlink name collisions + foo = self.sess.query(self.sm.Foo).filter_by(name='foo').one() + oof = self.sess.query(self.sm.Foo).filter_by(name='oof').one() + + # only one link from Bar 123 to foo + self.assertEqual( + [obj.n for obj in foo.backlink_via_foo_from_Bar], + [123], + ) + # only one link from Who 456 to oof + self.assertEqual( + [obj.x for obj in oof.backlink_via_foo_from_Who], + [456], + ) + + # foo is linked via `many_foo` from both Bar and Who + self.assertEqual( + [obj.n for obj in foo.backlink_via_many_foo_from_Bar], + [123], + ) + self.assertEqual( + [ + (obj.note, obj.source.x) + for obj in foo.backlink_via_many_foo_from_Who + ], + [('just one', 456)], + )