Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slack alert throttle #15

Merged
merged 14 commits into from
Nov 18, 2024
8 changes: 8 additions & 0 deletions files/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Virtual Environments
venv/
.venv/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
3 changes: 3 additions & 0 deletions files/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
psycopg2-binary>=2.9,<3.0
PyYAML>=6.0,<7.0
requests>=2.32,<3.0
58 changes: 58 additions & 0 deletions files/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest
neoformit marked this conversation as resolved.
Show resolved Hide resolved
from unittest.mock import patch
from datetime import datetime, timedelta
import pathlib
import tempfile
from walle import NotificationHistory

SLACK_NOTIFY_PERIOD_DAYS = 7


class TestNotificationHistory(unittest.TestCase):
def setUp(self):
self.temp_file = tempfile.NamedTemporaryFile(delete=False)
self.record = NotificationHistory(self.temp_file.name)

def tearDown(self):
pathlib.Path(self.temp_file.name).unlink(missing_ok=True)

def test_contains_new_entry(self):
jwd = "unique_id_1"
self.assertFalse(
self.record.contains(jwd), "New entry should initially return False"
)
self.assertTrue(self.record.contains(jwd), "Duplicate entry should return True")

def test_contains_existing_entry(self):
jwd = "existing_id"
self.record._write_record(jwd)
self.assertTrue(self.record.contains(jwd), "Existing entry should return True")

@patch("walle.SLACK_NOTIFY_PERIOD_DAYS", new=SLACK_NOTIFY_PERIOD_DAYS)
def test_truncate_old_records(self):
old_jwd = "old_entry"
recent_jwd = "recent_entry"
old_date = datetime.now() - timedelta(days=SLACK_NOTIFY_PERIOD_DAYS + 1)
recent_date = datetime.now()

with open(self.temp_file.name, "a") as f:
f.write(f"{old_date.isoformat()}\t{old_jwd}\n")
f.write(f"{recent_date.isoformat()}\t{recent_jwd}\n")

self.record._truncate_records()
self.assertFalse(self.record.contains(old_jwd), "Old entry should be purged")
self.assertTrue(self.record.contains(recent_jwd), "Recent entry should remain")

def test_purge_invalid_records(self):
with open(self.temp_file.name, "w") as f:
f.write("invalid_date\tinvalid_path\n")

with patch("walle.logger.warning") as mock_warning:
self.record._read_records()
mock_warning.assert_called()

self.assertFalse(self.record._get_jwds(), "Invalid records should be purged")


if __name__ == "__main__":
unittest.main()
84 changes: 83 additions & 1 deletion files/walle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import sys
import time
import zlib
from typing import Dict, List
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Union

import galaxy_jwd
import requests
Expand All @@ -33,6 +35,9 @@
If you think your account was deleted due to an error, please contact
"""
ONLY_ONE_INSTANCE = "The other must be an instance of the Severity class"

# Number of days before repeating slack alert for the same JWD
SLACK_NOTIFY_PERIOD_DAYS = 7
SLACK_URL = "https://slack.com/api/chat.postMessage"

UserId = str
Expand All @@ -46,6 +51,9 @@
)
logger = logging.getLogger(__name__)
GXADMIN_PATH = os.getenv("GXADMIN_PATH", "/usr/local/bin/gxadmin")
NOTIFICATION_HISTORY_FILE = os.getenv(
"WALLE_NOTIFICATION_HISTORY_FILE", "/tmp/walle-notifications.txt"
)


def convert_arg_to_byte(mb: str) -> int:
Expand All @@ -56,6 +64,76 @@ def convert_arg_to_seconds(hours: str) -> float:
return float(hours) * 60 * 60


@dataclass
class Record:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool, thanks a lot 😊

date: str
jwd: str

def __post_init__(self):
if not (
isinstance(self.date, str) and isinstance(self.jwd, (str, pathlib.Path))
):
raise ValueError
self.jwd = str(self.jwd)
datetime.fromisoformat(self.date) # will raise ValueError if invalid


class NotificationHistory:
"""Record of Slack notifications to avoid spamming users."""

def __init__(self, record_file: str) -> None:
self.record_file = pathlib.Path(record_file)
if not self.record_file.exists():
self.record_file.touch()
self._truncate_records()

def _get_jwds(self) -> List[str]:
return [record.jwd for record in self._read_records()]

def _read_records(self) -> List[Record]:
try:
with open(self.record_file, "r") as f:
records = [
Record(*line.strip().split("\t"))
for line in f.readlines()
if line.strip()
]
except ValueError:
logger.warning(
f"Invalid records found in {self.record_file}. The"
" file will be purged. This may result in duplicate Slack"
" notifications."
)
self._purge_records()
return []
return records

def _write_record(self, jwd: str) -> None:
with open(self.record_file, "a") as f:
f.write(f"{datetime.now()}\t{jwd}\n")

def _purge_records(self) -> None:
self.record_file.unlink()
self.record_file.touch()

def _truncate_records(self) -> None:
"""Truncate older records."""
records = self._read_records()
with open(self.record_file, "w") as f:
for record in records:
if datetime.fromisoformat(record.date) > datetime.now() - timedelta(
days=SLACK_NOTIFY_PERIOD_DAYS
):
f.write(f"{record.date}\t{record.jwd}\n")

def contains(self, jwd: Union[pathlib.Path, str]) -> bool:
jwd = str(jwd)
exists = jwd in self._get_jwds()
if not exists:
self._write_record(jwd)
return exists


class Severity:
def __init__(self, number: int, name: str):
self.value = number
Expand Down Expand Up @@ -87,6 +165,7 @@ def __ge__(self, other) -> bool:


VALID_SEVERITIES = (Severity(0, "LOW"), Severity(1, "MEDIUM"), Severity(2, "HIGH"))
notification_history = NotificationHistory(NOTIFICATION_HISTORY_FILE)


def convert_str_to_severity(test_level: str) -> Severity:
Expand Down Expand Up @@ -406,6 +485,9 @@ def report_matching_malware(self):
)

def post_slack_alert(self):
if notification_history.contains(self.job.jwd):
logger.debug("Skipping Slack notification - already posted for this JWD")
return
msg = f"""
:rotating_light: WALLE: *Malware detected* :rotating_light:

Expand Down
Loading