From ca41ca383571d650e8c01929fde6be3ff4f9ba73 Mon Sep 17 00:00:00 2001 From: Javier Buzzi Date: Tue, 5 Mar 2024 12:00:40 +0100 Subject: [PATCH] Handles Postgres indexes (#62) --- django_squash/contrib/postgres.py | 13 ++ django_squash/db/migrations/autodetector.py | 18 ++- django_squash/db/migrations/writer.py | 19 ++- setup.py | 1 + .../migrations/pg_indexes/0001_initial.py | 10 ++ .../migrations/pg_indexes/0002_use_index.py | 41 ++++++ .../tests/migrations/pg_indexes/__init__.py | 0 .../pg_indexes_custom/0001_initial.py | 19 +++ .../pg_indexes_custom/0002_use_index.py | 41 ++++++ .../migrations/pg_indexes_custom/__init__.py | 0 tests/test_migrations.py | 120 +++++++++++++++++- 11 files changed, 273 insertions(+), 9 deletions(-) create mode 100644 django_squash/contrib/postgres.py create mode 100644 tests/app/tests/migrations/pg_indexes/0001_initial.py create mode 100644 tests/app/tests/migrations/pg_indexes/0002_use_index.py create mode 100644 tests/app/tests/migrations/pg_indexes/__init__.py create mode 100644 tests/app/tests/migrations/pg_indexes_custom/0001_initial.py create mode 100644 tests/app/tests/migrations/pg_indexes_custom/0002_use_index.py create mode 100644 tests/app/tests/migrations/pg_indexes_custom/__init__.py diff --git a/django_squash/contrib/postgres.py b/django_squash/contrib/postgres.py new file mode 100644 index 0000000..054391f --- /dev/null +++ b/django_squash/contrib/postgres.py @@ -0,0 +1,13 @@ +""" +Postgres specific code +""" + +try: + from django.contrib.postgres.operations import CreateExtension as PGCreateExtension +except ImportError: # pragma: no cover + + class PGCreateExtension: + pass + + +__all__ = ("PGCreateExtension",) diff --git a/django_squash/db/migrations/autodetector.py b/django_squash/db/migrations/autodetector.py index 3c0ad15..09fd0c6 100644 --- a/django_squash/db/migrations/autodetector.py +++ b/django_squash/db/migrations/autodetector.py @@ -9,6 +9,8 @@ from django.db import migrations as dj_migrations from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetectorBase +from django_squash.contrib import postgres + from . import utils RESERVED_MIGRATION_KEYWORDS = ("_deleted", "_dependencies_change", "_replaces_change", "_original_migration") @@ -60,10 +62,10 @@ def from_migration(cls, migration): class SquashMigrationAutodetector(MigrationAutodetectorBase): - def add_non_elidables(self, original, loader, changes): + def add_non_elidables(self, loader, changes): replacing_migrations_by_app = { app: [ - original.disk_migrations[r] + loader.disk_migrations[r] for r in list(dict.fromkeys(itertools.chain.from_iterable([m.replaces for m in migrations]))) ] for app, migrations in changes.items() @@ -71,6 +73,7 @@ def add_non_elidables(self, original, loader, changes): for app in changes.keys(): new_operations = [] + new_operations_bubble_top = [] new_imports = [] for migration in replacing_migrations_by_app[app]: @@ -84,10 +87,17 @@ def add_non_elidables(self, original, loader, changes): new_operations.append(operation) elif isinstance(operation, dj_migrations.RunPython): new_operations.append(operation) + elif isinstance(operation, postgres.PGCreateExtension): + new_operations_bubble_top.append(operation) elif isinstance(operation, dj_migrations.SeparateDatabaseAndState): # A valid use case for this should be given before any work is done. pass + if new_operations_bubble_top: + migration = changes[app][0] + migration.operations = new_operations_bubble_top + migration.operations + migration.extra_imports = new_imports + migration = changes[app][-1] migration.operations += new_operations migration.extra_imports = new_imports @@ -198,14 +208,14 @@ def squash(self, real_loader, squash_loader, ignore_apps, migration_name=None): self.convert_migration_references_to_objects(real_loader, changes, ignore_apps) self.rename_migrations(real_loader, graph, changes, migration_name) self.replace_current_migrations(real_loader, graph, changes) - self.add_non_elidables(real_loader, squash_loader, changes) + self.add_non_elidables(real_loader, changes) for app, change in changes_.items(): changes[app].extend(change) return changes - def delete_old_squashed(self, loader, ignore_apps=None): + def delete_old_squashed(self, loader, ignore_apps): changes = defaultdict(set) project_path = os.path.abspath(os.curdir) project_apps = [ diff --git a/django_squash/db/migrations/writer.py b/django_squash/db/migrations/writer.py index 0a32bac..e64a284 100644 --- a/django_squash/db/migrations/writer.py +++ b/django_squash/db/migrations/writer.py @@ -1,3 +1,4 @@ +import inspect import os import re import textwrap @@ -8,6 +9,7 @@ from django.db.migrations import writer as dj_writer from django.utils.timezone import now +from django_squash.contrib import postgres from django_squash.db.migrations import operators, utils SUPPORTED_DJANGO_WRITER = ( @@ -38,6 +40,16 @@ def check_django_migration_hash(): check_django_migration_hash() +class OperationWriter(dj_writer.OperationWriter): + def serialize(self): + if isinstance(self.operation, postgres.PGCreateExtension): + if not utils.is_code_in_site_packages(self.operation.__class__.__module__): + self.feed("%s()," % (self.operation.__class__.__name__)) + return self.render(), set() + + return super().serialize() + + class ReplacementMigrationWriter(dj_writer.MigrationWriter): """ Take a Migration instance and is able to produce the contents @@ -70,7 +82,7 @@ def get_kwargs(self): # pragma: no cover # Deconstruct operations operations = [] for operation in self.migration.operations: - operation_string, operation_imports = dj_writer.OperationWriter(operation).serialize() + operation_string, operation_imports = OperationWriter(operation).serialize() imports.update(operation_imports) operations.append(operation_string) items["operations"] = "\n".join(operations) + "\n" if operations else "" @@ -205,7 +217,7 @@ def replace_in_migration(self): source = utils.replace_migration_attribute(source, "replaces", self.migration.replaces) changed = True if not changed: - raise NotImplementedError() + raise NotImplementedError() # pragma: no cover return source @@ -244,6 +256,9 @@ def get_kwargs(self): variables.append( self.template_variable % (operation.reverse_sql.name, repr(operation.reverse_sql.value)) ) + elif isinstance(operation, postgres.PGCreateExtension): + if not utils.is_code_in_site_packages(operation.__class__.__module__): + functions.append(textwrap.dedent(inspect.getsource(operation.__class__))) kwargs["functions"] = ("\n\n" if functions else "") + "\n\n".join(functions) kwargs["variables"] = ("\n\n" if variables else "") + "\n\n".join(variables) diff --git a/setup.py b/setup.py index 7a4bd51..3e488d2 100755 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ "ipdb", "isort", "libcst", + "psycopg2", "pytest-cov", "pytest-django", "restructuredtext-lint", diff --git a/tests/app/tests/migrations/pg_indexes/0001_initial.py b/tests/app/tests/migrations/pg_indexes/0001_initial.py new file mode 100644 index 0000000..8365a27 --- /dev/null +++ b/tests/app/tests/migrations/pg_indexes/0001_initial.py @@ -0,0 +1,10 @@ +from django.contrib.postgres.operations import BtreeGinExtension +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [ + BtreeGinExtension(), + ] diff --git a/tests/app/tests/migrations/pg_indexes/0002_use_index.py b/tests/app/tests/migrations/pg_indexes/0002_use_index.py new file mode 100644 index 0000000..90af9a4 --- /dev/null +++ b/tests/app/tests/migrations/pg_indexes/0002_use_index.py @@ -0,0 +1,41 @@ +# Generated by Django 3.2.22 on 2023-10-13 10:38 + +import django.contrib.postgres.indexes +from django.db import migrations, models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("app", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="Message", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("score", models.IntegerField(default=0)), + ("unicode_name", models.CharField(db_index=True, max_length=255)), + ], + ), + migrations.AddIndex( + model_name="message", + index=models.Index(fields=["-score"], name="message_e_score_385f90_idx"), + ), + migrations.AddIndex( + model_name="message", + index=django.contrib.postgres.indexes.GinIndex( + fields=["unicode_name"], name="message_e_unicode_6789fc_gin" + ), + ), + ] diff --git a/tests/app/tests/migrations/pg_indexes/__init__.py b/tests/app/tests/migrations/pg_indexes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/app/tests/migrations/pg_indexes_custom/0001_initial.py b/tests/app/tests/migrations/pg_indexes_custom/0001_initial.py new file mode 100644 index 0000000..bb66edd --- /dev/null +++ b/tests/app/tests/migrations/pg_indexes_custom/0001_initial.py @@ -0,0 +1,19 @@ +from django.contrib.postgres.operations import BtreeGinExtension +from django.db import migrations + + +class IgnoreRollbackBtreeGinExtension(BtreeGinExtension): + """ + Custom extension that doesn't rollback no matter what + """ + + def database_backwards(self, *args, **kwargs): + pass + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [ + IgnoreRollbackBtreeGinExtension(), + ] diff --git a/tests/app/tests/migrations/pg_indexes_custom/0002_use_index.py b/tests/app/tests/migrations/pg_indexes_custom/0002_use_index.py new file mode 100644 index 0000000..90af9a4 --- /dev/null +++ b/tests/app/tests/migrations/pg_indexes_custom/0002_use_index.py @@ -0,0 +1,41 @@ +# Generated by Django 3.2.22 on 2023-10-13 10:38 + +import django.contrib.postgres.indexes +from django.db import migrations, models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ("app", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="Message", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("score", models.IntegerField(default=0)), + ("unicode_name", models.CharField(db_index=True, max_length=255)), + ], + ), + migrations.AddIndex( + model_name="message", + index=models.Index(fields=["-score"], name="message_e_score_385f90_idx"), + ), + migrations.AddIndex( + model_name="message", + index=django.contrib.postgres.indexes.GinIndex( + fields=["unicode_name"], name="message_e_unicode_6789fc_gin" + ), + ), + ] diff --git a/tests/app/tests/migrations/pg_indexes_custom/__init__.py b/tests/app/tests/migrations/pg_indexes_custom/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_migrations.py b/tests/test_migrations.py index a0a7f47..1f79872 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -6,6 +6,7 @@ import black import libcst import pytest +from django.contrib.postgres.indexes import GinIndex from django.core.management import CommandError from django.db import models @@ -17,7 +18,9 @@ def load_migration_module(path): spec.loader.exec_module(module) except Exception as e: with open(path) as f: - raise type(e)(f"{e}.\nError loading module file containing:\n\n{f.read()}") from e + lines = f.readlines() + formatted_lines = "".join(f"{i}: {line}" for i, line in enumerate(lines, start=1)) + raise type(e)(f"{e}.\nError loading module file containing:\n\n{formatted_lines}") from e return module @@ -501,13 +504,12 @@ class Meta: call_squash_migrations() files_in_app = migration_app_dir.migration_files() - expected_files = [ + assert files_in_app == [ "0001_initial.py", "0002_add_dob.py", "0003_squashed.py", "__init__.py", ] - assert files_in_app == expected_files app_squash = load_migration_module(migration_app_dir / "0003_squashed.py") expected = textwrap.dedent( @@ -540,3 +542,115 @@ class Migration(migrations.Migration): """ # noqa ) assert pretty_extract_piece(app_squash, "") == expected + + +@pytest.mark.temporary_migration_module(module="app.tests.migrations.pg_indexes", app_label="app") +def test_squashing_migration_pg_indexes(migration_app_dir, call_squash_migrations): + + class Message(models.Model): + score = models.IntegerField(default=0) + unicode_name = models.CharField(max_length=255, db_index=True) + + class Meta: + indexes = [models.Index(fields=["-score"]), GinIndex(fields=["unicode_name"])] + app_label = "app" + + call_squash_migrations() + assert migration_app_dir.migration_files() == [ + "0001_initial.py", + "0002_use_index.py", + "0003_squashed.py", + "__init__.py", + ] + app_squash = load_migration_module(migration_app_dir / "0003_squashed.py") + expected = textwrap.dedent( + """\ + import django.contrib.postgres.indexes + import django.contrib.postgres.operations + from django.contrib.postgres.operations import BtreeGinExtension + from django.db import migrations + from django.db import migrations, models + + + class Migration(migrations.Migration): + + replaces = [("app", "0001_initial"), ("app", "0002_use_index")] + + initial = True + + dependencies = [] + + operations = [ + django.contrib.postgres.operations.BtreeGinExtension(), + migrations.CreateModel( + name="Message", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("score", models.IntegerField(default=0)), + ("unicode_name", models.CharField(db_index=True, max_length=255)), + ], + """ # noqa + ) + # NOTE: different django versions handle index differently, since the Index part is actually not + # being tested, it doesn't matter that is not checked + assert pretty_extract_piece(app_squash, "").startswith(expected) + + +@pytest.mark.temporary_migration_module(module="app.tests.migrations.pg_indexes_custom", app_label="app") +def test_squashing_migration_pg_indexes_custom(migration_app_dir, call_squash_migrations): + + class Message(models.Model): + score = models.IntegerField(default=0) + unicode_name = models.CharField(max_length=255, db_index=True) + + class Meta: + indexes = [models.Index(fields=["-score"]), GinIndex(fields=["unicode_name"])] + app_label = "app" + + call_squash_migrations() + assert migration_app_dir.migration_files() == [ + "0001_initial.py", + "0002_use_index.py", + "0003_squashed.py", + "__init__.py", + ] + app_squash = load_migration_module(migration_app_dir / "0003_squashed.py") + expected = textwrap.dedent( + """\ + import django.contrib.postgres.indexes + from django.contrib.postgres.operations import BtreeGinExtension + from django.db import migrations + from django.db import migrations, models + + + class IgnoreRollbackBtreeGinExtension(BtreeGinExtension): + \"\"\" + Custom extension that doesn't rollback no matter what + \"\"\" + + def database_backwards(self, *args, **kwargs): + pass + + + class Migration(migrations.Migration): + + replaces = [("app", "0001_initial"), ("app", "0002_use_index")] + + initial = True + + dependencies = [] + + operations = [ + IgnoreRollbackBtreeGinExtension(), + migrations.CreateModel( + name="Message", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("score", models.IntegerField(default=0)), + ("unicode_name", models.CharField(db_index=True, max_length=255)), + ], + """ # noqa + ) + # NOTE: different django versions handle index differently, since the Index part is actually not + # being tested, it doesn't matter that is not checked + assert pretty_extract_piece(app_squash, "").startswith(expected)