Skip to content

Commit

Permalink
Merge branch 'master' into add-circular-migration-finder
Browse files Browse the repository at this point in the history
  • Loading branch information
kingbuzzman authored Mar 5, 2024
2 parents 9041631 + ca41ca3 commit cb4611f
Show file tree
Hide file tree
Showing 14 changed files with 355 additions and 166 deletions.
13 changes: 13 additions & 0 deletions django_squash/contrib/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Postgres specific code
"""

try:
from django.contrib.postgres.operations import CreateExtension as PGCreateExtension
except ImportError: # pragma: no cover

class PGCreateExtension:
pass


__all__ = ("PGCreateExtension",)
69 changes: 30 additions & 39 deletions django_squash/db/migrations/autodetector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from django.db import migrations as dj_migrations
from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetectorBase

from . import operators, utils
from django_squash.contrib import postgres

from . import utils

RESERVED_MIGRATION_KEYWORDS = ("_deleted", "_dependencies_change", "_replaces_change", "_original_migration")

Expand Down Expand Up @@ -58,58 +60,47 @@ def from_migration(cls, migration):
return new


def all_custom_operations(operations, unique_names):
"""
Generator that loops over all the operations and traverses sub-operations such as those inside a -
SeparateDatabaseAndState class.
"""

for operation in operations:
if operation.elidable:
continue

if isinstance(operation, dj_migrations.RunSQL):
yield operators.RunSQL.from_operation(operation, unique_names)
elif isinstance(operation, dj_migrations.RunPython):
yield operators.RunPython.from_operation(operation, unique_names)
elif isinstance(operation, dj_migrations.SeparateDatabaseAndState):
# A valid use case for this should be given before any work is done.
pass


class SquashMigrationAutodetector(MigrationAutodetectorBase):

def add_non_elidables(self, original, loader, changes):
unique_names = utils.UniqueVariableName()
def add_non_elidables(self, loader, changes):
replacing_migrations_by_app = {
app: [
original.disk_migrations[r]
loader.disk_migrations[r]
for r in list(dict.fromkeys(itertools.chain.from_iterable([m.replaces for m in migrations])))
]
for app, migrations in changes.items()
}

for app in changes.keys():
operations = []
imports = []
new_operations = []
new_operations_bubble_top = []
new_imports = []

for migration in replacing_migrations_by_app[app]:
module = sys.modules[migration.__module__]
imports.extend(utils.get_imports(module))
for operation in all_custom_operations(migration.operations, unique_names):
if isinstance(operation, dj_migrations.RunPython):
operation.code = utils.copy_func(operation.code)
operation.code.__in_migration_file__ = module.__name__ == operation.code.__module__

if operation.reverse_code:
operation.reverse_code = utils.copy_func(operation.reverse_code)
in_migration_file = module.__name__ == operation.reverse_code.__module__
operation.reverse_code.__in_migration_file__ = in_migration_file
operations.append(operation)
new_imports.extend(utils.get_imports(module))
for operation in migration.operations:
if operation.elidable:
continue

if isinstance(operation, dj_migrations.RunSQL):
new_operations.append(operation)
elif isinstance(operation, dj_migrations.RunPython):
new_operations.append(operation)
elif isinstance(operation, postgres.PGCreateExtension):
new_operations_bubble_top.append(operation)
elif isinstance(operation, dj_migrations.SeparateDatabaseAndState):
# A valid use case for this should be given before any work is done.
pass

if new_operations_bubble_top:
migration = changes[app][0]
migration.operations = new_operations_bubble_top + migration.operations
migration.extra_imports = new_imports

migration = changes[app][-1]
migration.operations += operations
migration.extra_imports = imports
migration.operations += new_operations
migration.extra_imports = new_imports

def replace_current_migrations(self, original, graph, changes):
"""
Expand Down Expand Up @@ -222,7 +213,7 @@ def squash(self, real_loader, squash_loader, ignore_apps, migration_name=None):
self.convert_migration_references_to_objects(real_loader, changes, ignore_apps)
self.rename_migrations(real_loader, graph, changes, migration_name)
self.replace_current_migrations(real_loader, graph, changes)
self.add_non_elidables(real_loader, squash_loader, changes)
self.add_non_elidables(real_loader, changes)

for app, change in changes_.items():
changes[app].extend(change)
Expand Down
51 changes: 0 additions & 51 deletions django_squash/db/migrations/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from django.db.migrations import RunPython as RunPythonBase, RunSQL as RunSQLBase


