Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for multi DB producer #17

Merged
merged 5 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ dist/
*.egg-info/

.coverage

# Ignore files
experiment/
51 changes: 50 additions & 1 deletion common/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union
from typing import Union, List, Dict, Any
import click


class DeserializerUtils:
Expand All @@ -9,3 +10,51 @@ def convert_bytes_to_int(in_bytes: bytes) -> int:
@staticmethod
def convert_bytes_to_utf8(in_bytes: Union[bytes, bytearray]) -> str:
return in_bytes.decode('utf-8')


def validate_db_configs(db_configs: List[Dict[str, Any]]) -> None:
    """Validate database configurations for MultiDBEventProducer.

    Args:
        db_configs: Non-empty list holding one dictionary per database. Each
            dictionary must provide ``pg_host``, ``pg_port``, ``pg_database``,
            ``pg_user``, ``pg_password``, ``pg_tables`` (comma-separated
            ``schema.table`` names) and ``pg_replication_slot``.

    Raises:
        click.BadParameter: If ``db_configs`` is not a non-empty list, an
            entry is not a dictionary, a required field is missing or has the
            wrong type, the port is outside 1024-65535, or a table name is
            not of the form ``schema.table``.
    """
    required_fields = {
        'pg_host': str,
        'pg_port': int,
        'pg_database': str,
        'pg_user': str,
        'pg_password': str,
        'pg_tables': str,
        'pg_replication_slot': str
    }

    if not isinstance(db_configs, list):
        raise click.BadParameter("db_configs must be a list of database configurations")

    if not db_configs:
        raise click.BadParameter("db_configs cannot be empty")

    for idx, config in enumerate(db_configs):
        # A JSON array may contain scalars or nested arrays; reject them here
        # so the membership tests below cannot fail with a TypeError.
        if not isinstance(config, dict):
            raise click.BadParameter(
                f"Configuration at index {idx} must be a dictionary of database settings"
            )

        # Check required fields
        for field, field_type in required_fields.items():
            if field not in config:
                raise click.BadParameter(f"Missing required field '{field}' in configuration at index {idx}")

            value = config[field]
            # bool is a subclass of int, so JSON true/false would otherwise
            # slip through the int check; no required field is a boolean.
            if isinstance(value, bool) or not isinstance(value, field_type):
                raise click.BadParameter(
                    f"Field '{field}' in configuration at index {idx} "
                    f"must be of type {field_type.__name__}"
                )

        # Validate pg_port range
        if not (1024 <= config['pg_port'] <= 65535):
            raise click.BadParameter(
                f"Invalid port number {config['pg_port']} in configuration at index {idx}. "
                "Port must be between 1024 and 65535"
            )

        # Validate pg_tables format: both the schema and the table part must
        # be non-empty, so '.users' and 'public.' are rejected as well.
        for table in config['pg_tables'].split(','):
            table = table.strip()
            schema, separator, table_name = table.partition('.')
            if not (separator and schema.strip() and table_name.strip()):
                raise click.BadParameter(
                    f"Invalid table format '{table}' in configuration at index {idx}. "
                    "Format should be 'schema.table'"
                )
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ services:
command: ["bash", "-c", "producer"]
volumes:
- ./producer:/pgevents/producer
- ./test:/pgevents/test
- ./tests:/pgevents/tests
depends_on:
- database
- rabbitmq
Expand Down
31 changes: 30 additions & 1 deletion producer/event_producer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import threading

from abc import ABC
from typing import Type, Union
from typing import Type, Union, List

import psycopg2
from psycopg2.extras import LogicalReplicationConnection
Expand Down Expand Up @@ -221,3 +222,31 @@ def check_shutdown(self):
if self.__shutdown and self.__db_conn:
logger.warning('Shutting down...')
self.__db_conn.close()


