feat: Support for multi DB producer #17

Merged
merged 5 commits on Dec 23, 2024
Changes from 2 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -17,3 +17,6 @@ dist/
*.egg-info/

.coverage

# Ignore files
experiment/
54 changes: 53 additions & 1 deletion common/utils.py
@@ -1,4 +1,5 @@
from typing import Union
from typing import Union, List, Dict, Any
import click


class DeserializerUtils:
@@ -9,3 +10,54 @@ def convert_bytes_to_int(in_bytes: bytes) -> int:
@staticmethod
def convert_bytes_to_utf8(in_bytes: Union[bytes, bytearray]) -> str:
return in_bytes.decode('utf-8')


def validate_db_configs(db_configs: List[Dict[str, Any]]) -> None:
"""Validate database configurations for MultiDBEventProducer."""
required_fields = {
'pg_host': str,
'pg_port': int,
'pg_database': str,
'pg_user': str,
'pg_password': str,
'pg_tables': str,
'pg_replication_slot': str
}

if not isinstance(db_configs, list):
raise click.BadParameter("db_configs must be a list of database configurations")

if not db_configs:
raise click.BadParameter("db_configs cannot be empty")

for idx, config in enumerate(db_configs):
if not isinstance(config, dict):
raise click.BadParameter(f"Configuration at index {idx} must be a dictionary")

# Check required fields
for field, field_type in required_fields.items():
if field not in config:
raise click.BadParameter(f"Missing required field '{field}' in configuration at index {idx}")

if not isinstance(config[field], field_type):
raise click.BadParameter(
f"Field '{field}' in configuration at index {idx} "
f"must be of type {field_type.__name__}"
)

# Validate pg_port range
if not (1024 <= config['pg_port'] <= 65535):
raise click.BadParameter(
f"Invalid port number {config['pg_port']} in configuration at index {idx}. "
"Port must be between 1024 and 65535"
)

# Validate pg_tables format
tables = config['pg_tables'].split(',')
for table in tables:
table = table.strip()
if not table or '.' not in table:
raise click.BadParameter(
f"Invalid table format '{table}' in configuration at index {idx}. "
"Format should be 'schema.table'"
)
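
For reference, a minimal sketch of a DB_CONFIGS payload that satisfies this validation; the hosts, credentials, tables and slot names below are illustrative, not taken from this PR:

import json

from common.utils import validate_db_configs

# Two illustrative database configurations; every key matches the
# required_fields dict checked by validate_db_configs above.
db_configs = [
    {
        'pg_host': 'db-one.internal',        # hypothetical host
        'pg_port': 5432,
        'pg_database': 'orders',
        'pg_user': 'replicator',
        'pg_password': 'secret',
        'pg_tables': 'public.orders,public.order_items',   # comma-separated 'schema.table'
        'pg_replication_slot': 'orders_slot'
    },
    {
        'pg_host': 'db-two.internal',
        'pg_port': 5433,
        'pg_database': 'payments',
        'pg_user': 'replicator',
        'pg_password': 'secret',
        'pg_tables': 'public.payments',
        'pg_replication_slot': 'payments_slot'
    }
]

validate_db_configs(db_configs)    # raises click.BadParameter on any violation
print(json.dumps(db_configs))      # the same structure is passed as a JSON string via --db_configs / $DB_CONFIGS
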
31 changes: 30 additions & 1 deletion producer/event_producer.py
@@ -1,7 +1,8 @@
import json
import threading

from abc import ABC
from typing import Type, Union
from typing import Type, Union, List

import psycopg2
from psycopg2.extras import LogicalReplicationConnection
@@ -221,3 +222,31 @@ def check_shutdown(self):
if self.__shutdown and self.__db_conn:
logger.warning('Shutting down...')
self.__db_conn.close()


class MultiDBEventProducer:
def __init__(self, db_configs: List[dict], **common_kwargs):
self.producers = []
self.threads = []

for db_config in db_configs:
config = {**common_kwargs, **db_config}
producer = EventProducer(**config)
self.producers.append(producer)

def start(self):
for producer in self.producers:
thread = threading.Thread(target=self._run_producer, args=(producer,))
self.threads.append(thread)
thread.start()

def _run_producer(self, producer: EventProducer):
producer.connect()
producer.start_consuming()

def shutdown(self):
for producer in self.producers:
producer.shutdown()

for thread in self.threads:
thread.join()
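
A rough usage sketch of the class above; the keyword arguments mirror the common_kwargs assembled in producer/main.py, and the connection values are illustrative. Each per-DB dict is merged over the shared kwargs, so a per-database value overrides a shared one with the same key:

from common.event import BaseEvent
from common.qconnector import RabbitMQConnector
from producer.event_producer import MultiDBEventProducer

db_configs = [
    {'pg_host': 'db-one.internal', 'pg_port': 5432, 'pg_database': 'orders',
     'pg_user': 'replicator', 'pg_password': 'secret',
     'pg_tables': 'public.orders', 'pg_replication_slot': 'orders_slot'},
]

multi = MultiDBEventProducer(
    db_configs,
    qconnector_cls=RabbitMQConnector,
    event_cls=BaseEvent,
    pg_output_plugin='wal2json',
    pg_publication_name=None,
    rabbitmq_url='amqp://guest:guest@localhost:5672/',   # illustrative URL
    rabbitmq_exchange='events'
)

multi.start()       # spawns one consuming thread per database
# ... later, e.g. from a signal handler:
multi.shutdown()    # signals each EventProducer, then joins all threads
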
31 changes: 30 additions & 1 deletion producer/main.py
@@ -1,12 +1,14 @@
import os
import json
import signal

import click

from common import log
from common.event import BaseEvent
from common.qconnector import RabbitMQConnector
from producer.event_producer import EventProducer
from common.utils import validate_db_configs
from producer.event_producer import EventProducer, MultiDBEventProducer

logger = log.get_logger(__name__)

@@ -45,3 +47,30 @@ def produce(pg_host, pg_port, pg_database, pg_user, pg_password, pg_replication_

p.connect()
p.start_consuming()



@click.command()
@click.option('--db_configs', default=lambda: os.environ.get('DB_CONFIGS', None), required=True, help='DB Configs')
Contributor
For multiple DBs this looks a little dirty. I would prefer that we get all this config from a YAML file.
E.g. what if for a different producer I want different settings altogether?

Contributor
This is okay for now, but please fix this later.

@click.option('--pg_output_plugin', default=lambda: os.environ.get('PGOUTPUTPLUGIN', 'wal2json'), required=True, help='Postgresql Output Plugin ($PGOUTPUTPLUGIN)')
@click.option('--pg_publication_name', default=lambda: os.environ.get('PGPUBLICATION', None), required=False, help='Restrict to specific publications e.g. events')
@click.option('--rabbitmq_url', default=lambda: os.environ.get('RABBITMQ_URL', None), required=True, help='RabbitMQ url ($RABBITMQ_URL)')
@click.option('--rabbitmq_exchange', default=lambda: os.environ.get('RABBITMQ_EXCHANGE', None), required=True, help='RabbitMQ exchange ($RABBITMQ_EXCHANGE)')
def produce_multiple_dbs(db_configs, pg_output_plugin, pg_publication_name, rabbitmq_url, rabbitmq_exchange):
try:
db_configs = json.loads(db_configs)
validate_db_configs(db_configs)
except json.JSONDecodeError:
raise click.BadParameter("db_configs must be a valid JSON string")

common_kwargs = {
'qconnector_cls': RabbitMQConnector,
'event_cls': BaseEvent,
'pg_output_plugin': pg_output_plugin,
'pg_publication_name': pg_publication_name,
'rabbitmq_url': rabbitmq_url,
'rabbitmq_exchange': rabbitmq_exchange
}

multi_producer = MultiDBEventProducer(db_configs, **common_kwargs)
multi_producer.start()
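
The review comments above prefer a YAML file over a JSON flag for per-database settings. A hypothetical sketch of that direction, assuming PyYAML (not currently in requirements.txt) and an illustrative file layout; it reuses the validation already added in common/utils.py:

import yaml  # PyYAML; would need to be added to requirements.txt

from common.utils import validate_db_configs


def load_db_configs_from_yaml(path):
    """Load per-database settings from a YAML file.

    Expected (illustrative) layout:
        db_configs:
          - pg_host: db-one.internal
            pg_port: 5432
            pg_database: orders
            pg_user: replicator
            pg_password: secret
            pg_tables: public.orders
            pg_replication_slot: orders_slot
    """
    with open(path) as fh:
        data = yaml.safe_load(fh)
    db_configs = data.get('db_configs', [])
    validate_db_configs(db_configs)    # same checks as the JSON code path
    return db_configs
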
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,7 +6,7 @@ iniconfig==1.1.1
isort==4.3.21
lazy-object-proxy==1.7.1
mccabe==0.6.1
packaging==21.3
packaging==22.0
pika==1.1.0
pluggy==1.0.0
psycopg2==2.8.6
1 change: 1 addition & 0 deletions setup.py
@@ -9,6 +9,7 @@
entry_points='''
[console_scripts]
producer=producer.main:produce
producer_multiple_dbs=producer.main:produce_multiple_dbs
event_logger=consumer.main:consume
'''
)