Skip to content

Commit

Permalink
Handles Postgres indexes (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingbuzzman authored Mar 5, 2024
1 parent e6ff6b7 commit ca41ca3
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 9 deletions.
13 changes: 13 additions & 0 deletions django_squash/contrib/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Postgres specific code
"""

try:
from django.contrib.postgres.operations import CreateExtension as PGCreateExtension
except ImportError: # pragma: no cover

class PGCreateExtension:
pass


__all__ = ("PGCreateExtension",)
18 changes: 14 additions & 4 deletions django_squash/db/migrations/autodetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from django.db import migrations as dj_migrations
from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetectorBase

from django_squash.contrib import postgres

from . import utils

RESERVED_MIGRATION_KEYWORDS = ("_deleted", "_dependencies_change", "_replaces_change", "_original_migration")
Expand Down Expand Up @@ -60,17 +62,18 @@ def from_migration(cls, migration):

class SquashMigrationAutodetector(MigrationAutodetectorBase):

def add_non_elidables(self, original, loader, changes):
def add_non_elidables(self, loader, changes):
replacing_migrations_by_app = {
app: [
original.disk_migrations[r]
loader.disk_migrations[r]
for r in list(dict.fromkeys(itertools.chain.from_iterable([m.replaces for m in migrations])))
]
for app, migrations in changes.items()
}

for app in changes.keys():
new_operations = []
new_operations_bubble_top = []
new_imports = []

for migration in replacing_migrations_by_app[app]:
Expand All @@ -84,10 +87,17 @@ def add_non_elidables(self, original, loader, changes):
new_operations.append(operation)
elif isinstance(operation, dj_migrations.RunPython):
new_operations.append(operation)
elif isinstance(operation, postgres.PGCreateExtension):
new_operations_bubble_top.append(operation)
elif isinstance(operation, dj_migrations.SeparateDatabaseAndState):
# A valid use case for this should be given before any work is done.
pass

if new_operations_bubble_top:
migration = changes[app][0]
migration.operations = new_operations_bubble_top + migration.operations
migration.extra_imports = new_imports

migration = changes[app][-1]
migration.operations += new_operations
migration.extra_imports = new_imports
Expand Down Expand Up @@ -198,14 +208,14 @@ def squash(self, real_loader, squash_loader, ignore_apps, migration_name=None):
self.convert_migration_references_to_objects(real_loader, changes, ignore_apps)
self.rename_migrations(real_loader, graph, changes, migration_name)
self.replace_current_migrations(real_loader, graph, changes)
self.add_non_elidables(real_loader, squash_loader, changes)
self.add_non_elidables(real_loader, changes)

for app, change in changes_.items():
changes[app].extend(change)

return changes

def delete_old_squashed(self, loader, ignore_apps=None):
def delete_old_squashed(self, loader, ignore_apps):
changes = defaultdict(set)
project_path = os.path.abspath(os.curdir)
project_apps = [
Expand Down
19 changes: 17 additions & 2 deletions django_squash/db/migrations/writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import os
import re
import textwrap
Expand All @@ -8,6 +9,7 @@
from django.db.migrations import writer as dj_writer
from django.utils.timezone import now

from django_squash.contrib import postgres
from django_squash.db.migrations import operators, utils

SUPPORTED_DJANGO_WRITER = (
Expand Down Expand Up @@ -38,6 +40,16 @@ def check_django_migration_hash():
check_django_migration_hash()


class OperationWriter(dj_writer.OperationWriter):
def serialize(self):
if isinstance(self.operation, postgres.PGCreateExtension):
if not utils.is_code_in_site_packages(self.operation.__class__.__module__):
self.feed("%s()," % (self.operation.__class__.__name__))
return self.render(), set()

return super().serialize()


class ReplacementMigrationWriter(dj_writer.MigrationWriter):
"""
Take a Migration instance and is able to produce the contents
Expand Down Expand Up @@ -70,7 +82,7 @@ def get_kwargs(self): # pragma: no cover
# Deconstruct operations
operations = []
for operation in self.migration.operations:
operation_string, operation_imports = dj_writer.OperationWriter(operation).serialize()
operation_string, operation_imports = OperationWriter(operation).serialize()
imports.update(operation_imports)
operations.append(operation_string)
items["operations"] = "\n".join(operations) + "\n" if operations else ""
Expand Down Expand Up @@ -205,7 +217,7 @@ def replace_in_migration(self):
source = utils.replace_migration_attribute(source, "replaces", self.migration.replaces)
changed = True
if not changed:
raise NotImplementedError()
raise NotImplementedError() # pragma: no cover

return source

Expand Down Expand Up @@ -244,6 +256,9 @@ def get_kwargs(self):
variables.append(
self.template_variable % (operation.reverse_sql.name, repr(operation.reverse_sql.value))
)
elif isinstance(operation, postgres.PGCreateExtension):
if not utils.is_code_in_site_packages(operation.__class__.__module__):
functions.append(textwrap.dedent(inspect.getsource(operation.__class__)))

kwargs["functions"] = ("\n\n" if functions else "") + "\n\n".join(functions)
kwargs["variables"] = ("\n\n" if variables else "") + "\n\n".join(variables)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"ipdb",
"isort",
"libcst",
"psycopg2",
"pytest-cov",
"pytest-django",
"restructuredtext-lint",
Expand Down
10 changes: 10 additions & 0 deletions tests/app/tests/migrations/pg_indexes/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.contrib.postgres.operations import BtreeGinExtension
from django.db import migrations


class Migration(migrations.Migration):
dependencies = []

operations = [
BtreeGinExtension(),
]
41 changes: 41 additions & 0 deletions tests/app/tests/migrations/pg_indexes/0002_use_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Generated by Django 3.2.22 on 2023-10-13 10:38

import django.contrib.postgres.indexes
from django.db import migrations, models


class Migration(migrations.Migration):
initial = True

dependencies = [
("app", "0001_initial"),
]

operations = [
migrations.CreateModel(
name="Message",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("score", models.IntegerField(default=0)),
("unicode_name", models.CharField(db_index=True, max_length=255)),
],
),
migrations.AddIndex(
model_name="message",
index=models.Index(fields=["-score"], name="message_e_score_385f90_idx"),
),
migrations.AddIndex(
model_name="message",
index=django.contrib.postgres.indexes.GinIndex(
fields=["unicode_name"], name="message_e_unicode_6789fc_gin"
),
),
]
Empty file.
19 changes: 19 additions & 0 deletions tests/app/tests/migrations/pg_indexes_custom/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from django.contrib.postgres.operations import BtreeGinExtension
from django.db import migrations


class IgnoreRollbackBtreeGinExtension(BtreeGinExtension):
"""
Custom extension that doesn't rollback no matter what
"""

def database_backwards(self, *args, **kwargs):
pass


class Migration(migrations.Migration):
dependencies = []

operations = [
IgnoreRollbackBtreeGinExtension(),
]
41 changes: 41 additions & 0 deletions tests/app/tests/migrations/pg_indexes_custom/0002_use_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Generated by Django 3.2.22 on 2023-10-13 10:38

import django.contrib.postgres.indexes
from django.db import migrations, models


class Migration(migrations.Migration):
initial = True

dependencies = [
("app", "0001_initial"),
]

operations = [
migrations.CreateModel(
name="Message",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("score", models.IntegerField(default=0)),
("unicode_name", models.CharField(db_index=True, max_length=255)),
],
),
migrations.AddIndex(
model_name="message",
index=models.Index(fields=["-score"], name="message_e_score_385f90_idx"),
),
migrations.AddIndex(
model_name="message",
index=django.contrib.postgres.indexes.GinIndex(
fields=["unicode_name"], name="message_e_unicode_6789fc_gin"
),
),
]
Empty file.
Loading

0 comments on commit ca41ca3

Please sign in to comment.