class MultiDBEventProducer:
    """Run one EventProducer per database configuration, each on its own thread.

    Attributes:
        producers: The EventProducer instances, in the order of ``db_configs``.
        threads: The threads started by ``start()``; empty before ``start()``
            and again after ``shutdown()``.
    """

    def __init__(self, db_configs: List[dict], **common_kwargs):
        """Build one producer per entry of ``db_configs``.

        Args:
            db_configs: Per-database keyword arguments for EventProducer.
            **common_kwargs: Keyword arguments shared by every producer. On a
                key collision the per-database value wins.
        """
        self.producers = [
            EventProducer(**{**common_kwargs, **db_config})
            for db_config in db_configs
        ]
        self.threads = []

    def start(self):
        """Start one thread per producer.

        Calling ``start()`` again while threads are registered is a no-op, so
        the same producer is never connected and consumed from twice.
        """
        if self.threads:
            return

        for producer in self.producers:
            thread = threading.Thread(target=self._run_producer, args=(producer,))
            self.threads.append(thread)
            thread.start()

    def _run_producer(self, producer: "EventProducer"):
        """Thread target: connect the producer, then start consuming."""
        producer.connect()
        producer.start_consuming()

    def shutdown(self):
        """Call ``shutdown()`` on every producer, then join all threads."""
        for producer in self.producers:
            producer.shutdown()

        for thread in self.threads:
            thread.join()

        # Drop the finished threads so a later start() is not blocked by them.
        self.threads.clear()
31 changes: 30 additions & 1 deletion producer/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import json
import signal

import click

