diff --git a/madmigration/basemigration/base.py b/madmigration/basemigration/base.py
index 47da293..35c4bb2 100644
--- a/madmigration/basemigration/base.py
+++ b/madmigration/basemigration/base.py
@@ -20,10 +20,12 @@ logger = logging.getLogger(__name__)
 
 
 
+
 class BaseMigrate():
     q = Queue() # Static queue for fk constraints data
     tables = set()
-    def __init__(self, config: Config,destination_db):
+
+    def __init__(self, config: Config, destination_db):
         self.global_config = config
         self.migration_tables = config.migrationTables
         self.engine = destination_db.engine
@@ -39,15 +41,14 @@ def __init__(self, config: Config,destination_db):
         self.db_operations = DbOperations(self.engine)
         signal.signal(signal.SIGINT, self.sig_handler)
-
-
+
     def __enter__(self):
         return self
 
     def __exit__(self, type, value, traceback):
         self.engine.session.close()
 
-
-    def sig_handler(self,sig_num,_sig_frame):
+
+    def sig_handler(self, sig_num, _sig_frame):
         logger.warn("TERMINATE APP WITH SIGNAL -> %d" % sig_num)
         if self.tables:
             self.db_operations.db_drop_everything(self.tables)
@@ -85,9 +86,8 @@ def collect_drop_fk(self):
             return False
         finally:
             conn.close()
 
-
-    def parse_migration_tables(self,tabels_schema:MigrationTablesSchema):
+    def parse_migration_tables(self, tabels_schema: MigrationTablesSchema):
         """
         This function parses migrationTables from yaml file
         """
@@ -98,7 +98,6 @@ def parse_migration_tables(self,tabels_schema:MigrationTablesSchema):
         except Exception as err:
             logger.error("parse_migration_tables [error] -> %s" % err)
 
-
     def parse_migration_columns(
         self, tablename: str, migration_columns: ColumnParametersSchema
     ):
