Skip to content

Commit

Permalink
refactor: smtp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
s-aga-r committed Jan 8, 2025
1 parent db158b1 commit 735a947
Showing 1 changed file with 182 additions and 47 deletions.
229 changes: 182 additions & 47 deletions mail/smtp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,77 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from queue import Queue
from smtplib import SMTP, SMTP_SSL
from threading import Lock
from smtplib import SMTP, SMTP_SSL, SMTPServerDisconnected
from threading import Lock, Thread


class SMTPConnectionLimitError(Exception):
pass


class SMTPConnection:
def __init__(
self,
host: str,
port: int,
username: str,
password: str,
use_ssl: bool,
use_tls: bool,
inactivity_timeout: int,
session_duration: int,
max_emails: int,
) -> None:
self.__created_at = time.time()
self.__inactivity_timeout = inactivity_timeout
self.__session_duration = session_duration
self.__max_emails = max_emails
self.__email_count = 0

self.session = self.__create_connection(host, port, username, password, use_ssl, use_tls)
self.host = host
self.port = port
self.username = username
self.last_used = time.time()

def __create_connection(
self, host: str, port: int, username: str, password: str, use_ssl: bool, use_tls: bool
) -> SMTP | SMTP_SSL:
_SMTP = SMTP_SSL if use_ssl else SMTP
session = _SMTP(host, port)
if use_tls:
session.ehlo()
session.starttls()
session.ehlo()
session.login(username, password)
return session

def is_active(self) -> bool:
try:
self.session.noop()
return True
except (SMTPServerDisconnected, OSError):
return False

def is_valid(self) -> bool:
current_time = time.time()
expired = (
current_time - self.last_used > self.__inactivity_timeout
or current_time - self.__created_at > self.__session_duration
or self.__email_count >= self.__max_emails
)
return not expired and self.is_active()

def increment_email_count(self) -> None:
self.__email_count += 1
self.last_used = time.time()

def close(self) -> None:
try:
self.session.quit()
except (SMTPServerDisconnected, OSError):
pass


class SMTPConnectionPool:
Expand All @@ -12,71 +81,114 @@ class SMTPConnectionPool:
def __new__(cls, *args, **kwargs) -> "SMTPConnectionPool":
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
cls._instance = super().__new__(cls)
cls._instance._pools = {}
cls._instance._pool_lock = Lock()
cls._instance._running = False
return cls._instance

def __init__(self, max_connections: int) -> None:
if hasattr(self, "_initialized"):
return

self.max_connections = max_connections
self._initialized = True
self._running = True
self._cleanup_interval = 60
self._cleanup_thread = None
self._initialize_cleanup_thread()

def get_connection(
self,
host: str,
port: int,
username: str,
password: str,
use_ssl: bool = False,
use_tls: bool = False,
max_connections: int = 5,
) -> type[SMTP] | type[SMTP_SSL]:
use_ssl: bool,
use_tls: bool,
inactivity_timeout: int,
session_duration: int,
max_emails: int,
) -> "SMTPConnection":
key = (host, port, username)
with self._pool_lock:
if key not in self._pools:
self._pools[key] = Queue(max_connections)
self._pools[key] = Queue(self.max_connections)
self._running = True
self._initialize_cleanup_thread()

pool = self._pools[key]

with Lock():
if not pool.empty():
return pool.get()
if pool.qsize() < max_connections:
return self._create_connection(host, port, username, password, use_ssl, use_tls)
raise Exception(f"SMTP connection pool limit reached for {key}")

def return_connection(
self, host: str, port: int, username: str, connection: type[SMTP] | type[SMTP_SSL]
) -> None:
key = (host, port, username)
while not pool.empty():
connection: SMTPConnection = pool.get()
if connection.is_valid():
connection.last_used = time.time()
return connection
else:
connection.close()

if pool.qsize() < self.max_connections:
return SMTPConnection(
host,
port,
username,
password,
use_ssl,
use_tls,
inactivity_timeout,
session_duration,
max_emails,
)

raise SMTPConnectionLimitError(
f"SMTP connection pool limit ({self.max_connections}) reached for {key}"
)

