diff --git a/.gitignore b/.gitignore index 364f8a8..c74175e 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ dist/ *.egg-info/ .coverage + +# Ignore files +experiment/ diff --git a/common/utils.py b/common/utils.py index ff1a4fa..6b33227 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,4 +1,5 @@ -from typing import Union +from typing import Union, List, Dict, Any +import click class DeserializerUtils: @@ -9,3 +10,51 @@ def convert_bytes_to_int(in_bytes: bytes) -> int: @staticmethod def convert_bytes_to_utf8(in_bytes: Union[bytes, bytearray]) -> str: return in_bytes.decode('utf-8') + + +def validate_db_configs(db_configs: List[Dict[str, Any]]) -> None: + """Validate database configurations for MultiDBEventProducer.""" + required_fields = { + 'pg_host': str, + 'pg_port': int, + 'pg_database': str, + 'pg_user': str, + 'pg_password': str, + 'pg_tables': str, + 'pg_replication_slot': str + } + + if not isinstance(db_configs, list): + raise click.BadParameter("db_configs must be a list of database configurations") + + if not db_configs: + raise click.BadParameter("db_configs cannot be empty") + + for idx, config in enumerate(db_configs): + # Check required fields + for field, field_type in required_fields.items(): + if field not in config: + raise click.BadParameter(f"Missing required field '{field}' in configuration at index {idx}") + + if not isinstance(config[field], field_type): + raise click.BadParameter( + f"Field '{field}' in configuration at index {idx} " + f"must be of type {field_type.__name__}" + ) + + # Validate pg_port range + if not (1024 <= config['pg_port'] <= 65535): + raise click.BadParameter( + f"Invalid port number {config['pg_port']} in configuration at index {idx}. " + "Port must be between 1024 and 65535" + ) + + # Validate pg_tables format + tables = config['pg_tables'].split(',') + for table in tables: + table = table.strip() + if not table or '.' not in table: + raise click.BadParameter( + f"Invalid table format '{table}' in configuration at index {idx}. " + "Format should be 'schema.table'" + ) diff --git a/docker-compose.yaml b/docker-compose.yaml index 8a8ff0d..fe329f2 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -41,7 +41,7 @@ services: command: ["bash", "-c", "producer"] volumes: - ./producer:/pgevents/producer - - ./test:/pgevents/test + - ./tests:/pgevents/tests depends_on: - database - rabbitmq diff --git a/producer/event_producer.py b/producer/event_producer.py index a7ef083..3ed8a73 100644 --- a/producer/event_producer.py +++ b/producer/event_producer.py @@ -1,7 +1,8 @@ import json +import threading from abc import ABC -from typing import Type, Union +from typing import Type, Union, List import psycopg2 from psycopg2.extras import LogicalReplicationConnection @@ -221,3 +222,31 @@ def check_shutdown(self): if self.__shutdown and self.__db_conn: logger.warning('Shutting down...') self.__db_conn.close() + + +class MultiDBEventProducer: + def __init__(self, db_configs: List[dict], **common_kwargs): + self.producers = [] + self.threads = [] + + for db_config in db_configs: + config = {**common_kwargs, **db_config} + producer = EventProducer(**config) + self.producers.append(producer) + + def start(self): + for producer in self.producers: + thread = threading.Thread(target=self._run_producer, args=(producer,)) + self.threads.append(thread) + thread.start() + + def _run_producer(self, producer: EventProducer): + producer.connect() + producer.start_consuming() + + def shutdown(self): + for producer in self.producers: + producer.shutdown() + + for thread in self.threads: + thread.join() diff --git a/producer/main.py b/producer/main.py index efd3dde..6f72418 100644 --- a/producer/main.py +++ b/producer/main.py @@ -1,4 +1,5 @@ import os +import json import signal import click @@ -6,7 +7,8 @@ from common import log from common.event import BaseEvent from common.qconnector import RabbitMQConnector -from producer.event_producer import EventProducer +from common.utils import validate_db_configs +from producer.event_producer import EventProducer, MultiDBEventProducer logger = log.get_logger(__name__) @@ -45,3 +47,30 @@ def produce(pg_host, pg_port, pg_database, pg_user, pg_password, pg_replication_ p.connect() p.start_consuming() + + + +@click.command() +@click.option('--db_configs', default=lambda: os.environ.get('DB_CONFIGS', None), required=True, help='DB Configs') +@click.option('--pg_output_plugin', default=lambda: os.environ.get('PGOUTPUTPLUGIN', 'wal2json'), required=True, help='Postgresql Output Plugin ($PGOUTPUTPLUGIN)') +@click.option('--pg_publication_name', default=lambda: os.environ.get('PGPUBLICATION', None), required=False, help='Restrict to specific publications e.g. events') +@click.option('--rabbitmq_url', default=lambda: os.environ.get('RABBITMQ_URL', None), required=True, help='RabbitMQ url ($RABBITMQ_URL)') +@click.option('--rabbitmq_exchange', default=lambda: os.environ.get('RABBITMQ_EXCHANGE', None), required=True, help='RabbitMQ exchange ($RABBITMQ_EXCHANGE)') +def produce_multiple_dbs(db_configs, pg_output_plugin, pg_publication_name, rabbitmq_url, rabbitmq_exchange): + try: + db_configs = json.loads(db_configs) + validate_db_configs(db_configs) + except json.JSONDecodeError: + raise click.BadParameter("db_configs must be a valid JSON string") + + common_kwargs = { + 'qconnector_cls': RabbitMQConnector, + 'event_cls': BaseEvent, + 'pg_output_plugin': pg_output_plugin, + 'pg_publication_name': pg_publication_name, + 'rabbitmq_url': rabbitmq_url, + 'rabbitmq_exchange': rabbitmq_exchange + } + + multi_producer = MultiDBEventProducer(db_configs, **common_kwargs) + multi_producer.start() diff --git a/requirements.txt b/requirements.txt index c450072..777cb3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ iniconfig==1.1.1 isort==4.3.21 lazy-object-proxy==1.7.1 mccabe==0.6.1 -packaging==21.3 +packaging==22.0 pika==1.1.0 pluggy==1.0.0 psycopg2==2.8.6 diff --git a/setup.py b/setup.py index 7ab617d..e0962cd 100644 --- a/setup.py +++ b/setup.py @@ -2,13 +2,14 @@ setup( name='pgevents', - version='0.1', + version='0.2', packages=find_packages( include=['src', 'producer', 'consumer'] ), entry_points=''' [console_scripts] producer=producer.main:produce + producer_multiple_dbs=producer.main:produce_multiple_dbs event_logger=consumer.main:consume ''' ) diff --git a/tests/test_event_producer.py b/tests/test_event_producer.py index 2e829ea..577d807 100644 --- a/tests/test_event_producer.py +++ b/tests/test_event_producer.py @@ -1,9 +1,14 @@ import json +import click +from click.testing import CliRunner from unittest import mock +import pytest from common.event import base_event from common.qconnector.rabbitmq_connector import RabbitMQConnector +from common.utils import validate_db_configs from producer.event_producer import EventProducer +from producer.main import produce_multiple_dbs # Test init @@ -125,3 +130,159 @@ def test_start_consuming(mock_producer): mock_producer._EventProducer__pg_output_plugin = 'wal2json' mock_producer._EventProducer__db_cur.consume_stream.call_args[1]['consume']('test') mock_producer.wal2json_msg_processor.assert_called_once() + + +def test_valid_single_db_config(): + config = [{ + 'pg_host': 'localhost', + 'pg_port': 5432, + 'pg_database': 'test_db', + 'pg_user': 'postgres', + 'pg_password': 'secret', + 'pg_tables': 'public.users', + 'pg_replication_slot': 'test_slot' + }] + assert validate_db_configs(config) == None + + +def test_valid_multiple_db_configs(): + configs = [ + { + 'pg_host': 'localhost', + 'pg_port': 5432, + 'pg_database': 'db1', + 'pg_user': 'postgres', + 'pg_password': 'secret', + 'pg_tables': 'public.users', + 'pg_replication_slot': 'slot1' + }, + { + 'pg_host': 'localhost', + 'pg_port': 5433, + 'pg_database': 'db2', + 'pg_user': 'postgres', + 'pg_password': 'secret', + 'pg_tables': 'public.orders', + 'pg_replication_slot': 'slot2' + } + ] + assert validate_db_configs(configs) == None + + +def test_empty_config_list(): + with pytest.raises(click.BadParameter, match="db_configs cannot be empty"): + validate_db_configs([]) + + +def test_non_list_input(): + with pytest.raises(click.BadParameter, match="db_configs must be a list"): + validate_db_configs({'some': 'dict'}) + + +def test_missing_required_field(): + config = [{ + 'pg_host': 'localhost', + 'pg_port': 5432 + }] + with pytest.raises(click.BadParameter, match="Missing required field"): + validate_db_configs(config) + + +def test_invalid_field_type(): + config = [{ + 'pg_host': 'localhost', + 'pg_port': '5432', + 'pg_database': 'test_db', + 'pg_user': 'postgres', + 'pg_password': 'secret', + 'pg_tables': 'public.users', + 'pg_replication_slot': 'test_slot' + }] + with pytest.raises(click.BadParameter, match="must be of type int"): + validate_db_configs(config) + + +def test_invalid_port_range(): + config = [{ + 'pg_host': 'localhost', + 'pg_port': 80, + 'pg_database': 'test_db', + 'pg_user': 'postgres', + 'pg_password': 'secret', + 'pg_tables': 'public.users', + 'pg_replication_slot': 'test_slot' + }] + with pytest.raises(click.BadParameter, match="Invalid port number"): + validate_db_configs(config) + + +def test_invalid_table_format(): + config = [{ + 'pg_host': 'localhost', + 'pg_port': 5432, + 'pg_database': 'test_db', + 'pg_user': 'postgres', + 'pg_password': 'secret', + 'pg_tables': 'invalid_format', + 'pg_replication_slot': 'test_slot' + }] + with pytest.raises(click.BadParameter, match="Invalid table format"): + validate_db_configs(config) + + +@pytest.fixture +def valid_db_configs(): + return json.dumps([{ + 'pg_host': 'database', + 'pg_port': 5432, + 'pg_database': 'dummy', + 'pg_user': 'postgres', + 'pg_password': 'postgres', + 'pg_tables': 'public.users', + 'pg_replication_slot': 'events' + }]) + +def test_produce_multiple_dbs_success(valid_db_configs): + runner = CliRunner() + result = runner.invoke(produce_multiple_dbs, [ + '--db_configs', valid_db_configs, + '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0', + '--rabbitmq_exchange', 'pgevents_exchange' + ]) + + assert result.exit_code == 0 + +def test_produce_multiple_dbs_invalid_json(): + runner = CliRunner() + result = runner.invoke(produce_multiple_dbs, [ + '--db_configs', 'invalid-json', + '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0', + '--rabbitmq_exchange', 'pgevents_exchange' + ]) + + assert "db_configs must be a valid JSON string" in str(result.__dict__) + +@mock.patch('producer.main.validate_db_configs') +def test_produce_multiple_dbs_invalid_config(mock_validate, valid_db_configs): + mock_validate.side_effect = click.BadParameter("Invalid config") + + runner = CliRunner() + result = runner.invoke(produce_multiple_dbs, [ + '--db_configs', valid_db_configs, + '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0', + '--rabbitmq_exchange', 'pgevents_exchange' + ]) + + assert "Invalid config" in str(result.__dict__) + +def test_produce_multiple_dbs_common_kwargs(valid_db_configs): + runner = CliRunner() + result = runner.invoke(produce_multiple_dbs, [ + '--db_configs', valid_db_configs, + '--pg_output_plugin', 'test_plugin', + '--pg_publication_name', 'test_pub', + '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0', + '--rabbitmq_exchange', 'pgevents_exchange' + ]) + + assert result.exit_code == 0