Skip to content

Commit

Permalink
feat: Support for multi DB producer (#17)
Browse files Browse the repository at this point in the history
* feat: Support for multi DB producer

* get publication and plugin

* add tests

* bump version

* add tests
  • Loading branch information
ashwin1111 authored Dec 23, 2024
1 parent 060d03f commit 5c50c5f
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ dist/
*.egg-info/

.coverage

# Ignore files
experiment/
51 changes: 50 additions & 1 deletion common/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union
from typing import Union, List, Dict, Any
import click


class DeserializerUtils:
Expand All @@ -9,3 +10,51 @@ def convert_bytes_to_int(in_bytes: bytes) -> int:
@staticmethod
def convert_bytes_to_utf8(in_bytes: Union[bytes, bytearray]) -> str:
return in_bytes.decode('utf-8')


def validate_db_configs(db_configs: List[Dict[str, Any]]) -> None:
"""Validate database configurations for MultiDBEventProducer."""
required_fields = {
'pg_host': str,
'pg_port': int,
'pg_database': str,
'pg_user': str,
'pg_password': str,
'pg_tables': str,
'pg_replication_slot': str
}

if not isinstance(db_configs, list):
raise click.BadParameter("db_configs must be a list of database configurations")

if not db_configs:
raise click.BadParameter("db_configs cannot be empty")

for idx, config in enumerate(db_configs):
# Check required fields
for field, field_type in required_fields.items():
if field not in config:
raise click.BadParameter(f"Missing required field '{field}' in configuration at index {idx}")

if not isinstance(config[field], field_type):
raise click.BadParameter(
f"Field '{field}' in configuration at index {idx} "
f"must be of type {field_type.__name__}"
)

# Validate pg_port range
if not (1024 <= config['pg_port'] <= 65535):
raise click.BadParameter(
f"Invalid port number {config['pg_port']} in configuration at index {idx}. "
"Port must be between 1024 and 65535"
)

# Validate pg_tables format
tables = config['pg_tables'].split(',')
for table in tables:
table = table.strip()
if not table or '.' not in table:
raise click.BadParameter(
f"Invalid table format '{table}' in configuration at index {idx}. "
"Format should be 'schema.table'"
)
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ services:
command: ["bash", "-c", "producer"]
volumes:
- ./producer:/pgevents/producer
- ./test:/pgevents/test
- ./tests:/pgevents/tests
depends_on:
- database
- rabbitmq
Expand Down
31 changes: 30 additions & 1 deletion producer/event_producer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import threading

from abc import ABC
from typing import Type, Union
from typing import Type, Union, List

import psycopg2
from psycopg2.extras import LogicalReplicationConnection
Expand Down Expand Up @@ -221,3 +222,31 @@ def check_shutdown(self):
if self.__shutdown and self.__db_conn:
logger.warning('Shutting down...')
self.__db_conn.close()


class MultiDBEventProducer:
def __init__(self, db_configs: List[dict], **common_kwargs):
self.producers = []
self.threads = []

for db_config in db_configs:
config = {**common_kwargs, **db_config}
producer = EventProducer(**config)
self.producers.append(producer)

def start(self):
for producer in self.producers:
thread = threading.Thread(target=self._run_producer, args=(producer,))
self.threads.append(thread)
thread.start()

def _run_producer(self, producer: EventProducer):
producer.connect()
producer.start_consuming()

def shutdown(self):
for producer in self.producers:
producer.shutdown()

for thread in self.threads:
thread.join()
31 changes: 30 additions & 1 deletion producer/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import json
import signal

import click

from common import log
from common.event import BaseEvent
from common.qconnector import RabbitMQConnector
from producer.event_producer import EventProducer
from common.utils import validate_db_configs
from producer.event_producer import EventProducer, MultiDBEventProducer

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -45,3 +47,30 @@ def produce(pg_host, pg_port, pg_database, pg_user, pg_password, pg_replication_

p.connect()
p.start_consuming()



@click.command()
@click.option('--db_configs', default=lambda: os.environ.get('DB_CONFIGS', None), required=True, help='DB Configs')
@click.option('--pg_output_plugin', default=lambda: os.environ.get('PGOUTPUTPLUGIN', 'wal2json'), required=True, help='Postgresql Output Plugin ($PGOUTPUTPLUGIN)')
@click.option('--pg_publication_name', default=lambda: os.environ.get('PGPUBLICATION', None), required=False, help='Restrict to specific publications e.g. events')
@click.option('--rabbitmq_url', default=lambda: os.environ.get('RABBITMQ_URL', None), required=True, help='RabbitMQ url ($RABBITMQ_URL)')
@click.option('--rabbitmq_exchange', default=lambda: os.environ.get('RABBITMQ_EXCHANGE', None), required=True, help='RabbitMQ exchange ($RABBITMQ_EXCHANGE)')
def produce_multiple_dbs(db_configs, pg_output_plugin, pg_publication_name, rabbitmq_url, rabbitmq_exchange):
try:
db_configs = json.loads(db_configs)
validate_db_configs(db_configs)
except json.JSONDecodeError:
raise click.BadParameter("db_configs must be a valid JSON string")

common_kwargs = {
'qconnector_cls': RabbitMQConnector,
'event_cls': BaseEvent,
'pg_output_plugin': pg_output_plugin,
'pg_publication_name': pg_publication_name,
'rabbitmq_url': rabbitmq_url,
'rabbitmq_exchange': rabbitmq_exchange
}

multi_producer = MultiDBEventProducer(db_configs, **common_kwargs)
multi_producer.start()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ iniconfig==1.1.1
isort==4.3.21
lazy-object-proxy==1.7.1
mccabe==0.6.1
packaging==21.3
packaging==22.0
pika==1.1.0
pluggy==1.0.0
psycopg2==2.8.6
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

setup(
name='pgevents',
version='0.1',
version='0.2',
packages=find_packages(
include=['src', 'producer', 'consumer']
),
entry_points='''
[console_scripts]
producer=producer.main:produce
producer_multiple_dbs=producer.main:produce_multiple_dbs
event_logger=consumer.main:consume
'''
)
161 changes: 161 additions & 0 deletions tests/test_event_producer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import json
import click
from click.testing import CliRunner
from unittest import mock
import pytest

from common.event import base_event
from common.qconnector.rabbitmq_connector import RabbitMQConnector
from common.utils import validate_db_configs
from producer.event_producer import EventProducer
from producer.main import produce_multiple_dbs


# Test init
Expand Down Expand Up @@ -125,3 +130,159 @@ def test_start_consuming(mock_producer):
mock_producer._EventProducer__pg_output_plugin = 'wal2json'
mock_producer._EventProducer__db_cur.consume_stream.call_args[1]['consume']('test')
mock_producer.wal2json_msg_processor.assert_called_once()


def test_valid_single_db_config():
config = [{
'pg_host': 'localhost',
'pg_port': 5432,
'pg_database': 'test_db',
'pg_user': 'postgres',
'pg_password': 'secret',
'pg_tables': 'public.users',
'pg_replication_slot': 'test_slot'
}]
assert validate_db_configs(config) == None


def test_valid_multiple_db_configs():
configs = [
{
'pg_host': 'localhost',
'pg_port': 5432,
'pg_database': 'db1',
'pg_user': 'postgres',
'pg_password': 'secret',
'pg_tables': 'public.users',
'pg_replication_slot': 'slot1'
},
{
'pg_host': 'localhost',
'pg_port': 5433,
'pg_database': 'db2',
'pg_user': 'postgres',
'pg_password': 'secret',
'pg_tables': 'public.orders',
'pg_replication_slot': 'slot2'
}
]
assert validate_db_configs(configs) == None


def test_empty_config_list():
with pytest.raises(click.BadParameter, match="db_configs cannot be empty"):
validate_db_configs([])


def test_non_list_input():
with pytest.raises(click.BadParameter, match="db_configs must be a list"):
validate_db_configs({'some': 'dict'})


def test_missing_required_field():
config = [{
'pg_host': 'localhost',
'pg_port': 5432
}]
with pytest.raises(click.BadParameter, match="Missing required field"):
validate_db_configs(config)


def test_invalid_field_type():
config = [{
'pg_host': 'localhost',
'pg_port': '5432',
'pg_database': 'test_db',
'pg_user': 'postgres',
'pg_password': 'secret',
'pg_tables': 'public.users',
'pg_replication_slot': 'test_slot'
}]
with pytest.raises(click.BadParameter, match="must be of type int"):
validate_db_configs(config)


def test_invalid_port_range():
config = [{
'pg_host': 'localhost',
'pg_port': 80,
'pg_database': 'test_db',
'pg_user': 'postgres',
'pg_password': 'secret',
'pg_tables': 'public.users',
'pg_replication_slot': 'test_slot'
}]
with pytest.raises(click.BadParameter, match="Invalid port number"):
validate_db_configs(config)


def test_invalid_table_format():
config = [{
'pg_host': 'localhost',
'pg_port': 5432,
'pg_database': 'test_db',
'pg_user': 'postgres',
'pg_password': 'secret',
'pg_tables': 'invalid_format',
'pg_replication_slot': 'test_slot'
}]
with pytest.raises(click.BadParameter, match="Invalid table format"):
validate_db_configs(config)


@pytest.fixture
def valid_db_configs():
return json.dumps([{
'pg_host': 'database',
'pg_port': 5432,
'pg_database': 'dummy',
'pg_user': 'postgres',
'pg_password': 'postgres',
'pg_tables': 'public.users',
'pg_replication_slot': 'events'
}])

def test_produce_multiple_dbs_success(valid_db_configs):
runner = CliRunner()
result = runner.invoke(produce_multiple_dbs, [
'--db_configs', valid_db_configs,
'--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
'--rabbitmq_exchange', 'pgevents_exchange'
])

assert result.exit_code == 0

def test_produce_multiple_dbs_invalid_json():
runner = CliRunner()
result = runner.invoke(produce_multiple_dbs, [
'--db_configs', 'invalid-json',
'--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
'--rabbitmq_exchange', 'pgevents_exchange'
])

assert "db_configs must be a valid JSON string" in str(result.__dict__)

@mock.patch('producer.main.validate_db_configs')
def test_produce_multiple_dbs_invalid_config(mock_validate, valid_db_configs):
mock_validate.side_effect = click.BadParameter("Invalid config")

runner = CliRunner()
result = runner.invoke(produce_multiple_dbs, [
'--db_configs', valid_db_configs,
'--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
'--rabbitmq_exchange', 'pgevents_exchange'
])

assert "Invalid config" in str(result.__dict__)

def test_produce_multiple_dbs_common_kwargs(valid_db_configs):
runner = CliRunner()
result = runner.invoke(produce_multiple_dbs, [
'--db_configs', valid_db_configs,
'--pg_output_plugin', 'test_plugin',
'--pg_publication_name', 'test_pub',
'--rabbitmq_url', 'amqp://admin:password@rabbitmq:5672/?heartbeat=0',
'--rabbitmq_exchange', 'pgevents_exchange'
])

assert result.exit_code == 0

1 comment on commit 5c50c5f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
common
   compression.py5180%9
   log.py90100% 
   utils.py280100% 
common/qconnector
   rabbitmq_connector.py451371%23–24, 37–49, 52–65
consumer
   event_consumer.py340100% 
pgoutput_parser
   base.py610100% 
   delete.py140100% 
   insert.py150100% 
   relation.py270100% 
   update.py190100% 
producer
   event_producer.py1481590%118–122, 215–217, 223–224, 245, 248–252
TOTAL4052993% 

Tests Skipped Failures Errors Time
30 0 💤 0 ❌ 0 🔥 0.244s ⏱️

Please sign in to comment.