class Variable:
"""
Wrapper type to be able to format the variable name correctly inside a migration
Expand All @@ -12,51 +9,3 @@ def __init__(self, name, value):

def __bool__(self):
return bool(self.value)


class RunPython(RunPythonBase):
# Fake the class so the OperationWriter thinks its the internal class and not a custom one
__class__ = RunPythonBase

def deconstruct(self):
name, args, kwargs = super().deconstruct()
kwargs["elidable"] = self.elidable
return name, args, kwargs

@classmethod
def from_operation(cls, operation, unique_names):
operation.code.__original_qualname__ = operation.code.__qualname__
operation.code.__qualname__ = unique_names.function(operation.code)
if operation.reverse_code:
operation.reverse_code.__original_qualname__ = operation.reverse_code.__qualname__
operation.reverse_code.__qualname__ = unique_names.function(operation.reverse_code)
return cls(
code=operation.code,
reverse_code=operation.reverse_code,
atomic=operation.atomic,
hints=operation.hints,
elidable=operation.elidable,
)


class RunSQL(RunSQLBase):
# Fake the class so the OperationWriter thinks its the internal class and not a custom one
__class__ = RunSQLBase

def deconstruct(self):
name, args, kwargs = super().deconstruct()
kwargs["elidable"] = self.elidable
return name, args, kwargs

@classmethod
def from_operation(cls, operation, unique_names):
name = unique_names("SQL", force_number=True)
reverse_sql = Variable("%s_ROLLBACK" % name, operation.reverse_sql) if operation.reverse_sql else None

return cls(
sql=Variable(name, operation.sql),
reverse_sql=reverse_sql,
state_operations=operation.state_operations,
hints=operation.hints,
elidable=operation.elidable,
)
29 changes: 9 additions & 20 deletions django_squash/db/migrations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,30 +115,19 @@ def normalize_function_name(name):
return function_name


def extract_function_source(f):
function_source = inspect.getsource(f)
if normalize_function_name(f.__original_qualname__) == normalize_function_name(f.__qualname__):
return function_source

function_source = re.sub(
rf"(def\s+){normalize_function_name(f.__original_qualname__)}",
rf"\1{normalize_function_name(f.__qualname__)}",
function_source,
1,
)
return function_source


def copy_func(f, name=None):
def copy_func(f, name):
"""
Return a function with same code, globals, defaults, closure, and name (or provide a new name)
"""
name = name or f.__qualname__
func = types.FunctionType(f.__code__, f.__globals__, name, f.__defaults__, f.__closure__)
func.__qualname__ = f.__qualname__
func.__original_qualname__ = f.__original_qualname__
func.__original_module__ = f.__module__
func.__original_function__ = f
func.__qualname__ = name
func.__original__ = f
func.__source__ = re.sub(
rf"(def\s+){normalize_function_name(f.__qualname__)}",
rf"\1{normalize_function_name(name)}",
inspect.getsource(f),
1,
)
return func


Expand Down
93 changes: 73 additions & 20 deletions django_squash/db/migrations/writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import os
import re
import textwrap
Expand All @@ -8,7 +9,8 @@
from django.db.migrations import writer as dj_writer
from django.utils.timezone import now

from django_squash.db.migrations import utils
from django_squash.contrib import postgres
from django_squash.db.migrations import operators, utils

SUPPORTED_DJANGO_WRITER = (
"39645482d4eb04b9dd21478dc4bdfeea02393913dd2161bf272f4896e8b3b343", # 5.0
Expand Down Expand Up @@ -38,6 +40,16 @@ def check_django_migration_hash():
check_django_migration_hash()


class OperationWriter(dj_writer.OperationWriter):
def serialize(self):
if isinstance(self.operation, postgres.PGCreateExtension):
if not utils.is_code_in_site_packages(self.operation.__class__.__module__):
self.feed("%s()," % (self.operation.__class__.__name__))
return self.render(), set()

return super().serialize()


class ReplacementMigrationWriter(dj_writer.MigrationWriter):
"""
Take a Migration instance and is able to produce the contents
Expand Down Expand Up @@ -70,7 +82,7 @@ def get_kwargs(self): # pragma: no cover
# Deconstruct operations
operations = []
for operation in self.migration.operations:
operation_string, operation_imports = dj_writer.OperationWriter(operation).serialize()
operation_string, operation_imports = OperationWriter(operation).serialize()
imports.update(operation_imports)
operations.append(operation_string)
items["operations"] = "\n".join(operations) + "\n" if operations else ""
Expand Down Expand Up @@ -152,8 +164,42 @@ class Migration(migrations.Migration):
def as_string(self):
if hasattr(self.migration, "is_migration_level") and self.migration.is_migration_level:
return self.replace_in_migration()
else:
return super().as_string()

variables = []
unique_names = utils.UniqueVariableName()
for operation in self.migration.operations:
operation._deconstruct = operation.__class__.deconstruct

