diff --git a/.gitignore b/.gitignore index 364f8a8..c74175e 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ dist/ *.egg-info/ .coverage + +# Ignore files +experiment/ diff --git a/common/utils.py b/common/utils.py index ff1a4fa..ff90e62 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,4 +1,5 @@ -from typing import Union +from typing import Union, List, Dict, Any +import click class DeserializerUtils: @@ -9,3 +10,54 @@ def convert_bytes_to_int(in_bytes: bytes) -> int: @staticmethod def convert_bytes_to_utf8(in_bytes: Union[bytes, bytearray]) -> str: return in_bytes.decode('utf-8') + + +def validate_db_configs(db_configs: List[Dict[str, Any]]) -> None: + """Validate database configurations for MultiDBEventProducer.""" + required_fields = { + 'pg_host': str, + 'pg_port': int, + 'pg_database': str, + 'pg_user': str, + 'pg_password': str, + 'pg_tables': str, + 'pg_replication_slot': str + } + + if not isinstance(db_configs, list): + raise click.BadParameter("db_configs must be a list of database configurations") + + if not db_configs: + raise click.BadParameter("db_configs cannot be empty") + + for idx, config in enumerate(db_configs): + if not isinstance(config, dict): + raise click.BadParameter(f"Configuration at index {idx} must be a dictionary") + + # Check required fields + for field, field_type in required_fields.items(): + if field not in config: + raise click.BadParameter(f"Missing required field '{field}' in configuration at index {idx}") + + if not isinstance(config[field], field_type): + raise click.BadParameter( + f"Field '{field}' in configuration at index {idx} " + f"must be of type {field_type.__name__}" + ) + + # Validate pg_port range + if not (1024 <= config['pg_port'] <= 65535): + raise click.BadParameter( + f"Invalid port number {config['pg_port']} in configuration at index {idx}. " + "Port must be between 1024 and 65535" + ) + + # Validate pg_tables format + tables = config['pg_tables'].split(',') + for table in tables: + table = table.strip() + if not table or '.' not in table: + raise click.BadParameter( + f"Invalid table format '{table}' in configuration at index {idx}. " + "Format should be 'schema.table'" + ) diff --git a/producer/event_producer.py b/producer/event_producer.py index a7ef083..3ed8a73 100644 --- a/producer/event_producer.py +++ b/producer/event_producer.py @@ -1,7 +1,8 @@ import json +import threading from abc import ABC -from typing import Type, Union +from typing import Type, Union, List import psycopg2 from psycopg2.extras import LogicalReplicationConnection @@ -221,3 +222,31 @@ def check_shutdown(self): if self.__shutdown and self.__db_conn: logger.warning('Shutting down...') self.__db_conn.close() + + +class MultiDBEventProducer: + def __init__(self, db_configs: List[dict], **common_kwargs): + self.producers = [] + self.threads = [] + + for db_config in db_configs: + config = {**common_kwargs, **db_config} + producer = EventProducer(**config) + self.producers.append(producer) + + def start(self): + for producer in self.producers: + thread = threading.Thread(target=self._run_producer, args=(producer,)) + self.threads.append(thread) + thread.start() + + def _run_producer(self, producer: EventProducer): + producer.connect() + producer.start_consuming() + + def shutdown(self): + for producer in self.producers: + producer.shutdown() + + for thread in self.threads: + thread.join() diff --git a/producer/main.py b/producer/main.py index efd3dde..b4101f7 100644 --- a/producer/main.py +++ b/producer/main.py @@ -1,4 +1,5 @@ import os +import json import signal import click @@ -6,7 +7,8 @@ from common import log from common.event import BaseEvent from common.qconnector import RabbitMQConnector -from producer.event_producer import EventProducer +from common.utils import validate_db_configs +from producer.event_producer import EventProducer, MultiDBEventProducer logger = log.get_logger(__name__) @@ -45,3 +47,28 @@ def produce(pg_host, pg_port, pg_database, pg_user, pg_password, pg_replication_ p.connect() p.start_consuming() + + + +@click.command() +@click.option('--db_configs', default=lambda: os.environ.get('DB_CONFIGS', None), required=True, help='DB Configs') +@click.option('--rabbitmq_url', default=lambda: os.environ.get('RABBITMQ_URL', None), required=True, help='RabbitMQ url ($RABBITMQ_URL)') +@click.option('--rabbitmq_exchange', default=lambda: os.environ.get('RABBITMQ_EXCHANGE', None), required=True, help='RabbitMQ exchange ($RABBITMQ_EXCHANGE)') +def produce_multiple_dbs(db_configs, rabbitmq_url, rabbitmq_exchange): + try: + db_configs = json.loads(db_configs) + validate_db_configs(db_configs) + except json.JSONDecodeError: + raise click.BadParameter("db_configs must be a valid JSON string") + + common_kwargs = { + 'qconnector_cls': RabbitMQConnector, + 'event_cls': BaseEvent, + 'pg_output_plugin': 'wal2json', + 'pg_publication_name': 'events', + 'rabbitmq_url': rabbitmq_url, + 'rabbitmq_exchange': rabbitmq_exchange + } + + multi_producer = MultiDBEventProducer(db_configs, **common_kwargs) + multi_producer.start() diff --git a/requirements.txt b/requirements.txt index c450072..777cb3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ iniconfig==1.1.1 isort==4.3.21 lazy-object-proxy==1.7.1 mccabe==0.6.1 -packaging==21.3 +packaging==22.0 pika==1.1.0 pluggy==1.0.0 psycopg2==2.8.6 diff --git a/setup.py b/setup.py index 7ab617d..4ddc747 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ entry_points=''' [console_scripts] producer=producer.main:produce + producer_multiple_dbs=producer.main:produce_multiple_dbs event_logger=consumer.main:consume ''' )