async support for mad-migration #72

Merged
merged 4 commits on Jan 7, 2021
Changes from all commits
92 changes: 41 additions & 51 deletions madmigration/basemigration/base.py
@@ -20,10 +20,12 @@

logger = logging.getLogger(__name__)


class BaseMigrate():
q = Queue() # Static queue for fk constraints data
tables = set()
def __init__(self, config: Config,destination_db):

def __init__(self, config: Config, destination_db):
self.global_config = config
self.migration_tables = config.migrationTables
self.engine = destination_db.engine
@@ -39,15 +41,14 @@ def __init__(self, config: Config,destination_db):
self.db_operations = DbOperations(self.engine)

signal.signal(signal.SIGINT, self.sig_handler)



def __enter__(self):
return self

def __exit__(self, type, value, traceback):
self.engine.session.close()
def sig_handler(self,sig_num,_sig_frame):

def sig_handler(self, sig_num, _sig_frame):
logger.warn("TERMINATE APP WITH SIGNAL -> %d" % sig_num)
if self.tables:
self.db_operations.db_drop_everything(self.tables)
@@ -85,9 +86,8 @@ def collect_drop_fk(self):
return False
finally:
conn.close()


def parse_migration_tables(self,tabels_schema:MigrationTablesSchema):
def parse_migration_tables(self, tabels_schema: MigrationTablesSchema):
"""
This function parses migrationTables from yaml file
"""
@@ -98,7 +98,6 @@ def parse_migration_tables(self,tabels_schema:MigrationTablesSchema):
except Exception as err:
logger.error("parse_migration_tables [error] -> %s" % err)


def parse_migration_columns(
self, tablename: str, migration_columns: ColumnParametersSchema
):
@@ -107,7 +106,7 @@ def parse_migration_columns(
"""
try:
update = self.check_table(tablename)

for col in migration_columns:
self.source_column = col.sourceColumn
self.destination_column = col.destinationColumn
@@ -119,48 +118,47 @@ def parse_migration_columns(
col = Column(self.destination_column.name, column_type, **self.dest_options)
if update:
if not self.check_column(tablename, self.destination_column.name):
# self.add_alter_column(tablename, {"column_name": self.destination_column.name,"type":column_type,"options":{**self.dest_options}})
# else:
self.add_updated_table(tablename,col)
# self.add_alter_column(tablename, {"column_name": self.destination_column.name,"type":column_type,"options":{**self.dest_options}})
# else:
self.add_updated_table(tablename, col)
else:
self.add_created_table(tablename,col)
self.add_created_table(tablename, col)
except Exception as err:
logger.error("parse_migration_columns [error] -> %s" % err)


def add_updated_table(self,table_name: str, col: Column):
def add_updated_table(self, table_name: str, col: Column):
self.table_update[table_name].append(col)

def add_created_table(self,table_name: str, col: Column):
def add_created_table(self, table_name: str, col: Column):
self.table_create[table_name].append(col)

def add_alter_column(self,table_name: str, col: Column):
def add_alter_column(self, table_name: str, col: Column):
self.alter_col[table_name].append(col)

def prepare_tables(self):
try:
for migrate_table in self.migration_tables:
if migrate_table.migrationTable.DestinationTable.create:
self.parse_migration_tables(migrate_table)
self.parse_migration_columns(self.destination_table.get("name"),self.columns)
self.parse_migration_columns(self.destination_table.get("name"), self.columns)
except Exception as err:
logger.error("prepare_tables [error] -> %s" % err)

def update_table(self):
for tab,col in self.table_update.items():
self.db_operations.add_column(tab,*col)

for tab, col in self.table_update.items():
self.db_operations.add_column(tab, *col)
return True

def alter_columns(self):
for tab, val in self.alter_col.items():
for i in val:
self.db_operations.update_column(tab,i.pop("column_name"),i.pop("type"), **i.pop("options"))
self.db_operations.update_column(tab, i.pop("column_name"), i.pop("type"), **i.pop("options"))
return True

def create_tables(self):
for tab, col in self.table_create.items():
self.db_operations.create_table(tab,*col)
self.db_operations.create_table(tab, *col)
return True

def process(self):
@@ -174,14 +172,14 @@ def process(self):
self.collect_drop_fk()
self.update_table()
self.create_tables()
self.db_operations.create_fk_constraint(self.fk_constraints,self.contraints_columns)
self.db_operations.create_fk_constraint(self.fk_constraints, self.contraints_columns)
return True
except Exception as err:
logger.error("create_tables [error] -> %s" % err)

def _parse_column_type(self) -> object:
""" Parse column type and options (length,type and etc.) """

try:
column_type = self.get_column_type(self.dest_options.pop("type_cast"))
type_length = self.dest_options.pop("length")
@@ -190,7 +188,7 @@ def _parse_column_type(self) -> object:
return column_type
except Exception as err:
logger.error("_parse_column_type [error] -> %s" % err)

# logger.error(self.dest_options.get("length"))
type_length = self.dest_options.pop("length")
if type_length:
@@ -207,7 +205,6 @@ def _parse_fk(self, tablename, fk_options):
except Exception as err:
logger.error("_parse_fk [error] -> %s" % err)


def check_table(self, table_name: str) -> bool:
""" Check table exist or not, and wait user input """
try:
@@ -218,11 +215,12 @@ def check_table(self, table_name: str) -> bool:
logger.error("check_table [error] -> %s" % err)
return False

def get_input(self,table_name):
def get_input(self, table_name):
while True:
answ = input(
f"Table with name '{table_name}' already exist,\
'{table_name}' table will be dropped and recreated,your table data will be lost,process?(y/n) ")
'{table_name}' table will be dropped and recreated,your table data will be lost in the process.\
Do you want to continue?(y/n) ")
if answ.lower() == "y":
self.db_operations.drop_fk(self.dest_fk)
self.db_operations.drop_table(table_name)
@@ -231,7 +229,7 @@ def get_input(self,table_name):
return True
else:
continue

def get_table_attribute_from_base_class(self, source_table_name: str):
"""
This function gets table name attribute from sourceDB.base.classes. Example sourceDB.base.class.(table name)
@@ -251,36 +249,32 @@ def get_data_from_source_table(self, source_table_name: str, source_columns: lis
data[column] = getattr(row, column)
yield data


def check_column(self, table_name: str, column_name: str) -> bool:
"""
Check column exist in destination table or not
param:: column_name -> is destination column name
"""
try:
insp = reflection.Inspector.from_engine(self.engine)
has_column = False
for col in insp.get_columns(table_name):
if column_name not in col["name"]:
continue
return True
return has_column
if column_name in col["name"]:
return True
return False
except Exception as err:
logger.error("check_column [error] -> %s" % err)
return False


@staticmethod
def insert_data(engine, table_name, data: dict):
# stmt = None
try:
stmt = engine.base.metadata.tables[table_name].insert().values(**data)

except Exception as err:
logger.error("insert_data stmt [error] -> %s" % err)
return
# logger.error("STMT",stmt)

try:
engine.session.execute(stmt)
except Exception as err:
Expand All @@ -298,13 +292,13 @@ def insert_data(engine, table_name, data: dict):
@staticmethod
def insert_queue(engine):
for stmt in BaseMigrate.q.queue:

try:
logger.info("Inserting from queue")
engine.session.execute(stmt)
except Exception as err:
logger.error("insert_queue [error] -> %s" % err)

try:
engine.session.commit()
except Exception as err:
Expand All @@ -313,7 +307,6 @@ def insert_queue(engine):
finally:
engine.session.close()


@staticmethod
def put_queue(data):
BaseMigrate.q.put(data)
@@ -327,7 +320,7 @@ def type_cast(data_from_source, mt, convert_info: dict):
for columns in mt.migrationTable.MigrationColumns:
source_column = columns.sourceColumn.name
destination_column = columns.destinationColumn.name

if columns.destinationColumn.options.type_cast:
destination_type_cast = columns.destinationColumn.options.type_cast
else:
Expand All @@ -336,7 +329,7 @@ def type_cast(data_from_source, mt, convert_info: dict):
if convert_info.get(destination_column):
# ClassType is Class of data type (int, str, float, etc...)
# Using this ClassType we are converting data into format specified in type_cast
datatype = get_type_object(destination_type_cast)
datatype = get_type_object(destination_type_cast)

try:
if datatype == type(data_from_source.get(source_column)):
@@ -350,15 +343,14 @@ def type_cast(data_from_source, mt, convert_info: dict):
except Exception as err:
logger.error("type_cast [error] -> %s" % err)
data_from_source[destination_column] = None

except Exception as err:
logger.error("type_cast [error] -> %s" % err)
data_from_source[destination_column] = None
else:
data_from_source[destination_column] = data_from_source.pop(source_column)

return data_from_source

return data_from_source

@staticmethod
def get_column_type(type_name: str) -> object:
@@ -367,5 +359,3 @@ def get_column_type(type_name: str) -> object:
:return: object class
"""
raise NotImplementedError
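
get_column_type is left abstract here; a concrete backend class is expected to map the type name from the YAML config to a SQLAlchemy type class. A minimal sketch of such an override, assuming a MySQL-flavoured subclass (the class name and the mapping below are illustrative, not the project's actual code):

from sqlalchemy import DateTime, Float, Integer, String
from madmigration.basemigration.base import BaseMigrate

class MysqlMigrate(BaseMigrate):
    @staticmethod
    def get_column_type(type_name: str) -> object:
        # Illustrative mapping from YAML type names to SQLAlchemy type classes.
        type_map = {
            "int": Integer,
            "integer": Integer,
            "string": String,
            "varchar": String,
            "float": Float,
            "datetime": DateTime,
        }
        return type_map[type_name.lower()]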


42 changes: 42 additions & 0 deletions madmigration/db_init/async_connection_engine.py
@@ -0,0 +1,42 @@
import asyncio

import gino
from sqlalchemy import event, Table
from sqlalchemy.ext.automap import automap_base
from sqlalchemy import create_engine

from madmigration.db_init.connection_engine import DestinationDB
from madmigration.utils.helpers import (
database_not_exists,
goodby_message,
aio_database_exists,
run_await_funtion
)


@event.listens_for(Table, "after_parent_attach")
def before_parent_attach(target, parent):
if not target.primary_key and "id" in target.c:
print(target)


class AsyncSourceDB:
def __init__(self, source_uri):
if not aio_database_exists(source_uri):
goodby_message(database_not_exists(source_uri), 0)

metadata = gino.Gino()
self.base = automap_base(metadata=metadata)
self.engine = create_engine(source_uri)
self.base.prepare()


class AsyncDestinationDB(DestinationDB):
def __init__(self, destination_uri):
self.check_for_or_create_database(destination_uri, check_for_database=aio_database_exists)
self.engine = create_engine(destination_uri, strategy='gino')


@run_await_funtion()
async def create_engine(*args, **kwargs):
return await gino.create_engine(*args, **kwargs)
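
The run_await_funtion helper imported above comes from madmigration.utils.helpers and is not part of this diff; it is what lets the module-level create_engine coroutine be called synchronously from the __init__ methods. A minimal sketch of what such a decorator presumably looks like, assuming it simply drives the coroutine to completion on an event loop (a guess at the helper, not its actual code):

import asyncio
import functools

def run_await_funtion():
    def decorator(coro_func):
        # Wrap a coroutine function so callers can invoke it synchronously;
        # the wrapper runs the coroutine to completion on an event loop.
        @functools.wraps(coro_func)
        def wrapper(*args, **kwargs):
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(coro_func(*args, **kwargs))
        return wrapper
    return decorator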
22 changes: 11 additions & 11 deletions madmigration/db_init/connection_engine.py
@@ -3,7 +3,7 @@
from sqlalchemy.ext.automap import automap_base
from sqlalchemy_utils.functions.database import database_exists, create_database
import sys
from madmigration.utils.helpers import issue_url,app_name,parse_uri
from madmigration.utils.helpers import issue_url, app_name, parse_uri
from madmigration.utils.helpers import database_not_exists, goodby_message
import logging
logger = logging.getLogger(__name__)
@@ -15,7 +15,6 @@ def before_parent_attach(target, parent):
print(target)



class SourceDB:
def __init__(self, source_uri):
if not database_exists(source_uri):
@@ -28,7 +27,15 @@ def __init__(self, source_uri):

class DestinationDB:
def __init__(self, destination_uri):
if not database_exists(destination_uri):
self.check_for_or_create_database(destination_uri)

self.base = automap_base()
self.engine = create_engine(destination_uri)
# self.base.prepare(self.engine, reflect=True)
self.session = Session(self.engine, autocommit=False, autoflush=False)

def check_for_or_create_database(self, destination_uri, check_for_database: callable = database_exists):
if not check_for_database(destination_uri):
while True:
database = parse_uri(destination_uri)
msg = input(f"The database {database} does not exists, would you like to create it in the destination?(y/n) ")
@@ -41,13 +48,6 @@ def __init__(self, destination_uri):
goodby_message(database_not_exists(destination_uri), 1)
break
elif msg.lower() == "n":
goodby_message("Destination database does not exit \nExiting ..", 0)
goodby_message("Destination database does not exist \nExiting ...", 0)
break
print("Please, select command")


self.base = automap_base()
self.engine = create_engine(destination_uri)
# self.base.prepare(self.engine, reflect=True)
self.session = Session(self.engine, autocommit=False, autoflush=False)
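
The refactor above moves the create-or-abort prompt into check_for_or_create_database so that AsyncDestinationDB can reuse it with a different existence check. A hedged usage sketch (connection URIs are placeholders):

from madmigration.db_init.connection_engine import DestinationDB
from madmigration.db_init.async_connection_engine import AsyncDestinationDB

# Sync path: database_exists() is the default existence check and the engine
# is a plain SQLAlchemy engine with a Session attached.
dest = DestinationDB("postgresql://user:pass@localhost/target_db")

# Async path: the same prompt logic runs, but existence is checked with
# aio_database_exists and the engine is created through gino.
async_dest = AsyncDestinationDB("postgresql://user:pass@localhost/target_db")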
