Skip to content

Commit

Permalink
feat: Support for multi DB producer
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwin1111 committed Dec 13, 2024
1 parent 060d03f commit c698a95
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ dist/
*.egg-info/

.coverage

# Ignore the experiment/ directory
experiment/
54 changes: 53 additions & 1 deletion common/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union
from typing import Union, List, Dict, Any
import click


class DeserializerUtils:
Expand All @@ -9,3 +10,54 @@ def convert_bytes_to_int(in_bytes: bytes) -> int:
@staticmethod
def convert_bytes_to_utf8(in_bytes: Union[bytes, bytearray]) -> str:
return in_bytes.decode('utf-8')


def validate_db_configs(db_configs: List[Dict[str, Any]]) -> None:
    """Validate database configurations for MultiDBEventProducer.

    Args:
        db_configs: Non-empty list of per-database configuration dicts. Each
            dict must provide ``pg_host``, ``pg_database``, ``pg_user``,
            ``pg_password``, ``pg_tables`` and ``pg_replication_slot`` as
            strings and ``pg_port`` as an integer.

    Raises:
        click.BadParameter: If ``db_configs`` is not a non-empty list of
            dicts, a required field is missing or has the wrong type, the
            port is outside 1024-65535, or an entry of the comma-separated
            ``pg_tables`` string is not of the form ``schema.table``.
    """
    required_fields = {
        'pg_host': str,
        'pg_port': int,
        'pg_database': str,
        'pg_user': str,
        'pg_password': str,
        'pg_tables': str,
        'pg_replication_slot': str,
    }

    if not isinstance(db_configs, list):
        raise click.BadParameter("db_configs must be a list of database configurations")

    if not db_configs:
        raise click.BadParameter("db_configs cannot be empty")

    for idx, config in enumerate(db_configs):
        if not isinstance(config, dict):
            raise click.BadParameter(f"Configuration at index {idx} must be a dictionary")

        # Check required fields
        for field, field_type in required_fields.items():
            if field not in config:
                raise click.BadParameter(f"Missing required field '{field}' in configuration at index {idx}")

            value = config[field]
            # bool is a subclass of int, so JSON true/false would otherwise
            # slip through the pg_port type check; no field here is a bool.
            if isinstance(value, bool) or not isinstance(value, field_type):
                raise click.BadParameter(
                    f"Field '{field}' in configuration at index {idx} "
                    f"must be of type {field_type.__name__}"
                )

        # Validate pg_port range
        if not (1024 <= config['pg_port'] <= 65535):
            raise click.BadParameter(
                f"Invalid port number {config['pg_port']} in configuration at index {idx}. "
                "Port must be between 1024 and 65535"
            )

        # Validate pg_tables format: every comma-separated entry needs a
        # non-empty schema and a non-empty table name around the first dot,
        # so 'public.' and '.users' are rejected as well as 'users'.
        for table in config['pg_tables'].split(','):
            table = table.strip()
            schema, dot, name = table.partition('.')
            if not dot or not schema.strip() or not name.strip():
                raise click.BadParameter(
                    f"Invalid table format '{table}' in configuration at index {idx}. "
                    "Format should be 'schema.table'"
                )
31 changes: 30 additions & 1 deletion producer/event_producer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import threading

from abc import ABC
from typing import Type, Union
from typing import Type, Union, List

import psycopg2
from psycopg2.extras import LogicalReplicationConnection
Expand Down Expand Up @@ -221,3 +222,31 @@ def check_shutdown(self):
if self.__shutdown and self.__db_conn:
logger.warning('Shutting down...')
self.__db_conn.close()


class MultiDBEventProducer:
    """Run one EventProducer per database configuration, each on its own thread."""

    def __init__(self, db_configs: List[dict], **common_kwargs):
        """Build an EventProducer for every entry of ``db_configs``.

        Each producer is constructed from ``common_kwargs`` merged with its
        own config dict; on a key clash the per-database value wins.
        """
        self.threads = []
        self.producers = [
            EventProducer(**{**common_kwargs, **per_db})
            for per_db in db_configs
        ]

    def start(self):
        """Create and start one thread per producer, keeping a handle to each."""
        for producer in self.producers:
            worker = threading.Thread(
                target=self._run_producer,
                args=(producer,),
            )
            # Record the thread before starting it so shutdown() can join it.
            self.threads.append(worker)
            worker.start()

    def _run_producer(self, producer: EventProducer):
        """Thread target: connect the producer, then start consuming."""
        producer.connect()
        producer.start_consuming()

    def shutdown(self):
        """Call shutdown() on every producer, then join every started thread."""
        for producer in self.producers:
            producer.shutdown()

        for worker in self.threads:
            worker.join()
29 changes: 28 additions & 1 deletion producer/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import json
import signal

import click

from common import log
from common.event import BaseEvent
from common.qconnector import RabbitMQConnector
from producer.event_producer import EventProducer
from common.utils import validate_db_configs
from producer.event_producer import EventProducer, MultiDBEventProducer

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -45,3 +47,28 @@ def produce(pg_host, pg_port, pg_database, pg_user, pg_password, pg_replication_

p.connect()
p.start_consuming()



# CLI entry point: parse the JSON db_configs option, validate it, and start a
# MultiDBEventProducer with the shared queue settings.
# (A comment rather than a docstring: click would print a docstring in --help.)
@click.command()
@click.option('--db_configs', default=lambda: os.environ.get('DB_CONFIGS', None), required=True, help='DB Configs')
@click.option('--rabbitmq_url', default=lambda: os.environ.get('RABBITMQ_URL', None), required=True, help='RabbitMQ url ($RABBITMQ_URL)')
@click.option('--rabbitmq_exchange', default=lambda: os.environ.get('RABBITMQ_EXCHANGE', None), required=True, help='RabbitMQ exchange ($RABBITMQ_EXCHANGE)')
def produce_multiple_dbs(db_configs, rabbitmq_url, rabbitmq_exchange):
    try:
        parsed_configs = json.loads(db_configs)
    except json.JSONDecodeError:
        raise click.BadParameter("db_configs must be a valid JSON string")
    else:
        # Validation raises click.BadParameter itself on bad input.
        validate_db_configs(parsed_configs)

    # Settings shared by every per-database producer.
    shared_settings = dict(
        qconnector_cls=RabbitMQConnector,
        event_cls=BaseEvent,
        pg_output_plugin='wal2json',
        pg_publication_name='events',
        rabbitmq_url=rabbitmq_url,
        rabbitmq_exchange=rabbitmq_exchange,
    )

    MultiDBEventProducer(parsed_configs, **shared_settings).start()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ iniconfig==1.1.1
isort==4.3.21
lazy-object-proxy==1.7.1
mccabe==0.6.1
packaging==21.3
packaging==22.0
pika==1.1.0
pluggy==1.0.0
psycopg2==2.8.6
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
entry_points='''
[console_scripts]
producer=producer.main:produce
producer_multiple_dbs=producer.main:produce_multiple_dbs
event_logger=consumer.main:consume
'''
)

0 comments on commit c698a95

Please sign in to comment.