diff --git a/.github/workflows/linting-and-tests.yml b/.github/workflows/linting-and-tests.yml index fc43b57276..23688595e6 100644 --- a/.github/workflows/linting-and-tests.yml +++ b/.github/workflows/linting-and-tests.yml @@ -244,6 +244,7 @@ jobs: grafana_version: - 10.3.0 - 11.2.0 + - latest fail-fast: false with: grafana_version: ${{ matrix.grafana_version }} diff --git a/Tiltfile b/Tiltfile index 264424161c..00d7ec4189 100644 --- a/Tiltfile +++ b/Tiltfile @@ -32,12 +32,23 @@ def plugin_json(): return plugin_file return 'NOT_A_PLUGIN' +def extra_grafana_ini(): + return { + 'feature_toggles': { + 'accessControlOnCall': 'false' + } + } + def extra_env(): return { "GF_APP_URL": grafana_url, "GF_SERVER_ROOT_URL": grafana_url, "GF_FEATURE_TOGGLES_ENABLE": "externalServiceAccounts", - "ONCALL_API_URL": "http://oncall-dev-engine:8080" + "ONCALL_API_URL": "http://oncall-dev-engine:8080", + + # Enables managed service accounts for plugin authentication in Grafana >= 11.3 + # https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#managed_service_accounts_enabled + "GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED": "true", } def extra_deps(): @@ -132,7 +143,16 @@ def load_grafana(): "GF_APP_URL": grafana_url, # older versions of grafana need this "GF_SERVER_ROOT_URL": grafana_url, "GF_FEATURE_TOGGLES_ENABLE": "externalServiceAccounts", - "ONCALL_API_URL": "http://oncall-dev-engine:8080" + "ONCALL_API_URL": "http://oncall-dev-engine:8080", + + # Enables managed service accounts for plugin authentication in Grafana >= 11.3 + # https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#managed_service_accounts_enabled + "GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED": "true", + }, + extra_grafana_ini={ + "feature_toggles": { + "accessControlOnCall": "false" + } }, ) # --- GRAFANA END ---- diff --git a/dev/helm-local.yml b/dev/helm-local.yml index 33a28790c6..770a5dfb0c 100644 --- a/dev/helm-local.yml +++ b/dev/helm-local.yml @@ -39,7 +39,7 @@ engine: replicaCount: 1 celery: replicaCount: 1 - worker_beat_enabled: false + worker_beat_enabled: true externalGrafana: url: http://grafana:3000 @@ -47,6 +47,8 @@ externalGrafana: grafana: enabled: false grafana.ini: + feature_toggles: + accessControlOnCall: false server: domain: localhost:3000 root_url: "%(protocol)s://%(domain)s" @@ -71,6 +73,7 @@ grafana: value: oncallpassword env: GF_FEATURE_TOGGLES_ENABLE: externalServiceAccounts + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true GF_SECURITY_ADMIN_PASSWORD: oncall GF_SECURITY_ADMIN_USER: oncall GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app diff --git a/docker-compose-developer.yml b/docker-compose-developer.yml index b751ab1e98..ee668df794 100644 --- a/docker-compose-developer.yml +++ b/docker-compose-developer.yml @@ -324,6 +324,7 @@ services: GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_FEATURE_TOGGLES_ENABLE: externalServiceAccounts ONCALL_API_URL: http://host.docker.internal:8080 + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true env_file: - ./dev/.env.${DB}.dev ports: diff --git a/docker-compose-mysql-rabbitmq.yml b/docker-compose-mysql-rabbitmq.yml index f587902e76..60b320e80f 100644 --- a/docker-compose-mysql-rabbitmq.yml +++ b/docker-compose-mysql-rabbitmq.yml @@ -144,6 +144,7 @@ services: GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-admin} GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_INSTALL_PLUGINS: grafana-oncall-app + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true deploy: resources: limits: @@ -156,7 +157,16 @@ 
services: condition: service_healthy profiles: - with_grafana + configs: + - source: grafana.ini + target: /etc/grafana/grafana.ini volumes: dbdata: rabbitmqdata: + +configs: + grafana.ini: + content: | + [feature_toggles] + accessControlOnCall = false diff --git a/docker-compose.yml b/docker-compose.yml index b115199f8c..c54c2fb33f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -94,6 +94,7 @@ services: GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-admin} GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_INSTALL_PLUGINS: grafana-oncall-app + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true volumes: - grafana_data:/var/lib/grafana deploy: @@ -103,9 +104,18 @@ services: cpus: "0.5" profiles: - with_grafana + configs: + - source: grafana.ini + target: /etc/grafana/grafana.ini volumes: grafana_data: prometheus_data: oncall_data: redis_data: + +configs: + grafana.ini: + content: | + [feature_toggles] + accessControlOnCall = false diff --git a/docs/sources/configure/jinja2-templating/_index.md b/docs/sources/configure/jinja2-templating/_index.md index 6883785877..6cc158f7f8 100644 --- a/docs/sources/configure/jinja2-templating/_index.md +++ b/docs/sources/configure/jinja2-templating/_index.md @@ -23,8 +23,7 @@ refs: destination: /docs/grafana-cloud/alerting-and-irm/oncall/configure/integrations/references/webhook/ --- - -## Configure templates +# Configure templates Grafana OnCall integrates with your monitoring systems using webhooks with JSON payloads. By default, these webhooks deliver raw JSON payloads. diff --git a/engine/apps/alerts/migrations/0001_squashed_initial.py b/engine/apps/alerts/migrations/0001_squashed_initial.py index 0c96d7d4ad..8426d2635a 100644 --- a/engine/apps/alerts/migrations/0001_squashed_initial.py +++ b/engine/apps/alerts/migrations/0001_squashed_initial.py @@ -119,7 +119,7 @@ class Migration(migrations.Migration): name='AlertGroupPostmortem', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('public_primary_key', models.CharField(default=apps.alerts.models.resolution_note.generate_public_primary_key_for_alert_group_postmortem, max_length=20, unique=True, validators=[django.core.validators.MinLengthValidator(13)])), + ('public_primary_key', models.CharField(max_length=20, unique=True, validators=[django.core.validators.MinLengthValidator(13)])), ('created_at', models.DateTimeField(auto_now_add=True)), ('last_modified', models.DateTimeField(auto_now=True)), ('text', models.TextField(default=None, max_length=3000, null=True)), diff --git a/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py b/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py new file mode 100644 index 0000000000..306d8a0408 --- /dev/null +++ b/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ('alerts', '0064_migrate_resolutionnoteslackmessage_slack_channel_id'), + ] + + operations = [ + migrations.AddField( + model_name='alertreceivechannel', + name='service_account', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='alert_receive_channels', to='user_management.serviceaccount'), + ), + ] diff --git 
a/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py b/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py new file mode 100644 index 0000000000..03c5f53430 --- /dev/null +++ b/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:11 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('alerts', '0065_alertreceivechannel_service_account'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='channelfilter', + name='_slack_channel_id', + ), + migrations.RemoveField( + model_name='resolutionnoteslackmessage', + name='_slack_channel_id', + ), + migrations.DeleteModel( + name='AlertGroupPostmortem', + ), + ] diff --git a/engine/apps/alerts/models/alert_receive_channel.py b/engine/apps/alerts/models/alert_receive_channel.py index 4fd926ac47..7a351d2aad 100644 --- a/engine/apps/alerts/models/alert_receive_channel.py +++ b/engine/apps/alerts/models/alert_receive_channel.py @@ -234,6 +234,13 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): author = models.ForeignKey( "user_management.User", on_delete=models.SET_NULL, related_name="alert_receive_channels", blank=True, null=True ) + service_account = models.ForeignKey( + "user_management.ServiceAccount", + on_delete=models.SET_NULL, + related_name="alert_receive_channels", + blank=True, + null=True, + ) team = models.ForeignKey( "user_management.Team", on_delete=models.SET_NULL, @@ -518,29 +525,21 @@ def short_name(self): ) @property - def short_name_with_maintenance_status(self): - if self.maintenance_mode is not None: - return ( - self.short_name + f" *[ on " - f"{AlertReceiveChannel.MAINTENANCE_MODE_CHOICES[self.maintenance_mode][1]}" - f" :construction: ]*" - ) - else: - return self.short_name - - @property - def created_name(self): + def created_name(self) -> str: return f"{self.get_integration_display()} {self.smile_code}" @property def web_link(self) -> str: return UIURLBuilder(self.organization).integration_detail(self.public_primary_key) + @property + def is_maintenace_integration(self) -> bool: + return self.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE + @property def integration_url(self) -> str | None: if self.integration in [ AlertReceiveChannel.INTEGRATION_MANUAL, - AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL, AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, AlertReceiveChannel.INTEGRATION_MAINTENANCE, ]: @@ -764,15 +763,16 @@ def listen_for_alertreceivechannel_model_save( from apps.heartbeat.models import IntegrationHeartBeat if created: - write_resource_insight_log(instance=instance, author=instance.author, event=EntityEvent.CREATED) + author = instance.author or instance.service_account + write_resource_insight_log(instance=instance, author=author, event=EntityEvent.CREATED) default_filter = ChannelFilter(alert_receive_channel=instance, filtering_term=None, is_default=True) default_filter.save() - write_resource_insight_log(instance=default_filter, author=instance.author, event=EntityEvent.CREATED) + write_resource_insight_log(instance=default_filter, author=author, event=EntityEvent.CREATED) TEN_MINUTES = 600 # this is timeout for cloud heartbeats if instance.is_available_for_integration_heartbeat: heartbeat = IntegrationHeartBeat.objects.create(alert_receive_channel=instance, 
timeout_seconds=TEN_MINUTES) - write_resource_insight_log(instance=heartbeat, author=instance.author, event=EntityEvent.CREATED) + write_resource_insight_log(instance=heartbeat, author=author, event=EntityEvent.CREATED) metrics_add_integrations_to_cache([instance], instance.organization) diff --git a/engine/apps/alerts/models/channel_filter.py b/engine/apps/alerts/models/channel_filter.py index f7cb302f7a..3ea2ea8bcb 100644 --- a/engine/apps/alerts/models/channel_filter.py +++ b/engine/apps/alerts/models/channel_filter.py @@ -69,9 +69,6 @@ class ChannelFilter(OrderedModel): notify_in_slack = models.BooleanField(null=True, default=True) notify_in_telegram = models.BooleanField(null=True, default=False) - - # TODO: remove _slack_channel_id in future release - _slack_channel_id = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, @@ -79,7 +76,6 @@ class ChannelFilter(OrderedModel): on_delete=models.SET_NULL, related_name="+", ) - telegram_channel = models.ForeignKey( "telegram.TelegramToOrganizationConnector", on_delete=models.SET_NULL, diff --git a/engine/apps/alerts/models/resolution_note.py b/engine/apps/alerts/models/resolution_note.py index e2f3586a55..90e651662a 100644 --- a/engine/apps/alerts/models/resolution_note.py +++ b/engine/apps/alerts/models/resolution_note.py @@ -14,20 +14,7 @@ if typing.TYPE_CHECKING: from apps.alerts.models import AlertGroup from apps.slack.models import SlackChannel - - -def generate_public_primary_key_for_alert_group_postmortem(): - prefix = "P" - new_public_primary_key = generate_public_primary_key(prefix) - - failure_counter = 0 - while AlertGroupPostmortem.objects.filter(public_primary_key=new_public_primary_key).exists(): - new_public_primary_key = increase_public_primary_key_length( - failure_counter=failure_counter, prefix=prefix, model_name="AlertGroupPostmortem" - ) - failure_counter += 1 - - return new_public_primary_key + from apps.user_management.models import User def generate_public_primary_key_for_resolution_note(): @@ -75,9 +62,6 @@ class ResolutionNoteSlackMessage(models.Model): related_name="added_resolution_note_slack_messages", ) text = models.TextField(max_length=3000, default=None, null=True) - - # TODO: remove _slack_channel_id in future release - _slack_channel_id = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, @@ -85,7 +69,6 @@ class ResolutionNoteSlackMessage(models.Model): on_delete=models.SET_NULL, related_name="+", ) - ts = models.CharField(max_length=100, null=True, default=None) thread_ts = models.CharField(max_length=100, null=True, default=None) permalink = models.CharField(max_length=250, null=True, default=None) @@ -130,6 +113,7 @@ def filter(self, *args, **kwargs): class ResolutionNote(models.Model): alert_group: "AlertGroup" + author: typing.Optional["User"] resolution_note_slack_message: typing.Optional[ResolutionNoteSlackMessage] objects = ResolutionNoteQueryset.as_manager() @@ -213,29 +197,11 @@ def render_log_line_json(self): return result - def author_verbal(self, mention): - """ - Postmortems to resolution notes included migrating AlertGroupPostmortem to ResolutionNotes. - But AlertGroupPostmortem has no author field. So this method was introduces as workaround. 
+ def author_verbal(self, mention: bool) -> str: """ - if self.author is not None: - return self.author.get_username_with_slack_verbal(mention) - else: - return "" + Postmortems to resolution notes included migrating `AlertGroupPostmortem` to `ResolutionNote`s. + But `AlertGroupPostmortem` has no author field. So this method was introduced as a workaround. - -class AlertGroupPostmortem(models.Model): - public_primary_key = models.CharField( - max_length=20, - validators=[MinLengthValidator(settings.PUBLIC_PRIMARY_KEY_MIN_LENGTH + 1)], - unique=True, - default=generate_public_primary_key_for_alert_group_postmortem, - ) - alert_group = models.ForeignKey( - "alerts.AlertGroup", - on_delete=models.CASCADE, - related_name="postmortem_text", - ) - created_at = models.DateTimeField(auto_now_add=True) - last_modified = models.DateTimeField(auto_now=True) - text = models.TextField(max_length=3000, default=None, null=True) + (see git history for more details on what `AlertGroupPostmortem` was) + """ + return "" if self.author is None else self.author.get_username_with_slack_verbal(mention) diff --git a/engine/apps/alerts/tasks/check_escalation_finished.py b/engine/apps/alerts/tasks/check_escalation_finished.py index 9f3fb62d8c..8ae6d8146c 100644 --- a/engine/apps/alerts/tasks/check_escalation_finished.py +++ b/engine/apps/alerts/tasks/check_escalation_finished.py @@ -2,7 +2,9 @@ import typing import requests +from celery import uuid as celery_uuid from django.conf import settings +from django.core.cache import cache from django.db.models import Avg, F, Max, Q from django.utils import timezone @@ -174,6 +176,42 @@ def check_personal_notifications_task() -> None: task_logger.info(f"personal_notifications_triggered={triggered} personal_notifications_completed={completed}") +# Retries an alert group that has failed auditing if it is within the retry limit +# Returns whether an alert group escalation is being retried +def retry_audited_alert_group(alert_group) -> bool: + cache_key = f"audited-alert-group-retry-count-{alert_group.id}" + retry_count = cache.get(cache_key, 0) + if retry_count >= settings.AUDITED_ALERT_GROUP_MAX_RETRIES: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} max retries exceeded.") + return False + + if alert_group.is_silenced_for_period: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} as it is silenced.") + return False + + if not alert_group.escalation_snapshot: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} as its escalation snapshot is empty.") + return False + + retry_count += 1 + cache.set(cache_key, retry_count, timeout=3600) + + task_id = celery_uuid() + alert_group.active_escalation_id = task_id + alert_group.save(update_fields=["active_escalation_id"]) + + from apps.alerts.tasks import escalate_alert_group + + escalate_alert_group.apply_async( + args=(alert_group.pk,), + immutable=True, + task_id=task_id, + eta=alert_group.next_step_eta, + ) + task_logger.info(f"Retrying audited alert_group={alert_group.id} attempt={retry_count}") + return True + + @shared_log_exception_on_failure_task def check_escalation_finished_task() -> None: """ @@ -221,7 +259,8 @@ def check_escalation_finished_task() -> None: try: audit_alert_group_escalation(alert_group) except AlertGroupEscalationPolicyExecutionAuditException: - alert_group_ids_that_failed_audit.append(str(alert_group.id)) + if not retry_audited_alert_group(alert_group): + alert_group_ids_that_failed_audit.append(str(alert_group.id)) failed_alert_groups_count = 
len(alert_group_ids_that_failed_audit) success_ratio = ( diff --git a/engine/apps/alerts/tests/test_check_escalation_finished_task.py b/engine/apps/alerts/tests/test_check_escalation_finished_task.py index 8aa5cbbdd9..229fabff49 100644 --- a/engine/apps/alerts/tests/test_check_escalation_finished_task.py +++ b/engine/apps/alerts/tests/test_check_escalation_finished_task.py @@ -6,12 +6,14 @@ from django.utils import timezone from apps.alerts.models import EscalationPolicy +from apps.alerts.tasks import escalate_alert_group from apps.alerts.tasks.check_escalation_finished import ( AlertGroupEscalationPolicyExecutionAuditException, audit_alert_group_escalation, check_alert_group_personal_notifications_task, check_escalation_finished_task, check_personal_notifications_task, + retry_audited_alert_group, send_alert_group_escalation_auditor_task_heartbeat, ) from apps.base.models import UserNotificationPolicy, UserNotificationPolicyLogRecord @@ -580,3 +582,124 @@ def test_check_escalation_finished_task_calls_audit_alert_group_personal_notific check_personal_notifications_task() assert "personal_notifications_triggered=6 personal_notifications_completed=2" in caplog.text + + +@patch("apps.alerts.tasks.check_escalation_finished.audit_alert_group_escalation") +@patch("apps.alerts.tasks.check_escalation_finished.retry_audited_alert_group") +@patch("apps.alerts.tasks.check_escalation_finished.send_alert_group_escalation_auditor_task_heartbeat") +@pytest.mark.django_db +def test_invoke_retry_from_check_escalation_finished_task( + mocked_send_alert_group_escalation_auditor_task_heartbeat, + mocked_retry_audited_alert_group, + mocked_audit_alert_group_escalation, + make_organization_and_user, + make_alert_receive_channel, + make_alert_group_that_started_at_specific_date, +): + organization, _ = make_organization_and_user() + alert_receive_channel = make_alert_receive_channel(organization) + + # Pass audit (should not be counted in final message or go to retry function) + alert_group1 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=1) + # Fail audit and don't retry (should be counted in final message) + alert_group2 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=5) + # Fail audit but retry (should not be counted in final message) + alert_group3 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=10) + + def _mocked_audit_alert_group_escalation(alert_group): + if alert_group.id == alert_group2.id or alert_group.id == alert_group3.id: + raise AlertGroupEscalationPolicyExecutionAuditException(f"{alert_group.id} failed audit") + + mocked_audit_alert_group_escalation.side_effect = _mocked_audit_alert_group_escalation + + def _mocked_retry_audited_alert_group(alert_group): + if alert_group.id == alert_group2.id: + return False + return True + + mocked_retry_audited_alert_group.side_effect = _mocked_retry_audited_alert_group + + with pytest.raises(AlertGroupEscalationPolicyExecutionAuditException) as exc: + check_escalation_finished_task() + + error_msg = str(exc.value) + + assert "The following alert group id(s) failed auditing:" in error_msg + assert str(alert_group1.id) not in error_msg + assert str(alert_group2.id) in error_msg + assert str(alert_group3.id) not in error_msg + + assert mocked_retry_audited_alert_group.call_count == 2 + mocked_send_alert_group_escalation_auditor_task_heartbeat.assert_not_called() + + +@patch.object(escalate_alert_group, "apply_async")
+@override_settings(AUDITED_ALERT_GROUP_MAX_RETRIES=1) +@pytest.mark.django_db +def test_retry_audited_alert_group( + mocked_escalate_alert_group, + make_organization_and_user, + make_user_for_organization, + make_user_notification_policy, + make_escalation_chain, + make_escalation_policy, + make_channel_filter, + make_alert_receive_channel, + make_alert_group_that_started_at_specific_date, +): + organization, user = make_organization_and_user() + make_user_notification_policy( + user=user, + step=UserNotificationPolicy.Step.NOTIFY, + notify_by=UserNotificationPolicy.NotificationChannel.SLACK, + ) + + alert_receive_channel = make_alert_receive_channel(organization) + escalation_chain = make_escalation_chain(organization) + channel_filter = make_channel_filter(alert_receive_channel, escalation_chain=escalation_chain) + notify_to_multiple_users_step = make_escalation_policy( + escalation_chain=channel_filter.escalation_chain, + escalation_policy_step=EscalationPolicy.STEP_NOTIFY_MULTIPLE_USERS, + ) + notify_to_multiple_users_step.notify_to_users_queue.set([user]) + + alert_group1 = make_alert_group_that_started_at_specific_date(alert_receive_channel, channel_filter=channel_filter) + alert_group1.raw_escalation_snapshot = alert_group1.build_raw_escalation_snapshot() + alert_group1.raw_escalation_snapshot["last_active_escalation_policy_order"] = 1 + alert_group1.save() + + # Retry should occur + is_retrying = retry_audited_alert_group(alert_group1) + assert is_retrying + mocked_escalate_alert_group.assert_called() + mocked_escalate_alert_group.reset_mock() + + # No retry as attempts == max + is_retrying = retry_audited_alert_group(alert_group1) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() + mocked_escalate_alert_group.reset_mock() + + alert_group2 = make_alert_group_that_started_at_specific_date(alert_receive_channel, channel_filter=channel_filter) + # No retry because no escalation snapshot + is_retrying = retry_audited_alert_group(alert_group2) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() + mocked_escalate_alert_group.reset_mock() + + alert_group3 = make_alert_group_that_started_at_specific_date( + alert_receive_channel, + channel_filter=channel_filter, + silenced=True, + silenced_at=timezone.now(), + silenced_by_user=user, + silenced_until=(timezone.now() + timezone.timedelta(hours=1)), + ) + alert_group3.raw_escalation_snapshot = alert_group3.build_raw_escalation_snapshot() + alert_group3.raw_escalation_snapshot["last_active_escalation_policy_order"] = 1 + alert_group3.save() + + # No retry because alert group silenced + is_retrying = retry_audited_alert_group(alert_group3) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() diff --git a/engine/apps/api/permissions.py b/engine/apps/api/permissions.py index 852506a109..d9dad6b37d 100644 --- a/engine/apps/api/permissions.py +++ b/engine/apps/api/permissions.py @@ -18,6 +18,7 @@ RBAC_PERMISSIONS_ATTR = "rbac_permissions" RBAC_OBJECT_PERMISSIONS_ATTR = "rbac_object_permissions" + ViewSetOrAPIView = typing.Union[ViewSet, APIView] diff --git a/engine/apps/api/tests/test_schedules.py b/engine/apps/api/tests/test_schedules.py index 4a29dc9dd6..8efcb6b236 100644 --- a/engine/apps/api/tests/test_schedules.py +++ b/engine/apps/api/tests/test_schedules.py @@ -1442,8 +1442,9 @@ def test_next_shifts_per_user( ("B", "UTC"), ("C", None), ("D", "America/Montevideo"), + ("E", None), ) - user_a, user_b, user_c, user_d = ( + user_a, user_b, user_c, user_d, user_e = (
make_user_for_organization(organization, username=i, _timezone=tz) for i, tz in users ) @@ -1469,8 +1470,7 @@ def test_next_shifts_per_user( ) on_call_shift.add_rolling_users([[user]]) - # override in the past: 17-18 / D - # won't be listed, but user D will still be included in the response + # override in the past, won't be listed: 17-18 / D override_data = { "start": tomorrow - timezone.timedelta(days=3), "rotation_start": tomorrow - timezone.timedelta(days=3), @@ -1483,6 +1483,7 @@ def test_next_shifts_per_user( override.add_rolling_users([[user_d]]) # override: 17-18 / C + # this is before C's shift, so it will be listed as upcoming override_data = { "start": tomorrow + timezone.timedelta(hours=17), "rotation_start": tomorrow + timezone.timedelta(hours=17), @@ -1494,11 +1495,26 @@ def test_next_shifts_per_user( ) override.add_rolling_users([[user_c]]) + # override: 17-18 / E + fifteen_days_later = tomorrow + timezone.timedelta(days=15) + override_data = { + "start": fifteen_days_later + timezone.timedelta(hours=17), + "rotation_start": fifteen_days_later + timezone.timedelta(hours=17), + "duration": timezone.timedelta(hours=1), + "schedule": schedule, + } + override = make_on_call_shift( + organization=organization, shift_type=CustomOnCallShift.TYPE_OVERRIDE, **override_data + ) + override.add_rolling_users([[user_e]]) + # final schedule: 7-12: B, 15-16: A, 16-17: B, 17-18: C (override), 18-20: C schedule.refresh_ical_final_schedule() url = reverse("api-internal:schedule-next-shifts-per-user", kwargs={"pk": schedule.public_primary_key}) - response = client.get(url, format="json", **make_user_auth_headers(admin, token)) + + # check for users with shifts in the next week + response = client.get(url + "?days=7", format="json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_200_OK expected = { @@ -1517,13 +1533,27 @@ def test_next_shifts_per_user( tomorrow + timezone.timedelta(hours=18), user_c.timezone, ), - user_d.public_primary_key: (None, None, user_d.timezone), } returned_data = { u: (ev.get("start"), ev.get("end"), ev.get("user_timezone")) for u, ev in response.data["users"].items() } assert returned_data == expected + # by default it will check for shifts in the next 30 days + response = client.get(url, format="json", **make_user_auth_headers(admin, token)) + assert response.status_code == status.HTTP_200_OK + + # include user E with the override + expected[user_e.public_primary_key] = ( + fifteen_days_later + timezone.timedelta(hours=17), + fifteen_days_later + timezone.timedelta(hours=18), + user_e.timezone, + ) + returned_data = { + u: (ev.get("start"), ev.get("end"), ev.get("user_timezone")) for u, ev in response.data["users"].items() + } + assert returned_data == expected + @pytest.mark.django_db def test_next_shifts_per_user_ical_schedule_using_emails( diff --git a/engine/apps/api/views/schedule.py b/engine/apps/api/views/schedule.py index 78635290de..e30aa8cbde 100644 --- a/engine/apps/api/views/schedule.py +++ b/engine/apps/api/views/schedule.py @@ -388,20 +388,22 @@ def filter_shift_swaps(self, request: Request, pk: str) -> Response: @action(detail=True, methods=["get"]) def next_shifts_per_user(self, request, pk): """Return next shift for users in schedule.""" + days = self.request.query_params.get("days") + days = int(days) if days else 30 now = timezone.now() - datetime_end = now + datetime.timedelta(days=30) + datetime_end = now + datetime.timedelta(days=days) schedule = self.get_object(annotate=False) + users = {} events =
schedule.final_events(now, datetime_end) - - # include user TZ information for every user - users = {u.public_primary_key: {"user_timezone": u.timezone} for u in schedule.related_users()} + users_tz = {u.public_primary_key: u.timezone for u in schedule.related_users()} added_users = set() for e in events: - user = e["users"][0]["pk"] if e["users"] else None - if user is not None and user not in added_users and user in users and e["end"] > now: - users[user].update(e) - added_users.add(user) + user_ppk = e["users"][0]["pk"] if e["users"] else None + if user_ppk is not None and user_ppk not in users and user_ppk in users_tz and e["end"] > now: + users[user_ppk] = e + users[user_ppk]["user_timezone"] = users_tz[user_ppk] + added_users.add(user_ppk) result = {"users": users} return Response(result, status=status.HTTP_200_OK) diff --git a/engine/apps/auth_token/auth.py b/engine/apps/auth_token/auth.py index dc6ccf7ae0..3a7e25d6bd 100644 --- a/engine/apps/auth_token/auth.py +++ b/engine/apps/auth_token/auth.py @@ -9,7 +9,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.request import Request -from apps.api.permissions import GrafanaAPIPermissions, LegacyAccessControlRole from apps.grafana_plugin.helpers.gcom import check_token from apps.grafana_plugin.sync_data import SyncPermission, SyncUser from apps.user_management.exceptions import OrganizationDeletedException, OrganizationMovedException @@ -20,13 +19,13 @@ from .constants import GOOGLE_OAUTH2_AUTH_TOKEN_NAME, SCHEDULE_EXPORT_TOKEN_NAME, SLACK_AUTH_TOKEN_NAME from .exceptions import InvalidToken -from .grafana.grafana_auth_token import get_service_account_token_permissions from .models import ( ApiAuthToken, GoogleOAuth2Token, IntegrationBacksyncAuthToken, PluginAuthToken, ScheduleExportAuthToken, + ServiceAccountToken, SlackAuthToken, UserScheduleExportAuthToken, ) @@ -336,8 +335,8 @@ def authenticate_credentials( return auth_token.user, auth_token +X_GRAFANA_URL = "X-Grafana-URL" X_GRAFANA_INSTANCE_ID = "X-Grafana-Instance-ID" -GRAFANA_SA_PREFIX = "glsa_" class GrafanaServiceAccountAuthentication(BaseAuthentication): @@ -345,7 +344,7 @@ def authenticate(self, request): auth = get_authorization_header(request).decode("utf-8") if not auth: raise exceptions.AuthenticationFailed("Invalid token.") - if not auth.startswith(GRAFANA_SA_PREFIX): + if not auth.startswith(ServiceAccountToken.GRAFANA_SA_PREFIX): return None organization = self.get_organization(request) @@ -359,6 +358,13 @@ def authenticate(self, request): return self.authenticate_credentials(organization, auth) def get_organization(self, request): + grafana_url = request.headers.get(X_GRAFANA_URL) + if grafana_url: + organization = Organization.objects.filter(grafana_url=grafana_url).first() + if not organization: + raise exceptions.AuthenticationFailed("Invalid Grafana URL.") + return organization + if settings.LICENSE == settings.CLOUD_LICENSE_NAME: instance_id = request.headers.get(X_GRAFANA_INSTANCE_ID) if not instance_id: @@ -370,36 +376,13 @@ def get_organization(self, request): return Organization.objects.filter(org_slug=org_slug, stack_slug=instance_slug).first() def authenticate_credentials(self, organization, token): - permissions = get_service_account_token_permissions(organization, token) - if not permissions: + try: + user, auth_token = ServiceAccountToken.validate_token(organization, token) + except InvalidToken: raise exceptions.AuthenticationFailed("Invalid token.") - role = LegacyAccessControlRole.NONE - 
if not organization.is_rbac_permissions_enabled: - role = self.determine_role_from_permissions(permissions) - - user = User( - organization_id=organization.pk, - name="Grafana Service Account", - username="grafana_service_account", - role=role, - permissions=GrafanaAPIPermissions.construct_permissions(permissions.keys()), - ) - - auth_token = ApiAuthToken(organization=organization, user=user, name="Grafana Service Account") return user, auth_token - # Using default permissions as proxies for roles since we cannot explicitly get role from the service account token - def determine_role_from_permissions(self, permissions): - if "plugins:write" in permissions: - return LegacyAccessControlRole.ADMIN - if "dashboards:write" in permissions: - return LegacyAccessControlRole.EDITOR - if "dashboards:read" in permissions: - return LegacyAccessControlRole.VIEWER - return LegacyAccessControlRole.NONE - class IntegrationBacksyncAuthentication(BaseAuthentication): model = IntegrationBacksyncAuthToken diff --git a/engine/apps/auth_token/grafana/grafana_auth_token.py b/engine/apps/auth_token/grafana/grafana_auth_token.py index 07bae6446f..6576e41793 100644 --- a/engine/apps/auth_token/grafana/grafana_auth_token.py +++ b/engine/apps/auth_token/grafana/grafana_auth_token.py @@ -46,3 +46,9 @@ def get_service_account_token_permissions(organization: Organization, token: str grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) permissions, _ = grafana_api_client.get_service_account_token_permissions() return permissions + + +def get_service_account_details(organization: Organization, token: str) -> typing.Optional[typing.Dict[str, typing.Any]]: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) + user_data, _ = grafana_api_client.get_current_user() + return user_data diff --git a/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py b/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py new file mode 100644 index 0000000000..920b9ada3e --- /dev/null +++ b/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py @@ -0,0 +1,29 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ('auth_token', '0006_googleoauth2token'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceAccountToken', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('token_key', models.CharField(db_index=True, max_length=8)), + ('digest', models.CharField(max_length=128)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('revoked_at', models.DateTimeField(null=True)), + ('service_account', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tokens', to='user_management.serviceaccount')), + ], + options={ + 'unique_together': {('token_key', 'service_account', 'digest')}, + }, + ), + ] diff --git a/engine/apps/auth_token/models/__init__.py b/engine/apps/auth_token/models/__init__.py index 272adbda60..42cc60c516 100644 --- a/engine/apps/auth_token/models/__init__.py +++ b/engine/apps/auth_token/models/__init__.py @@ -4,5 +4,6 @@ from .integration_backsync_auth_token import IntegrationBacksyncAuthToken # noqa: F401 from .plugin_auth_token import PluginAuthToken # noqa: F401 from .schedule_export_auth_token import ScheduleExportAuthToken # noqa: F401 +from
.service_account_token import ServiceAccountToken # noqa: F401 from .slack_auth_token import SlackAuthToken # noqa: F401 from .user_schedule_export_auth_token import UserScheduleExportAuthToken # noqa: F401 diff --git a/engine/apps/auth_token/models/service_account_token.py b/engine/apps/auth_token/models/service_account_token.py new file mode 100644 index 0000000000..716dc55db3 --- /dev/null +++ b/engine/apps/auth_token/models/service_account_token.py @@ -0,0 +1,110 @@ +import binascii +from hmac import compare_digest + +from django.db import models + +from apps.api.permissions import GrafanaAPIPermissions, LegacyAccessControlRole +from apps.auth_token import constants +from apps.auth_token.crypto import hash_token_string +from apps.auth_token.exceptions import InvalidToken +from apps.auth_token.grafana.grafana_auth_token import ( + get_service_account_details, + get_service_account_token_permissions, +) +from apps.auth_token.models import BaseAuthToken +from apps.user_management.models import ServiceAccount, ServiceAccountUser + + +class ServiceAccountTokenManager(models.Manager): + def get_queryset(self): + return super().get_queryset().select_related("service_account__organization") + + +class ServiceAccountToken(BaseAuthToken): + GRAFANA_SA_PREFIX = "glsa_" + + objects = ServiceAccountTokenManager() + + service_account: "ServiceAccount" + service_account = models.ForeignKey(ServiceAccount, on_delete=models.CASCADE, related_name="tokens") + + class Meta: + unique_together = ("token_key", "service_account", "digest") + + @property + def organization(self): + return self.service_account.organization + + @classmethod + def validate_token(cls, organization, token): + # require RBAC enabled to allow service account auth + if not organization.is_rbac_permissions_enabled: + raise InvalidToken + + # Grafana API request: get permissions and confirm token is valid + permissions = get_service_account_token_permissions(organization, token) + if not permissions: + # NOTE: a token can be disabled/re-enabled (not setting as revoked in oncall DB for now) + raise InvalidToken + + # check if we have already seen this token + validated_token = None + service_account = None + prefix_length = len(cls.GRAFANA_SA_PREFIX) + token_key = token[prefix_length : prefix_length + constants.TOKEN_KEY_LENGTH] + try: + hashable_token = binascii.hexlify(token.encode()).decode() + digest = hash_token_string(hashable_token) + except (TypeError, binascii.Error): + raise InvalidToken + for existing_token in cls.objects.filter(service_account__organization=organization, token_key=token_key): + if compare_digest(digest, existing_token.digest): + validated_token = existing_token + service_account = existing_token.service_account + break + + if not validated_token: + # if it didn't match an existing token, create a new one + # make request to Grafana API api/user using token + service_account_data = get_service_account_details(organization, token) + if not service_account_data: + # Grafana versions < 11.3 return 403 trying to get user details with service account token + # use some default values + service_account_data = { + "login": "grafana_service_account", + "uid": None, # "service-account:7" + } + + grafana_id = 0 # default to zero for old Grafana versions (to keep service account unique) + if service_account_data["uid"] is not None: + # extract service account Grafana ID + try: + grafana_id = int(service_account_data["uid"].split(":")[-1]) + except ValueError: + pass + + # get or create service account + service_account, 
_ = ServiceAccount.objects.get_or_create( + organization=organization, + grafana_id=grafana_id, + defaults={ + "login": service_account_data["login"], + }, + ) + # create token + validated_token, _ = cls.objects.get_or_create( + service_account=service_account, + token_key=token_key, + digest=digest, + ) + + user = ServiceAccountUser( + organization=organization, + service_account=service_account, + username=service_account.username, + public_primary_key=service_account.public_primary_key, + role=LegacyAccessControlRole.NONE, + permissions=GrafanaAPIPermissions.construct_permissions(permissions.keys()), + ) + + return user, validated_token diff --git a/engine/apps/auth_token/tests/helpers.py b/engine/apps/auth_token/tests/helpers.py new file mode 100644 index 0000000000..bcecce6f2c --- /dev/null +++ b/engine/apps/auth_token/tests/helpers.py @@ -0,0 +1,18 @@ +import json + +import httpretty + + +def setup_service_account_api_mocks(organization, perms=None, user_data=None, perms_status=200, user_status=200): + # requires enabling httpretty + if perms is None: + perms = {} + mock_response = httpretty.Response(status=perms_status, body=json.dumps(perms)) + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response]) + + if user_data is None: + user_data = {"login": "some-login", "uid": "service-account:42"} + mock_response = httpretty.Response(status=user_status, body=json.dumps(user_data)) + user_url = f"{organization.grafana_url}/api/user" + httpretty.register_uri(httpretty.GET, user_url, responses=[mock_response]) diff --git a/engine/apps/auth_token/tests/test_grafana_auth.py b/engine/apps/auth_token/tests/test_grafana_auth.py index 5b78636c4a..3a8ec56c0d 100644 --- a/engine/apps/auth_token/tests/test_grafana_auth.py +++ b/engine/apps/auth_token/tests/test_grafana_auth.py @@ -1,11 +1,16 @@ import typing from unittest.mock import patch +import httpretty import pytest from rest_framework import exceptions from rest_framework.test import APIRequestFactory -from apps.auth_token.auth import GRAFANA_SA_PREFIX, X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication +from apps.api.permissions import LegacyAccessControlRole +from apps.auth_token.auth import X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ServiceAccountToken +from apps.auth_token.tests.helpers import setup_service_account_api_mocks +from apps.user_management.models import ServiceAccountUser from settings.base import CLOUD_LICENSE_NAME, OPEN_SOURCE_LICENSE_NAME, SELF_HOSTED_SETTINGS @@ -53,7 +58,7 @@ def test_grafana_authentication_cloud_inputs(make_organization, settings): mock.assert_called_once_with(organization, token) -def check_common_inputs() -> (dict[str, typing.Any], str): +def check_common_inputs() -> tuple[dict[str, typing.Any], str]: request = APIRequestFactory().get("/") with pytest.raises(exceptions.AuthenticationFailed): GrafanaServiceAccountAuthentication().authenticate(request) @@ -65,7 +70,7 @@ def check_common_inputs() -> (dict[str, typing.Any], str): result = GrafanaServiceAccountAuthentication().authenticate(request) assert result is None - token = f"{GRAFANA_SA_PREFIX}xyz" + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" headers = { "HTTP_AUTHORIZATION": token, } @@ -74,3 +79,221 @@ def check_common_inputs() -> (dict[str, typing.Any], str): GrafanaServiceAccountAuthentication().authenticate(request) return headers, token + + +@pytest.mark.django_db 
+@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_missing_org(): + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid organization." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_invalid_grafana_url(): + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": "http://grafana.test", # no org for this URL + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid Grafana URL." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_rbac_disabled_fails(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if organization.is_rbac_permissions_enabled: + return + + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid token." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_permissions_call_fails(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + # permissions endpoint returns a 401 + setup_service_account_api_mocks(organization, perms_status=401) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid token." 
+ + last_request = httpretty.last_request() + assert last_request.method == "GET" + expected_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert last_request.url == expected_url + # the request uses the given token + assert last_request.headers["Authorization"] == f"Bearer {token}" + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_existing_token( + make_organization, make_service_account_for_organization, make_token_for_service_account +): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + service_account = make_service_account_for_organization(organization) + token_string = "glsa_the-token" + token = make_token_for_service_account(service_account, token_string) + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + setup_service_account_api_mocks(organization, {"some-perm": "value"}) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + assert user.service_account == service_account + assert user.public_primary_key == service_account.public_primary_key + assert user.username == service_account.username + assert user.role == LegacyAccessControlRole.NONE + assert auth_token == token + + last_request = httpretty.last_request() + assert last_request.method == "GET" + expected_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert last_request.url == expected_url + # the request uses the given token + assert last_request.headers["Authorization"] == f"Bearer {token_string}" + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_created(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + user_data = {"login": "some-login", "uid": "service-account:42"} + setup_service_account_api_mocks(organization, permissions, user_data) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + service_account = user.service_account + assert service_account.organization == organization + assert user.public_primary_key == service_account.public_primary_key + assert user.username == service_account.username + assert service_account.grafana_id == 42 + assert service_account.login == "some-login" + assert user.role == LegacyAccessControlRole.NONE + assert user.permissions == [{"action": p} for p in permissions] + assert auth_token.service_account == user.service_account + + perms_request, user_request = httpretty.latest_requests() + for req in (perms_request, user_request): + assert req.method == "GET" + assert req.headers["Authorization"] == f"Bearer {token_string}" + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert perms_request.url == perms_url + user_url = f"{organization.grafana_url}/api/user" + assert user_request.url == user_url + + 
+@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_created_older_grafana(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + # User API fails for older Grafana versions + setup_service_account_api_mocks(organization, permissions, user_status=400) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + service_account = user.service_account + assert service_account.organization == organization + # use fallback data + assert service_account.grafana_id == 0 + assert service_account.login == "grafana_service_account" + assert auth_token.service_account == user.service_account + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_reuse_service_account(make_organization, make_service_account_for_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + service_account = make_service_account_for_organization(organization) + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + user_data = { + "login": service_account.login, + "uid": f"service-account:{service_account.grafana_id}", + } + setup_service_account_api_mocks(organization, permissions, user_data) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + assert user.service_account == service_account + assert auth_token.service_account == service_account diff --git a/engine/apps/email/inbound.py b/engine/apps/email/inbound.py index 1780f00c83..6c86e19485 100644 --- a/engine/apps/email/inbound.py +++ b/engine/apps/email/inbound.py @@ -1,27 +1,45 @@ import logging +from functools import cached_property from typing import Optional, TypedDict -from anymail.exceptions import AnymailInvalidAddress, AnymailWebhookValidationFailure +import requests +from anymail.exceptions import AnymailAPIError, AnymailInvalidAddress, AnymailWebhookValidationFailure from anymail.inbound import AnymailInboundMessage from anymail.signals import AnymailInboundEvent from anymail.webhooks import amazon_ses, mailgun, mailjet, mandrill, postal, postmark, sendgrid, sparkpost from django.http import HttpResponse, HttpResponseNotAllowed from django.utils import timezone from rest_framework import status -from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView from apps.base.utils import live_settings +from apps.email.validate_amazon_sns_message import validate_amazon_sns_message from apps.integrations.mixins import AlertChannelDefiningMixin from apps.integrations.tasks import create_alert logger = logging.getLogger(__name__) +class AmazonSESValidatedInboundWebhookView(amazon_ses.AmazonSESInboundWebhookView): + # disable "Your Anymail webhooks are 
insecure and open to anyone on the web." warning + warn_if_no_basic_auth = False + + def validate_request(self, request): + """Add SNS message validation to Amazon SES inbound webhook view, which is not implemented in Anymail.""" + if not validate_amazon_sns_message(self._parse_sns_message(request)): + raise AnymailWebhookValidationFailure("SNS message validation failed") + + def auto_confirm_sns_subscription(self, sns_message): + """This method is called after validate_request, so we can be sure that the message is valid.""" + response = requests.get(sns_message["SubscribeURL"]) + response.raise_for_status() + + # {<esp name>: (<inbound webhook view class>, <webhook secret argument name to pass>), ...} INBOUND_EMAIL_ESP_OPTIONS = { "amazon_ses": (amazon_ses.AmazonSESInboundWebhookView, None), + "amazon_ses_validated": (AmazonSESValidatedInboundWebhookView, None), "mailgun": (mailgun.MailgunInboundWebhookView, "webhook_signing_key"), "mailjet": (mailjet.MailjetInboundWebhookView, "webhook_secret"), "mandrill": (mandrill.MandrillCombinedWebhookView, "webhook_key"), @@ -62,38 +80,33 @@ def dispatch(self, request): return super().dispatch(request, alert_channel_key=integration_token) def post(self, request): - timestamp = timezone.now().isoformat() - for message in self.get_messages_from_esp_request(request): - payload = self.get_alert_payload_from_email_message(message) - create_alert.delay( - title=payload["subject"], - message=payload["message"], - alert_receive_channel_pk=request.alert_receive_channel.pk, - image_url=None, - link_to_upstream_details=None, - integration_unique_data=None, - raw_request_data=payload, - received_at=timestamp, - ) - + payload = self.get_alert_payload_from_email_message(self.message) + create_alert.delay( + title=payload["subject"], + message=payload["message"], + alert_receive_channel_pk=request.alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data=payload, + received_at=timezone.now().isoformat(), + ) return Response("OK", status=status.HTTP_200_OK) def get_integration_token_from_request(self, request) -> Optional[str]: - messages = self.get_messages_from_esp_request(request) - if not messages: + if not self.message: return None - message = messages[0] # First try envelope_recipient field. # According to AnymailInboundMessage it's provided not by all ESPs.
- if message.envelope_recipient: - recipients = message.envelope_recipient.split(",") + if self.message.envelope_recipient: + recipients = self.message.envelope_recipient.split(",") for recipient in recipients: # if there is more than one recipient, the first matching the expected domain will be used try: token, domain = recipient.strip().split("@") except ValueError: logger.error( - f"get_integration_token_from_request: envelope_recipient field has unexpected format: {message.envelope_recipient}" + f"get_integration_token_from_request: envelope_recipient field has unexpected format: {self.message.envelope_recipient}" ) continue if domain == live_settings.INBOUND_EMAIL_DOMAIN: @@ -113,20 +126,27 @@ def get_integration_token_from_request(self, request) -> Optional[str]: # return cc.address.split("@")[0] return None - def get_messages_from_esp_request(self, request: Request) -> list[AnymailInboundMessage]: - view_class, secret_name = INBOUND_EMAIL_ESP_OPTIONS[live_settings.INBOUND_EMAIL_ESP] + @cached_property + def message(self) -> AnymailInboundMessage | None: + esps = live_settings.INBOUND_EMAIL_ESP.split(",") + for esp in esps: + view_class, secret_name = INBOUND_EMAIL_ESP_OPTIONS[esp] - kwargs = {secret_name: live_settings.INBOUND_EMAIL_WEBHOOK_SECRET} if secret_name else {} - view = view_class(**kwargs) + kwargs = {secret_name: live_settings.INBOUND_EMAIL_WEBHOOK_SECRET} if secret_name else {} + view = view_class(**kwargs) - try: - view.run_validators(request) - events = view.parse_events(request) - except AnymailWebhookValidationFailure as e: - logger.info(f"get_messages_from_esp_request: inbound email webhook validation failed: {e}") - return [] + try: + view.run_validators(self.request) + events = view.parse_events(self.request) + except (AnymailWebhookValidationFailure, AnymailAPIError) as e: + logger.info(f"inbound email webhook validation failed for ESP {esp}: {e}") + continue - return [event.message for event in events if isinstance(event, AnymailInboundEvent)] + messages = [event.message for event in events if isinstance(event, AnymailInboundEvent)] + if messages: + return messages[0] + + return None def check_inbound_email_settings_set(self): """ diff --git a/engine/apps/email/tests/test_inbound_email.py b/engine/apps/email/tests/test_inbound_email.py index 81a76e923a..252b529208 100644 --- a/engine/apps/email/tests/test_inbound_email.py +++ b/engine/apps/email/tests/test_inbound_email.py @@ -1,13 +1,299 @@ +import datetime +import hashlib +import hmac import json +from base64 import b64encode from textwrap import dedent +from unittest.mock import ANY, Mock, patch import pytest from anymail.inbound import AnymailInboundMessage +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.x509 import CertificateBuilder, NameOID +from django.conf import settings from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient +from apps.alerts.models import AlertReceiveChannel from apps.email.inbound import InboundEmailWebhookView +from apps.integrations.tasks import create_alert + +PRIVATE_KEY = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, +) +ISSUER_NAME = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"), + 
x509.NameAttribute(NameOID.COMMON_NAME, "Test"), + ] +) +CERTIFICATE = ( + CertificateBuilder() + .subject_name(ISSUER_NAME) + .issuer_name(ISSUER_NAME) + .public_key(PRIVATE_KEY.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now() - datetime.timedelta(days=1)) + .not_valid_after(datetime.datetime.now() + datetime.timedelta(days=10)) + .sign(PRIVATE_KEY, hashes.SHA256()) + .public_bytes(serialization.Encoding.PEM) +) +AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789012:test" +SIGNING_CERT_URL = "https://sns.us-east-2.amazonaws.com/SimpleNotificationService-example.pem" +SENDER_EMAIL = "sender@example.com" +TO_EMAIL = "test-token@inbound.example.com" +SUBJECT = "Test email" +MESSAGE = "This is a test email message body." + + +def _sns_inbound_email_payload_and_headers(sender_email, to_email, subject, message): + content = ( + f"From: Sender Name <{sender_email}>\n" + f"To: {to_email}\n" + f"Subject: {subject}\n" + "Date: Tue, 5 Nov 2024 16:05:39 +0000\n" + "Message-ID: \n\n" + f"{message}\r\n" + ) + + message = { + "notificationType": "Received", + "mail": { + "timestamp": "2024-11-05T16:05:52.387Z", + "source": sender_email, + "messageId": "example-message-id-5678", + "destination": [to_email], + "headersTruncated": False, + "headers": [ + {"name": "Return-Path", "value": f"<{sender_email}>"}, + { + "name": "Received", + "value": ( + f"from mail.example.com (mail.example.com [203.0.113.1]) " + f"by inbound-smtp.us-east-2.amazonaws.com with SMTP id example-id " + f"for {to_email}; Tue, 05 Nov 2024 16:05:52 +0000 (UTC)" + ), + }, + {"name": "X-SES-Spam-Verdict", "value": "PASS"}, + {"name": "X-SES-Virus-Verdict", "value": "PASS"}, + { + "name": "Received-SPF", + "value": ( + "pass (spfCheck: domain of example.com designates 203.0.113.1 as permitted sender) " + f"client-ip=203.0.113.1; envelope-from={sender_email}; helo=mail.example.com;" + ), + }, + { + "name": "Authentication-Results", + "value": ( + "amazonses.com; spf=pass (spfCheck: domain of example.com designates 203.0.113.1 as permitted sender) " + f"client-ip=203.0.113.1; envelope-from={sender_email}; helo=mail.example.com; " + "dkim=pass header.i=@example.com; dmarc=pass header.from=example.com;" + ), + }, + {"name": "X-SES-RECEIPT", "value": "example-receipt-data"}, + {"name": "X-SES-DKIM-SIGNATURE", "value": "example-dkim-signature"}, + { + "name": "Received", + "value": ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " + "Tue, 05 Nov 2024 08:05:52 -0800 (PST)" + ), + }, + { + "name": "DKIM-Signature", + "value": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; t=1234567890; " + "bh=examplehash; h=From:To:Subject:Date:Message-ID; b=example-signature" + ), + }, + {"name": "X-Google-DKIM-Signature", "value": "example-google-dkim-signature"}, + {"name": "X-Gm-Message-State", "value": "example-message-state"}, + {"name": "X-Google-Smtp-Source", "value": "example-smtp-source"}, + { + "name": "X-Received", + "value": "by 2002:a17:example with SMTP id example-id; Tue, 05 Nov 2024 08:05:50 -0800 (PST)", + }, + {"name": "MIME-Version", "value": "1.0"}, + {"name": "From", "value": f"Sender Name <{sender_email}>"}, + {"name": "Date", "value": "Tue, 5 Nov 2024 16:05:39 +0000"}, + {"name": "Message-ID", "value": ""}, + {"name": "Subject", "value": subject}, + {"name": "To", "value": to_email}, + { + "name": "Content-Type", + "value": 'multipart/alternative; boundary="00000000000036b9f706262c9312"', + }, + ], + "commonHeaders": { + "returnPath": 
sender_email, + "from": [f"Sender Name <{sender_email}>"], + "date": "Tue, 5 Nov 2024 16:05:39 +0000", + "to": [to_email], + "messageId": "", + "subject": subject, + }, + }, + "receipt": { + "timestamp": "2024-11-05T16:05:52.387Z", + "processingTimeMillis": 638, + "recipients": [to_email], + "spamVerdict": {"status": "PASS"}, + "virusVerdict": {"status": "PASS"}, + "spfVerdict": {"status": "PASS"}, + "dkimVerdict": {"status": "PASS"}, + "dmarcVerdict": {"status": "PASS"}, + "action": { + "type": "SNS", + "topicArn": "arn:aws:sns:us-east-2:123456789012:test", + "encoding": "BASE64", + }, + }, + "content": b64encode(content.encode()).decode(), + } + + payload = { + "Type": "Notification", + "MessageId": "example-message-id-1234", + "TopicArn": AMAZON_SNS_TOPIC_ARN, + "Subject": "Amazon SES Email Receipt Notification", + "Message": json.dumps(message), + "Timestamp": "2024-11-05T16:05:53.041Z", + "SignatureVersion": "1", + "SigningCertURL": SIGNING_CERT_URL, + "UnsubscribeURL": ( + "https://sns.us-east-2.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=" + "arn:aws:sns:us-east-2:123456789012:test:example-subscription-id" + ), + } + # Sign the payload + canonical_message = "".join( + f"{key}\n{payload[key]}\n" for key in ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type") + ) + signature = PRIVATE_KEY.sign( + canonical_message.encode(), + padding.PKCS1v15(), + hashes.SHA1(), + ) + payload["Signature"] = b64encode(signature).decode() + + headers = { + "X-Amz-Sns-Message-Type": "Notification", + "X-Amz-Sns-Message-Id": "example-message-id-1234", + } + return payload, headers + + +def _mailgun_inbound_email_payload(sender_email, to_email, subject, message): + timestamp, token = "1731341416", "example-token" + signature = hmac.new( + key=settings.INBOUND_EMAIL_WEBHOOK_SECRET.encode("ascii"), + msg="{}{}".format(timestamp, token).encode("ascii"), + digestmod=hashlib.sha256, + ).hexdigest() + + return { + "Content-Type": 'multipart/alternative; boundary="000000000000267130626a556e5"', + "Date": "Mon, 11 Nov 2024 16:10:03 +0000", + "Dkim-Signature": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; " + "t=1731341415; x=1731946215; darn=example.com; " + "h=to:subject:message-id:date:from:mime-version:from:to:cc:subject " + ":date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + "From": f"Sender Name <{sender_email}>", + "Message-Id": "", + "Mime-Version": "1.0", + "Received": ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " "Mon, 11 Nov 2024 08:10:15 -0800 (PST)" + ), + "Subject": subject, + "To": to_email, + "X-Envelope-From": sender_email, + "X-Gm-Message-State": "example-message-state", + "X-Google-Dkim-Signature": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; " + "t=1731341415; x=1731946215; " + "h=to:subject:message-id:date:from:mime-version:x-gm-message-state " + ":from:to:cc:subject:date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + "X-Google-Smtp-Source": "example-smtp-source", + "X-Mailgun-Incoming": "Yes", + "X-Received": "by 2002:a17:example with SMTP id example-id; Mon, 11 Nov 2024 08:10:14 -0800 (PST)", + "body-html": f'
<div dir="ltr">{message}</div>
\r\n', + "body-plain": f"{message}\r\n", + "from": f"Sender Name <{sender_email}>", + "message-headers": json.dumps( + [ + ["X-Mailgun-Incoming", "Yes"], + ["X-Envelope-From", sender_email], + [ + "Received", + ( + "from mail.example.com (mail.example.com [203.0.113.1]) " + "by example.com with SMTP id example-id; " + "Mon, 11 Nov 2024 16:10:15 GMT" + ), + ], + [ + "Received", + ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " + "Mon, 11 Nov 2024 08:10:15 -0800 (PST)" + ), + ], + [ + "Dkim-Signature", + ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; " + "t=1731341415; x=1731946215; darn=example.com; " + "h=to:subject:message-id:date:from:mime-version:from:to:cc:subject " + ":date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + ], + [ + "X-Google-Dkim-Signature", + ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; " + "t=1731341415; x=1731946215; " + "h=to:subject:message-id:date:from:mime-version:x-gm-message-state " + ":from:to:cc:subject:date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + ], + ["X-Gm-Message-State", "example-message-state"], + ["X-Google-Smtp-Source", "example-smtp-source"], + [ + "X-Received", + "by 2002:a17:example with SMTP id example-id; Mon, 11 Nov 2024 08:10:14 -0800 (PST)", + ], + ["Mime-Version", "1.0"], + ["From", f"Sender Name <{sender_email}>"], + ["Date", "Mon, 11 Nov 2024 16:10:03 +0000"], + ["Message-Id", ""], + ["Subject", subject], + ["To", to_email], + [ + "Content-Type", + 'multipart/alternative; boundary="000000000000267130626a556e5"', + ], + ] + ), + "recipient": to_email, + "sender": sender_email, + "signature": signature, + "stripped-html": f'
<div dir="ltr">{message}</div>
\n', + "stripped-text": f"{message}\n", + "subject": subject, + "timestamp": timestamp, + "token": token, + } @pytest.mark.parametrize( @@ -141,3 +427,235 @@ def test_get_sender_from_email_message(sender_value, expected_result): view = InboundEmailWebhookView() result = view.get_sender_from_email_message(email) assert result == expected_result + + +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + create_alert_mock.assert_called_once_with( + title=SUBJECT, + message=MESSAGE, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, + }, + received_at=ANY, + ) + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_pass( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + mock_create_alert.assert_called_once_with( + title=SUBJECT, + message=MESSAGE, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, + }, + received_at=ANY, + ) + + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_fail_wrong_sns_topic_arn( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + 
settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789013:test" + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + mock_create_alert.assert_not_called() + mock_requests_get.assert_not_called() + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_fail_wrong_signature( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + sns_payload["Signature"] = "invalid-signature" + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + mock_create_alert.assert_not_called() + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + mailgun_payload = _mailgun_inbound_email_payload( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=mailgun_payload, + format="multipart", + ) + + assert response.status_code == status.HTTP_200_OK + create_alert_mock.assert_called_once_with( + title=SUBJECT, + message=MESSAGE, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, + }, + received_at=ANY, + ) + + +@pytest.mark.django_db +def test_multiple_esps_fail(settings): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + client = APIClient() + response = client.post(reverse("integrations:inbound_email_webhook"), data={}) + + assert response.status_code == 
status.HTTP_400_BAD_REQUEST diff --git a/engine/apps/email/validate_amazon_sns_message.py b/engine/apps/email/validate_amazon_sns_message.py new file mode 100644 index 0000000000..f3d2aec482 --- /dev/null +++ b/engine/apps/email/validate_amazon_sns_message.py @@ -0,0 +1,99 @@ +import logging +import re +from base64 import b64decode +from urllib.parse import urlparse + +import requests +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.hashes import SHA1, SHA256 +from cryptography.x509 import NameOID, load_pem_x509_certificate +from django.conf import settings + +logger = logging.getLogger(__name__) + +HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$") +REQUIRED_KEYS = ( + "Message", + "MessageId", + "Timestamp", + "TopicArn", + "Type", + "Signature", + "SigningCertURL", + "SignatureVersion", +) +SIGNING_KEYS_NOTIFICATION = ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type") +SIGNING_KEYS_SUBSCRIPTION = ("Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type") + + +def validate_amazon_sns_message(message: dict) -> bool: + """ + Validate an AWS SNS message. Based on: + - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html + - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js + - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php + """ + + # Check if the message has all the required keys + if not all(key in message for key in REQUIRED_KEYS): + logger.warning("Missing required keys in the message, got: %s", message.keys()) + return False + + # Check TopicArn + if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN: + logger.warning("Invalid TopicArn: %s", message["TopicArn"]) + return False + + # Construct the canonical message + if message["Type"] == "Notification": + signing_keys = SIGNING_KEYS_NOTIFICATION + elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"): + signing_keys = SIGNING_KEYS_SUBSCRIPTION + else: + logger.warning("Invalid message type: %s", message["Type"]) + return False + canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode() + + # Check if SigningCertURL is a valid SNS URL + signing_cert_url = message["SigningCertURL"] + parsed_url = urlparse(signing_cert_url) + if ( + parsed_url.scheme != "https" + or not HOST_PATTERN.match(parsed_url.netloc) + or not parsed_url.path.endswith(".pem") + ): + logger.warning("Invalid SigningCertURL: %s", signing_cert_url) + return False + + # Fetch the certificate + try: + response = requests.get(signing_cert_url, timeout=5) + response.raise_for_status() + certificate_bytes = response.content + except requests.RequestException as e: + logger.warning("Failed to fetch the certificate from %s: %s", signing_cert_url, e) + return False + + # Verify the certificate issuer + certificate = load_pem_x509_certificate(certificate_bytes) + if certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)[0].value != "Amazon": + logger.warning("Invalid certificate issuer: %s", certificate.issuer) + return False + + # Verify the signature + signature = b64decode(message["Signature"]) + if message["SignatureVersion"] == "1": + hash_algorithm = SHA1() + elif message["SignatureVersion"] == "2": + hash_algorithm = SHA256() + 
else: + logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"]) + return False + try: + certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm) + except InvalidSignature: + logger.warning("Invalid signature") + return False + + return True diff --git a/engine/apps/grafana_plugin/helpers/client.py b/engine/apps/grafana_plugin/helpers/client.py index 2beafa8bdf..17d1cabd20 100644 --- a/engine/apps/grafana_plugin/helpers/client.py +++ b/engine/apps/grafana_plugin/helpers/client.py @@ -315,6 +315,9 @@ def get_grafana_labels_plugin_settings(self) -> APIClientResponse["GrafanaAPICli def get_grafana_irm_plugin_settings(self) -> APIClientResponse["GrafanaAPIClient.Types.PluginSettings"]: return self.get_grafana_plugin_settings(PluginID.IRM) + def get_current_user(self) -> APIClientResponse[typing.Dict[str, typing.List[str]]]: + return self.api_get("api/user") + def get_service_account(self, login: str) -> APIClientResponse["GrafanaAPIClient.Types.ServiceAccountResponse"]: return self.api_get(f"api/serviceaccounts/search?query={login}") diff --git a/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py b/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py new file mode 100644 index 0000000000..e50d915ee5 --- /dev/null +++ b/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.16 on 2024-11-20 15:39 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('heartbeat', '0002_delete_heartbeat'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='integrationheartbeat', + name='actual_check_up_task_id', + ), + migrations.RemoveField( + model_name='integrationheartbeat', + name='last_checkup_task_time', + ), + ] diff --git a/engine/apps/heartbeat/models.py b/engine/apps/heartbeat/models.py index 0c0084bd15..4688cc716d 100644 --- a/engine/apps/heartbeat/models.py +++ b/engine/apps/heartbeat/models.py @@ -48,16 +48,6 @@ class IntegrationHeartBeat(models.Model): Stores the latest received heartbeat signal time """ - last_checkup_task_time = models.DateTimeField(default=None, null=True) - """ - Deprecated. This field is not used. TODO: remove it - """ - - actual_check_up_task_id = models.CharField(max_length=100) - """ - Deprecated. Stored the latest scheduled `integration_heartbeat_checkup` task id. TODO: remove it - """ - previous_alerted_state_was_life = models.BooleanField(default=True) """ Last status of the heartbeat. Determines if integration was alive on latest checkup diff --git a/engine/apps/heartbeat/tasks.py b/engine/apps/heartbeat/tasks.py index 7939290ec5..e9d26c578d 100644 --- a/engine/apps/heartbeat/tasks.py +++ b/engine/apps/heartbeat/tasks.py @@ -105,12 +105,6 @@ def _get_timeout_expression() -> ExpressionWrapper: return f"Found {expired_count} expired and {restored_count} restored heartbeats" -@shared_dedicated_queue_retry_task() -def integration_heartbeat_checkup(heartbeat_id: int) -> None: - """Deprecated. 
TODO: Remove this task after this task cleared from queue""" - pass - - @shared_dedicated_queue_retry_task() def process_heartbeat_task(alert_receive_channel_pk): IntegrationHeartBeat.objects.filter( diff --git a/engine/apps/heartbeat/tests/factories.py b/engine/apps/heartbeat/tests/factories.py index 5e69db9de9..40011255e3 100644 --- a/engine/apps/heartbeat/tests/factories.py +++ b/engine/apps/heartbeat/tests/factories.py @@ -4,7 +4,5 @@ class IntegrationHeartBeatFactory(factory.DjangoModelFactory): - actual_check_up_task_id = "none" - class Meta: model = IntegrationHeartBeat diff --git a/engine/apps/integrations/tasks.py b/engine/apps/integrations/tasks.py index 45f3e04f2a..91f6a7d416 100644 --- a/engine/apps/integrations/tasks.py +++ b/engine/apps/integrations/tasks.py @@ -31,10 +31,7 @@ def create_alertmanager_alerts(alert_receive_channel_pk, alert, is_demo=False, r from apps.alerts.models import Alert, AlertReceiveChannel alert_receive_channel = AlertReceiveChannel.objects_with_deleted.get(pk=alert_receive_channel_pk) - if ( - alert_receive_channel.deleted_at is not None - or alert_receive_channel.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE - ): + if alert_receive_channel.deleted_at is not None or alert_receive_channel.is_maintenace_integration: logger.info("AlertReceiveChannel alert ignored if deleted/maintenance") return diff --git a/engine/apps/mobile_app/demo_push.py b/engine/apps/mobile_app/demo_push.py index 19daca5b2f..01194c1487 100644 --- a/engine/apps/mobile_app/demo_push.py +++ b/engine/apps/mobile_app/demo_push.py @@ -8,7 +8,7 @@ from apps.mobile_app.exceptions import DeviceNotSet from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import add_stack_slug_to_message_title, construct_fcm_message, send_push_notification +from apps.mobile_app.utils import construct_fcm_message, send_push_notification from apps.user_management.models import User if typing.TYPE_CHECKING: @@ -47,7 +47,7 @@ def _get_test_escalation_fcm_message(user: User, device_to_notify: "FCMDevice", apns_sound_name = mobile_app_user_settings.get_notification_sound_name(message_type, Platform.IOS) fcm_message_data: FCMMessageData = { - "title": add_stack_slug_to_message_title(get_test_push_title(critical), user.organization), + "title": get_test_push_title(critical), "orgName": user.organization.stack_slug, # Pass user settings, so the Android app can use them to play the correct sound and volume "default_notification_sound_name": mobile_app_user_settings.get_notification_sound_name( diff --git a/engine/apps/mobile_app/tasks/going_oncall_notification.py b/engine/apps/mobile_app/tasks/going_oncall_notification.py index 214fa19df8..34fd41607c 100644 --- a/engine/apps/mobile_app/tasks/going_oncall_notification.py +++ b/engine/apps/mobile_app/tasks/going_oncall_notification.py @@ -12,12 +12,7 @@ from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, Message from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.schedules.models.on_call_schedule import OnCallSchedule, ScheduleEvent from apps.user_management.models import User from common.cache import ensure_cache_key_allocates_to_the_same_hash_slot @@ -82,7 +77,7 @@ def _get_fcm_message( notification_subtitle = 
_get_notification_subtitle(schedule, schedule_event, mobile_app_user_settings) data: FCMMessageData = { - "title": add_stack_slug_to_message_title(notification_title, user.organization), + "title": notification_title, "subtitle": notification_subtitle, "orgName": user.organization.stack_slug, "info_notification_sound_name": mobile_app_user_settings.get_notification_sound_name( diff --git a/engine/apps/mobile_app/tasks/new_alert_group.py b/engine/apps/mobile_app/tasks/new_alert_group.py index e33e91112e..2b759f5f6e 100644 --- a/engine/apps/mobile_app/tasks/new_alert_group.py +++ b/engine/apps/mobile_app/tasks/new_alert_group.py @@ -8,12 +8,7 @@ from apps.alerts.models import AlertGroup from apps.mobile_app.alert_rendering import get_push_notification_subtitle, get_push_notification_title from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.user_management.models import User from common.custom_celery_tasks import shared_dedicated_queue_retry_task @@ -46,7 +41,7 @@ def _get_fcm_message(alert_group: AlertGroup, user: User, device_to_notify: "FCM apns_sound_name = mobile_app_user_settings.get_notification_sound_name(message_type, Platform.IOS) fcm_message_data: FCMMessageData = { - "title": add_stack_slug_to_message_title(alert_title, alert_group.channel.organization), + "title": alert_title, "subtitle": alert_subtitle, "orgId": alert_group.channel.organization.public_primary_key, "orgName": alert_group.channel.organization.stack_slug, diff --git a/engine/apps/mobile_app/tasks/new_shift_swap_request.py b/engine/apps/mobile_app/tasks/new_shift_swap_request.py index a6d49c8b20..3ab7167410 100644 --- a/engine/apps/mobile_app/tasks/new_shift_swap_request.py +++ b/engine/apps/mobile_app/tasks/new_shift_swap_request.py @@ -10,12 +10,7 @@ from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, Message from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.schedules.models import ShiftSwapRequest from apps.user_management.models import User from common.custom_celery_tasks import shared_dedicated_queue_retry_task @@ -121,7 +116,7 @@ def _get_fcm_message( route = f"/schedules/{shift_swap_request.schedule.public_primary_key}/ssrs/{shift_swap_request.public_primary_key}" data: FCMMessageData = { - "title": add_stack_slug_to_message_title(notification_title, user.organization), + "title": notification_title, "subtitle": notification_subtitle, "orgName": user.organization.stack_slug, "route": route, diff --git a/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py b/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py index 2541d507f9..051e4ffbfe 100644 --- a/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py +++ b/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py @@ -18,7 +18,6 @@ conditionally_send_going_oncall_push_notifications_for_schedule, ) from apps.mobile_app.types import MessageType, Platform -from apps.mobile_app.utils import add_stack_slug_to_message_title from apps.schedules.models import 
OnCallScheduleCalendar, OnCallScheduleICal, OnCallScheduleWeb from apps.schedules.models.on_call_schedule import ScheduleEvent @@ -228,7 +227,7 @@ def test_get_fcm_message( maus = MobileAppUserSettings.objects.create(user=user, time_zone=user_tz) data = { - "title": add_stack_slug_to_message_title(mock_notification_title, organization), + "title": mock_notification_title, "subtitle": mock_notification_subtitle, "orgName": organization.stack_slug, "info_notification_sound_name": maus.get_notification_sound_name(MessageType.INFO, Platform.ANDROID), diff --git a/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py b/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py index 452b98952f..f77674f8fc 100644 --- a/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py +++ b/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py @@ -19,7 +19,6 @@ notify_shift_swap_requests, notify_user_about_shift_swap_request, ) -from apps.mobile_app.utils import add_stack_slug_to_message_title from apps.schedules.models import CustomOnCallShift, OnCallScheduleWeb, ShiftSwapRequest from apps.user_management.models import User from apps.user_management.models.user import default_working_hours @@ -288,7 +287,7 @@ def test_notify_user_about_shift_swap_request( message: Message = mock_send_push_notification.call_args.args[1] assert message.data["type"] == "oncall.info" - assert message.data["title"] == add_stack_slug_to_message_title("New shift swap request", organization) + assert message.data["title"] == "New shift swap request" assert message.data["subtitle"] == "John Doe, Test Schedule" assert ( message.data["route"] @@ -487,9 +486,7 @@ def test_notify_beneficiary_about_taken_shift_swap_request( message: Message = mock_send_push_notification.call_args.args[1] assert message.data["type"] == "oncall.info" - assert message.data["title"] == add_stack_slug_to_message_title( - "Your shift swap request has been taken", organization - ) + assert message.data["title"] == "Your shift swap request has been taken" assert message.data["subtitle"] == schedule_name assert ( message.data["route"] diff --git a/engine/apps/mobile_app/tests/test_demo_push.py b/engine/apps/mobile_app/tests/test_demo_push.py index 769691f75e..abf5f6eb9f 100644 --- a/engine/apps/mobile_app/tests/test_demo_push.py +++ b/engine/apps/mobile_app/tests/test_demo_push.py @@ -2,7 +2,6 @@ from apps.mobile_app.demo_push import _get_test_escalation_fcm_message, get_test_push_title from apps.mobile_app.models import FCMDevice, MobileAppUserSettings -from apps.mobile_app.utils import add_stack_slug_to_message_title @pytest.mark.django_db @@ -34,7 +33,7 @@ def test_test_escalation_fcm_message_user_settings( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=False) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=False), organization) + assert message.data["title"] == get_test_push_title(critical=False) assert message.data["type"] == "oncall.message" @@ -68,7 +67,7 @@ def test_escalation_fcm_message_user_settings_critical( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=True) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=True), organization) + assert message.data["title"] == get_test_push_title(critical=True) 
assert message.data["type"] == "oncall.critical_message" @@ -94,4 +93,4 @@ def test_escalation_fcm_message_user_settings_critical_override_dnd_disabled( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=True) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=True), organization) + assert message.data["title"] == get_test_push_title(critical=True) diff --git a/engine/apps/public_api/serializers/integrations.py b/engine/apps/public_api/serializers/integrations.py index b16aeb5472..0cbf460583 100644 --- a/engine/apps/public_api/serializers/integrations.py +++ b/engine/apps/public_api/serializers/integrations.py @@ -7,6 +7,7 @@ from apps.alerts.models import AlertReceiveChannel from apps.base.messaging import get_messaging_backends from apps.integrations.legacy_prefix import has_legacy_prefix, remove_legacy_prefix +from apps.user_management.models import ServiceAccountUser from common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField from common.api_helpers.exceptions import BadRequest from common.api_helpers.mixins import PHONE_CALL, SLACK, SMS, TELEGRAM, WEB, EagerLoadingMixin @@ -123,11 +124,13 @@ def create(self, validated_data): connection_error = GrafanaAlertingSyncManager.check_for_connection_errors(organization) if connection_error: raise serializers.ValidationError(connection_error) + user = self.context["request"].user with transaction.atomic(): try: instance = AlertReceiveChannel.create( **validated_data, - author=self.context["request"].user, + author=user if not isinstance(user, ServiceAccountUser) else None, + service_account=user.service_account if isinstance(user, ServiceAccountUser) else None, organization=organization, ) except AlertReceiveChannel.DuplicateDirectPagingError: diff --git a/engine/apps/public_api/tests/test_alert_groups.py b/engine/apps/public_api/tests/test_alert_groups.py index 71421cd318..e3cc872e3a 100644 --- a/engine/apps/public_api/tests/test_alert_groups.py +++ b/engine/apps/public_api/tests/test_alert_groups.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import httpretty import pytest from django.urls import reverse from django.utils import timezone @@ -9,6 +10,8 @@ from apps.alerts.constants import ActionSource from apps.alerts.models import AlertGroup, AlertReceiveChannel from apps.alerts.tasks import delete_alert_group, wipe +from apps.api import permissions +from apps.auth_token.tests.helpers import setup_service_account_api_mocks def construct_expected_response_from_alert_groups(alert_groups): @@ -736,3 +739,34 @@ def test_alert_group_unsilence( assert alert_group.silenced == silenced assert response.status_code == status_code assert response_msg == response.json()["detail"] + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_actions_disabled_for_service_accounts( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + make_escalation_chain, +): + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + make_escalation_chain(organization) + + perms = { + permissions.RBACPermission.Permissions.ALERT_GROUPS_WRITE.value: ["*"], + } + setup_service_account_api_mocks(organization, perms=perms) + + client = APIClient() + 
disabled_actions = ["acknowledge", "unacknowledge", "resolve", "unresolve", "silence", "unsilence"] + for action in disabled_actions: + url = reverse(f"api-public:alert_groups-{action}", kwargs={"pk": "ABCDEFG"}) + response = client.post( + url, + HTTP_AUTHORIZATION=f"{token_string}", + HTTP_X_GRAFANA_URL=organization.grafana_url, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/engine/apps/public_api/tests/test_integrations.py b/engine/apps/public_api/tests/test_integrations.py index b021df33e1..9a4e29c64f 100644 --- a/engine/apps/public_api/tests/test_integrations.py +++ b/engine/apps/public_api/tests/test_integrations.py @@ -1,9 +1,12 @@ +import httpretty import pytest from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient from apps.alerts.models import AlertReceiveChannel +from apps.api import permissions +from apps.auth_token.tests.helpers import setup_service_account_api_mocks from apps.base.tests.messaging_backend import TestOnlyBackend TEST_MESSAGING_BACKEND_FIELD = TestOnlyBackend.backend_id.lower() @@ -104,6 +107,47 @@ def test_create_integration( assert response.status_code == status.HTTP_201_CREATED +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_create_integration_via_service_account( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + make_escalation_chain, +): + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + make_escalation_chain(organization) + + perms = { + permissions.RBACPermission.Permissions.INTEGRATIONS_WRITE.value: ["*"], + } + setup_service_account_api_mocks(organization, perms) + + client = APIClient() + data_for_create = { + "type": "grafana", + "name": "grafana_created", + "team_id": None, + } + url = reverse("api-public:integrations-list") + response = client.post( + url, + data=data_for_create, + format="json", + HTTP_AUTHORIZATION=f"{token_string}", + HTTP_X_GRAFANA_URL=organization.grafana_url, + ) + if not organization.is_rbac_permissions_enabled: + assert response.status_code == status.HTTP_403_FORBIDDEN + else: + assert response.status_code == status.HTTP_201_CREATED + integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"]) + assert integration.service_account == service_account + + @pytest.mark.django_db def test_integration_name_uniqueness( make_organization_and_user_with_token, @@ -859,7 +903,6 @@ def test_get_list_integrations_link_and_inbound_email( if integration_type in [ AlertReceiveChannel.INTEGRATION_MANUAL, - AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL, AlertReceiveChannel.INTEGRATION_MAINTENANCE, ]: assert integration_link is None diff --git a/engine/apps/public_api/tests/test_rbac_permissions.py b/engine/apps/public_api/tests/test_rbac_permissions.py index 9829550d8c..95154ab4de 100644 --- a/engine/apps/public_api/tests/test_rbac_permissions.py +++ b/engine/apps/public_api/tests/test_rbac_permissions.py @@ -1,5 +1,7 @@ +import json from unittest.mock import patch +import httpretty import pytest from django.urls import reverse from rest_framework import status @@ -9,6 +11,13 @@ from apps.api.permissions import GrafanaAPIPermission, LegacyAccessControlRole, get_most_authorized_role from apps.public_api.urls import router +VIEWS_REQUIRING_USER_AUTH = ( + 
"EscalationView", + "PersonalNotificationView", + "MakeCallView", + "SendSMSView", +) + @pytest.mark.parametrize( "rbac_enabled,role,give_perm", @@ -96,3 +105,98 @@ def test_rbac_permissions( with patch(method_path, return_value=success): response = client.generic(path=url, method=http_method, HTTP_AUTHORIZATION=token) assert response.status_code == expected + + +@pytest.mark.parametrize( + "rbac_enabled,role,give_perm", + [ + # rbac disabled: auth is disabled + (False, LegacyAccessControlRole.ADMIN, None), + # rbac enabled: having role None, check the perm is required + (True, LegacyAccessControlRole.NONE, False), + (True, LegacyAccessControlRole.NONE, True), + ], +) +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_service_account_auth( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + rbac_enabled, + role, + give_perm, +): + # APIView default actions + # (name, http method, detail-based) + default_actions = { + "create": ("post", False), + "list": ("get", False), + "retrieve": ("get", True), + "update": ("put", True), + "partial_update": ("patch", True), + "destroy": ("delete", True), + } + + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + + if organization.is_rbac_permissions_enabled != rbac_enabled: + # skip if the organization's rbac_enabled is not the expected by the test + return + + client = APIClient() + # check all actions for all public API viewsets + for _, viewset, _basename in router.registry: + if viewset.__name__ == "ActionView": + # old actions (webhooks) are deprecated, no RBAC or service account support + continue + for viewset_method_name, required_perms in viewset.rbac_permissions.items(): + # setup Grafana API permissions response + if rbac_enabled: + permissions = {"perm": "value"} + expected = status.HTTP_403_FORBIDDEN + if give_perm: + permissions = {perm.value: "value" for perm in required_perms} + expected = status.HTTP_200_OK + mock_response = httpretty.Response(status=200, body=json.dumps(permissions)) + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response]) + else: + # service account auth is disabled + expected = status.HTTP_403_FORBIDDEN + + # iterate over all viewset actions, making an API request for each, + # using the user's token and confirming the response status code + if viewset_method_name in default_actions: + http_method, detail = default_actions[viewset_method_name] + else: + action_method = getattr(viewset, viewset_method_name) + http_method = list(action_method.mapping.keys())[0] + detail = action_method.detail + + method_path = f"{viewset.__module__}.{viewset.__name__}.{viewset_method_name}" + success = Response(status=status.HTTP_200_OK) + kwargs = {"pk": "NONEXISTENT"} if detail else None + if viewset_method_name in default_actions and detail: + url = reverse(f"api-public:{_basename}-detail", kwargs=kwargs) + elif viewset_method_name in default_actions and not detail: + url = reverse(f"api-public:{_basename}-list", kwargs=kwargs) + else: + name = viewset_method_name.replace("_", "-") + url = reverse(f"api-public:{_basename}-{name}", kwargs=kwargs) + + with patch(method_path, return_value=success): + headers = { + "HTTP_AUTHORIZATION": token_string, + 
"HTTP_X_GRAFANA_URL": organization.grafana_url, + } + response = client.generic(path=url, method=http_method, **headers) + assert ( + response.status_code == expected + if viewset.__name__ not in VIEWS_REQUIRING_USER_AUTH + # user-specific APIs do not support service account auth + else status.HTTP_403_FORBIDDEN + ) diff --git a/engine/apps/public_api/tests/test_resolution_notes.py b/engine/apps/public_api/tests/test_resolution_notes.py index c3a89a1da4..7a730e18ca 100644 --- a/engine/apps/public_api/tests/test_resolution_notes.py +++ b/engine/apps/public_api/tests/test_resolution_notes.py @@ -6,8 +6,8 @@ from rest_framework.test import APIClient from apps.alerts.models import ResolutionNote -from apps.auth_token.auth import GRAFANA_SA_PREFIX, ApiTokenAuthentication, GrafanaServiceAccountAuthentication -from apps.auth_token.models import ApiAuthToken +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ApiAuthToken, ServiceAccountToken @pytest.mark.django_db @@ -366,7 +366,7 @@ def test_create_resolution_note_grafana_auth(make_organization_and_user, make_al mock_api_key_auth.assert_called_once() assert response.status_code == status.HTTP_403_FORBIDDEN - token = f"{GRAFANA_SA_PREFIX}123" + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}123" # GrafanaServiceAccountAuthentication handle invalid token with patch( "apps.auth_token.auth.ApiTokenAuthentication.authenticate", wraps=api_token_auth.authenticate diff --git a/engine/apps/public_api/views/alert_groups.py b/engine/apps/public_api/views/alert_groups.py index d4f4a302ff..fc5d01d029 100644 --- a/engine/apps/public_api/views/alert_groups.py +++ b/engine/apps/public_api/views/alert_groups.py @@ -12,12 +12,13 @@ from apps.alerts.tasks import delete_alert_group, wipe from apps.api.label_filtering import parse_label_query from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.constants import VALID_DATE_FOR_DELETE_INCIDENT from apps.public_api.helpers import is_valid_group_creation_date, team_has_slack_token_for_deleting from apps.public_api.serializers import AlertGroupSerializer from apps.public_api.throttlers.user_throttle import UserThrottle -from common.api_helpers.exceptions import BadRequest +from apps.user_management.models import ServiceAccountUser +from common.api_helpers.exceptions import BadRequest, Forbidden from common.api_helpers.filters import ( NO_TEAM_VALUE, ByTeamModelFieldFilterMixin, @@ -57,7 +58,7 @@ class AlertGroupView( mixins.DestroyModelMixin, GenericViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { @@ -170,6 +171,9 @@ def destroy(self, request, *args, **kwargs): @action(methods=["post"], detail=True) def acknowledge(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to acknowledge alert groups") + alert_group = self.get_object() if alert_group.acknowledged: @@ -189,6 +193,9 @@ def acknowledge(self, request, pk): @action(methods=["post"], detail=True) def unacknowledge(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unacknowledge alert 
groups") + alert_group = self.get_object() if not alert_group.acknowledged: @@ -208,6 +215,9 @@ def unacknowledge(self, request, pk): @action(methods=["post"], detail=True) def resolve(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to resolve alert groups") + alert_group = self.get_object() if alert_group.resolved: @@ -225,6 +235,9 @@ def resolve(self, request, pk): @action(methods=["post"], detail=True) def unresolve(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unresolve alert groups") + alert_group = self.get_object() if not alert_group.resolved: @@ -241,6 +254,9 @@ def unresolve(self, request, pk): @action(methods=["post"], detail=True) def silence(self, request, pk=None): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to silence alert groups") + alert_group = self.get_object() delay = request.data.get("delay") @@ -267,6 +283,9 @@ def silence(self, request, pk=None): @action(methods=["post"], detail=True) def unsilence(self, request, pk=None): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unsilence alert groups") + alert_group = self.get_object() if not alert_group.silenced: diff --git a/engine/apps/public_api/views/alerts.py b/engine/apps/public_api/views/alerts.py index b96d51c50c..0f3d1d4669 100644 --- a/engine/apps/public_api/views/alerts.py +++ b/engine/apps/public_api/views/alerts.py @@ -7,7 +7,7 @@ from apps.alerts.models import Alert from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.alerts import AlertSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin @@ -19,7 +19,7 @@ class AlertFilter(filters.FilterSet): class AlertView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/escalation_chains.py b/engine/apps/public_api/views/escalation_chains.py index 84bb71628d..52a1cc444c 100644 --- a/engine/apps/public_api/views/escalation_chains.py +++ b/engine/apps/public_api/views/escalation_chains.py @@ -5,7 +5,7 @@ from apps.alerts.models import EscalationChain from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import EscalationChainSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.filters import ByTeamFilter @@ -15,7 +15,7 @@ class EscalationChainView(RateLimitHeadersMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/escalation_policies.py b/engine/apps/public_api/views/escalation_policies.py index 
ddbaeae803..e91e52f48b 100644 --- a/engine/apps/public_api/views/escalation_policies.py +++ b/engine/apps/public_api/views/escalation_policies.py @@ -5,7 +5,7 @@ from apps.alerts.models import EscalationPolicy from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import EscalationPolicySerializer, EscalationPolicyUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin, UpdateSerializerMixin @@ -14,7 +14,7 @@ class EscalationPolicyView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/integrations.py b/engine/apps/public_api/views/integrations.py index 26c55224fd..e8ec9a852b 100644 --- a/engine/apps/public_api/views/integrations.py +++ b/engine/apps/public_api/views/integrations.py @@ -5,7 +5,7 @@ from apps.alerts.models import AlertReceiveChannel from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import IntegrationSerializer, IntegrationUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.exceptions import BadRequest @@ -24,7 +24,7 @@ class IntegrationView( MaintainableObjectMixin, ModelViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/on_call_shifts.py b/engine/apps/public_api/views/on_call_shifts.py index e825ea3537..2e091e947c 100644 --- a/engine/apps/public_api/views/on_call_shifts.py +++ b/engine/apps/public_api/views/on_call_shifts.py @@ -5,7 +5,7 @@ from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import CustomOnCallShiftSerializer, CustomOnCallShiftUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.schedules.models import CustomOnCallShift @@ -16,7 +16,7 @@ class CustomOnCallShiftView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/organizations.py b/engine/apps/public_api/views/organizations.py index 1df2f63a5d..473d79de6c 100644 --- a/engine/apps/public_api/views/organizations.py +++ b/engine/apps/public_api/views/organizations.py @@ -3,7 +3,7 @@ from rest_framework.viewsets import ReadOnlyModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, 
GrafanaServiceAccountAuthentication from apps.public_api.serializers import OrganizationSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.user_management.models import Organization @@ -15,7 +15,7 @@ class OrganizationView( RateLimitHeadersMixin, ReadOnlyModelViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/routes.py b/engine/apps/public_api/views/routes.py index 7946152718..19ddc1056a 100644 --- a/engine/apps/public_api/views/routes.py +++ b/engine/apps/public_api/views/routes.py @@ -7,7 +7,7 @@ from apps.alerts.models import ChannelFilter from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import ChannelFilterSerializer, ChannelFilterUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.exceptions import BadRequest @@ -17,7 +17,7 @@ class ChannelFilterView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/schedules.py b/engine/apps/public_api/views/schedules.py index 6dcca6fd08..5960ad4894 100644 --- a/engine/apps/public_api/views/schedules.py +++ b/engine/apps/public_api/views/schedules.py @@ -9,7 +9,11 @@ from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication, ScheduleExportAuthentication +from apps.auth_token.auth import ( + ApiTokenAuthentication, + GrafanaServiceAccountAuthentication, + ScheduleExportAuthentication, +) from apps.public_api.custom_renderers import CalendarRenderer from apps.public_api.serializers import PolymorphicScheduleSerializer, PolymorphicScheduleUpdateSerializer from apps.public_api.serializers.schedules_base import FinalShiftQueryParamsSerializer @@ -28,7 +32,7 @@ class OnCallScheduleChannelView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/shift_swap.py b/engine/apps/public_api/views/shift_swap.py index 07f978e5c9..c46c141965 100644 --- a/engine/apps/public_api/views/shift_swap.py +++ b/engine/apps/public_api/views/shift_swap.py @@ -10,7 +10,7 @@ from apps.api.permissions import AuthenticatedRequest, RBACPermission from apps.api.views.shift_swap import BaseShiftSwapViewSet -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.throttlers.user_throttle import UserThrottle from apps.schedules.models import ShiftSwapRequest from apps.user_management.models import User @@ -23,7 +23,7 @@ class ShiftSwapViewSet(RateLimitHeadersMixin, BaseShiftSwapViewSet): # set authentication and permission classes - authentication_classes = 
(ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/slack_channels.py b/engine/apps/public_api/views/slack_channels.py index 77581f3dde..35f384021a 100644 --- a/engine/apps/public_api/views/slack_channels.py +++ b/engine/apps/public_api/views/slack_channels.py @@ -3,7 +3,7 @@ from rest_framework.viewsets import GenericViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.slack_channel import SlackChannelSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.slack.models import SlackChannel @@ -12,7 +12,7 @@ class SlackChannelView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/teams.py b/engine/apps/public_api/views/teams.py index 490e74efb1..6d399bade5 100644 --- a/engine/apps/public_api/views/teams.py +++ b/engine/apps/public_api/views/teams.py @@ -3,7 +3,7 @@ from rest_framework.permissions import IsAuthenticated from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.teams import TeamSerializer from apps.public_api.tf_sync import is_request_from_terraform, sync_teams_on_tf_request from apps.public_api.throttlers.user_throttle import UserThrottle @@ -14,7 +14,7 @@ class TeamView(PublicPrimaryKeyMixin, RetrieveModelMixin, ListModelMixin, viewsets.GenericViewSet): serializer_class = TeamSerializer - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/user_groups.py b/engine/apps/public_api/views/user_groups.py index ced7f626bf..bb1dac7f37 100644 --- a/engine/apps/public_api/views/user_groups.py +++ b/engine/apps/public_api/views/user_groups.py @@ -3,7 +3,7 @@ from rest_framework.viewsets import GenericViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.user_groups import UserGroupSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.slack.models import SlackUserGroup @@ -12,7 +12,7 @@ class UserGroupView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/users.py b/engine/apps/public_api/views/users.py index 97315fe202..129096e560 100644 --- a/engine/apps/public_api/views/users.py +++ b/engine/apps/public_api/views/users.py @@ -6,7 +6,11 @@ from rest_framework.viewsets 
import ReadOnlyModelViewSet from apps.api.permissions import LegacyAccessControlRole, RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication, UserScheduleExportAuthentication +from apps.auth_token.auth import ( + ApiTokenAuthentication, + GrafanaServiceAccountAuthentication, + UserScheduleExportAuthentication, +) from apps.public_api.custom_renderers import CalendarRenderer from apps.public_api.serializers import FastUserSerializer, UserSerializer from apps.public_api.tf_sync import is_request_from_terraform, sync_users_on_tf_request @@ -35,7 +39,7 @@ class Meta: class UserView(RateLimitHeadersMixin, ShortSerializerMixin, ReadOnlyModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/webhooks.py b/engine/apps/public_api/views/webhooks.py index 8f75148b71..b1a6a47bb1 100644 --- a/engine/apps/public_api/views/webhooks.py +++ b/engine/apps/public_api/views/webhooks.py @@ -6,7 +6,7 @@ from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.webhooks import ( WebhookCreateSerializer, WebhookResponseSerializer, @@ -21,7 +21,7 @@ class WebhooksView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py b/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py new file mode 100644 index 0000000000..e4d1913827 --- /dev/null +++ b/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:13 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('schedules', '0019_auto_20241021_1735'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='oncallschedule', + name='channel', + ), + ] diff --git a/engine/apps/schedules/models/on_call_schedule.py b/engine/apps/schedules/models/on_call_schedule.py index 544ec847b2..e57cf4bc48 100644 --- a/engine/apps/schedules/models/on_call_schedule.py +++ b/engine/apps/schedules/models/on_call_schedule.py @@ -209,8 +209,6 @@ class OnCallSchedule(PolymorphicModel): name = models.CharField(max_length=200) - # TODO: drop this field in a subsequent release, this has been migrated to slack_channel field - channel = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True,
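Every public API view above now lists GrafanaServiceAccountAuthentication ahead of the pre-existing ApiTokenAuthentication. The ordering is what keeps legacy tokens working: DRF walks authentication_classes left to right and treats a None return as "not my token, try the next class". A minimal sketch of that fall-through contract follows; the class body and helper are hypothetical, and only the glsa_ prefix (matching ServiceAccountToken.GRAFANA_SA_PREFIX used later in this diff) is taken from the source.

from rest_framework import authentication, exceptions

GRAFANA_SA_PREFIX = "glsa_"  # Grafana service-account token prefix

class ServiceAccountAuthSketch(authentication.BaseAuthentication):
    """Hypothetical stand-in for GrafanaServiceAccountAuthentication."""

    def authenticate(self, request):
        header = authentication.get_authorization_header(request).decode()
        token = header.removeprefix("Bearer ").strip()
        if not token.startswith(GRAFANA_SA_PREFIX):
            # Not a service-account token: returning None makes DRF fall
            # through to the next class, i.e. ApiTokenAuthentication.
            return None
        user = self._resolve(token)
        if user is None:
            raise exceptions.AuthenticationFailed("Invalid service account token")
        return user, token

    def _resolve(self, token):
        # Placeholder: the real class matches the token against stored
        # ServiceAccountToken rows (see the conftest fixture later in this diff).
        return None

Listing the new class first is therefore purely additive: requests carrying glsa_ tokens short-circuit there, everything else falls through unchanged.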
logger.info(f"Update message for alert_group {alert_group.pk}") + try: self._slack_client.chat_update( channel=alert_group.slack_message.channel_id, @@ -47,7 +46,7 @@ def update_alert_group_slack_message(self, alert_group: "AlertGroup") -> None: ) logger.info(f"Message has been updated for alert_group {alert_group.pk}") except SlackAPIRatelimitError as e: - if alert_group.channel.integration != AlertReceiveChannel.INTEGRATION_MAINTENANCE: + if not alert_group.channel.is_maintenace_integration: if not alert_group.channel.is_rate_limited_in_slack: alert_group.channel.start_send_rate_limit_message_task(e.retry_after) logger.info( diff --git a/engine/apps/slack/scenarios/alertgroup_timeline.py b/engine/apps/slack/scenarios/alertgroup_timeline.py index 08f74b8802..7ca3a56f2d 100644 --- a/engine/apps/slack/scenarios/alertgroup_timeline.py +++ b/engine/apps/slack/scenarios/alertgroup_timeline.py @@ -2,6 +2,7 @@ from apps.api.permissions import RBACPermission from apps.slack.chatops_proxy_routing import make_private_metadata +from apps.slack.constants import BLOCK_SECTION_TEXT_MAX_SIZE from apps.slack.scenarios import scenario_step from apps.slack.scenarios.slack_renderer import AlertGroupLogSlackRenderer from apps.slack.types import ( @@ -47,9 +48,13 @@ def process_scenario( future_log_report = AlertGroupLogSlackRenderer.render_alert_group_future_log_report_text(alert_group) blocks: typing.List[Block.Section] = [] if past_log_report: - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": past_log_report}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": past_log_report[:BLOCK_SECTION_TEXT_MAX_SIZE]}} + ) if future_log_report: - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": future_log_report}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": future_log_report[:BLOCK_SECTION_TEXT_MAX_SIZE]}} + ) view: ModalView = { "blocks": blocks, diff --git a/engine/apps/slack/scenarios/distribute_alerts.py b/engine/apps/slack/scenarios/distribute_alerts.py index 3a7090e320..3d3c1a60a8 100644 --- a/engine/apps/slack/scenarios/distribute_alerts.py +++ b/engine/apps/slack/scenarios/distribute_alerts.py @@ -141,22 +141,6 @@ def _post_alert_group_to_slack( channel_id=channel_id, ) - # If alert was made out of a message: - if alert_group.channel.integration == AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL: - channel = json.loads(alert.integration_unique_data)["channel"] - result = self._slack_client.chat_postMessage( - channel=channel, - thread_ts=json.loads(alert.integration_unique_data)["ts"], - text=":rocket: <{}|Incident registered!>".format(alert_group.slack_message.permalink), - team=slack_team_identity, - ) - alert_group.slack_messages.create( - slack_id=result["ts"], - organization=alert_group.channel.organization, - _slack_team_identity=self.slack_team_identity, - channel_id=channel, - ) - alert.delivered = True except SlackAPITokenError: alert_group.reason_to_skip_escalation = AlertGroup.ACCOUNT_INACTIVE @@ -172,7 +156,7 @@ def _post_alert_group_to_slack( logger.info("Not delivering alert due to channel is archived.") except SlackAPIRatelimitError as e: # don't rate limit maintenance alert - if alert_group.channel.integration != AlertReceiveChannel.INTEGRATION_MAINTENANCE: + if not alert_group.channel.is_maintenace_integration: alert_group.reason_to_skip_escalation = AlertGroup.RATE_LIMITED alert_group.save(update_fields=["reason_to_skip_escalation"]) 
diff --git a/engine/apps/user_management/migrations/0027_serviceaccount.py b/engine/apps/user_management/migrations/0027_serviceaccount.py new file mode 100644 index 0000000000..dc9e520b3b --- /dev/null +++ b/engine/apps/user_management/migrations/0027_serviceaccount.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0026_auto_20241017_1919'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceAccount', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('grafana_id', models.PositiveIntegerField()), + ('login', models.CharField(max_length=300)), + ('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='service_accounts', to='user_management.organization')), + ], + options={ + 'unique_together': {('grafana_id', 'organization')}, + }, + ), + ] diff --git a/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py b/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py new file mode 100644 index 0000000000..6d415bdb44 --- /dev/null +++ b/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:11 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='organization', + name='general_log_channel_id', + ), + ] diff --git a/engine/apps/user_management/models/__init__.py b/engine/apps/user_management/models/__init__.py index e2bcd4c7f0..2fd5a9aa1e 100644 --- a/engine/apps/user_management/models/__init__.py +++ b/engine/apps/user_management/models/__init__.py @@ -1,4 +1,5 @@ from .user import User # noqa: F401, isort: skip from .organization import Organization # noqa: F401 from .region import Region # noqa: F401 +from .service_account import ServiceAccount, ServiceAccountUser # noqa: F401 from .team import Team # noqa: F401 diff --git a/engine/apps/user_management/models/organization.py b/engine/apps/user_management/models/organization.py index aac0aeae9a..2fbeefca1d 100644 --- a/engine/apps/user_management/models/organization.py +++ b/engine/apps/user_management/models/organization.py @@ -162,9 +162,6 @@ class Organization(MaintainableObject): slack_team_identity = models.ForeignKey( "slack.SlackTeamIdentity", on_delete=models.PROTECT, null=True, default=None, related_name="organizations" ) - - # TODO: drop this field in a subsequent release, this has been migrated to default_slack_channel field - general_log_channel_id = models.CharField(max_length=100, null=True, default=None) default_slack_channel = models.ForeignKey( "slack.SlackChannel", null=True,
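Migration 0027 above gives ServiceAccount a composite uniqueness constraint on (grafana_id, organization), which is the natural upsert key when syncing accounts from Grafana; the model definition itself follows below. A hypothetical sync helper illustrating the idea (not part of this diff):

from apps.user_management.models import Organization, ServiceAccount

def upsert_service_account(organization: Organization, grafana_id: int, login: str) -> ServiceAccount:
    # (grafana_id, organization) is unique_together, so update_or_create
    # either refreshes the mutable login or inserts a new row.
    service_account, _created = ServiceAccount.objects.update_or_create(
        organization=organization,
        grafana_id=grafana_id,
        defaults={"login": login},
    )
    return service_account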
diff --git a/engine/apps/user_management/models/service_account.py b/engine/apps/user_management/models/service_account.py new file mode 100644 index 0000000000..5082f7b965 --- /dev/null +++ b/engine/apps/user_management/models/service_account.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import List + +from django.db import models + +from apps.user_management.models import Organization + + +@dataclass +class ServiceAccountUser: + """Authenticated service account in public API requests.""" + + service_account: "ServiceAccount" + organization: "Organization" # required for insight logs interface + username: str # required for insight logs interface + public_primary_key: str # required for insight logs interface + role: str # required for permissions check + permissions: List[str] # required for permissions check + + @property + def id(self): + return self.service_account.id + + @property + def pk(self): + return self.service_account.id + + @property + def organization_id(self): + return self.organization.id + + @property + def is_authenticated(self): + return True + + +class ServiceAccount(models.Model): + organization: "Organization" + + grafana_id = models.PositiveIntegerField() + organization = models.ForeignKey(Organization, on_delete=models.CASCADE, related_name="service_accounts") + login = models.CharField(max_length=300) + + class Meta: + unique_together = ("grafana_id", "organization") + + @property + def username(self): + # required for insight logs interface + return self.login + + @property + def public_primary_key(self): + # required for insight logs interface + return f"service-account:{self.grafana_id}" diff --git a/engine/apps/user_management/tests/factories.py b/engine/apps/user_management/tests/factories.py index ccfbb8586e..a33aefaca1 100644 --- a/engine/apps/user_management/tests/factories.py +++ b/engine/apps/user_management/tests/factories.py @@ -1,6 +1,6 @@ import factory -from apps.user_management.models import Organization, Region, Team, User +from apps.user_management.models import Organization, Region, ServiceAccount, Team, User from common.utils import UniqueFaker @@ -41,3 +41,11 @@ class RegionFactory(factory.DjangoModelFactory): class Meta: model = Region + + +class ServiceAccountFactory(factory.DjangoModelFactory): + grafana_id = UniqueFaker("pyint") + login = UniqueFaker("user_name") + + class Meta: + model = ServiceAccount diff --git a/engine/config_integrations/heartbeat.py b/engine/config_integrations/heartbeat.py deleted file mode 100644 index 60699c4507..0000000000 --- a/engine/config_integrations/heartbeat.py +++ /dev/null @@ -1,29 +0,0 @@ -# Main -enabled = True -title = "Heartbeat" -slug = "heartbeat" -short_description = None -description = None -is_displayed_on_web = False -is_featured = False -is_able_to_autoresolve = True -is_demo_alert_enabled = False - -description = None - -# Default templates -slack_title = """\ -*<{{ grafana_oncall_link }}|#{{ grafana_oncall_incident_id }} {{ payload.get("title", "Title undefined (check Slack Title Template)") }}>* via {{ integration_name }} -{% if source_link %} - (*<{{ source_link }}|source>*) -{%- endif %}""" - -grouping_id = """\ -{{ payload.get("id", "") }}{{ payload.get("user_defined_id", "") }} -""" - -resolve_condition = '{{ payload.get("is_resolve", False) == True }}' - -acknowledge_condition = None - -example_payload = None
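ServiceAccountUser, introduced above, is deliberately not a model: it mirrors just enough of User's surface (username, public_primary_key, role, permissions, is_authenticated, id/pk) for the insight-logs and RBAC code paths to accept it unmodified. A sketch of how an authenticator might assemble one; the role and permissions values are placeholders, since their real source is not shown in this diff.

from apps.user_management.models import ServiceAccount, ServiceAccountUser

def service_account_request_user(
    service_account: ServiceAccount, role: str, permissions: list[str]
) -> ServiceAccountUser:
    return ServiceAccountUser(
        service_account=service_account,
        organization=service_account.organization,
        username=service_account.username,  # proxies .login
        public_primary_key=service_account.public_primary_key,  # "service-account:<grafana_id>"
        role=role,
        permissions=permissions,
    )

Because is_authenticated is hard-wired to True and id/pk proxy the underlying row, IsAuthenticated and the insight-logs code can treat the result exactly like a User without isinstance checks.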
diff --git a/engine/config_integrations/slack_channel.py b/engine/config_integrations/slack_channel.py deleted file mode 100644 index 05021935f1..0000000000 --- a/engine/config_integrations/slack_channel.py +++ /dev/null @@ -1,44 +0,0 @@ -# Main -enabled = True -title = "Slack Channel" -slug = "slack_channel" -short_description = None -description = None -is_displayed_on_web = False -is_featured = False -is_able_to_autoresolve = False -is_demo_alert_enabled = False - -description = None - -# Default templates -slack_title = """\ -{% if source_link -%} -*<{{ source_link }}|<#{{ payload.get("channel", "") }}>>* -{%- else -%} -<#{{ payload.get("channel", "") }}> -{%- endif %}""" - -web_title = """\ -{% if source_link -%} -[#{{ grafana_oncall_incident_id }}]{{ source_link }}) <#{{ payload.get("channel", "") }}>>* -{%- else -%} -*#{{ grafana_oncall_incident_id }}* <#{{ payload.get("channel", "") }}> -{%- endif %}""" - -telegram_title = """\ -{% if source_link -%} -#{{ grafana_oncall_incident_id }} {{ payload.get("channel", "") }} -{%- else -%} -*#{{ grafana_oncall_incident_id }}* <#{{ payload.get("channel", "") }}> -{%- endif %}""" - -grouping_id = '{{ payload.get("ts", "") }}' - -resolve_condition = None - -acknowledge_condition = None - -source_link = '{{ payload.get("amixr_mixin", {}).get("permalink", "")}}' - -example_payload = None diff --git a/engine/conftest.py b/engine/conftest.py index a95383dd94..0b66e3adea 100644 --- a/engine/conftest.py +++ b/engine/conftest.py @@ -1,3 +1,4 @@ +import binascii import datetime import json import os @@ -46,11 +47,14 @@ LegacyAccessControlRole, RBACPermission, ) +from apps.auth_token import constants as auth_token_constants +from apps.auth_token.crypto import hash_token_string from apps.auth_token.models import ( ApiAuthToken, GoogleOAuth2Token, IntegrationBacksyncAuthToken, PluginAuthToken, + ServiceAccountToken, SlackAuthToken, ) from apps.base.models.user_notification_policy_log_record import ( @@ -102,7 +106,13 @@ TelegramVerificationCodeFactory, ) from apps.user_management.models.user import User, listen_for_user_model_save -from apps.user_management.tests.factories import OrganizationFactory, RegionFactory, TeamFactory, UserFactory +from apps.user_management.tests.factories import ( + OrganizationFactory, + RegionFactory, + ServiceAccountFactory, + TeamFactory, + UserFactory, +) from apps.webhooks.presets.preset_options import WebhookPresetOptions from apps.webhooks.tests.factories import CustomWebhookFactory, WebhookResponseFactory from apps.webhooks.tests.test_webhook_presets import ( @@ -252,6 +262,30 @@ def _make_user_for_organization(organization, role: typing.Optional[LegacyAccess return _make_user_for_organization + +@pytest.fixture +def make_service_account_for_organization(make_user): + def _make_service_account_for_organization(organization, **kwargs): + return ServiceAccountFactory(organization=organization, **kwargs) + + return _make_service_account_for_organization + + +@pytest.fixture +def make_token_for_service_account(): + def _make_token_for_service_account(service_account, token_string): + prefix_length = len(ServiceAccountToken.GRAFANA_SA_PREFIX) + token_key = token_string[prefix_length : prefix_length + auth_token_constants.TOKEN_KEY_LENGTH] + hashable_token = binascii.hexlify(token_string.encode()).decode() + digest = hash_token_string(hashable_token) + return ServiceAccountToken.objects.create( + service_account=service_account, + token_key=token_key, + digest=digest, + ) + + return _make_token_for_service_account + + @pytest.fixture def make_token_for_organization(): def _make_token_for_organization(organization):
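The make_token_for_service_account fixture above mirrors what server-side storage must look like: only a short lookup key and a digest are persisted, never the plaintext glsa_ token. The same derivation as a standalone sketch; TOKEN_KEY_LENGTH is assumed here, and sha256 merely stands in for apps.auth_token.crypto.hash_token_string, whose actual algorithm is not shown in this diff.

import binascii
import hashlib

GRAFANA_SA_PREFIX = "glsa_"
TOKEN_KEY_LENGTH = 8  # assumed for illustration; the real value lives in apps.auth_token.constants

def derive_token_key_and_digest(token_string: str) -> tuple[str, str]:
    """Split a raw glsa_ token into its indexable key and a stored digest."""
    start = len(GRAFANA_SA_PREFIX)
    token_key = token_string[start : start + TOKEN_KEY_LENGTH]
    hashable_token = binascii.hexlify(token_string.encode()).decode()
    digest = hashlib.sha256(hashable_token.encode()).hexdigest()  # stand-in hash
    return token_key, digest

At request time the authenticator can index on token_key and then compare digests, so a leaked database dump yields no usable tokens.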
f"service_account_id={user_id} " + else: + message += f"user_id={user_id} " org_id = request.user.organization.id org_slug = request.user.organization.org_slug - message += f"user_id={user_id} org_id={org_id} org_slug={org_slug} " + message += f"org_id={org_id} org_slug={org_slug} " if request.path.startswith("/integrations/v1"): split_path = request.path.split("/") integration_type = split_path[3] diff --git a/engine/settings/base.py b/engine/settings/base.py index 5b6eba8f14..0f73c8d5af 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -867,6 +867,7 @@ class BrokerTypes: INBOUND_EMAIL_ESP = os.getenv("INBOUND_EMAIL_ESP") INBOUND_EMAIL_DOMAIN = os.getenv("INBOUND_EMAIL_DOMAIN") INBOUND_EMAIL_WEBHOOK_SECRET = os.getenv("INBOUND_EMAIL_WEBHOOK_SECRET") +INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = os.getenv("INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN") INSTALLED_ONCALL_INTEGRATIONS = [ # Featured @@ -877,11 +878,9 @@ class BrokerTypes: "config_integrations.formatted_webhook", "config_integrations.kapacitor", "config_integrations.elastalert", - "config_integrations.heartbeat", "config_integrations.inbound_email", "config_integrations.maintenance", "config_integrations.manual", - "config_integrations.slack_channel", "config_integrations.zabbix", "config_integrations.direct_paging", # Actually it's Grafana 8 integration. @@ -987,3 +986,5 @@ class BrokerTypes: SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6) SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240) SYNC_V2_BATCH_SIZE = getenv_integer("SYNC_V2_BATCH_SIZE", 500) + +AUDITED_ALERT_GROUP_MAX_RETRIES = getenv_integer("AUDITED_ALERT_GROUP_MAX_RETRIES", 1) diff --git a/engine/settings/celery_task_routes.py b/engine/settings/celery_task_routes.py index 04a8ffa49a..7ef62121dd 100644 --- a/engine/settings/celery_task_routes.py +++ b/engine/settings/celery_task_routes.py @@ -12,7 +12,6 @@ "common.oncall_gateway.tasks.delete_oncall_connector_async": {"queue": "default"}, "common.oncall_gateway.tasks.create_slack_connector_async_v2": {"queue": "default"}, "common.oncall_gateway.tasks.delete_slack_connector_async_v2": {"queue": "default"}, - "apps.heartbeat.tasks.integration_heartbeat_checkup": {"queue": "default"}, "apps.heartbeat.tasks.process_heartbeat_task": {"queue": "default"}, "apps.labels.tasks.update_labels_cache": {"queue": "default"}, "apps.labels.tasks.update_instances_labels_cache": {"queue": "default"}, diff --git a/helm/oncall/values.yaml b/helm/oncall/values.yaml index 8ca59a2664..826e0a5be3 100644 --- a/helm/oncall/values.yaml +++ b/helm/oncall/values.yaml @@ -639,6 +639,9 @@ grafana: serve_from_sub_path: true feature_toggles: enable: externalServiceAccounts + accessControlOnCall: false + env: + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true persistence: enabled: true # Disable psp as PodSecurityPolicy is deprecated in v1.21+, unavailable in v1.25+