def return_connection(self, connection: SMTPConnection) -> None:
key = (connection.host, connection.port, connection.username)
with self._pool_lock:
if key in self._pools:
pool = self._pools[key]
with Lock():
if pool.qsize() < pool.maxsize:
if connection.is_active() and pool.qsize() < pool.maxsize:
pool.put(connection)
return
connection.quit()
connection.close()

def close_all(self) -> None:
def close_all_connections(self) -> None:
with self._pool_lock:
for pool in list(self._pools.values()):
while not pool.empty():
connection = pool.get()
connection.quit()
connection: SMTPConnection = pool.get()
connection.close()
self._pools.clear()
self._running = False
self._stop_cleanup_thread()

@staticmethod
def _create_connection(
host: str, port: int, username: str, password: str, use_ssl: bool, use_tls: bool
) -> type[SMTP] | type[SMTP_SSL]:
_SMTP = SMTP_SSL if use_ssl else SMTP

connection = _SMTP(host, port)
def _initialize_cleanup_thread(self) -> None:
if self._running and self._cleanup_thread is None:
self._cleanup_thread = Thread(target=self._cleanup_stale_connections, daemon=True)
self._cleanup_thread.start()

if use_tls:
connection.ehlo()
connection.starttls()
connection.ehlo()
def _cleanup_stale_connections(self) -> None:
while self._running:
time.sleep(self._cleanup_interval)
with self._pool_lock:
for key, pool in self._pools.items():
valid_connections = Queue(self.max_connections)
while not pool.empty():
connection: SMTPConnection = pool.get()
if connection.is_valid():
valid_connections.put(connection)
else:
connection.close()
self._pools[key] = valid_connections

connection.login(username, password)
return connection
def _stop_cleanup_thread(self) -> None:
if self._cleanup_thread and self._cleanup_thread.is_alive():
self._cleanup_thread.join()
self._cleanup_thread = None


class SMTPContext:
Expand All @@ -88,27 +200,43 @@ def __init__(
password: str,
use_ssl: bool = False,
use_tls: bool = False,
inactivity_timeout: int = 300,
session_duration: int = 600,
max_emails: int = 10,
max_connections: int = 5,
) -> None:
self._pool = SMTPConnectionPool()
self._pool = SMTPConnectionPool(max_connections)
self._host = host
self._port = port
self._username = username
self._password = password
self._use_ssl = use_ssl
self._use_tls = use_tls
self._inactivity_timeout = inactivity_timeout
self._session_duration = session_duration
self._max_emails = max_emails
self._connection = None

def __enter__(self) -> SMTP | SMTP_SSL:
self._connection = self._pool.get_connection(
self._host, self._port, self._username, self._password, self._use_ssl, self._use_tls
self._connection: SMTPConnection = self._pool.get_connection(
self._host,
self._port,
self._username,
self._password,
self._use_ssl,
self._use_tls,
self._inactivity_timeout,
self._session_duration,
self._max_emails,
)
return self._connection
return self._connection.session

def __exit__(self, exc_type, exc_value, traceback) -> None:
if exc_type is not None:
self._connection.quit()
self._connection.close()
else:
self._pool.return_connection(self._host, self._port, self._username, self._connection)
self._connection.increment_email_count()
self._pool.return_connection(self._connection)


@contextmanager
Expand All @@ -119,12 +247,19 @@ def smtp_server(
password: str,
use_ssl: bool = False,
use_tls: bool = False,
inactivity_timeout: int = 300,
session_duration: int = 600,
max_emails: int = 10,
max_connections: int = 5,
) -> Generator[type[SMTP] | type[SMTP_SSL], None, None]:
pool = SMTPConnectionPool()
connection = pool.get_connection(host, port, username, password, use_ssl, use_tls)
_pool = SMTPConnectionPool(max_connections)
_connection: SMTPConnection = _pool.get_connection(
host, port, username, password, use_ssl, use_tls, inactivity_timeout, session_duration, max_emails
)

try:
yield connection
yield _connection.session
finally:
if connection:
pool.return_connection(host, port, username, connection)
if _connection:
_connection.increment_email_count()
_pool.return_connection(_connection)

0 comments on commit 735a947

Please sign in to comment.