def deconstruct(self):
name, args, kwargs = self._deconstruct(self)
kwargs["elidable"] = self.elidable
return name, args, kwargs

if isinstance(operation, dj_migrations.RunPython):
# Bind the deconstruct() to the instance to get the elidable
operation.deconstruct = deconstruct.__get__(operation, operation.__class__)
if not utils.is_code_in_site_packages(operation.code.__module__):
code_name = unique_names.function(operation.code)
operation.code = utils.copy_func(operation.code, code_name)
operation.code.__in_migration_file__ = True
if operation.reverse_code:
if not utils.is_code_in_site_packages(operation.reverse_code.__module__):
reversed_code_name = unique_names.function(operation.reverse_code)
operation.reverse_code = utils.copy_func(operation.reverse_code, reversed_code_name)
operation.reverse_code.__in_migration_file__ = True
elif isinstance(operation, dj_migrations.RunSQL):
# Bind the deconstruct() to the instance to get the elidable
operation.deconstruct = deconstruct.__get__(operation, operation.__class__)

variable_name = unique_names("SQL", force_number=True)
variables.append(self.template_variable % (variable_name, repr(operation.sql)))
operation.sql = operators.Variable(variable_name, operation.sql)
if operation.reverse_sql:
reverse_variable_name = "%s_ROLLBACK" % variable_name
variables.append(self.template_variable % (reverse_variable_name, repr(operation.reverse_sql)))
operation.reverse_sql = operators.Variable(reverse_variable_name, operation.reverse_sql)

return super().as_string()

def replace_in_migration(self):
if self.migration._deleted:
Expand All @@ -171,41 +217,48 @@ def replace_in_migration(self):
source = utils.replace_migration_attribute(source, "replaces", self.migration.replaces)
changed = True
if not changed:
raise NotImplementedError()
raise NotImplementedError() # pragma: no cover

return source

def get_kwargs(self):
kwargs = super().get_kwargs()

functions_references = []
functions = []
variables = []
for operation in self.migration.operations:
if isinstance(operation, dj_migrations.RunPython):
code_reference = operation.code
if hasattr(operation.code, "__original_function__"):
code_reference = operation.code.__original_function__
if code_reference in functions_references:
continue
functions_references.append(code_reference)
if hasattr(operation.code, "__original__"):
if operation.code.__original__ in functions_references:
continue
functions_references.append(operation.code.__original__)
else:
if operation.code in functions_references:
continue
functions_references.append(operation.code)

if not utils.is_code_in_site_packages(operation.code.__module__):
functions.append(textwrap.dedent(utils.extract_function_source(operation.code)))
functions.append(textwrap.dedent(operation.code.__source__))
if operation.reverse_code:
reverse_code_reference = operation.reverse_code
if hasattr(operation.reverse_code, "__original_function__"):
reverse_code_reference = operation.reverse_code.__original_function__
if reverse_code_reference in functions_references:
continue
functions_references.append(reverse_code_reference)
if hasattr(operation.reverse_code, "__original__"):
if operation.reverse_code.__original__ in functions_references:
continue
functions_references.append(operation.reverse_code.__original__)
else:
if operation.reverse_code in functions_references:
continue
functions_references.append(operation.reverse_code)
if not utils.is_code_in_site_packages(operation.reverse_code.__module__):
functions.append(textwrap.dedent(utils.extract_function_source(operation.reverse_code)))
functions.append(textwrap.dedent(operation.reverse_code.__source__))
elif isinstance(operation, dj_migrations.RunSQL):
variables.append(self.template_variable % (operation.sql.name, repr(operation.sql.value)))
if operation.reverse_sql:
variables.append(
self.template_variable % (operation.reverse_sql.name, repr(operation.reverse_sql.value))
)
elif isinstance(operation, postgres.PGCreateExtension):
if not utils.is_code_in_site_packages(operation.__class__.__module__):
functions.append(textwrap.dedent(inspect.getsource(operation.__class__)))

kwargs["functions"] = ("\n\n" if functions else "") + "\n\n".join(functions)
kwargs["variables"] = ("\n\n" if variables else "") + "\n\n".join(variables)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"ipdb",
"isort",
"libcst",
"psycopg2",
"pytest-cov",
"pytest-django",
"restructuredtext-lint",
Expand Down
10 changes: 10 additions & 0 deletions tests/app/tests/migrations/pg_indexes/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.contrib.postgres.operations import BtreeGinExtension
from django.db import migrations


class Migration(migrations.Migration):
dependencies = []

operations = [
BtreeGinExtension(),
]
Loading

0 comments on commit cb4611f

Please sign in to comment.