from common import log
from common.event import BaseEvent
from common.qconnector import RabbitMQConnector
from producer.event_producer import EventProducer
from common.utils import validate_db_configs
from producer.event_producer import EventProducer, MultiDBEventProducer

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -45,3 +47,30 @@ def produce(pg_host, pg_port, pg_database, pg_user, pg_password, pg_replication_

p.connect()
p.start_consuming()



@click.command()
@click.option('--db_configs', default=lambda: os.environ.get('DB_CONFIGS', None), required=True, help='DB Configs')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For multiple DBs this looks a little dirty.
I would prefer that we read all of this config from a YAML file.

E.g., what if I want entirely different settings for each producer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is okay for now, but please fix this later

@click.option('--pg_output_plugin', default=lambda: os.environ.get('PGOUTPUTPLUGIN', 'wal2json'), required=True, help='Postgresql Output Plugin ($PGOUTPUTPLUGIN)')
@click.option('--pg_publication_name', default=lambda: os.environ.get('PGPUBLICATION', None), required=False, help='Restrict to specific publications e.g. events')
@click.option('--rabbitmq_url', default=lambda: os.environ.get('RABBITMQ_URL', None), required=True, help='RabbitMQ url ($RABBITMQ_URL)')
@click.option('--rabbitmq_exchange', default=lambda: os.environ.get('RABBITMQ_EXCHANGE', None), required=True, help='RabbitMQ exchange ($RABBITMQ_EXCHANGE)')
def produce_multiple_dbs(db_configs, pg_output_plugin, pg_publication_name, rabbitmq_url, rabbitmq_exchange):
    """Parse and validate the JSON ``db_configs``, then start a MultiDBEventProducer.

    The output plugin, publication name and RabbitMQ settings are forwarded
    to MultiDBEventProducer as shared keyword arguments.

    Raises:
        click.BadParameter: If ``db_configs`` is not valid JSON, or if
            ``validate_db_configs`` rejects the parsed value.
    """
    try:
        parsed_configs = json.loads(db_configs)
    except json.JSONDecodeError:
        raise click.BadParameter("db_configs must be a valid JSON string")

    validate_db_configs(parsed_configs)

    multi_producer = MultiDBEventProducer(
        parsed_configs,
        qconnector_cls=RabbitMQConnector,
        event_cls=BaseEvent,
        pg_output_plugin=pg_output_plugin,
        pg_publication_name=pg_publication_name,
        rabbitmq_url=rabbitmq_url,
        rabbitmq_exchange=rabbitmq_exchange,
    )
    multi_producer.start()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ iniconfig==1.1.1
isort==4.3.21
lazy-object-proxy==1.7.1
mccabe==0.6.1
packaging==21.3
packaging==22.0
pika==1.1.0
pluggy==1.0.0
psycopg2==2.8.6
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

setup(
name='pgevents',
version='0.1',
version='0.2',
packages=find_packages(
include=['src', 'producer', 'consumer']
),
entry_points='''
[console_scripts]
producer=producer.main:produce
producer_multiple_dbs=producer.main:produce_multiple_dbs
event_logger=consumer.main:consume
'''
)
161 changes: 161 additions & 0 deletions tests/test_event_producer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import json
import click
from click.testing import CliRunner
from unittest import mock
import pytest

from common.event import base_event
from common.qconnector.rabbitmq_connector import RabbitMQConnector
from common.utils import validate_db_configs
from producer.event_producer import EventProducer
from producer.main import produce_multiple_dbs


# Test init
Expand Down Expand Up @@ -125,3 +130,159 @@ def test_start_consuming(mock_producer):
mock_producer._EventProducer__pg_output_plugin = 'wal2json'
mock_producer._EventProducer__db_cur.consume_stream.call_args[1]['consume']('test')
mock_producer.wal2json_msg_processor.assert_called_once()


def test_valid_single_db_config():
    """A single complete configuration passes validation (returns None)."""
    config = [{
        'pg_host': 'localhost',
        'pg_port': 5432,
        'pg_database': 'test_db',
        'pg_user': 'postgres',
        'pg_password': 'secret',
        'pg_tables': 'public.users',
        'pg_replication_slot': 'test_slot'
    }]
    # Identity comparison is the correct way to test for None.
    assert validate_db_configs(config) is None


def test_valid_multiple_db_configs():
    """Two complete configurations pass validation (returns None)."""
    configs = [
        {
            'pg_host': 'localhost',
            'pg_port': 5432,
            'pg_database': 'db1',
            'pg_user': 'postgres',
            'pg_password': 'secret',
            'pg_tables': 'public.users',
            'pg_replication_slot': 'slot1'
        },
        {
            'pg_host': 'localhost',
            'pg_port': 5433,
            'pg_database': 'db2',
            'pg_user': 'postgres',
            'pg_password': 'secret',
            'pg_tables': 'public.orders',
            'pg_replication_slot': 'slot2'
        }
    ]
    # Identity comparison is the correct way to test for None.
    assert validate_db_configs(configs) is None


def test_empty_config_list():
    """An empty list of configurations is rejected."""
    no_configs = []
    expected_error = pytest.raises(click.BadParameter, match="db_configs cannot be empty")
    with expected_error:
        validate_db_configs(no_configs)


def test_non_list_input():
    """A dictionary passed in place of the list is rejected."""
    not_a_list = {'some': 'dict'}
    expected_error = pytest.raises(click.BadParameter, match="db_configs must be a list")
    with expected_error:
        validate_db_configs(not_a_list)


def test_missing_required_field():
    """A configuration that only has host and port is rejected."""
    incomplete_config = dict(pg_host='localhost', pg_port=5432)
    expected_error = pytest.raises(click.BadParameter, match="Missing required field")
    with expected_error:
        validate_db_configs([incomplete_config])


def test_invalid_field_type():
    """A port given as a string instead of an int is rejected."""
    string_port_config = dict(
        pg_host='localhost',
        pg_port='5432',
        pg_database='test_db',
        pg_user='postgres',
        pg_password='secret',
        pg_tables='public.users',
        pg_replication_slot='test_slot',
    )
    expected_error = pytest.raises(click.BadParameter, match="must be of type int")
    with expected_error:
        validate_db_configs([string_port_config])


def test_invalid_port_range():
    """Port 80 is rejected by the validator."""
    low_port_config = dict(
        pg_host='localhost',
        pg_port=80,
        pg_database='test_db',
        pg_user='postgres',
        pg_password='secret',
        pg_tables='public.users',
        pg_replication_slot='test_slot',
    )
    expected_error = pytest.raises(click.BadParameter, match="Invalid port number")
    with expected_error:
        validate_db_configs([low_port_config])


def test_invalid_table_format():
    """A table name without a '.' separator is rejected."""
    bad_table_config = dict(
        pg_host='localhost',
        pg_port=5432,
        pg_database='test_db',
        pg_user='postgres',
        pg_password='secret',
        pg_tables='invalid_format',
        pg_replication_slot='test_slot',
    )
    expected_error = pytest.raises(click.BadParameter, match="Invalid table format")
    with expected_error:
        validate_db_configs([bad_table_config])


@pytest.fixture
def valid_db_configs():
    """Return a JSON string holding one complete database configuration."""
    single_db = {
        'pg_host': 'database',
        'pg_port': 5432,
        'pg_database': 'dummy',
        'pg_user': 'postgres',
        'pg_password': 'postgres',
        'pg_tables': 'public.users',
        'pg_replication_slot': 'events',
    }
    serialized = json.dumps([single_db])
    return serialized

def test_produce_multiple_dbs_success(valid_db_configs):
    """Invoke the command with a valid configuration and expect exit code 0.

    NOTE(review): nothing in this test stubs the producer, so it presumably
    needs reachable database and RabbitMQ services -- confirm.
    """
    cli_args = [
        '--db_configs', valid_db_configs,
        '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
        '--rabbitmq_exchange', 'pgevents_exchange',
    ]
    outcome = CliRunner().invoke(produce_multiple_dbs, cli_args)

    assert outcome.exit_code == 0

def test_produce_multiple_dbs_invalid_json():
    """Malformed JSON in --db_configs makes the command fail with a clear error."""
    runner = CliRunner()
    result = runner.invoke(produce_multiple_dbs, [
        '--db_configs', 'invalid-json',
        '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
        '--rabbitmq_exchange', 'pgevents_exchange'
    ])

    # Check the exit status and the captured output rather than the repr of
    # the whole result object, which would match the text anywhere.
    assert result.exit_code != 0
    assert "db_configs must be a valid JSON string" in result.output

@mock.patch('producer.main.validate_db_configs')
def test_produce_multiple_dbs_invalid_config(mock_validate, valid_db_configs):
    """A BadParameter raised by the validator is reported as a command failure."""
    mock_validate.side_effect = click.BadParameter("Invalid config")

    runner = CliRunner()
    result = runner.invoke(produce_multiple_dbs, [
        '--db_configs', valid_db_configs,
        '--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
        '--rabbitmq_exchange', 'pgevents_exchange'
    ])

    # The patched validator must have been reached, and its error must show
    # up in the command's exit status and output.
    mock_validate.assert_called_once()
    assert result.exit_code != 0
    assert "Invalid config" in result.output

@mock.patch('producer.main.MultiDBEventProducer')
def test_produce_multiple_dbs_common_kwargs(mock_multi_producer, valid_db_configs):
    """The shared CLI options are forwarded to MultiDBEventProducer as kwargs.

    MultiDBEventProducer is patched so the test checks the arguments the
    command passes on, without starting real producer threads.
    """
    rabbitmq_url = 'amqp://admin:password@rabbitmq:5672/?heartbeat=0'

    runner = CliRunner()
    result = runner.invoke(produce_multiple_dbs, [
        '--db_configs', valid_db_configs,
        '--pg_output_plugin', 'test_plugin',
        '--pg_publication_name', 'test_pub',
        '--rabbitmq_url', rabbitmq_url,
        '--rabbitmq_exchange', 'pgevents_exchange'
    ])

    assert result.exit_code == 0

    mock_multi_producer.assert_called_once()
    args, kwargs = mock_multi_producer.call_args
    # The parsed JSON list is the only positional argument.
    assert args == (json.loads(valid_db_configs),)
    assert kwargs['pg_output_plugin'] == 'test_plugin'
    assert kwargs['pg_publication_name'] == 'test_pub'
    assert kwargs['rabbitmq_url'] == rabbitmq_url
    assert kwargs['rabbitmq_exchange'] == 'pgevents_exchange'
    mock_multi_producer.return_value.start.assert_called_once_with()
Loading