@@ -107,7 +106,7 @@ def parse_migration_columns(
         """
         try:
             update = self.check_table(tablename)
-            
+
             for col in migration_columns:
                 self.source_column = col.sourceColumn
                 self.destination_column = col.destinationColumn
@@ -119,22 +118,21 @@ def parse_migration_columns(
                 col = Column(self.destination_column.name, column_type, **self.dest_options)
                 if update:
                     if not self.check_column(tablename, self.destination_column.name):
-                    # self.add_alter_column(tablename, {"column_name": self.destination_column.name,"type":column_type,"options":{**self.dest_options}})
-                    # else:
-                        self.add_updated_table(tablename,col)
+                        # self.add_alter_column(tablename, {"column_name": self.destination_column.name,"type":column_type,"options":{**self.dest_options}})
+                        # else:
+                        self.add_updated_table(tablename, col)
                 else:
-                    self.add_created_table(tablename,col)
+                    self.add_created_table(tablename, col)
         except Exception as err:
             logger.error("parse_migration_columns [error] -> %s" % err)
 
-
-    def add_updated_table(self,table_name: str, col: Column):
+    def add_updated_table(self, table_name: str, col: Column):
         self.table_update[table_name].append(col)
 
-    def add_created_table(self,table_name: str, col: Column):
+    def add_created_table(self, table_name: str, col: Column):
         self.table_create[table_name].append(col)
 
-    def add_alter_column(self,table_name: str, col: Column):
+    def add_alter_column(self, table_name: str, col: Column):
         self.alter_col[table_name].append(col)
 
     def prepare_tables(self):
@@ -142,25 +140,25 @@ def prepare_tables(self):
         try:
             for migrate_table in self.migration_tables:
                 if migrate_table.migrationTable.DestinationTable.create:
                     self.parse_migration_tables(migrate_table)
-                    self.parse_migration_columns(self.destination_table.get("name"),self.columns)
+                    self.parse_migration_columns(self.destination_table.get("name"), self.columns)
         except Exception as err:
             logger.error("prepare_tables [error] -> %s" % err)
 
     def update_table(self):
-
-        for tab,col in self.table_update.items():
-            self.db_operations.add_column(tab,*col)
+
+        for tab, col in self.table_update.items():
+            self.db_operations.add_column(tab, *col)
         return True
-    
+
     def alter_columns(self):
         for tab, val in self.alter_col.items():
             for i in val:
-                self.db_operations.update_column(tab,i.pop("column_name"),i.pop("type"), **i.pop("options"))
+                self.db_operations.update_column(tab, i.pop("column_name"), i.pop("type"), **i.pop("options"))
         return True
-    
+
     def create_tables(self):
         for tab, col in self.table_create.items():
-            self.db_operations.create_table(tab,*col)
+            self.db_operations.create_table(tab, *col)
         return True
 
     def process(self):
@@ -174,14 +172,14 @@ def process(self):
             self.collect_drop_fk()
             self.update_table()
             self.create_tables()
-            self.db_operations.create_fk_constraint(self.fk_constraints,self.contraints_columns)
+            self.db_operations.create_fk_constraint(self.fk_constraints, self.contraints_columns)
             return True
         except Exception as err:
             logger.error("create_tables [error] -> %s" % err)
 
     def _parse_column_type(self) -> object:
         """ Parse column type and options (length,type and etc.) """
-        
+
         try:
             column_type = self.get_column_type(self.dest_options.pop("type_cast"))
             type_length = self.dest_options.pop("length")
@@ -190,7 +188,7 @@ def _parse_column_type(self) -> object:
             return column_type
         except Exception as err:
             logger.error("_parse_column_type [error] -> %s" % err)
-        
+
         # logger.error(self.dest_options.get("length"))
         type_length = self.dest_options.pop("length")
         if type_length:
@@ -207,7 +205,6 @@ def _parse_fk(self, tablename, fk_options):
         except Exception as err:
             logger.error("_parse_fk [error] -> %s" % err)
 
-
     def check_table(self, table_name: str) -> bool:
         """ Check table exist or not, and wait user input """
         try:
@@ -218,11 +215,12 @@ def check_table(self, table_name: str) -> bool:
             logger.error("check_table [error] -> %s" % err)
             return False
 
-    def get_input(self,table_name):
+    def get_input(self, table_name):
         while True:
             answ = input(
                 f"Table with name '{table_name}' already exist,\
-'{table_name}' table will be dropped and recreated,your table data will be lost,process?(y/n) ")
+'{table_name}' table will be dropped and recreated,your table data will be lost in the process.\
+ Do you want to continue?(y/n) ")
             if answ.lower() == "y":
                 self.db_operations.drop_fk(self.dest_fk)
                 self.db_operations.drop_table(table_name)
@@ -231,7 +229,7 @@ def get_input(self,table_name):
                 return True
             else:
                 continue
-    
+
     def get_table_attribute_from_base_class(self, source_table_name: str):
         """ This function gets table name attribute from sourceDB.base.classes.
        Example sourceDB.base.class.(table name)
@@ -251,7 +249,6 @@ def get_data_from_source_table(self, source_table_name: str, source_columns: lis
                 data[column] = getattr(row, column)
             yield data
 
-
     def check_column(self, table_name: str, column_name: str) -> bool:
         """
         Check column exist in destination table or not
@@ -259,28 +256,25 @@ def check_column(self, table_name: str, column_name: str) -> bool:
         """
         try:
             insp = reflection.Inspector.from_engine(self.engine)
-            has_column = False
             for col in insp.get_columns(table_name):
-                if column_name not in col["name"]:
-                    continue
-                return True
-            return has_column
+                if column_name in col["name"]:
+                    return True
+            return False
         except Exception as err:
             logger.error("check_column [error] -> %s" % err)
             return False
 
-
     @staticmethod
     def insert_data(engine, table_name, data: dict):
         # stmt = None
         try:
             stmt = engine.base.metadata.tables[table_name].insert().values(**data)
-            
+
         except Exception as err:
             logger.error("insert_data stmt [error] -> %s" % err)
             return
         # logger.error("STMT",stmt)
-        
+
         try:
             engine.session.execute(stmt)
         except Exception as err:
@@ -298,13 +292,13 @@ def insert_data(engine, table_name, data: dict):
     @staticmethod
     def insert_queue(engine):
         for stmt in BaseMigrate.q.queue:
-            
+
             try:
                 logger.info("Inserting from queue")
                 engine.session.execute(stmt)
             except Exception as err:
                 logger.error("insert_queue [error] -> %s" % err)
-        
+
         try:
             engine.session.commit()
         except Exception as err:
@@ -313,7 +307,6 @@ def insert_queue(engine):
         finally:
             engine.session.close()
 
-
     @staticmethod
     def put_queue(data):
         BaseMigrate.q.put(data)
@@ -327,7 +320,7 @@ def type_cast(data_from_source, mt, convert_info: dict):
        for columns in mt.migrationTable.MigrationColumns:
            source_column = columns.sourceColumn.name
            destination_column = columns.destinationColumn.name
-            
+
            if columns.destinationColumn.options.type_cast:
                destination_type_cast = columns.destinationColumn.options.type_cast
            else:
@@ -336,7 +329,7 @@ def type_cast(data_from_source, mt, convert_info: dict):
            if convert_info.get(destination_column):
                # ClassType is Class of data type (int, str, float, etc...)
                # Using this ClassType we are converting data into format specified in type_cast
-                datatype = get_type_object(destination_type_cast) 
+                datatype = get_type_object(destination_type_cast)
 
                try:
                    if datatype == type(data_from_source.get(source_column)):
@@ -350,15 +343,14 @@ def type_cast(data_from_source, mt, convert_info: dict):
                        except Exception as err:
                            logger.error("type_cast [error] -> %s" % err)
                            data_from_source[destination_column] = None
-                
+
                except Exception as err:
                    logger.error("type_cast [error] -> %s" % err)
                    data_from_source[destination_column] = None
            else:
                data_from_source[destination_column] = data_from_source.pop(source_column)
-
-        return data_from_source
 
+        return data_from_source
 
     @staticmethod
     def get_column_type(type_name: str) -> object:
@@ -367,5 +359,3 @@ def get_column_type(type_name: str) -> object:
         :return: object class
         """
         raise NotImplementedError
-
-
diff --git a/madmigration/db_init/async_connection_engine.py b/madmigration/db_init/async_connection_engine.py
new file mode 100644
index 0000000..3ec4463
--- /dev/null
+++ b/madmigration/db_init/async_connection_engine.py
@@ -0,0 +1,42 @@
+import asyncio
+
+import gino
+from sqlalchemy import event, Table
+from sqlalchemy.ext.automap import automap_base
+from sqlalchemy import create_engine
+
+from madmigration.db_init.connection_engine import DestinationDB
+from madmigration.utils.helpers import (
+    database_not_exists,
+    goodby_message,
+    aio_database_exists,
+    run_await_funtion
+)
+
+
+@event.listens_for(Table, "after_parent_attach")
+def before_parent_attach(target, parent):
+    if not target.primary_key and "id" in target.c:
+        print(target)
+
+
+class AsyncSourceDB:
+    def __init__(self, source_uri):
+        if not aio_database_exists(source_uri):
+            goodby_message(database_not_exists(source_uri), 0)
+
+        metadata = gino.Gino()
+        self.base = automap_base(metadata=metadata)
+        self.engine = create_engine(source_uri)
+        self.base.prepare()
+
+
+class AsyncDestinationDB(DestinationDB):
+    def __init__(self, destination_uri):
+        self.check_for_or_create_database(destination_uri, check_for_database=aio_database_exists)
+        self.engine = create_engine(destination_uri, strategy='gino')
+
+
+@run_await_funtion()
+async def create_engine(*args, **kwargs):
+    return await gino.create_engine(*args, **kwargs)
diff --git a/madmigration/db_init/connection_engine.py b/madmigration/db_init/connection_engine.py
index 296defc..808f92b 100644
--- a/madmigration/db_init/connection_engine.py
+++ b/madmigration/db_init/connection_engine.py
@@ -3,7 +3,7 @@
 from sqlalchemy.ext.automap import automap_base
 from sqlalchemy_utils.functions.database import database_exists, create_database
 import sys
-from madmigration.utils.helpers import issue_url,app_name,parse_uri
+from madmigration.utils.helpers import issue_url, app_name, parse_uri
 from madmigration.utils.helpers import database_not_exists, goodby_message
 import logging
 logger = logging.getLogger(__name__)
@@ -15,7 +15,6 @@ def before_parent_attach(target, parent):
         print(target)
 
 
-
 class SourceDB:
     def __init__(self, source_uri):
         if not database_exists(source_uri):
@@ -28,7 +27,15 @@ def __init__(self, destination_uri):
 
 class DestinationDB:
     def __init__(self, destination_uri):
-        if not database_exists(destination_uri):
+        self.check_for_or_create_database(destination_uri)
+
+        self.base = automap_base()
+        self.engine = create_engine(destination_uri)
+        # self.base.prepare(self.engine, reflect=True)
+        self.session = Session(self.engine, autocommit=False, autoflush=False)
+
+    def check_for_or_create_database(self, destination_uri, check_for_database: callable = database_exists):
+        if not check_for_database(destination_uri):
             while True:
                 database = parse_uri(destination_uri)
                 msg = input(f"The database {database} does not exists, would you like to create it in the destination?(y/n) ")
@@ -41,13 +48,6 @@ def __init__(self, destination_uri):
                     goodby_message(database_not_exists(destination_uri), 1)
                     break
                 elif msg.lower() == "n":
-                    goodby_message("Destination database does not exit \nExiting ..", 0)
+                    goodby_message("Destination database does not exist \nExiting ...", 0)
                     break
                 print("Please, select command")
-
-
-        self.base = automap_base()
-        self.engine = create_engine(destination_uri)
-        # self.base.prepare(self.engine, reflect=True)
-        self.session = Session(self.engine, autocommit=False, autoflush=False)
-
diff --git a/madmigration/db_operations/async_operations.py b/madmigration/db_operations/async_operations.py
new file mode 100644
index 0000000..420bd60
--- /dev/null
+++ b/madmigration/db_operations/async_operations.py
@@ -0,0 +1,35 @@
+import asyncio
+import logging
+
+import gino
+
+from gino.schema import DropConstraint, DropTable
+from sqlalchemy import ForeignKeyConstraint, Table
+from sqlalchemy.engine import reflection
+
+from alembic.migration import MigrationContext
+from alembic.operations import Operations
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncDbOperations:
+    def __init__(self, engine):
+        self.engine = engine
+        self.metadata = gino.Gino()
+        self.loop = asyncio.get_event_loop()
+
+    async def create_table(self, table_name: str, *columns) -> bool:
+        """ create a new table """
+        try:
+            table = Table(table_name, self.metadata, *columns)
+
+            # the gino object comes from "metadata" included in the Table class. metadata is
+            # a gino.Gino object which is why table.gino is possible
+            await table.gino.create(self.engine, checkfirst=True)
+
+            logger.info("%s is created", table_name)
+            return True
+        except Exception as err:
+            logger.error("_create_table [error] -> %s", err)
+            return False
diff --git a/madmigration/utils/helpers.py b/madmigration/utils/helpers.py
index cd37b1a..586738a 100644
--- a/madmigration/utils/helpers.py
+++ b/madmigration/utils/helpers.py
@@ -1,7 +1,13 @@
+import asyncio
+import functools
 from datetime import datetime
+from copy import copy
+
+import gino
 from sqlalchemy.schema import DropTable
 from sqlalchemy.schema import ForeignKeyConstraint
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.engine.url import make_url
 from typing import Union
 import os
 import sys
@@ -27,6 +33,8 @@
 ###########################
 # Get class of cast       #
 ###########################
+
+
 def get_cast_type(type_name: str) -> object:
     """
     :param type_name: str
@@ -52,18 +60,18 @@ def detect_driver(driver: str) -> Union[MysqlMigrate, PgMigrate, MongoDbMigrate]
     :return: object class
     """
     return {
-        "mysqldb" : MysqlMigrate,
+        "mysqldb": MysqlMigrate,
         "mysql+mysqldb": MysqlMigrate,
         "pymysql": MysqlMigrate,
-        "mysql+pymysql" : MysqlMigrate,
-        "mariadb+pymsql" : MysqlMigrate,
-        "psycopg2": PgMigrate,
+        "mysql+pymysql": MysqlMigrate,
+        "mariadb+pymsql": MysqlMigrate,
+        "psycopg2": PgMigrate,
         "pg8000": PgMigrate,
         "pyodbc": MssqlMigrate,
         "mongodb": MongoDbMigrate
-        # "postgresql+asyncpg": postgres_migrate,
-        # "asyncpg": postgres_migrate
+        # "postgresql+asyncpg": postgres_migrate,
+        # "asyncpg": postgres_migrate
     }.get(driver)
 
 
@@ -77,7 +85,6 @@ def _compile_drop_table(element, compiler, **kwargs):
     return compiler.visit_drop_table(element) + " CASCADE"
 
 
-
"mysql", "mariadb") def process(element, compiler, **kw): element.deferrable = element.initially = None @@ -88,7 +95,6 @@ def check_file(file): return Path(file).is_file() and os.access(file, os.R_OK) - def file_not_found(file): logger.error(f"Given file does not exists file: {file}") sys.exit(1) @@ -98,13 +104,14 @@ def file_not_found(file): def issue_url(): return "https://github.com/MadeByMads/mad-migration/issues" + def app_name(): return "madmigrate" def parse_uri(uri): if "///" in uri: - database_name = uri.split("///")[-1] + database_name = uri.split("///")[-1] else: database_name = uri.split("/")[-1] @@ -112,7 +119,6 @@ def parse_uri(uri): def database_not_exists(database): - """This function will be executed if there is no database exists """ database = parse_uri(database) @@ -134,4 +140,80 @@ def database_not_exists(database): def goodby_message(message, exit_code=0): print(message, flush=True) logger.error(message) - sys.exit(int(exit_code)) \ No newline at end of file + sys.exit(int(exit_code)) + + +def run_await_funtion(loop=None): + """ + a decorator to help run async functions like they were sync + """ + if not loop: + loop = asyncio.get_event_loop() + + if not loop.is_running(): + loop = asyncio.get_event_loop() + + def wrapper_function(func): + @functools.wraps(func) + def wrapped_function(*args, **kwargs): + return loop.run_until_complete(func(*args, **kwargs)) + return wrapped_function + + return wrapper_function + + +@run_await_funtion() +async def aio_database_exists(url): + async def get_scalar_result(engine, sql): + conn = await engine.acquire() + result = await conn.scalar(sql) + await conn.release() + await engine.close() + return result + + def sqlite_file_exists(database): + if not os.path.isfile(database) or os.path.getsize(database) < 100: + return False + + with open(database, 'rb') as f: + header = f.read(100) + + return header[:16] == b'SQLite format 3\x00' + + url = copy(make_url(url)) + database, url.database = url.database, None + engine = await gino.create_engine(url) + + if engine.dialect.name == 'postgresql': + text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database + result = await get_scalar_result(engine, text) + return bool(result) + + elif engine.dialect.name == 'mysql': + text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA " + "WHERE SCHEMA_NAME = '%s'" % database) + result = await get_scalar_result(engine, text) + return bool(result) + + elif engine.dialect.name == 'sqlite': + if database: + return database == ':memory:' or sqlite_file_exists(database) + else: + return True + + else: + await engine.close() + engine = None + text = 'SELECT 1' + try: + url.database = database + engine = await gino.create_engine(url) + result = engine.scalar(text) + await result.release() + return True + + except (ProgrammingError, OperationalError): + return False + finally: + if engine is not None: + await engine.close() diff --git a/requirements.txt b/requirements.txt index 6f33dcc..2b4b193 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,17 @@ alembic==1.4.2 appdirs==1.4.4 astroid==2.4.2 +asyncpg==0.21.0 attrs==19.3.0 black==19.10b0 click==7.1.2 +contextvars==2.4 +dataclasses==0.8 flake8==3.8.3 future==0.18.2 +gino==1.0.1 +greenlet==1.0a1 +immutables==0.14 importlib-metadata==1.7.0 isort==4.3.21 Jinja2==2.11.2 @@ -36,7 +42,7 @@ python-editor==1.0.4 PyYAML==5.3.1 regex==2020.7.14 six==1.15.0 -SQLAlchemy==1.3.18 +SQLAlchemy==1.3.22 SQLAlchemy-Utils==0.36.8 toml==0.10.1 tornado==6.0.4