diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
new file mode 100644
index 00000000..09778b48
--- /dev/null
+++ b/.buildkite/pipeline.yml
@@ -0,0 +1,27 @@
+
+steps:
+ - command:
+ - "python -m pip install black"
+ - "black --check ."
+ label: "Check Code Formatting"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+
+ - command:
+ - "python -m pip install flake8"
+ - "flake8 ."
+ label: "Check Code Style"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
+
+ - wait
+
+ - command:
+ - "python -m pip install -e ."
+ - "trial tests"
+ label: "Run unit tests"
+ plugins:
+ - docker#v3.0.1:
+ image: "python:3.7"
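The pipeline above just runs the linters and the test suite; the same checks can be reproduced locally (assuming a Python 3.7 environment)::

    python -m pip install black flake8
    black --check .
    flake8 .
    python -m pip install -e .
    trial tests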
diff --git a/.gitignore b/.gitignore
index b6e20577..d4c05d27 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,9 @@
*.pyc
sygnal.conf
+sygnal.yaml
gunicorn_config.py
var/
sygnal.pid
sygnal.db
+/_trial_temp*
+/.idea
diff --git a/README.rst b/README.rst
index 5d3d7c97..1dc04fd9 100644
--- a/README.rst
+++ b/README.rst
@@ -1,47 +1,28 @@
Introduction
============
-sygnal is a reference Push Gateway for Matrix (http://matrix.org/).
+Sygnal is a reference Push Gateway for `Matrix <https://matrix.org/>`_.
-See
-http://matrix.org/docs/spec/client_server/r0.2.0.html#id51 for a high level overview of how notifications work in Matrix.
+See https://matrix.org/docs/spec/client_server/r0.5.0#id134
+for a high level overview of how notifications work in Matrix.
-http://matrix.org/docs/spec/push_gateway/unstable.html#post-matrix-push-r0-notify
-describes the protocol that Matrix Home Servers use to send notifications to
-Push Gateways such as sygnal.
+https://matrix.org/docs/spec/push_gateway/r0.1.0
+describes the protocol that Matrix Home Servers use to send notifications to Push Gateways such as Sygnal.
Setup
=====
-sygnal is a plain WSGI app, although these instructions use gunicorn which
-will create a complete, standalone webserver. When used with gunicorn,
-sygnal can use gunicorn's extra hook to perform a clean shutdown which tries as
-hard as possible to ensure no messages are lost.
+Sygnal is configured through a YAML configuration file.
+By default, this configuration file is assumed to be named ``sygnal.yaml`` and to be in the working directory.
+To change this, set the ``SYGNAL_CONF`` environment variable to the path to your configuration file.
+A sample configuration file is provided in this repository;
+see ``sygnal.yaml.sample``.
-There are two config files:
- * sygnal.conf (The app-specific config file)
- * gunicorn_config.py (gunicorn's config file)
+The ``apps:`` section is where you set up the apps that are to be handled.
+Each app should be given its own subsection, with the key of that subsection
+being the app's ``app_id``, as specified when setting up a Matrix pusher
+(see https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-pushers-set).
-sygnal.conf contains configuration for sygnal itself. This includes the location
-and level of sygnal's log file. The [apps] section is where you set up different
-apps that are to be handled. Keys in this section take the form of the app_id
-and the name of the configuration key, joined by a single dot ('.'). The app_id
-is as specified when setting up a Matrix pusher (see
-http://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-pushers-set). So for example, the `type` for
-the App ID of `com.example.myapp.ios.prod` would be specified as follows::
-
- com.example.myapp.ios.prod.type = foobar
-
-By default sygnal.conf is assumed to be in the working directory, but the path
-can be overriden by setting the `sygnal.conf` environment variable.
-
-The gunicorn sample config contains everything necessary to run sygnal from
-gunicorn. The shutdown hook handles clean shutdown. You can customise other
-aspects of this file as you wish to change, for example, the log location or the
-bind port.
-
-Note that sygnal uses gevent. You should therefore not change the worker class
-or the number of workers (which should be 1: in gevent, a single worker uses
-multiple greenlets to handle all the requests).
+See the sample configuration for examples.
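+
+For example, an ``apps`` section containing a single certificate-based APNs
+app (reusing the values from the old ``sygnal.conf.sample``) would look like::
+
+    apps:
+      com.example.myapp.ios:
+        type: apns
+        certfile: com.example.myApp_prod_APNS.pem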
App Types
---------
@@ -49,42 +30,44 @@ There are two supported App Types:
apns
This sends push notifications to iOS apps via the Apple Push Notification
- Service (APNS). It expects the 'certfile' parameter to be a path relative to
- sygnal's working directory of a PEM file containing the APNS certificate and
- unencrypted private key.
+ Service (APNS).
-gcm
- This sends messages via Google Cloud Messaging (GCM) and hence can be used
- to deliver notifications to Android apps. It expects the 'apiKey' parameter
- to contain the secret GCM key.
+ Expected configuration depends on which kind of authentication you wish to use:
-Running
-=======
-To run with gunicorn:
+ |
+
+ For certificate-based authentication:
+ It expects:
-gunicorn -c gunicorn_config.py sygnal:app
+ * the ``certfile`` parameter to be a path, relative to Sygnal's working
+ directory, to a PEM file containing the APNS certificate and
+ unencrypted private key.
-You can customise the gunicorn_config.py to determine whether this daemonizes or runs in the foreground.
+ For token-based authentication:
+ It expects:
-Gunicorn maintains its own logging in addition to the app's, so the access_log
-and error_log contain gunicorn's accesses and gunicorn specific errors. The log
-file in sygnal.conf contains app level logging.
+ * the ``keyfile`` parameter to be a path, relative to Sygnal's working directory, to a p8 file
+ * the ``key_id`` parameter
+ * the ``team_id`` parameter
+ * the ``topic`` parameter (usually the app's bundle identifier)
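+
+   For example, a token-based app might be configured like this (placeholder
+   values as in ``sygnal.yaml.sample``)::
+
+      com.example.myapp2.ios:
+        type: apns
+        keyfile: my_key.p8
+        key_id: MY_KEY_ID
+        team_id: MY_TEAM_ID
+        topic: MY_TOPIC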
-Clean shutdown
-==============
-The code for APNS uses a grace period where it waits for errors to come down the
-socket before declaring it safe for the app to shut down (due to the design of
-APNS). Terminating using SIGTERM performs a clean shutdown::
+gcm
+ This sends messages via Google/Firebase Cloud Messaging (GCM/FCM) and hence can be used
+ to deliver notifications to Android apps. It expects the ``api_key`` parameter
+ to contain the 'Server key', which can be acquired from the Firebase Console at:
+ ``https://console.firebase.google.com/project/<project id>/settings/cloudmessaging/``
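+
+   For example (placeholder values as in ``sygnal.yaml.sample``)::
+
+      com.example.myapp.android:
+        type: gcm
+        api_key: your_api_key_for_gcm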
- kill -TERM `cat sygnal.pid`
+Running
+=======
-Restarting sygnal using SIGHUP will handle this gracefully::
+Sygnal can be started with ``python -m sygnal.sygnal``.
- kill -HUP `cat sygnal.pid`
+Python 3.7 or higher is required.
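+
+For example, to start Sygnal with a configuration file at a non-default
+location::
+
+    SYGNAL_CONF=/path/to/sygnal.yaml python -m sygnal.sygnal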
Log Rotation
============
-Gunicorn appends to files but does not use a rotating logger.
-Sygnal's app logging does the same. Gunicorn will re-open all log files
-(including the app's) when sent SIGUSR1. The recommended configuration is
-therefore to use logrotate.
+Sygnal's logging appends to files but does not rotate them itself.
+The recommended configuration is therefore to use ``logrotate``.
+Sygnal will automatically reopen a log file that has been moved or replaced,
+for example by ``logrotate``.
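+
+An illustrative ``logrotate`` stanza (the log path here must match the
+``filename`` of your file handler in the logging configuration)::
+
+    /path/to/sygnal.log {
+        weekly
+        rotate 4
+        compress
+        missingok
+    }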
+
diff --git a/gunicorn_config.py.sample b/gunicorn_config.py.sample
deleted file mode 100644
index aa70f82c..00000000
--- a/gunicorn_config.py.sample
+++ /dev/null
@@ -1,43 +0,0 @@
-# Customise these settings to your needs
-bind = '0.0.0.0:5000'
-daemon = False
-accesslog = 'var/access_log'
-errorlog = 'var/error_log'
-#accesslog = '-'
-#errorlog = '-'
-loglevel = 'debug'
-pidfile = 'sygnal.pid'
-worker_connections = 1000
-keepalive = 2
-proc_name = 'sygnal'
-
-# It is inadvisable to change anything below here,
-# since these settings make gunicorn work appropriately
-# with sygnal.
-
-preload_app = False
-workers = 1
-worker_class = 'gevent'
-
-
-def worker_exit(server, worker):
- # Used in the hooks
- try:
- # This must be imported inside the hook or it won't find
- # the import
- import sygnal
- # NB. We obviously need to clean up in the worker, not
- # the arbiter process. worker_exit runs in the worker
- # (despite the docs claiming it runs after the worker
- # has exited)
- # We use a flask hook to handle the worker setup.
- # Unfortunately flask doesn't have a shutdown hook
- # (it's not a standard thing in WSGI).
- sygnal.shutdown()
-
- except:
- # We swallow this exception because it's generally a completely
- # useless, "No module named sygnal" due to it failing to load
- # the sygnal module because an exception was thrown.
- print("Failed to load sygnal - check your log file")
-
diff --git a/setup.cfg b/setup.cfg
index d08a9014..b9ef0b54 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,15 @@
+[flake8]
+# the default line length used by black
+max-line-length = 88
+
+# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes
+# for error codes. The ones we ignore are:
+# W503: line break before binary operator
+# W504: line break after binary operator
+# E203: whitespace before ':' (black's slice formatting produces this, though pycodestyle flags it)
+# (this is a subset of those ignored in Synapse)
+ignore=W503,W504,E203
+
[isort]
line_length = 80
not_skip = __init__.py
diff --git a/setup.py b/setup.py
index 15d55a11..1a351a7b 100755
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,7 @@
# Copyright 2014 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
# limitations under the License.
import os
+
from setuptools import setup, find_packages
@@ -26,18 +28,20 @@
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
+
setup(
name="matrix-sygnal",
version=read("VERSION").strip(),
packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Push Gateway for Matrix Notifications",
install_requires=[
- "flask>=1.0.2",
- "gevent>=1.0.1",
- "pushbaby>=0.0.9",
- "grequests",
- "six",
- "prometheus_client>=0.7.0,<0.8"
+ "Twisted>=19.2.1",
+ "prometheus_client>=0.7.0,<0.8",
+ "aioapns>=1.7",
+ "pyyaml>=5.1.1",
+ "service_identity>=18.1.0",
+ "jaeger-client>=4.0.0",
+ "opentracing>=2.2.0",
],
long_description=read("README.rst"),
)
diff --git a/sygnal.conf.sample b/sygnal.conf.sample
deleted file mode 100644
index 0da51b56..00000000
--- a/sygnal.conf.sample
+++ /dev/null
@@ -1,14 +0,0 @@
-[log]
-loglevel = debug
-logfile = var/sygnal.log
-
-# [metrics]
-# prometheus_port = 8800 # Enable prometheus metrics and bind to this port
-# prometheus_addr = 127.0.0.1 # Optional bind address
-# sentry_dsn = https://... # Enable sentry integration
-
-[apps]
-com.example.myapp.ios.type = apns
-com.example.myapp.ios.certfile = com.example.myApp_prod_APNS.pem
-com.example.myapp.android.type = gcm
-com.example.myapp.android.apiKey = your_api_key_for_gcm
diff --git a/sygnal.yaml.sample b/sygnal.yaml.sample
new file mode 100644
index 00000000..95a6f41c
--- /dev/null
+++ b/sygnal.yaml.sample
@@ -0,0 +1,168 @@
+##
+# This is a configuration for Sygnal, the reference Push Gateway for Matrix
+# See: matrix.org
+##
+
+## Logging #
+#
+log:
+ # Specify a Python logging 'dictConfig', as described at:
+ # https://docs.python.org/3.7/library/logging.config.html#logging.config.dictConfig
+ #
+ setup:
+ version: 1
+ formatters:
+ normal:
+ format: "%(asctime)s [%(process)d] %(levelname)-5s %(name)s %(message)s"
+ handlers:
+ # This handler prints to Standard Error
+ #
+ stderr:
+ class: "logging.StreamHandler"
+ formatter: "normal"
+ stream: "ext://sys.stderr"
+
+ # This handler prints to Standard Output.
+ #
+ stdout:
+ class: "logging.StreamHandler"
+ formatter: "normal"
+ stream: "ext://sys.stdout"
+
+ # This handler demonstrates logging to a text file on the filesystem.
+ # You can use logrotate(8) to perform log rotation.
+ #
+ file:
+ class: "logging.handlers.WatchedFileHandler"
+ formatter: "normal"
+ filename: "./sygnal.log"
+ loggers:
+ # sygnal.access contains the access logging lines.
+ # Comment out this section if you don't want to give access logging
+ # any special treatment.
+ #
+ sygnal.access:
+ propagate: false
+ handlers: ["stdout"]
+ level: "INFO"
+
+ # sygnal contains log lines from Sygnal itself.
+ # You can comment out this section to fall back to the root logger.
+ #
+ sygnal:
+ propagate: false
+ handlers: ["stderr", "file"]
+
+ root:
+ # Specify the handler(s) to send log messages to.
+ handlers: ["stderr"]
+ level: "INFO"
+
+
+ access:
+ # Specify whether or not to trust the IP address in the `X-Forwarded-For`
+ # header. In general, you want to enable this if and only if you are using a
+ # reverse proxy which is configured to emit it.
+ #
+ x_forwarded_for: false
+
+## HTTP Server (Matrix Push Gateway API) #
+#
+http:
+ # Specify a list of interface addresses to bind to.
+ #
+ # This example listens on the IPv4 loopback device:
+ bind_addresses: ['127.0.0.1']
+ # This example listens on all IPv4 interfaces:
+ #bind_addresses: ['0.0.0.0']
+ # This example listens on all IPv4 and IPv6 interfaces:
+ #bind_addresses: ['0.0.0.0', '::']
+
+ # Specify the port number to listen on.
+ #
+ port: 5000
+
+## Metrics #
+#
+metrics:
+ ## Prometheus #
+ #
+ prometheus:
+ # Specify whether or not to enable Prometheus.
+ #
+ enabled: false
+
+ # Specify an address for the Prometheus HTTP Server to listen on.
+ #
+ address: '127.0.0.1'
+
+ # Specify a port for the Prometheus HTTP Server to listen on.
+ #
+ port: 8000
+
+ ## OpenTracing #
+ #
+ opentracing:
+ # Specify whether or not to enable OpenTracing.
+ #
+ enabled: false
+
+ # Specify an implementation of OpenTracing to use. Currently only 'jaeger'
+ # is supported.
+ #
+ implementation: jaeger
+
+ # Specify the service name to be reported to the tracer.
+ #
+ service_name: sygnal
+
+ # Specify configuration values to pass to jaeger_client.
+ #
+ jaeger:
+ sampler:
+ type: 'const'
+ param: 1
+# local_agent:
+# reporting_host: '127.0.0.1'
+# reporting_port:
+ logging: true
+
+ ## Sentry #
+ #
+ sentry:
+ # Specify whether or not to enable Sentry.
+ #
+ enabled: false
+
+ # Specify your Sentry DSN if you enable Sentry
+ #
+ #dsn: "https://<key>@sentry.example.org/<project>"
+
+## Pushkins/Apps #
+#
+# Add a section for every push application here, keyed by its app_id,
+# and specify its type.
+# For the type, you may specify a fully-qualified Python classname if desired.
+#
+apps:
+ # This is an example APNs push configuration using certificate authentication.
+ #
+ #com.example.myapp.ios:
+ # type: apns
+ # certfile: com.example.myApp_prod_APNS.pem
+
+ # This is an example APNs push configuration using key authentication.
+ #
+ #com.example.myapp2.ios:
+ # type: apns
+ # keyfile: my_key.p8
+ # key_id: MY_KEY_ID
+ # team_id: MY_TEAM_ID
+ # topic: MY_TOPIC
+
+ # This is an example GCM/FCM push configuration.
+ #
+ #com.example.myapp.android:
+ # type: gcm
+ # api_key: your_api_key_for_gcm
+
diff --git a/sygnal/__init__.py b/sygnal/__init__.py
index 21fd9c1b..e69de29b 100644
--- a/sygnal/__init__.py
+++ b/sygnal/__init__.py
@@ -1,365 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from six.moves import configparser
-
-import json
-import logging
-import os
-import sys
-import threading
-from logging.handlers import WatchedFileHandler
-
-import flask
-from flask import Flask, request
-
-import prometheus_client
-from prometheus_client import Counter
-
-import sygnal.db
-from sygnal.exceptions import InvalidNotificationException
-
-
-NOTIFS_RECEIVED_COUNTER = Counter(
- "sygnal_notifications_received", "Number of notification pokes received",
-)
-
-NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER = Counter(
- "sygnal_notifications_devices_received",
- "Number of devices been asked to push",
-)
-
-NOTIFS_BY_PUSHKIN = Counter(
- "sygnal_per_pushkin_type",
- "Number of pushes sent via each type of pushkin",
- labelnames=["pushkin"],
-)
-
-logger = logging.getLogger(__name__)
-
-app = Flask('sygnal')
-app.debug = False
-app.config.from_object(__name__)
-
-CONFIG_SECTIONS = ['http', 'log', 'apps', 'db', 'metrics']
-CONFIG_DEFAULTS = {
- 'port': '5000',
- 'loglevel': 'info',
- 'logfile': '',
- 'dbfile': 'sygnal.db'
-}
-
-pushkins = {}
-
-class RequestIdFilter(logging.Filter):
- """A logging filter which adds the current request id to each record"""
- def filter(self, record):
- request_id = ''
- if flask.has_request_context():
- request_id = flask.g.get('request_id', '')
- record.request_id = request_id
- return True
-
-class RequestCounter(object):
- def __init__(self):
- self._count = 0
- self._lock = threading.Lock()
-
- def get(self):
- with self._lock:
- c = self._count
- self._count = c + 1
- return c
-
-
-request_count = RequestCounter()
-
-
-class Tweaks:
- def __init__(self, raw):
- self.sound = None
-
- if 'sound' in raw:
- self.sound = raw['sound']
-
-
-class Device:
- def __init__(self, raw):
- self.app_id = None
- self.pushkey = None
- self.pushkey_ts = 0
- self.data = None
- self.tweaks = None
-
- if 'app_id' not in raw:
- raise InvalidNotificationException("Device with no app_id")
- if 'pushkey' not in raw:
- raise InvalidNotificationException("Device with no pushkey")
- if 'pushkey_ts' in raw:
- self.pushkey_ts = raw['pushkey_ts']
- if 'tweaks' in raw:
- self.tweaks = Tweaks(raw['tweaks'])
- else:
- self.tweaks = Tweaks({})
- self.app_id = raw['app_id']
- self.pushkey = raw['pushkey']
- if 'data' in raw:
- self.data = raw['data']
-
-
-class Counts:
- def __init__(self, raw):
- self.unread = None
- self.missed_calls = None
-
- if 'unread' in raw:
- self.unread = raw['unread']
- if 'missed_calls' in raw:
- self.missed_calls = raw['missed_calls']
-
-
-class Notification:
- def __init__(self, notif):
- optional_attrs = [
- 'room_name',
- 'room_alias',
- 'prio',
- 'membership',
- 'sender_display_name',
- 'content',
- 'event_id',
- 'room_id',
- 'user_is_target',
- 'type',
- 'sender',
- ]
- for a in optional_attrs:
- if a in notif:
- self.__dict__[a] = notif[a]
- else:
- self.__dict__[a] = None
-
- if 'devices' not in notif or not isinstance(notif['devices'], list):
- raise InvalidNotificationException("Expected list in 'devices' key")
-
- if 'counts' in notif:
- self.counts = Counts(notif['counts'])
- else:
- self.counts = Counts({})
-
- self.devices = [Device(d) for d in notif['devices']]
-
-
-class Pushkin(object):
- def __init__(self, name):
- self.name = name
-
- def setup(self):
- pass
-
- def getConfig(self, key):
- if not self.cfg.has_option('apps', '%s.%s' % (self.name, key)):
- return None
- return self.cfg.get('apps', '%s.%s' % (self.name, key))
-
- def dispatchNotification(self, n):
- pass
-
- def shutdown(self):
- pass
-
-
-class SygnalContext:
- pass
-
-
-class ClientError(Exception):
- pass
-
-
-def parse_config():
- cfg = configparser.SafeConfigParser(CONFIG_DEFAULTS)
- # Make keys case-sensitive
- cfg.optionxform = str
- for sect in CONFIG_SECTIONS:
- try:
- cfg.add_section(sect)
- except configparser.DuplicateSectionError:
- pass
- cfg.read(os.getenv("SYGNAL_CONF", "sygnal.conf"))
- return cfg
-
-
-def make_pushkin(kind, name):
- if '.' in kind:
- toimport = kind
- else:
- toimport = "sygnal.%spushkin" % kind
- toplevelmodule = __import__(toimport)
- pushkinmodule = getattr(toplevelmodule, "%spushkin" % kind)
- clarse = getattr(pushkinmodule, "%sPushkin" % kind.capitalize())
- return clarse(name)
-
-
-@app.before_request
-def log_request():
- flask.g.request_id = "%s-%i" % (
- request.method, request_count.get(),
- )
- logger.info("Processing request %s", request.url)
-
-
-@app.after_request
-def log_processed_request(response):
- logger.info(
- "Processed request %s: %i",
- request.url, response.status_code,
- )
- return response
-
-@app.errorhandler(ClientError)
-def handle_client_error(e):
- resp = flask.jsonify({ 'error': { 'msg': str(e) } })
- resp.status_code = 400
- return resp
-
-@app.route('/')
-def root():
- return ""
-
-@app.route('/_matrix/push/v1/notify', methods=['POST'])
-def notify():
- try:
- body = json.loads(request.data)
- except Exception:
- raise ClientError("Expecting json request body")
-
- if 'notification' not in body or not isinstance(body['notification'], dict):
- msg = "Invalid notification: expecting object in 'notification' key"
- logger.warn(msg)
- flask.abort(400, msg)
-
- try:
- notif = Notification(body['notification'])
- except InvalidNotificationException as e:
- logger.exception("Invalid notification")
- flask.abort(400, e.message)
-
- if len(notif.devices) == 0:
- msg = "No devices in notification"
- logger.warn(msg)
- flask.abort(400, msg)
-
- NOTIFS_RECEIVED_COUNTER.inc()
-
- rej = []
-
- for d in notif.devices:
- NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER.inc()
-
- appid = d.app_id
- if appid not in pushkins:
- logger.warn("Got notification for unknown app ID %s", appid)
- rej.append(d.pushkey)
- continue
-
- pushkin = pushkins[appid]
- logger.debug(
- "Sending push to pushkin %s for app ID %s",
- pushkin.name, appid,
- )
-
- NOTIFS_BY_PUSHKIN.labels(pushkin.name).inc()
-
- try:
- rej.extend(pushkin.dispatchNotification(notif))
- except:
- logger.exception("Failed to send push")
- flask.abort(500, "Failed to send push")
- return flask.jsonify({
- "rejected": rej
- })
-
-
-def setup():
- cfg = parse_config()
-
- logging.getLogger().setLevel(getattr(logging, cfg.get('log', 'loglevel').upper()))
- logfile = cfg.get('log', 'logfile')
- if logfile != '':
- handler = WatchedFileHandler(logfile)
- handler.addFilter(RequestIdFilter())
- formatter = logging.Formatter(
- '%(asctime)s [%(process)d] %(levelname)-5s '
- '%(request_id)s %(name)s %(message)s'
- )
- handler.setFormatter(formatter)
- logging.getLogger().addHandler(handler)
- else:
- logging.basicConfig()
-
- if cfg.has_option("metrics", "sentry_dsn"):
- # Only import sentry if enabled
- import sentry_sdk
- from sentry_sdk.integrations.flask import FlaskIntegration
- sentry_sdk.init(
- dsn=cfg.get("metrics", "sentry_dsn"),
- integrations=[FlaskIntegration()],
- )
-
- if cfg.has_option("metrics", "prometheus_port"):
- prometheus_client.start_http_server(
- port=cfg.getint("metrics", "prometheus_port"),
- addr=cfg.get("metrics", "prometheus_addr"),
- )
-
- ctx = SygnalContext()
- ctx.database = sygnal.db.Db(cfg.get('db', 'dbfile'))
-
- for key,val in cfg.items('apps'):
- parts = key.rsplit('.', 1)
- if len(parts) < 2:
- continue
- if parts[1] == 'type':
- try:
- pushkins[parts[0]] = make_pushkin(val, parts[0])
- except:
- logger.exception("Failed to load module for kind %s", val)
- raise
-
- if len(pushkins) == 0:
- logger.error("No app IDs are configured. Edit sygnal.conf to define some.")
- sys.exit(1)
-
- for p in pushkins:
- pushkins[p].cfg = cfg
- pushkins[p].setup(ctx)
- logger.info("Configured with app IDs: %r", pushkins.keys())
-
- logger.error("Setup completed")
-
-def shutdown():
- logger.info("Starting shutdown...")
- i = 0
- for p in pushkins.values():
- logger.info("Shutting down (%d/%d)..." % (i+1, len(pushkins)))
- p.shutdown()
- i += 1
- logger.info("Shutdown complete...")
-
-
-setup()
diff --git a/sygnal/apnspushkin.py b/sygnal/apnspushkin.py
index 0fb5d98b..8b4fcad4 100644
--- a/sygnal/apnspushkin.py
+++ b/sygnal/apnspushkin.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,240 +14,371 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
+import base64
+import logging
+import os
+from uuid import uuid4
+
+import aioapns
+from aioapns import APNs, NotificationRequest
+from opentracing import logs, tags
+from prometheus_client import Histogram, Counter
+from twisted.internet.defer import Deferred
+
+from sygnal import apnstruncate
+from sygnal.exceptions import (
+ PushkinSetupException,
+ TemporaryNotificationDispatchException,
+ NotificationDispatchException,
+)
+from sygnal.notifications import Pushkin
+from sygnal.utils import twisted_sleep, NotificationLoggerAdapter
+logger = logging.getLogger(__name__)
-from . import Pushkin
-from .exceptions import PushkinSetupException, NotificationDispatchException
+SEND_TIME_HISTOGRAM = Histogram(
+ "sygnal_apns_request_time", "Time taken to send HTTP request to APNS"
+)
-from pushbaby import PushBaby
-import pushbaby.errors
+RESPONSE_STATUS_CODES_COUNTER = Counter(
+ "sygnal_apns_status_codes",
+ "Number of HTTP response status codes received from APNS",
+ labelnames=["pushkin", "code"],
+)
-import logging
-import base64
-import time
-import gevent
-logger = logging.getLogger(__name__)
+class ApnsPushkin(Pushkin):
+ """
+ Relays notifications to the Apple Push Notification Service.
+ """
-create_failed_table_query = u"""
-CREATE TABLE IF NOT EXISTS apns_failed (id INTEGER PRIMARY KEY, b64token TEXT NOT NULL,
-last_failure_ts INTEGER NOT NULL,
-last_failure_type varchar(10) not null, last_failure_code INTEGER default -1, token_invalidated_ts INTEGER default -1);
-"""
+ # Errors for which the token should be rejected
+ TOKEN_ERROR_REASON = "Unregistered"
+ TOKEN_ERROR_CODE = 410
-create_failed_index_query = u"""
-CREATE UNIQUE INDEX IF NOT EXISTS b64token on apns_failed(b64token);
-"""
+ MAX_TRIES = 3
+ RETRY_DELAY_BASE = 10
-# Max length of individual fields. Pushbaby will truncate appropriate
-# fields of the push to fit the whole body in the max size, but it's
-# not very fast so keep things to a sensible size.
-MAX_FIELD_LENGTH = 1024
+ MAX_FIELD_LENGTH = 1024
+ MAX_JSON_BODY_SIZE = 4096
-class ApnsPushkin(Pushkin):
- MAX_TRIES = 2
- DELETE_FEEDBACK_AFTER_SECS = 28 * 24 * 60 * 60 # a month(ish)
- # These are the only ones of the errors returned in the APNS stream
- # that we want to feed back. Anything else is nothing to do with the
- # token.
- ERRORS_TO_FEED_BACK = (
- pushbaby.errors.INVALID_TOKEN_SIZE,
- pushbaby.errors.INVALID_TOKEN,
- )
-
- def __init__(self, name):
- super(ApnsPushkin, self).__init__(name);
-
- def setup(self, ctx):
- self.db = ctx.database
- self.certfile = self.getConfig('certfile')
- plaf = self.getConfig('platform')
- if not plaf or plaf == 'production' or plaf == 'prod':
- self.plaf = 'prod'
- elif plaf == 'sandbox':
- self.plaf = 'sandbox'
+ UNDERSTOOD_CONFIG_FIELDS = {"type", "platform", "certfile"}
+
+ def __init__(self, name, sygnal, config):
+ super().__init__(name, sygnal, config)
+
+ nonunderstood = set(self.cfg.keys()).difference(self.UNDERSTOOD_CONFIG_FIELDS)
+ if len(nonunderstood) > 0:
+ logger.warning(
+ "The following configuration fields are not understood: %s",
+ nonunderstood,
+ )
+
+ platform = self.get_config("platform")
+ if not platform or platform == "production" or platform == "prod":
+ self.use_sandbox = False
+ elif platform == "sandbox":
+ self.use_sandbox = True
else:
- raise PushkinSetupException("Invalid platform: %s" % plaf)
-
- self.db.query(create_failed_table_query)
- self.db.query(create_failed_index_query)
-
- self.pushbaby = PushBaby(certfile=self.certfile, platform=self.plaf)
- self.pushbaby.on_push_failed = self.on_push_failed
- logger.info("APNS with cert file %s on %s platform", self.certfile, self.plaf)
-
- # poll feedback in a little bit, not while we're busy starting up
- gevent.spawn_later(10, self.do_feedback_poll)
-
- def dispatchNotification(self, n):
- tokens = {}
- for d in n.devices:
- tokens[d.pushkey] = d
-
- # check for tokens that have previously failed
- token_set_str = u"(" + u",".join([u"?" for _ in tokens.keys()]) + u")"
- feed_back_errors_set_str = u"(" + u",".join([u"?" for _ in ApnsPushkin.ERRORS_TO_FEED_BACK]) + u")"
- q = (
- "SELECT b64token,last_failure_type,last_failure_code,token_invalidated_ts "+
- "FROM apns_failed WHERE b64token IN "+token_set_str+
- " and ("+
- "(last_failure_type = 'error' and last_failure_code in "+feed_back_errors_set_str+") "+
- "or (last_failure_type = 'feedback')"+
- ")"
+ raise PushkinSetupException(f"Invalid platform: {platform}")
+
+ certfile = self.get_config("certfile")
+ keyfile = self.get_config("keyfile")
+ if not certfile and not keyfile:
+ raise PushkinSetupException(
+ "You must provide a path to an APNs certificate, or an APNs token."
)
- args = []
- args.extend([unicode(t) for t in tokens.keys()])
- args.extend([(u"%d" % e) for e in ApnsPushkin.ERRORS_TO_FEED_BACK])
- rows = self.db.query(q, args, fetch='all')
-
- rejected = []
- for row in rows:
- token_invalidated_ts = row[3]
- token_pushkey_ts = tokens[row[0]].pushkey_ts
- if token_pushkey_ts < token_invalidated_ts:
- logger.warn(
- "Rejecting token %s with ts %d. Last failure of type '%s' code %d, invalidated at %d",
- row[0], token_pushkey_ts, row[1], row[2], token_invalidated_ts
+
+ if certfile:
+ if not os.path.exists(certfile):
+ raise PushkinSetupException(
+ f"The APNs certificate '{certfile}' does not exist."
)
- rejected.append(row[0])
- del tokens[row[0]]
- else:
- logger.info("Have a failure for token %s of type '%s' at %d code %d but this token postdates it (%d): allowing.", row[0], row[1], token_invalidated_ts, row[2], token_pushkey_ts)
- # This pushkey may be alive again, but we don't delete the
- # failure because HSes should probably have a fresh token
- # if they actually want to use it
-
- payload = None
- if n.event_id and not n.type:
- payload = self.get_payload_event_id_only(n)
else:
- payload = self.get_payload_full(n)
-
- prio = 10
- if n.prio == 'low':
- prio = 5
-
- tries = 0
- for t,d in tokens.items():
- while tries < ApnsPushkin.MAX_TRIES:
- thispayload = payload
- if 'aps' in thispayload:
- thispayload = payload.copy()
- thispayload['aps'] = thispayload['aps'].copy()
- if d.tweaks.sound:
- thispayload['aps']['sound'] = d.tweaks.sound
- logger.info("Sending (attempt %i): '%s' -> %s", tries, thispayload, t)
- poke_start_time = time.time()
+ # keyfile
+ if not os.path.exists(keyfile):
+ raise PushkinSetupException(
+ f"The APNs key file '{keyfile}' does not exist."
+ )
+ if not self.get_config("key_id"):
+ raise PushkinSetupException("You must supply key_id.")
+ if not self.get_config("team_id"):
+ raise PushkinSetupException("You must supply team_id.")
+ if not self.get_config("topic"):
+ raise PushkinSetupException("You must supply topic.")
+
+ if self.get_config("certfile") is not None:
+ self.apns_client = APNs(
+ client_cert=self.get_config("certfile"), use_sandbox=self.use_sandbox
+ )
+ else:
+ self.apns_client = APNs(
+ key=self.get_config("keyfile"),
+ key_id=self.get_config("key_id"),
+ team_id=self.get_config("team_id"),
+ topic=self.get_config("topic"),
+ use_sandbox=self.use_sandbox,
+ )
+
+ # without this, aioapns will retry every second forever.
+ self.apns_client.pool.max_connection_attempts = 3
+
+ async def _dispatch_request(self, log, span, device, shaved_payload, prio):
+ """
+ Actually attempts to dispatch the notification once.
+ """
+
+ # We can't simply derive an ID from the request, e.g.
+ # notif_id = context.request_id + f"-{n.devices.index(device)}",
+ # because APNs expects notification IDs in its own (UUID) format.
+
+ notif_id = str(uuid4())
+
+ log.info(f"Sending as APNs-ID {notif_id}")
+ span.set_tag("apns_id", notif_id)
+
+ device_token = base64.b64decode(device.pushkey).hex()
+
+ request = NotificationRequest(
+ device_token=device_token,
+ message=shaved_payload,
+ priority=prio,
+ notification_id=notif_id,
+ )
+
+ try:
+ with SEND_TIME_HISTOGRAM.time():
+ response = await self._send_notification(request)
+ except aioapns.ConnectionError:
+ raise TemporaryNotificationDispatchException("aioapns Connection Failure")
+
+ code = int(response.status)
+
+ span.set_tag(tags.HTTP_STATUS_CODE, code)
+
+ RESPONSE_STATUS_CODES_COUNTER.labels(pushkin=self.name, code=code).inc()
+
+ if response.is_successful:
+ return []
+ else:
+ # .description corresponds to the 'reason' response field
+ span.set_tag("apns_reason", response.description)
+ if (
+ code == self.TOKEN_ERROR_CODE
+ or response.description == self.TOKEN_ERROR_REASON
+ ):
+ return [device.pushkey]
+ else:
+ if 500 <= code < 600:
+ raise TemporaryNotificationDispatchException(
+ f"{response.status} {response.description}"
+ )
+ else:
+ raise NotificationDispatchException(
+ f"{response.status} {response.description}"
+ )
+
+ async def dispatch_notification(self, n, device, context):
+ log = NotificationLoggerAdapter(logger, {"request_id": context.request_id})
+
+ # The pushkey is kind of secret because you can use it to send push
+ # to someone.
+ # span_tags = {"pushkey": device.pushkey}
+ span_tags = {}
+
+ with self.sygnal.tracer.start_span(
+ "apns_dispatch", tags=span_tags, child_of=context.opentracing_span
+ ) as span_parent:
+
+ if n.event_id and not n.type:
+ payload = self._get_payload_event_id_only(n)
+ else:
+ payload = self._get_payload_full(n, log)
+
+ if payload is None:
+ # Nothing to do
+ span_parent.log_kv({logs.EVENT: "apns_no_payload"})
+ return
+ prio = 10
+ if n.prio == "low":
+ prio = 5
+
+ shaved_payload = apnstruncate.truncate(
+ payload, max_length=self.MAX_JSON_BODY_SIZE
+ )
+
+ for retry_number in range(self.MAX_TRIES):
try:
- res = self.pushbaby.send(thispayload, base64.b64decode(t), priority=prio)
- logger.info("Sent (%f secs): -> %s", time.time() - poke_start_time, t)
- break
- except:
- logger.exception("Exception sending push -> %s" % (t, ))
+ log.debug("Trying")
+
+ span_tags = {"retry_num": retry_number}
+
+ with self.sygnal.tracer.start_span(
+ "apns_dispatch_try", tags=span_tags, child_of=span_parent
+ ) as span:
+ return await self._dispatch_request(
+ log, span, device, shaved_payload, prio
+ )
+ except TemporaryNotificationDispatchException as exc:
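+ # Exponential back-off: with RETRY_DELAY_BASE = 10, wait 10s then 20s
+ # between attempts, unless the exception carried its own retry delay.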
+ retry_delay = self.RETRY_DELAY_BASE * (2 ** retry_number)
+ if exc.custom_retry_delay is not None:
+ retry_delay = exc.custom_retry_delay
+
+ log.warning(
+ "Temporary failure, will retry in %d seconds",
+ retry_delay,
+ exc_info=True,
+ )
+
+ span_parent.log_kv(
+ {"event": "temporary_fail", "retrying_in": retry_delay}
+ )
+
+ if retry_number == self.MAX_TRIES - 1:
+ raise NotificationDispatchException(
+ "Retried too many times."
+ ) from exc
+ else:
+ await twisted_sleep(
+ retry_delay, twisted_reactor=self.sygnal.reactor
+ )
+
+ def _get_payload_event_id_only(self, n):
+ """
+ Constructs a payload for a notification where we know only the event ID.
+ Args:
+ n: The notification to construct a payload for.
+
+ Returns:
+ The APNs payload as nested dicts.
+ """
+ payload = {}
- tries += 1
+ if n.room_id:
+ payload["room_id"] = n.room_id
+ if n.event_id:
+ payload["event_id"] = n.event_id
+
+ if n.counts.unread is not None:
+ payload["unread_count"] = n.counts.unread
+ if n.counts.missed_calls is not None:
+ payload["missed_calls"] = n.counts.missed_calls
- if tries == ApnsPushkin.MAX_TRIES:
- raise NotificationDispatchException("Max retries exceeded")
+ return payload
- return rejected
+ def _get_payload_full(self, n, log):
+ """
+ Constructs a payload for a notification.
+ Args:
+ n: The notification to construct a payload for.
+ log: A logger.
- def get_payload_full(self, n):
+ Returns:
+ The APNs payload as nested dicts.
+ """
from_display = n.sender
if n.sender_display_name is not None:
from_display = n.sender_display_name
- from_display = from_display[0:MAX_FIELD_LENGTH]
+ from_display = from_display[0 : self.MAX_FIELD_LENGTH]
loc_key = None
loc_args = None
- if n.type == 'm.room.message' or n.type == 'm.room.encrypted':
+ if n.type == "m.room.message" or n.type == "m.room.encrypted":
room_display = None
if n.room_name:
- room_display = n.room_name[0:MAX_FIELD_LENGTH]
+ room_display = n.room_name[0 : self.MAX_FIELD_LENGTH]
elif n.room_alias:
- room_display = n.room_alias[0:MAX_FIELD_LENGTH]
+ room_display = n.room_alias[0 : self.MAX_FIELD_LENGTH]
content_display = None
action_display = None
is_image = False
- if n.content and 'msgtype' in n.content and 'body' in n.content:
- if 'body' in n.content:
- if n.content['msgtype'] == 'm.text':
- content_display = n.content['body']
- elif n.content['msgtype'] == 'm.emote':
- action_display = n.content['body']
+ if n.content and "msgtype" in n.content and "body" in n.content:
+ if "body" in n.content:
+ if n.content["msgtype"] == "m.text":
+ content_display = n.content["body"]
+ elif n.content["msgtype"] == "m.emote":
+ action_display = n.content["body"]
else:
- # fallback: 'body' should always be user-visible text in an m.room.message
- content_display = n.content['body']
- if n.content['msgtype'] == 'm.image':
+ # fallback: 'body' should always be user-visible text
+ # in an m.room.message
+ content_display = n.content["body"]
+ if n.content["msgtype"] == "m.image":
is_image = True
if room_display:
if is_image:
- loc_key = 'IMAGE_FROM_USER_IN_ROOM'
+ loc_key = "IMAGE_FROM_USER_IN_ROOM"
loc_args = [from_display, content_display, room_display]
elif content_display:
- loc_key = 'MSG_FROM_USER_IN_ROOM_WITH_CONTENT'
+ loc_key = "MSG_FROM_USER_IN_ROOM_WITH_CONTENT"
loc_args = [from_display, room_display, content_display]
elif action_display:
- loc_key = 'ACTION_FROM_USER_IN_ROOM'
+ loc_key = "ACTION_FROM_USER_IN_ROOM"
loc_args = [room_display, from_display, action_display]
else:
- loc_key = 'MSG_FROM_USER_IN_ROOM'
+ loc_key = "MSG_FROM_USER_IN_ROOM"
loc_args = [from_display, room_display]
else:
if is_image:
- loc_key = 'IMAGE_FROM_USER'
+ loc_key = "IMAGE_FROM_USER"
loc_args = [from_display, content_display]
elif content_display:
- loc_key = 'MSG_FROM_USER_WITH_CONTENT'
+ loc_key = "MSG_FROM_USER_WITH_CONTENT"
loc_args = [from_display, content_display]
elif action_display:
- loc_key = 'ACTION_FROM_USER'
+ loc_key = "ACTION_FROM_USER"
loc_args = [from_display, action_display]
else:
- loc_key = 'MSG_FROM_USER'
+ loc_key = "MSG_FROM_USER"
loc_args = [from_display]
- elif n.type == 'm.call.invite':
+ elif n.type == "m.call.invite":
is_video_call = False
# This detection works only for hs that uses WebRTC for calls
- if n.content and 'offer' in n.content and 'sdp' in n.content['offer']:
- sdp = n.content['offer']['sdp']
- if 'm=video' in sdp:
+ if n.content and "offer" in n.content and "sdp" in n.content["offer"]:
+ sdp = n.content["offer"]["sdp"]
+ if "m=video" in sdp:
is_video_call = True
if is_video_call:
- loc_key = 'VIDEO_CALL_FROM_USER'
+ loc_key = "VIDEO_CALL_FROM_USER"
else:
- loc_key = 'VOICE_CALL_FROM_USER'
+ loc_key = "VOICE_CALL_FROM_USER"
loc_args = [from_display]
- elif n.type == 'm.room.member':
+ elif n.type == "m.room.member":
if n.user_is_target:
- if n.membership == 'invite':
+ if n.membership == "invite":
if n.room_name:
- loc_key = 'USER_INVITE_TO_NAMED_ROOM'
- loc_args = [from_display, n.room_name[0:MAX_FIELD_LENGTH]]
+ loc_key = "USER_INVITE_TO_NAMED_ROOM"
+ loc_args = [
+ from_display,
+ n.room_name[0 : self.MAX_FIELD_LENGTH],
+ ]
elif n.room_alias:
- loc_key = 'USER_INVITE_TO_NAMED_ROOM'
- loc_args = [from_display, n.room_alias[0:MAX_FIELD_LENGTH]]
+ loc_key = "USER_INVITE_TO_NAMED_ROOM"
+ loc_args = [
+ from_display,
+ n.room_alias[0 : self.MAX_FIELD_LENGTH],
+ ]
else:
- loc_key = 'USER_INVITE_TO_CHAT'
+ loc_key = "USER_INVITE_TO_CHAT"
loc_args = [from_display]
elif n.type:
# A type of message was received that we don't know about
# but it was important enough for a push to have got to us
- loc_key = 'MSG_FROM_USER'
+ loc_key = "MSG_FROM_USER"
loc_args = [from_display]
aps = {}
if loc_key:
- aps['alert'] = {'loc-key': loc_key }
+ aps["alert"] = {"loc-key": loc_key}
if loc_args:
- aps['alert']['loc-args'] = loc_args
+ aps["alert"]["loc-args"] = loc_args
badge = None
if n.counts.unread is not None:
@@ -257,88 +389,25 @@ def get_payload_full(self, n):
badge += n.counts.missed_calls
if badge is not None:
- aps['badge'] = badge
+ aps["badge"] = badge
if loc_key:
- aps['content-available'] = 1
+ aps["content-available"] = 1
if loc_key is None and badge is None:
- logger.info("Nothing to do for alert of type %s", n.type)
- return rejected
+ log.info("Nothing to do for alert of type %s", n.type)
+ return None
payload = {}
if loc_key and n.room_id:
- payload['room_id'] = n.room_id
+ payload["room_id"] = n.room_id
- payload['aps'] = aps
+ payload["aps"] = aps
return payload
- def get_payload_event_id_only(self, n):
- payload = {}
-
- if n.room_id:
- payload['room_id'] = n.room_id
- if n.event_id:
- payload['event_id'] = n.event_id
-
- if n.counts.unread is not None:
- payload['unread_count'] = n.counts.unread
- if n.counts.missed_calls is not None:
- payload['missed_calls'] = n.counts.missed_calls
-
- return payload
-
- def on_push_failed(self, token, identifier, status):
- logger.error("Error sending push to token %s, status %s", token, status)
- # We store all errors (could be useful to get failures instead of digging
- # through logs) but note that not all failures mean we should stop sending
- # to that token.
- self.db.query(
- "INSERT OR REPLACE INTO apns_failed "+
- "(b64token, last_failure_ts, last_failure_type, last_failure_code, token_invalidated_ts) "+
- " VALUES (?, ?, 'error', ?, ?)",
- (base64.b64encode(token), long(time.time()), status, long(time.time()))
+ async def _send_notification(self, request):
+ return await Deferred.fromFuture(
+ asyncio.ensure_future(self.apns_client.send_notification(request))
)
-
- def do_feedback_poll(self):
- logger.info("Polling feedback...")
- try:
- feedback = self.pushbaby.get_all_feedback()
- for fb in feedback:
- b64token = unicode(base64.b64encode(fb.token))
- logger.info("Got feedback for token %s which is invalid as of %d", b64token, long(fb.ts))
- self.db.query(
- "INSERT OR REPLACE INTO apns_failed "+
- "(b64token, last_failure_ts, last_failure_type, token_invalidated_ts) "+
- " VALUES (?, ?, 'feedback', ?)",
- (b64token, long(time.time()), long(fb.ts))
- )
- logger.info("Stored %d feedback items", len(feedback))
-
- # great, we're good until tomorrow
- gevent.spawn_later(24 * 60 * 60, self.do_feedback_poll)
- except:
- logger.exception("Failed to poll for feedback, trying again in 10 minutes")
- gevent.spawn_later(10 * 60, self.do_feedback_poll)
-
- self.prune_failures()
-
- def prune_failures(self):
- """
- Delete any failures older than a set amount of time.
- This is the only way we delete them - we can't delete
- them once we've sent them because a token could be in use by
- more than one Home Server.
- """
- cutoff = long(time.time()) - ApnsPushkin.DELETE_FEEDBACK_AFTER_SECS
- deleted = self.db.query(
- "DELETE FROM apns_failed WHERE last_failure_ts < ?",
- (cutoff,)
- )
- logger.info("deleted %d stale items from failure table", deleted)
-
- def shutdown(self):
- while self.pushbaby.messages_in_flight():
- gevent.wait(timeout=1.0)
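``_send_notification`` above is the seam where aioapns's asyncio coroutine is
bridged onto Sygnal's Twisted reactor via ``Deferred.fromFuture``. A minimal,
self-contained sketch of the same pattern (illustrative names, not part of
Sygnal; it requires Twisted's asyncio reactor, which the sketch installs
explicitly)::

    import asyncio

    from twisted.internet import asyncioreactor
    asyncioreactor.install(asyncio.get_event_loop())

    from twisted.internet import reactor
    from twisted.internet.defer import Deferred

    async def asyncio_work():
        await asyncio.sleep(0.1)
        return "done"

    def bridge():
        # Wrap the asyncio task in a Twisted Deferred.
        d = Deferred.fromFuture(asyncio.ensure_future(asyncio_work()))
        d.addCallback(print)  # prints "done"
        d.addCallback(lambda _: reactor.stop())

    reactor.callWhenRunning(bridge)
    reactor.run()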
diff --git a/sygnal/apnstruncate.py b/sygnal/apnstruncate.py
new file mode 100644
index 00000000..7b52a0e1
--- /dev/null
+++ b/sygnal/apnstruncate.py
@@ -0,0 +1,131 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copied and adapted from
+# https://raw.githubusercontent.com/matrix-org/pushbaby/master/pushbaby/truncate.py
+import json
+
+
+def json_encode(payload):
+ return json.dumps(payload, ensure_ascii=False).encode()
+
+
+class BodyTooLongException(Exception):
+ pass
+
+
+def is_too_long(payload, max_length=2048):
+ """
+ Returns True if the given payload dictionary is too long for a push.
+ Note that the maximum is now 2kB "In iOS 8 and later" although in
+ practice, payloads over 256 bytes (the old limit) are still
+ delivered to iOS 7 or earlier devices.
+
+ Maximum is 4 kiB in the new APNs with the HTTP/2 interface.
+ """
+ return len(json_encode(payload)) > max_length
+
+
+def truncate(payload, max_length=2048):
+ """
+ Truncate APNs fields to make the payload fit within the max length
+ specified.
+ Only truncates fields that are safe to do so.
+
+ Args:
+ payload: nested dict that will be passed to APNs
+ max_length: Maximum length, in bytes, that the payload should occupy
+ when JSON-encoded.
+
+ Returns:
+ Nested dict which should comply with the maximum length restriction.
+
+ """
+ payload = payload.copy()
+ if "aps" not in payload:
+ if is_too_long(payload, max_length):
+ raise BodyTooLongException()
+ else:
+ return payload
+ aps = payload["aps"]
+
+ # first ensure all our choppables are str objects.
+ # We need them to be for truncating to work and this
+ # makes more sense than checking every time.
+ for c in _choppables_for_aps(aps):
+ val = _choppable_get(aps, c)
+ if isinstance(val, bytes):
+ _choppable_put(aps, c, val.decode())
+
+ # chop off whole unicode characters until it fits (or we run out of chars)
+ while is_too_long(payload, max_length):
+ longest = _longest_choppable(aps)
+ if longest is None:
+ raise BodyTooLongException()
+
+ txt = _choppable_get(aps, longest)
+ # Note that python's support for this is actually broken on some OSes
+ # (see test_apnstruncate.py)
+ txt = txt[:-1]
+ _choppable_put(aps, longest, txt)
+ payload["aps"] = aps
+
+ return payload
+
+
+def _choppables_for_aps(aps):
+ ret = []
+ if "alert" not in aps:
+ return ret
+
+ alert = aps["alert"]
+ if isinstance(alert, str):
+ ret.append(("alert",))
+ elif isinstance(alert, dict):
+ if "body" in alert:
+ ret.append(("alert.body",))
+ if "loc-args" in alert:
+ ret.extend([("alert.loc-args", i) for i in range(len(alert["loc-args"]))])
+
+ return ret
+
+
+def _choppable_get(aps, choppable):
+ if choppable[0] == "alert":
+ return aps["alert"]
+ elif choppable[0] == "alert.body":
+ return aps["alert"]["body"]
+ elif choppable[0] == "alert.loc-args":
+ return aps["alert"]["loc-args"][choppable[1]]
+
+
+def _choppable_put(aps, choppable, val):
+ if choppable[0] == "alert":
+ aps["alert"] = val
+ elif choppable[0] == "alert.body":
+ aps["alert"]["body"] = val
+ elif choppable[0] == "alert.loc-args":
+ aps["alert"]["loc-args"][choppable[1]] = val
+
+
+def _longest_choppable(aps):
+ longest = None
+ length_of_longest = 0
+ for c in _choppables_for_aps(aps):
+ val = _choppable_get(aps, c)
+ val_len = len(val.encode())
+ if val_len > length_of_longest:
+ longest = c
+ length_of_longest = val_len
+ return longest
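A quick illustration of the truncation helper above, using a hypothetical
oversized payload (4096 is ``ApnsPushkin.MAX_JSON_BODY_SIZE``)::

    from sygnal.apnstruncate import json_encode, truncate

    payload = {
        "aps": {
            "alert": {
                "loc-key": "MSG_FROM_USER_WITH_CONTENT",
                # over-long loc-arg that will be chopped down to fit
                "loc-args": ["@alice:example.org", "x" * 5000],
            }
        }
    }
    shaved = truncate(payload, max_length=4096)
    assert len(json_encode(shaved)) <= 4096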
diff --git a/sygnal/db.py b/sygnal/db.py
deleted file mode 100644
index a034e6ec..00000000
--- a/sygnal/db.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2014 matrix.org
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sqlite3
-import logging
-import threading
-from six.moves import queue
-import sys
-
-logger = logging.getLogger(__name__)
-
-class Db:
- def __init__(self, dbfile):
- self.dbfile = dbfile
- self.db_queue = queue.Queue()
- # Sqlite is blocking and does so in the c library so we can't
- # use gevent's monkey patching to make it play nice. We just
- # run all sqlite in a separate thread.
- self.dbthread = threading.Thread(target=self.db_loop)
- self.dbthread.setDaemon(True)
- self.dbthread.start()
-
- def db_loop(self):
- self.db = sqlite3.connect(self.dbfile)
- while True:
- job = self.db_queue.get()
- job()
-
- def query(self, query, args=(), fetch=None):
- res = {}
- ev = threading.Event()
- def runquery():
- try:
- c = self.db.cursor()
- c.execute(query, args)
- if fetch == 1 or fetch == 'one':
- res['rows'] = c.fetchone()
- elif fetch == 'all':
- res['rows'] = c.fetchall()
- elif fetch == None:
- self.db.commit()
- res['rowcount'] = c.rowcount
- except:
- logger.exception("Caught exception running db query %s", query)
- res['ex'] = sys.exc_info()[1]
- ev.set()
- self.db_queue.put(runquery)
- ev.wait()
- if 'ex' in res:
- raise res['ex']
- elif 'rows' in res:
- return res['rows']
- elif 'rowcount' in res:
- return res['rowcount']
diff --git a/sygnal/exceptions.py b/sygnal/exceptions.py
index 2d2d4d3c..b72a9260 100644
--- a/sygnal/exceptions.py
+++ b/sygnal/exceptions.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
class InvalidNotificationException(Exception):
pass
@@ -23,3 +24,13 @@ class PushkinSetupException(Exception):
class NotificationDispatchException(Exception):
pass
+
+class TemporaryNotificationDispatchException(Exception):
+ """
+ To be raised by pushkins for errors that are not our fault and are
+ hopefully temporary, so the request may be retried after a delay.
+ """
+
+ def __init__(self, *args: object, custom_retry_delay=None) -> None:
+ super().__init__(*args)
+ self.custom_retry_delay = custom_retry_delay
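For example, the GCM pushkin below raises it when the server returns a 5xx
response, passing on any ``Retry-After`` value the server supplied::

    raise TemporaryNotificationDispatchException(
        "GCM server error, hopefully temporary.", custom_retry_delay=retry_after
    )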
diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py
index bd068d11..4effa2f6 100644
--- a/sygnal/gcmpushkin.py
+++ b/sygnal/gcmpushkin.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 Leon Handreke
# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +14,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from __future__ import absolute_import, division, print_function, unicode_literals
-
+import json
import logging
import time
+from io import BytesIO
+from json import JSONDecodeError
-import grequests
-import gevent
-from requests import adapters, Session
-from prometheus_client import Histogram
+from opentracing import logs, tags
+from prometheus_client import Histogram, Counter
+from twisted.web.client import HTTPConnectionPool, Agent, FileBodyProducer, readBody
+from twisted.web.http_headers import Headers
-from . import Pushkin
+from sygnal.exceptions import (
+ TemporaryNotificationDispatchException,
+ NotificationDispatchException,
+)
+from sygnal.utils import twisted_sleep, NotificationLoggerAdapter
from .exceptions import PushkinSetupException
-
+from .notifications import Pushkin
SEND_TIME_HISTOGRAM = Histogram(
- "sygnal_gcm_request_time",
- "Time taken to send HTTP request",
+ "sygnal_gcm_request_time", "Time taken to send HTTP request to GCM"
)
+RESPONSE_STATUS_CODES_COUNTER = Counter(
+ "sygnal_gcm_status_codes",
+ "Number of HTTP response status codes received from GCM",
+ labelnames=["pushkin", "code"],
+)
logger = logging.getLogger(__name__)
-GCM_URL = "https://fcm.googleapis.com/fcm/send"
+GCM_URL = b"https://fcm.googleapis.com/fcm/send"
MAX_TRIES = 3
RETRY_DELAY_BASE = 10
MAX_BYTES_PER_FIELD = 1024
@@ -47,156 +56,319 @@
# though gcm-client 'helpfully' extracts these into a separate
# list.
BAD_PUSHKEY_FAILURE_CODES = [
- 'MissingRegistration',
- 'InvalidRegistration',
- 'NotRegistered',
- 'InvalidPackageName',
- 'MismatchSenderId',
+ "MissingRegistration",
+ "InvalidRegistration",
+ "NotRegistered",
+ "InvalidPackageName",
+ "MismatchSenderId",
]
# Failure codes that mean the message in question will never
# succeed, so don't retry, but the registration ID is fine
# so we should not reject it upstream.
-BAD_MESSAGE_FAILURE_CODES = [
- 'MessageTooBig',
- 'InvalidDataKey',
- 'InvalidTtl',
-]
+BAD_MESSAGE_FAILURE_CODES = ["MessageTooBig", "InvalidDataKey", "InvalidTtl"]
+
+DEFAULT_MAX_CONNECTIONS = 20
+
class GcmPushkin(Pushkin):
+ """
+ Pushkin that relays notifications to Google/Firebase Cloud Messaging.
+ """
+
+ UNDERSTOOD_CONFIG_FIELDS = {"type", "api_key"}
+
+ def __init__(self, name, sygnal, config, canonical_reg_id_store):
+ super(GcmPushkin, self).__init__(name, sygnal, config)
+
+ nonunderstood = set(self.cfg.keys()).difference(self.UNDERSTOOD_CONFIG_FIELDS)
+ if len(nonunderstood) > 0:
+ logger.warning(
+ "The following configuration fields are not understood: %s",
+ nonunderstood,
+ )
+
+ self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor)
+ self.http_pool.maxPersistentPerHost = self.get_config(
+ "max_connections", DEFAULT_MAX_CONNECTIONS
+ )
- def __init__(self, name):
- super(GcmPushkin, self).__init__(name)
- self.session = Session()
- a = adapters.HTTPAdapter(pool_maxsize=20, pool_connections=20, pool_block=True)
- self.session.mount("https://", a)
+ self.http_agent = Agent(reactor=sygnal.reactor, pool=self.http_pool)
- def setup(self, ctx):
- self.db = ctx.database
+ self.db = sygnal.database
+ self.canonical_reg_id_store = canonical_reg_id_store
- self.api_key = self.getConfig('apiKey')
+ self.api_key = self.get_config("api_key")
if not self.api_key:
raise PushkinSetupException("No API key set in config")
- self.canonical_reg_id_store = CanonicalRegIdStore(self.db)
- def dispatchNotification(self, n):
- pushkeys = [device.pushkey for device in n.devices if device.app_id == self.name]
- # Resolve canonical IDs for all pushkeys
- pushkeys = [canonical_reg_id or reg_id for (reg_id, canonical_reg_id) in
- self.canonical_reg_id_store.get_canonical_ids(pushkeys).items()]
-
- data = GcmPushkin.build_data(n)
- headers = {
- "User-Agent": "sygnal",
- "Content-Type": "application/json",
- "Authorization": "key=%s" % (self.api_key,)
- }
+ @classmethod
+ async def create(cls, name, sygnal, config):
+ """
+ Override this if your pushkin needs to call async code in order to
+ be constructed. Otherwise, it defaults to just invoking the Python-standard
+ __init__ constructor.
+
+ Returns:
+ an instance of this Pushkin
+ """
+ logger.debug("About to set up CanonicalRegId Store")
+ canonical_reg_id_store = CanonicalRegIdStore()
+ await canonical_reg_id_store.setup(sygnal.database)
+ logger.debug("Finished setting up CanonicalRegId Store")
+
+ return cls(name, sygnal, config, canonical_reg_id_store)
+
+ async def _perform_http_request(self, body, headers):
+ """
+ Perform an HTTP request to the FCM server with the body and headers
+ specified.
+ Args:
+ body (nested dict): Body. Will be JSON-encoded.
+ headers (Headers): HTTP Headers.
+
+ Returns:
+ a 2-tuple of (response, response body text)
+ """
+ body_producer = FileBodyProducer(BytesIO(json.dumps(body).encode()))
+ try:
+ response = await self.http_agent.request(
+ b"POST", GCM_URL, headers=Headers(headers), bodyProducer=body_producer
+ )
+ except Exception as exception:
+ raise TemporaryNotificationDispatchException(
+ "GCM request failure"
+ ) from exception
+ response_text = (await readBody(response)).decode()
+ return response, response_text
+
+ async def _request_dispatch(self, n, log, body, headers, pushkeys, span):
+ poke_start_time = time.time()
- # TODO: Implement collapse_key to queue only one message per room.
failed = []
- for retry_number in range(0, MAX_TRIES):
- body = {
- "data": data,
- "priority": 'normal' if n.prio == 'low' else 'high',
- }
- if len(pushkeys) == 1:
- body['to'] = pushkeys[0]
- else:
- body['registration_ids'] = pushkeys
-
- logger.info("Sending (attempt %i): %r => %r", retry_number, data, pushkeys);
- poke_start_time = time.time()
-
- with SEND_TIME_HISTOGRAM.time():
- req = grequests.post(
- GCM_URL, json=body, headers=headers, timeout=10,
- session=self.session,
+ with SEND_TIME_HISTOGRAM.time():
+ response, response_text = await self._perform_http_request(body, headers)
+
+ RESPONSE_STATUS_CODES_COUNTER.labels(
+ pushkin=self.name, code=response.code
+ ).inc()
+
+ log.debug("GCM request took %f seconds", time.time() - poke_start_time)
+
+ span.set_tag(tags.HTTP_STATUS_CODE, response.code)
+
+ if 500 <= response.code < 600:
+ log.debug("%d from server, waiting to try again", response.code)
+
+ retry_after = None
+
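+ # Respect a Retry-After header from the server, if one was sent.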
+ for header_value in response.headers.getRawHeader(
+ b"retry-after", default=[]
+ ):
+ retry_after = int(header_value)
+ span.log_kv({"event": "gcm_retry_after", "retry_after": retry_after})
+
+ raise TemporaryNotificationDispatchException(
+ "GCM server error, hopefully temporary.", custom_retry_delay=retry_after
+ )
+ elif response.code == 400:
+ log.error(
+ "%d from server, we have sent something invalid! Error: %r",
+ response.code,
+ response_text,
+ )
+ # permanent failure: give up
+ raise NotificationDispatchException("Invalid request")
+ elif response.code == 401:
+ log.error(
+ "401 from server! Our API key is invalid? Error: %r", response_text
+ )
+ # permanent failure: give up
+ raise NotificationDispatchException("Not authorised to push")
+ elif 200 <= response.code < 300:
+ try:
+ resp_object = json.loads(response_text)
+ except JSONDecodeError:
+ raise NotificationDispatchException("Invalid JSON response from GCM.")
+ if "results" not in resp_object:
+ log.error(
+ "%d from server but response contained no 'results' key: %r",
+ response.code,
+ response_text,
)
- req.send()
-
- logger.debug("GCM request took %f seconds", time.time() - poke_start_time)
-
- if req.response is None:
- success = False
- logger.debug("Request failed, waiting to try again", req.exception)
- elif req.response.status_code / 100 == 5:
- success = False
- logger.debug("%d from server, waiting to try again", req.response.status_code)
- elif req.response.status_code == 400:
- logger.error(
- "%d from server, we have sent something invalid! Error: %r",
- req.response.status_code,
- req.response.text,
+ if len(resp_object["results"]) < len(pushkeys):
+ log.error(
+ "Sent %d notifications but only got %d responses!",
+ len(n.devices),
+ len(resp_object["results"]),
)
- # permanent failure: give up
- raise Exception("Invalid request")
- elif req.response.status_code == 401:
- logger.error(
- "401 from server! Our API key is invalid? Error: %r",
- req.response.text,
+ span.log_kv(
+ {
+ logs.EVENT: "gcm_response_mismatch",
+ "num_devices": len(n.devices),
+ "num_results": len(resp_object["results"]),
+ }
)
- # permanent failure: give up
- raise Exception("Not authorized to push")
- elif req.response.status_code / 100 == 2:
- resp_object = req.response.json()
- if 'results' not in resp_object:
- logger.error(
- "%d from server but response contained no 'results' key: %r",
- req.response.status_code, req.response.text,
+
+ # determine which pushkeys to retry or forget about
+ new_pushkeys = []
+ for i, result in enumerate(resp_object["results"]):
+ span.set_tag("gcm_regid_updated", "registration_id" in result)
+ if "registration_id" in result:
+ await self.canonical_reg_id_store.set_canonical_id(
+ pushkeys[i], result["registration_id"]
)
- if len(resp_object['results']) < len(pushkeys):
- logger.error(
- "Sent %d notifications but only got %d responses!",
- len(n.devices), len(resp_object['results'])
+ if "error" in result:
+ log.warning(
+ "Error for pushkey %s: %s", pushkeys[i], result["error"]
)
-
- new_pushkeys = []
- for i, result in enumerate(resp_object['results']):
- if 'registration_id' in result:
- self.canonical_reg_id_store.set_canonical_id(
- pushkeys[i], result['registration_id']
+ span.set_tag("gcm_error", result["error"])
+ if result["error"] in BAD_PUSHKEY_FAILURE_CODES:
+ log.info(
+ "Reg ID %r has permanently failed with code %r: "
+ "rejecting upstream",
+ pushkeys[i],
+ result["error"],
+ )
+ failed.append(pushkeys[i])
+ elif result["error"] in BAD_MESSAGE_FAILURE_CODES:
+ log.info(
+ "Message for reg ID %r has permanently failed with code %r",
+ pushkeys[i],
+ result["error"],
)
- if 'error' in result:
- logger.warn("Error for pushkey %s: %s", pushkeys[i], result['error'])
- if result['error'] in BAD_PUSHKEY_FAILURE_CODES:
- logger.info(
- "Reg ID %r has permanently failed with code %r: rejecting upstream",
- pushkeys[i], result['error']
- )
- failed.append(pushkeys[i])
- elif result['error'] in BAD_MESSAGE_FAILURE_CODES:
- logger.info(
- "Message for reg ID %r has permanently failed with code %r",
- pushkeys[i], result['error']
- )
- else:
- logger.info(
- "Reg ID %r has temporarily failed with code %r",
- pushkeys[i], result['error']
- )
- new_pushkeys.append(pushkeys[i])
- if len(new_pushkeys) == 0:
- return failed
- pushkeys = new_pushkeys
-
- retry_delay = RETRY_DELAY_BASE * (2 ** retry_number)
- if req.response and 'retry-after' in req.response.headers:
+ else:
+ log.info(
+ "Reg ID %r has temporarily failed with code %r",
+ pushkeys[i],
+ result["error"],
+ )
+ new_pushkeys.append(pushkeys[i])
+ return failed, new_pushkeys
+
+ async def dispatch_notification(self, n, device, context):
+ log = NotificationLoggerAdapter(logger, {"request_id": context.request_id})
+
+ pushkeys = [
+ device.pushkey for device in n.devices if device.app_id == self.name
+ ]
+
+ if pushkeys[0] != device.pushkey:
+ # We send a single request for all of this app's devices at once;
+ # dispatching for the first device covers the rest, so skip the others.
+ return []
+
+ # The pushkey is somewhat secret because it can be used to send pushes
+ # to the device, so we deliberately do not tag it on the span:
+ # span_tags = {"pushkeys": pushkeys}
+ span_tags = {"gcm_num_devices": len(pushkeys)}
+
+ with self.sygnal.tracer.start_span(
+ "gcm_dispatch", tags=span_tags, child_of=context.opentracing_span
+ ) as span_parent:
+ # Resolve canonical IDs for all pushkeys
+ reg_id_mappings = await self.canonical_reg_id_store.get_canonical_ids(
+ pushkeys
+ )
+
+ reg_id_mappings = {
+ reg_id: canonical_reg_id or reg_id
+ for (reg_id, canonical_reg_id) in reg_id_mappings.items()
+ }
+
+ inverse_reg_id_mappings = {v: k for (k, v) in reg_id_mappings.items()}
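+ # (Illustratively: if GCM told us "old_id" is now "new_id", then
+ # reg_id_mappings == {"old_id": "new_id"} and
+ # inverse_reg_id_mappings == {"new_id": "old_id"}.)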
+
+ data = GcmPushkin._build_data(n)
+ headers = {
+ b"User-Agent": ["sygnal"],
+ b"Content-Type": ["application/json"],
+ b"Authorization": ["key=%s" % (self.api_key,)],
+ }
+
+ # count the number of remapped registration IDs in the request
+ span_parent.set_tag(
+ "gcm_num_remapped_reg_ids_used",
+ [k != v for (k, v) in reg_id_mappings.items()].count(True),
+ )
+
+ # TODO: Implement collapse_key to queue only one message per room.
+ failed = []
+
+ body = {"data": data, "priority": "normal" if n.prio == "low" else "high"}
+
+ for retry_number in range(0, MAX_TRIES):
+ mapped_pushkeys = [reg_id_mappings[pk] for pk in pushkeys]
+
+ if len(pushkeys) == 1:
+ body["to"] = mapped_pushkeys[0]
+ else:
+ body["registration_ids"] = mapped_pushkeys
+
+ log.info("Sending (attempt %i) => %r", retry_number, mapped_pushkeys)
+
try:
- retry_delay = int(req.response.headers['retry-after'])
- except:
- pass
- logger.info("Retrying in %d seconds", retry_delay)
- gevent.sleep(seconds=retry_delay)
+ span_tags = {"retry_num": retry_number}
+
+ with self.sygnal.tracer.start_span(
+ "gcm_dispatch_try", tags=span_tags, child_of=span_parent
+ ) as span:
+ new_failed, new_pushkeys = await self._request_dispatch(
+ n, log, body, headers, mapped_pushkeys, span
+ )
+ # the pushkeys to retry come back as canonical IDs, so map
+ # them back to the homeserver-provided IDs before the next
+ # attempt looks them up in reg_id_mappings again
+ pushkeys = [
+ inverse_reg_id_mappings[pk] for pk in new_pushkeys
+ ]
+ failed += [
+ inverse_reg_id_mappings[canonical_pk]
+ for canonical_pk in new_failed
+ ]
+ if len(pushkeys) == 0:
+ break
+ except TemporaryNotificationDispatchException as exc:
+ retry_delay = RETRY_DELAY_BASE * (2 ** retry_number)
+ if exc.custom_retry_delay is not None:
+ retry_delay = exc.custom_retry_delay
+
+ log.warning(
+ "Temporary failure, will retry in %d seconds",
+ retry_delay,
+ exc_info=True,
+ )
+
+ span_parent.log_kv(
+ {"event": "temporary_fail", "retrying_in": retry_delay}
+ )
- logger.info("Gave up retrying reg IDs: %r", pushkeys)
- return failed
+ await twisted_sleep(
+ retry_delay, twisted_reactor=self.sygnal.reactor
+ )
+
+ if len(pushkeys) > 0:
+ log.info("Gave up retrying reg IDs: %r", pushkeys)
+ # Count the number of failed devices.
+ span_parent.set_tag("gcm_num_failed", len(failed))
+ return failed
@staticmethod
- def build_data(n):
+ def _build_data(n):
+ """
+ Build the payload data to be sent.
+ Args:
+ n: Notification to build the payload for.
+
+ Returns:
+ JSON-compatible dict
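+ (for illustration, roughly of the shape
+ {"event_id": "$ev", "type": "m.room.message",
+ "sender": "@alice:example.org", "prio": "high", "unread": 2})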
+ """
data = {}
- for attr in ['event_id', 'type', 'sender', 'room_name', 'room_alias', 'membership',
- 'sender_display_name', 'content', 'room_id']:
+ for attr in [
+ "event_id",
+ "type",
+ "sender",
+ "room_name",
+ "room_alias",
+ "membership",
+ "sender_display_name",
+ "content",
+ "room_id",
+ ]:
if hasattr(n, attr):
data[attr] = getattr(n, attr)
# Truncate fields to a sensible maximum length. If the whole
@@ -204,40 +376,84 @@ def build_data(n):
if data[attr] is not None and len(data[attr]) > MAX_BYTES_PER_FIELD:
data[attr] = data[attr][0:MAX_BYTES_PER_FIELD]
- data['prio'] = 'high'
- if n.prio == 'low':
- data['prio'] = 'normal';
+ data["prio"] = "high"
+ if n.prio == "low":
+ data["prio"] = "normal"
- if getattr(n, 'counts', None):
- data['unread'] = n.counts.unread
- data['missed_calls'] = n.counts.missed_calls
+ if getattr(n, "counts", None):
+ data["unread"] = n.counts.unread
+ data["missed_calls"] = n.counts.missed_calls
return data
class CanonicalRegIdStore(object):
-
TABLE_CREATE_QUERY = """
CREATE TABLE IF NOT EXISTS gcm_canonical_reg_id (
reg_id TEXT PRIMARY KEY,
- canonical_reg_id TEXT NOT NULL);"""
+ canonical_reg_id TEXT NOT NULL
+ );
+ """
- def __init__(self, db):
+ def __init__(self):
+ self.db = None
+
+ async def setup(self, db):
+ """
+ Prepares, if necessary, the database for storing canonical registration IDs.
+
+ Separate method from the constructor because we wait for an async request
+ to complete, so it must be an `async def` method.
+
+ Args:
+ db (adbapi.ConnectionPool): database to prepare
+
+ """
self.db = db
- self.db.query(self.TABLE_CREATE_QUERY)
- def set_canonical_id(self, reg_id, canonical_reg_id):
- self.db.query(
- "INSERT OR REPLACE INTO gcm_canonical_reg_id VALUES (?, ?);",
- (reg_id, canonical_reg_id))
+ await self.db.runQuery(self.TABLE_CREATE_QUERY)
- def get_canonical_ids(self, reg_ids):
- # TODO: Use one DB query
- return {reg_id: self._get_canonical_id(reg_id) for reg_id in reg_ids}
+ async def set_canonical_id(self, reg_id, canonical_reg_id):
+ """
+ Associates a GCM registration ID with a canonical registration ID.
+ Args:
+ reg_id (str): a registration ID
+ canonical_reg_id (str): the canonical registration ID for `reg_id`
+ """
+ await self.db.runQuery(
+ "INSERT OR REPLACE INTO gcm_canonical_reg_id VALUES (?, ?);",
+ (reg_id, canonical_reg_id),
+ )
+
+ async def get_canonical_ids(self, reg_ids):
+ """
+ Retrieves the canonical registration ID for multiple registration IDs.
+
+ Args:
+ reg_ids (iterable): registration IDs to retrieve canonical registration
+ IDs for.
+
+ Returns (dict):
+ mapping of registration ID to either its canonical registration ID,
+ or `None` if there is no entry.
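+ e.g. {"reg_id_1": "canonical_id_1", "reg_id_2": None}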
+ """
+ return {reg_id: await self.get_canonical_id(reg_id) for reg_id in reg_ids}
+
+ async def get_canonical_id(self, reg_id):
+ """
+ Retrieves the canonical registration ID for one registration ID.
+
+ Args:
+ reg_id (str): registration ID to retrieve the canonical registration
+ ID for.
+
+ Returns (str or None):
+ its canonical registration ID, or `None` if there is no entry.
+ """
+ rows = await self.db.runQuery(
+ "SELECT canonical_reg_id FROM gcm_canonical_reg_id WHERE reg_id = ?",
+ (reg_id,),
+ )
- def _get_canonical_id(self, reg_id):
- rows = self.db.query(
- "SELECT canonical_reg_id FROM gcm_canonical_reg_id WHERE reg_id = ?;",
- (reg_id, ), fetch='all')
if rows:
return rows[0][0]
diff --git a/sygnal/http.py b/sygnal/http.py
new file mode 100644
index 00000000..78872707
--- /dev/null
+++ b/sygnal/http.py
@@ -0,0 +1,289 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import sys
+import traceback
+from uuid import uuid4
+
+from opentracing import Format, tags, logs
+from prometheus_client import Counter
+from twisted.internet import defer
+from twisted.internet.defer import gatherResults, ensureDeferred
+from twisted.python.failure import Failure
+from twisted.web import server
+from twisted.web.http import proxiedLogFormatter
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from sygnal.notifications import NotificationContext
+from sygnal.utils import NotificationLoggerAdapter
+from .exceptions import InvalidNotificationException, NotificationDispatchException
+from .notifications import Notification
+
+logger = logging.getLogger(__name__)
+
+NOTIFS_RECEIVED_COUNTER = Counter(
+ "sygnal_notifications_received", "Number of notification pokes received"
+)
+
+NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER = Counter(
+ "sygnal_notifications_devices_received", "Number of devices been asked to push"
+)
+
+NOTIFS_BY_PUSHKIN = Counter(
+ "sygnal_per_pushkin_type",
+ "Number of pushes sent via each type of pushkin",
+ labelnames=["pushkin"],
+)
+
+PUSHGATEWAY_HTTP_RESPONSES_COUNTER = Counter(
+ "sygnal_pushgateway_status_codes",
+ "HTTP Response Codes given on the Push Gateway API",
+ labelnames=["code"],
+)
+
+
+class V1NotifyHandler(Resource):
+ def __init__(self, sygnal):
+ super().__init__()
+ self.sygnal = sygnal
+
+ isLeaf = True
+
+ def _make_request_id(self):
+ """
+ Generates a request ID, intended to be unique, for a request so it can
+ be followed through logging.
+ Returns: a request ID for the request.
+ """
+ return str(uuid4())
+
+ def render_POST(self, request):
+ response = self._handle_request(request)
+ if response != NOT_DONE_YET:
+ PUSHGATEWAY_HTTP_RESPONSES_COUNTER.labels(code=request.code).inc()
+ return response
+
+ def _handle_request(self, request):
+ """
+ Actually handle the request.
+ Args:
+ request (Request): The request, corresponding to a POST request.
+
+ Returns:
+ Either a str instance or NOT_DONE_YET.
+
+ """
+ request_id = self._make_request_id()
+ header_dict = {
+ k.decode(): v[0].decode()
+ for k, v in request.requestHeaders.getAllRawHeaders()
+ }
+
+ # extract OpenTracing scope from the HTTP headers
+ span_ctx = self.sygnal.tracer.extract(Format.HTTP_HEADERS, header_dict)
+ span_tags = {
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ "request_id": request_id,
+ }
+
+ root_span = self.sygnal.tracer.start_span(
+ "pushgateway_v1_notify", child_of=span_ctx, tags=span_tags
+ )
+
+ # if this is True, we will not close the root_span at the end of this
+ # function.
+ root_span_accounted_for = False
+
+ try:
+ context = NotificationContext(request_id, root_span)
+
+ log = NotificationLoggerAdapter(logger, {"request_id": request_id})
+
+ try:
+ body = json.loads(request.content.read())
+ except Exception as exc:
+ msg = "Expected JSON request body"
+ log.warning(msg, exc_info=exc)
+ root_span.log_kv({logs.EVENT: "error", "error.object": exc})
+ request.setResponseCode(400)
+ return msg.encode()
+
+ if "notification" not in body or not isinstance(body["notification"], dict):
+ msg = "Invalid notification: expecting object in 'notification' key"
+ log.warning(msg)
+ root_span.log_kv({logs.EVENT: "error", "message": msg})
+ request.setResponseCode(400)
+ return msg.encode()
+
+ try:
+ notif = Notification(body["notification"])
+ except InvalidNotificationException as e:
+ log.exception("Invalid notification")
+ request.setResponseCode(400)
+ root_span.log_kv({logs.EVENT: "error", "error.object": e})
+ return str(e).encode()
+
+ if notif.event_id is not None:
+ root_span.set_tag("event_id", notif.event_id)
+
+ # track whether the notification was passed with content
+ root_span.set_tag("has_content", notif.content is not None)
+
+ NOTIFS_RECEIVED_COUNTER.inc()
+
+ if len(notif.devices) == 0:
+ msg = "No devices in notification"
+ log.warning(msg)
+ request.setResponseCode(400)
+ return msg.encode()
+
+ rej = []
+ deferreds = []
+
+ pushkins = self.sygnal.pushkins
+
+ for d in notif.devices:
+ NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER.inc()
+
+ appid = d.app_id
+ if appid not in pushkins:
+ log.warning("Got notification for unknown app ID %s", appid)
+ rej.append(d.pushkey)
+ continue
+
+ pushkin = pushkins[appid]
+ log.debug(
+ "Sending push to pushkin %s for app ID %s", pushkin.name, appid
+ )
+
+ NOTIFS_BY_PUSHKIN.labels(pushkin.name).inc()
+
+ async def dispatch_checked():
+ """
+ Dispatches a notification and checks that the Pushkin
+ returns a list.
+ Returns (list):
+ the list of rejected pushkeys from the Pushkin
+ """
+ result = await pushkin.dispatch_notification(notif, d, context)
+ if not isinstance(result, list):
+ raise TypeError("Pushkin should return list.")
+ return result
+
+ deferreds.append(ensureDeferred(dispatch_checked()))
+
+ def callback(rejected_lists):
+ # combine all rejected pushkeys into one list
+
+ rejected = sum(rejected_lists, rej)
+
+ request.write(json.dumps({"rejected": rejected}).encode())
+
+ log.info(
+ "Successfully delivered notifications" " with %d rejected pushkeys",
+ len(rejected),
+ )
+
+ request.finish()
+
+ def errback(failure: Failure):
+ # due to gatherResults, errors will be wrapped in FirstError.
+ if issubclass(failure.type, defer.FirstError):
+ subfailure = failure.value.subFailure
+ if issubclass(subfailure.type, NotificationDispatchException):
+ request.setResponseCode(502)
+ log.warning(
+ "Failed to dispatch notification.", exc_info=subfailure
+ )
+ else:
+ request.setResponseCode(500)
+ log.error(
+ "Exception whilst dispatching notification.",
+ exc_info=subfailure,
+ )
+ else:
+ request.setResponseCode(500)
+ log.error(
+ "Exception whilst dispatching notification.", exc_info=failure
+ )
+
+ request.finish()
+
+ aggregate = gatherResults(deferreds, consumeErrors=True)
+ aggregate.addCallback(callback)
+ aggregate.addErrback(errback)
+
+ def count_deferred_code(_):
+ PUSHGATEWAY_HTTP_RESPONSES_COUNTER.labels(code=request.code).inc()
+ root_span.set_tag(tags.HTTP_STATUS_CODE, request.code)
+ if not 200 <= request.code < 300:
+ root_span.set_tag(tags.ERROR, True)
+ root_span.finish()
+
+ aggregate.addCallback(count_deferred_code)
+ root_span_accounted_for = True
+
+ # we have to try and send the notifications first,
+ # so we can find out which ones to reject
+ return NOT_DONE_YET
+ except Exception as exc_val:
+ root_span.set_tag(tags.ERROR, True)
+
+ # [2] corresponds to the traceback
+ trace = traceback.format_tb(sys.exc_info()[2])
+ root_span.log_kv(
+ {
+ logs.EVENT: tags.ERROR,
+ logs.MESSAGE: str(exc_val),
+ logs.ERROR_OBJECT: exc_val,
+ logs.ERROR_KIND: type(exc_val),
+ logs.STACK: trace,
+ }
+ )
+ raise
+ finally:
+ if not root_span_accounted_for:
+ root_span.finish()
+
+
+class PushGatewayApiServer(object):
+ def __init__(self, sygnal):
+ """
+ Initialises the /_matrix/push/* (Push Gateway API) server.
+ Args:
+ sygnal (Sygnal): the Sygnal object
+ """
+ root = Resource()
+ matrix = Resource()
+ push = Resource()
+ v1 = Resource()
+
+ # Note that using plain strings here will lead to silent failure
+ root.putChild(b"_matrix", matrix)
+ matrix.putChild(b"push", push)
+ push.putChild(b"v1", v1)
+ v1.putChild(b"notify", V1NotifyHandler(sygnal))
+
+ use_x_forwarded_for = sygnal.config["log"]["access"]["x_forwarded_for"]
+
+ log_formatter = proxiedLogFormatter if use_x_forwarded_for else None
+
+ self.site = server.Site(
+ root, reactor=sygnal.reactor, logFormatter=log_formatter
+ )
diff --git a/sygnal/notifications.py b/sygnal/notifications.py
new file mode 100644
index 00000000..59b884c5
--- /dev/null
+++ b/sygnal/notifications.py
@@ -0,0 +1,142 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .exceptions import InvalidNotificationException
+
+
+class Tweaks:
+ def __init__(self, raw):
+ self.sound = None
+
+ if "sound" in raw:
+ self.sound = raw["sound"]
+
+
+class Device:
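+ """
+ A device to push to, parsed from a raw dict such as (illustratively)
+ {"app_id": "com.example.myapp", "pushkey": "<push key>", "pushkey_ts": 42}.
+ """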
+ def __init__(self, raw):
+ self.app_id = None
+ self.pushkey = None
+ self.pushkey_ts = 0
+ self.data = None
+ self.tweaks = None
+
+ if "app_id" not in raw:
+ raise InvalidNotificationException("Device with no app_id")
+ if "pushkey" not in raw:
+ raise InvalidNotificationException("Device with no pushkey")
+ if "pushkey_ts" in raw:
+ self.pushkey_ts = raw["pushkey_ts"]
+ if "tweaks" in raw:
+ self.tweaks = Tweaks(raw["tweaks"])
+ else:
+ self.tweaks = Tweaks({})
+ self.app_id = raw["app_id"]
+ self.pushkey = raw["pushkey"]
+ if "data" in raw:
+ self.data = raw["data"]
+
+
+class Counts:
+ def __init__(self, raw):
+ self.unread = None
+ self.missed_calls = None
+
+ if "unread" in raw:
+ self.unread = raw["unread"]
+ if "missed_calls" in raw:
+ self.missed_calls = raw["missed_calls"]
+
+
+class Notification:
+ def __init__(self, notif):
+ optional_attrs = [
+ "room_name",
+ "room_alias",
+ "prio",
+ "membership",
+ "sender_display_name",
+ "content",
+ "event_id",
+ "room_id",
+ "user_is_target",
+ "type",
+ "sender",
+ ]
+ for a in optional_attrs:
+ if a in notif:
+ self.__dict__[a] = notif[a]
+ else:
+ self.__dict__[a] = None
+
+ if "devices" not in notif or not isinstance(notif["devices"], list):
+ raise InvalidNotificationException("Expected list in 'devices' key")
+
+ if "counts" in notif:
+ self.counts = Counts(notif["counts"])
+ else:
+ self.counts = Counts({})
+
+ self.devices = [Device(d) for d in notif["devices"]]
+
+
+class Pushkin(object):
+ def __init__(self, name, sygnal, config):
+ self.name = name
+ self.cfg = config
+ self.sygnal = sygnal
+
+ def get_config(self, key, default=None):
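+ """
+ Retrieves an option from this pushkin's configuration, or `default`
+ if it is absent (e.g. get_config("api_key") in the GCM pushkin).
+ """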
+ if key not in self.cfg:
+ return default
+ return self.cfg[key]
+
+ async def dispatch_notification(self, n, device, context):
+ """
+ Args:
+ n: The notification to dispatch via this pushkin
+ device: The device to dispatch the notification for.
+ context (NotificationContext): the request context
+
+ Returns:
+ A list of rejected pushkeys, to be reported back to the homeserver
+ """
+ pass
+
+ @classmethod
+ async def create(cls, name, sygnal, config):
+ """
+ Override this if your pushkin needs to call async code in order to
+ be constructed. Otherwise, it defaults to just invoking the Python-standard
+ __init__ constructor.
+
+ Returns:
+ an instance of this Pushkin
+ """
+ return cls(name, sygnal, config)
+
+
+class NotificationContext(object):
+ def __init__(self, request_id, opentracing_span):
+ """
+ Args:
+ request_id (str): An ID for the request, or None to have it
+ generated automatically.
+ opentracing_span (Span): The span for the API request triggering
+ the notification.
+ """
+ self.request_id = request_id
+ self.opentracing_span = opentracing_span
diff --git a/sygnal/sygnal.py b/sygnal/sygnal.py
new file mode 100644
index 00000000..44293a4a
--- /dev/null
+++ b/sygnal/sygnal.py
@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import importlib
+import logging
+import logging.config
+import os
+import sys
+
+import opentracing
+import prometheus_client
+
+# import twisted.internet.reactor
+import yaml
+from opentracing.scope_managers.asyncio import AsyncioScopeManager
+from twisted.enterprise.adbapi import ConnectionPool
+from twisted.internet import asyncioreactor
+from twisted.internet.defer import ensureDeferred
+from twisted.python import log as twisted_log
+
+from sygnal.http import PushGatewayApiServer
+
+logger = logging.getLogger(__name__)
+
+CONFIG_DEFAULTS = {
+ "http": {"port": 5000, "bind_addresses": ["127.0.0.1"]},
+ "log": {"setup": {}, "access": {"x_forwarded_for": False}},
+ "db": {"dbfile": "sygnal.db"},
+ "metrics": {
+ "prometheus": {"enabled": False, "address": "127.0.0.1", "port": 8000},
+ "opentracing": {
+ "enabled": False,
+ "implementation": None,
+ "jaeger": {},
+ "service_name": "sygnal",
+ },
+ "sentry": {"enabled": False},
+ },
+ "apps": {},
+}
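+
+# For illustration, a minimal sygnal.yaml overriding these defaults might
+# look like (values are examples only):
+#
+#     db:
+#       dbfile: /var/lib/sygnal/sygnal.db
+#     apps:
+#       com.example.myapp:
+#         type: gcm
+#         api_key: <your API key>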
+
+
+class Sygnal(object):
+ def __init__(self, config, custom_reactor, tracer=opentracing.tracer):
+ """
+ Object that holds state for the entirety of a Sygnal instance.
+ Args:
+ config (dict): Configuration for this Sygnal
+ custom_reactor: a Twisted Reactor to use.
+ tracer (optional): an OpenTracing tracer. The default is the no-op tracer.
+ """
+ self.config = config
+ self.reactor = custom_reactor
+ self.pushkins = {}
+ self.tracer = tracer
+
+ logging_dict_config = config["log"]["setup"]
+ logging.config.dictConfig(logging_dict_config)
+
+ logger.debug("Started logging")
+
+ observer = twisted_log.PythonLoggingObserver(loggerName="sygnal.access")
+ observer.start()
+
+ sentrycfg = config["metrics"]["sentry"]
+ if sentrycfg["enabled"] is True:
+ import sentry_sdk
+
+ logger.info("Initialising Sentry")
+ sentry_sdk.init(sentrycfg["dsn"])
+
+ promcfg = config["metrics"]["prometheus"]
+ if promcfg["enabled"] is True:
+ prom_addr = promcfg["address"]
+ prom_port = int(promcfg["port"])
+ logger.info(
+ "Starting Prometheus Server on %s port %d", prom_addr, prom_port
+ )
+
+ prometheus_client.start_http_server(port=prom_port, addr=prom_addr or "")
+
+ tracecfg = config["metrics"]["opentracing"]
+ if tracecfg["enabled"] is True:
+ if tracecfg["implementation"] == "jaeger":
+ try:
+ import jaeger_client
+
+ jaeger_cfg = jaeger_client.Config(
+ config=tracecfg["jaeger"],
+ service_name=tracecfg["service_name"],
+ scope_manager=AsyncioScopeManager(),
+ )
+
+ self.tracer = jaeger_cfg.initialize_tracer()
+
+ logger.info("Enabled OpenTracing support with Jaeger")
+ except ModuleNotFoundError:
+ logger.critical(
+ "You have asked for OpenTracing with Jaeger but do not have"
+ " the Python package 'jaeger_client' installed."
+ )
+ raise
+ else:
+ logger.error(
+ "Unknown OpenTracing implementation: %s.", tracecfg["impl"]
+ )
+ sys.exit(1)
+
+ self.database = ConnectionPool(
+ "sqlite3",
+ config["db"]["dbfile"],
+ cp_reactor=self.reactor,
+ cp_min=1,
+ cp_max=1,
+ check_same_thread=False,
+ )
+
+ async def _make_pushkin(self, app_name, app_config):
+ """
+ Load and instantiate a pushkin.
+ Args:
+ app_name (str): The pushkin's app_id
+ app_config (dict): The pushkin's configuration
+
+ Returns (Pushkin):
+ A pushkin of the desired type.
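+
+ For example, a `type` of "gcm" resolves to sygnal.gcmpushkin.GcmPushkin,
+ whereas a dotted `type` such as "tests.test_gcm.TestGcmPushkin" is
+ imported as-is.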
+ """
+ app_type = app_config["type"]
+ if "." in app_type:
+ kind_split = app_type.rsplit(".", 1)
+ to_import = kind_split[0]
+ to_construct = kind_split[1]
+ else:
+ to_import = f"sygnal.{app_type}pushkin"
+ to_construct = f"{app_type.capitalize()}Pushkin"
+
+ logger.info("Importing pushkin module: %s", to_import)
+ pushkin_module = importlib.import_module(to_import)
+ logger.info("Creating pushkin: %s", to_construct)
+ pushkin_class = getattr(pushkin_module, to_construct)
+ return await pushkin_class.create(app_name, self, app_config)
+
+ async def _make_pushkins_then_start(self, port, bind_addresses, pushgateway_api):
+ for app_id, app_cfg in self.config["apps"].items():
+ try:
+ self.pushkins[app_id] = await self._make_pushkin(app_id, app_cfg)
+ except Exception:
+ logger.exception(
+ "Failed to load and create pushkin for kind %s", app_cfg["type"]
+ )
+ sys.exit(1)
+
+ if len(self.pushkins) == 0:
+ logger.error("No app IDs are configured. Edit sygnal.yaml to define some.")
+ sys.exit(1)
+
+ logger.info("Configured with app IDs: %r", self.pushkins.keys())
+
+ for interface in bind_addresses:
+ logger.info("Starting listening on %s port %d", interface, port)
+ self.reactor.listenTCP(port, pushgateway_api.site, interface=interface)
+
+ def run(self):
+ """
+ Run Sygnal. This starts the reactor and blocks until the process exits.
+ """
+ port = int(self.config["http"]["port"])
+ bind_addresses = self.config["http"]["bind_addresses"]
+ pushgateway_api = PushGatewayApiServer(self)
+
+ ensureDeferred(
+ self._make_pushkins_then_start(port, bind_addresses, pushgateway_api)
+ )
+ self.reactor.run()
+
+
+def parse_config():
+ """
+ Find and load Sygnal's configuration file.
+ Returns (dict):
+ A loaded configuration.
+ """
+ config_path = os.getenv("SYGNAL_CONF", "sygnal.yaml")
+ try:
+ with open(config_path) as file_handle:
+ return yaml.safe_load(file_handle)
+ except FileNotFoundError:
+ logger.critical(
+ "Could not find configuration file!\n" "Path: %s\n" "Absolute Path: %s",
+ config_path,
+ os.path.realpath(config_path),
+ )
+ raise
+
+
+def check_config(config):
+ """
+ Lightly check the configuration and issue warnings as appropriate.
+ Args:
+ config: The loaded configuration.
+ """
+ UNDERSTOOD_CONFIG_FIELDS = CONFIG_DEFAULTS.keys()
+
+ def check_section(section_name, known_keys, cfgpart=config):
+ nonunderstood = set(cfgpart[section_name].keys()).difference(known_keys)
+ if len(nonunderstood) > 0:
+ logger.warning(
+ f"The following configuration fields in '{section_name}' "
+ f"are not understood: %s",
+ nonunderstood,
+ )
+
+ nonunderstood = set(config.keys()).difference(UNDERSTOOD_CONFIG_FIELDS)
+ if len(nonunderstood) > 0:
+ logger.warning(
+ "The following configuration fields are not understood: %s", nonunderstood
+ )
+
+ check_section("http", {"port", "bind_addresses"})
+ check_section("log", {"setup", "access"})
+ check_section(
+ "access", {"file", "enabled", "x_forwarded_for"}, cfgpart=config["log"]
+ )
+ check_section("db", {"dbfile"})
+ check_section("metrics", {"opentracing", "sentry", "prometheus"})
+ check_section(
+ "opentracing",
+ {"enabled", "implementation", "jaeger", "service_name"},
+ cfgpart=config["metrics"],
+ )
+ check_section(
+ "prometheus", {"enabled", "address", "port"}, cfgpart=config["metrics"]
+ )
+ check_section("sentry", {"enabled", "dsn"}, cfgpart=config["metrics"])
+
+
+def merge_left_with_defaults(defaults, loaded_config):
+ """
+ Merge two configurations, with one of them overriding the other.
+ Args:
+ defaults (dict): A configuration of defaults
+ loaded_config (dict): A configuration, as loaded from disk.
+
+ Returns (dict):
+ A merged configuration, with loaded_config preferred over defaults.
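+
+ For example (illustrative values):
+ merge_left_with_defaults({"db": {"dbfile": "sygnal.db"}, "apps": {}},
+ {"db": {"dbfile": "/tmp/test.db"}})
+ returns {"db": {"dbfile": "/tmp/test.db"}, "apps": {}}.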
+ """
+ result = defaults.copy()
+
+ if loaded_config is None:
+ return result
+
+ # copy defaults or override them
+ for k, v in result.items():
+ if isinstance(v, dict):
+ if k in loaded_config:
+ result[k] = merge_left_with_defaults(v, loaded_config[k])
+ else:
+ result[k] = copy.deepcopy(v)
+ elif k in loaded_config:
+ result[k] = loaded_config[k]
+
+ # copy things with no defaults
+ for k, v in loaded_config.items():
+ if k not in result:
+ result[k] = v
+
+ return result
+
+
+if __name__ == "__main__":
+ # TODO: we don't want to have to install the reactor globally if we can
+ # get away without doing so
+ asyncioreactor.install()
+
+ # we remove the global reactor to make it evident when it has accidentally
+ # been used:
+ # ! twisted.internet.reactor = None
+ # TODO can't do this ^ yet, since twisted.internet.task.{coiterate,cooperate}
+ # (indirectly) depend on the globally-installed reactor and there's no way
+ # to pass in a custom one.
+ # and twisted.web.client uses twisted.internet.task.cooperate
+
+ config = parse_config()
+ config = merge_left_with_defaults(CONFIG_DEFAULTS, config)
+ check_config(config)
+ sygnal = Sygnal(config, custom_reactor=asyncioreactor.AsyncioSelectorReactor())
+ sygnal.run()
diff --git a/sygnal/utils.py b/sygnal/utils.py
new file mode 100644
index 00000000..0da68868
--- /dev/null
+++ b/sygnal/utils.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from logging import LoggerAdapter
+
+from twisted.internet.defer import Deferred
+
+
+async def twisted_sleep(delay, twisted_reactor):
+ """
+ Waits for a set time without blocking the reactor; an async analogue
+ to L{time.sleep}.
+ Args:
+ delay: Delay in seconds
+ twisted_reactor: Reactor to use for sleeping.
+
+ Returns:
+ a coroutine which completes after `delay` seconds.
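+
+ Example (inside an async function):
+ await twisted_sleep(5, twisted_reactor=self.sygnal.reactor)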
+ """
+ deferred = Deferred()
+ twisted_reactor.callLater(delay, deferred.callback, None)
+ await deferred
+
+
+class NotificationLoggerAdapter(LoggerAdapter):
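+ """
+ A LoggerAdapter that prefixes messages with the request ID from its
+ `extra` dict: e.g. log.info("sent") renders as "[<request_id>] sent".
+ """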
+ def process(self, msg, kwargs):
+ return f"[{self.extra['request_id']}] {msg}", kwargs
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_apns.py b/tests/test_apns.py
new file mode 100644
index 00000000..e29277e2
--- /dev/null
+++ b/tests/test_apns.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import patch, MagicMock
+
+from aioapns.common import NotificationResult
+
+from sygnal import apnstruncate
+from tests import testutils
+
+PUSHKIN_ID = "com.example.apns"
+
+TEST_CERTFILE_PATH = "/path/to/my/certfile.pem"
+
+DEVICE_EXAMPLE = {"app_id": "com.example.apns", "pushkey": "spqr", "pushkey_ts": 42}
+
+
+class ApnsTestCase(testutils.TestCase):
+ def setUp(self):
+ self.apns_mock_class = patch("sygnal.apnspushkin.APNs").start()
+ self.apns_mock = MagicMock()
+ self.apns_mock_class.return_value = self.apns_mock
+
+ # pretend our certificate exists
+ patch("os.path.exists", lambda x: x == TEST_CERTFILE_PATH).start()
+ self.addCleanup(patch.stopall)
+
+ super(ApnsTestCase, self).setUp()
+
+ self.apns_pushkin_snotif = MagicMock()
+ self.sygnal.pushkins[PUSHKIN_ID]._send_notification = self.apns_pushkin_snotif
+
+ def config_setup(self, config):
+ super(ApnsTestCase, self).config_setup(config)
+ config["apps"][PUSHKIN_ID] = {"type": "apns", "certfile": TEST_CERTFILE_PATH}
+
+ def test_payload_truncation(self):
+ """
+ Tests that APNS message bodies will be truncated to fit the limits of
+ APNS.
+ """
+ # Arrange
+ method = self.apns_pushkin_snotif
+ method.return_value = testutils.make_async_magic_mock(
+ NotificationResult("notID", "200")
+ )
+ self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 200
+
+ # Act
+ self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ # Assert
+ self.assertEquals(1, method.call_count)
+ ((notification_req,), _kwargs) = method.call_args
+ payload = notification_req.message
+
+ self.assertLessEqual(len(apnstruncate.json_encode(payload)), 200)
+
+ def test_payload_truncation_test_validity(self):
+ """
+ This tests that L{test_payload_truncation} is a valid test
+ by showing that not limiting the truncation size would result in a
+ longer message.
+ """
+ # Arrange
+ method = self.apns_pushkin_snotif
+ method.return_value = testutils.make_async_magic_mock(
+ NotificationResult("notID", "200")
+ )
+ self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 4096
+
+ # Act
+ self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ # Assert
+ self.assertEquals(1, method.call_count)
+ ((notification_req,), _kwargs) = method.call_args
+ payload = notification_req.message
+
+ self.assertGreater(len(apnstruncate.json_encode(payload)), 200)
+
+ def test_expected(self):
+ """
+ Tests the expected case: a good response from APNS means we pass on
+ a good response to the homeserver.
+ """
+ # Arrange
+ method = self.apns_pushkin_snotif
+ method.side_effect = testutils.make_async_magic_mock(
+ NotificationResult("notID", "200")
+ )
+
+ # Act
+ resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ # Assert
+ self.assertEquals(1, method.call_count)
+ ((notification_req,), _kwargs) = method.call_args
+
+ self.assertEquals(
+ {
+ "room_id": "!slw48wfj34rtnrf:example.com",
+ "aps": {
+ "alert": {
+ "loc-key": "MSG_FROM_USER_IN_ROOM_WITH_CONTENT",
+ "loc-args": [
+ "Major Tom",
+ "Mission Control",
+ "I'm floating in a most peculiar way.",
+ ],
+ },
+ "badge": 3,
+ "content-available": 1,
+ },
+ },
+ notification_req.message,
+ )
+
+ self.assertEquals({"rejected": []}, resp)
+
+ def test_rejection(self):
+ """
+ Tests the rejection case: a rejection response from APNS leads to us
+ passing on a rejection to the homeserver.
+ """
+ # Arrange
+ method = self.apns_pushkin_snotif
+ method.side_effect = testutils.make_async_magic_mock(
+ NotificationResult("notID", "410", description="Unregistered")
+ )
+
+ # Act
+ resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ # Assert
+ self.assertEquals(1, method.call_count)
+ self.assertEquals({"rejected": ["spqr"]}, resp)
+
+ def test_no_retry_on_4xx(self):
+ """
+ Test that, on a 4xx error, we neither retry nor mark the pushkey as
+ rejected.
+ """
+ # Arrange
+ method = self.apns_pushkin_snotif
+ method.side_effect = testutils.make_async_magic_mock(
+ NotificationResult("notID", "429", description="TooManyRequests")
+ )
+
+ # Act
+ resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ # Assert
+ self.assertEquals(1, method.call_count)
+ self.assertEquals(502, resp)
+
+ def test_retry_on_5xx(self):
+ """
+ Test that we DO retry when we get a 5xx error, and that the pushkey is
+ not marked as rejected.
+ """
+ # Arrange
+ method = self.apns_pushkin_snotif
+ method.side_effect = testutils.make_async_magic_mock(
+ NotificationResult("notID", "503", description="ServiceUnavailable")
+ )
+
+ # Act
+ resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ # Assert
+ self.assertGreater(method.call_count, 1)
+ self.assertEquals(502, resp)
diff --git a/tests/test_apnstruncate.py b/tests/test_apnstruncate.py
new file mode 100644
index 00000000..7eb07ffd
--- /dev/null
+++ b/tests/test_apnstruncate.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copied and adapted from
+# https://raw.githubusercontent.com/matrix-org/pushbaby/master/tests/test_truncate.py
+
+
+import string
+import unittest
+
+from sygnal.apnstruncate import truncate, json_encode
+
+
+def simplestring(length, offset=0):
+ """
+ Deterministically generates a string.
+ Args:
+ length: Length of the string
+ offset: Offset of the string
+
+ Returns:
+ A string formed of lowercase ASCII characters.
+ """
+ return "".join(
+ [
+ string.ascii_lowercase[(i + offset) % len(string.ascii_lowercase)]
+ for i in range(length)
+ ]
+ )
+
+
+def sillystring(length, offset=0):
+ """
+ Deterministically generates a string.
+ Args:
+ length: Length of the string
+ offset: Offset of the string
+
+ Returns:
+ A string formed of weird and wonderful UTF-8 emoji characters.
+ """
+ chars = ["\U0001F430", "\U0001F431", "\U0001F432", "\U0001F433"]
+ return "".join([chars[(i + offset) % len(chars)] for i in range(length)])
+
+
+def payload_for_aps(aps):
+ """
+ Returns the APNS payload for an 'aps' dictionary.
+ """
+ return {"aps": aps}
+
+
+class TruncateTestCase(unittest.TestCase):
+ def test_dont_truncate(self):
+ """
+ Tests that truncation is not performed if unnecessary.
+ """
+ # This shouldn't need to be truncated
+ txt = simplestring(20)
+ aps = {"alert": txt}
+ self.assertEquals(txt, truncate(payload_for_aps(aps), 256)["aps"]["alert"])
+
+ def test_truncate_alert(self):
+ """
+ Tests that the 'alert' string field will be truncated when needed.
+ """
+ overhead = len(json_encode(payload_for_aps({"alert": ""})))
+ txt = simplestring(10)
+ aps = {"alert": txt}
+ self.assertEquals(
+ txt[:5], truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]
+ )
+
+ def test_truncate_alert_body(self):
+ """
+ Tests that the 'alert' 'body' field will be truncated when needed.
+ """
+ overhead = len(json_encode(payload_for_aps({"alert": {"body": ""}})))
+ txt = simplestring(10)
+ aps = {"alert": {"body": txt}}
+ self.assertEquals(
+ txt[:5],
+ truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["body"],
+ )
+
+ def test_truncate_loc_arg(self):
+ """
+ Tests that the 'alert' 'loc-args' field will be truncated when needed.
+ (Tests with one loc arg)
+ """
+ overhead = len(json_encode(payload_for_aps({"alert": {"loc-args": [""]}})))
+ txt = simplestring(10)
+ aps = {"alert": {"loc-args": [txt]}}
+ self.assertEquals(
+ txt[:5],
+ truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["loc-args"][0],
+ )
+
+ def test_truncate_loc_args(self):
+ """
+ Tests that the 'alert' 'loc-args' field will be truncated when needed.
+ (Tests with two loc args)
+ """
+ overhead = len(json_encode(payload_for_aps({"alert": {"loc-args": ["", ""]}})))
+ txt = simplestring(10)
+ txt2 = simplestring(10, 3)
+ aps = {"alert": {"loc-args": [txt, txt2]}}
+ self.assertEquals(
+ txt[:5],
+ truncate(payload_for_aps(aps), overhead + 10)["aps"]["alert"]["loc-args"][
+ 0
+ ],
+ )
+ self.assertEquals(
+ txt2[:5],
+ truncate(payload_for_aps(aps), overhead + 10)["aps"]["alert"]["loc-args"][
+ 1
+ ],
+ )
+
+ def test_python_unicode_support(self):
+ """
+ Tests Python's unicode support:
+ a one character unicode string should have a length of one, even if it's one
+ multibyte character.
+ OS X, for example, is broken, and counts the number of surrogate pairs.
+ I have no great desire to manually parse UTF-8 to work around this since
+ it works fine on Linux.
+ """
+ if len(u"\U0001F430") != 1:
+ msg = (
+ "Unicode support is broken in your Python binary. "
+ + "Truncating messages with multibyte unicode characters will fail."
+ )
+ self.fail(msg)
+
+ def test_truncate_string_with_multibyte(self):
+ """
+ Tests that truncation works as expected on strings containing one
+ multibyte character.
+ """
+ overhead = len(json_encode(payload_for_aps({"alert": ""})))
+ txt = u"\U0001F430" + simplestring(30)
+ aps = {"alert": txt}
+ # NB. The number of characters of the string we get is dependent
+ # on the json encoding used.
+ self.assertEquals(
+ txt[:17], truncate(payload_for_aps(aps), overhead + 20)["aps"]["alert"]
+ )
+
+ def test_truncate_multibyte(self):
+ """
+ Tests that truncation works as expected on strings containing only
+ multibyte characters.
+ """
+ overhead = len(json_encode(payload_for_aps({"alert": ""})))
+ txt = sillystring(30)
+ aps = {"alert": txt}
+ trunc = truncate(payload_for_aps(aps), overhead + 30)
+ # The string is all 4-byte characters, so the truncated UTF-8 string
+ # should be a multiple of 4 bytes long
+ self.assertEquals(len(trunc["aps"]["alert"].encode()) % 4, 0)
+ # NB. The number of characters of the string we get is dependent
+ # on the json encoding used.
+ self.assertEquals(txt[:7], trunc["aps"]["alert"])
diff --git a/tests/test_gcm.py b/tests/test_gcm.py
new file mode 100644
index 00000000..87496ae9
--- /dev/null
+++ b/tests/test_gcm.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from sygnal.gcmpushkin import GcmPushkin
+from tests import testutils
+from tests.testutils import DummyResponse
+
+DEVICE_EXAMPLE = {"app_id": "com.example.gcm", "pushkey": "spqr", "pushkey_ts": 42}
+DEVICE_EXAMPLE2 = {"app_id": "com.example.gcm", "pushkey": "spqr2", "pushkey_ts": 42}
+
+
+class TestGcmPushkin(GcmPushkin):
+ """
+ A GCM pushkin with its ability to make HTTP requests removed; instead,
+ it can be preloaded with virtual responses.
+ """
+
+ def __init__(self, name, sygnal, config, canonical_reg_id_store):
+ super().__init__(name, sygnal, config, canonical_reg_id_store)
+ self.preloaded_response = None
+ self.preloaded_response_payload = None
+ self.last_request_body = None
+ self.last_request_headers = None
+ self.num_requests = 0
+
+ def preload_with_response(self, code, response_payload):
+ """
+ Preloads a fake GCM response.
+ """
+ self.preloaded_response = DummyResponse(code)
+ self.preloaded_response_payload = response_payload
+
+ async def _perform_http_request(self, body, headers):
+ self.last_request_body = body
+ self.last_request_headers = headers
+ self.num_requests += 1
+ return self.preloaded_response, json.dumps(self.preloaded_response_payload)
+
+
+class GcmTestCase(testutils.TestCase):
+ def config_setup(self, config):
+ super(GcmTestCase, self).config_setup(config)
+ config["apps"]["com.example.gcm"] = {
+ "type": "tests.test_gcm.TestGcmPushkin",
+ "api_key": "kii",
+ }
+
+ def test_expected(self):
+ """
+ Tests the expected case: a good response from GCM leads to a good
+ response from Sygnal.
+ """
+ gcm = self.sygnal.pushkins["com.example.gcm"]
+ gcm.preload_with_response(
+ 200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
+ )
+
+ req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(resp, {"rejected": []})
+ self.assertEquals(gcm.num_requests, 1)
+
+ def test_rejected(self):
+ """
+ Tests the rejected case: a pushkey rejected to GCM leads to Sygnal
+ informing the homeserver of the rejection.
+ """
+ gcm = self.sygnal.pushkins["com.example.gcm"]
+ gcm.preload_with_response(
+ 200, {"results": [{"registration_id": "spqr", "error": "NotRegistered"}]}
+ )
+
+ req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(resp, {"rejected": ["spqr"]})
+ self.assertEquals(gcm.num_requests, 1)
+
+ def test_regenerated_id(self):
+ """
+ Tests that pushkeys regenerated by GCM are kept track of; that is,
+ the new device ID is used in lieu of the old one if we are aware that
+ it has changed.
+ """
+ gcm = self.sygnal.pushkins["com.example.gcm"]
+ gcm.preload_with_response(
+ 200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
+ )
+
+ req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(resp, {"rejected": []})
+
+ gcm.preload_with_response(
+ 200, {"results": [{"registration_id": "spqr_new", "message_id": "msg43"}]}
+ )
+
+ req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(gcm.last_request_body["to"], "spqr_new")
+
+ self.assertEquals(resp, {"rejected": []})
+ self.assertEquals(gcm.num_requests, 2)
+
+ def test_batching(self):
+ """
+ Tests that multiple GCM devices have their notification delivered to GCM
+ together, instead of being delivered separately.
+ """
+ gcm = self.sygnal.pushkins["com.example.gcm"]
+ gcm.preload_with_response(
+ 200,
+ {
+ "results": [
+ {"registration_id": "spqr", "message_id": "msg42"},
+ {"registration_id": "spqr2", "message_id": "msg42"},
+ ]
+ },
+ )
+
+ req = self._make_request(
+ self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2])
+ )
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(resp, {"rejected": []})
+ self.assertEquals(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
+ self.assertEquals(gcm.num_requests, 1)
+
+ def test_batching_individual_failure(self):
+ """
+ Tests that multiple GCM devices have their notification delivered to GCM
+ together, instead of being delivered separately,
+ and that if only one device ID is rejected, then only that device is
+ reported to the homeserver as rejected.
+ """
+ gcm = self.sygnal.pushkins["com.example.gcm"]
+ gcm.preload_with_response(
+ 200,
+ {
+ "results": [
+ {"registration_id": "spqr", "message_id": "msg42"},
+ {"registration_id": "spqr2", "error": "NotRegistered"},
+ ]
+ },
+ )
+
+ req = self._make_request(
+ self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2])
+ )
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(resp, {"rejected": ["spqr2"]})
+ self.assertEquals(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
+ self.assertEquals(gcm.num_requests, 1)
+
+ def test_regenerated_failure(self):
+ """
+ Tests that use of a regenerated device ID does not cause confusion
+ when reporting a rejection to the homeserver.
+ The homeserver doesn't know about the regenerated ID so the rejection
+ must be translated back into the one provided by the homeserver.
+ """
+ gcm = self.sygnal.pushkins["com.example.gcm"]
+ gcm.preload_with_response(
+ 200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
+ )
+
+ req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(resp, {"rejected": []})
+
+ # imagine there is some non-negligible time between these two,
+ # and the device in question is unregistered
+
+ gcm.preload_with_response(
+ 200,
+ {"results": [{"registration_id": "spqr_new", "error": "NotRegistered"}]},
+ )
+
+ req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))
+
+ resp = self._collect_request(req)
+
+ self.assertEquals(gcm.last_request_body["to"], "spqr_new")
+
+ # the ID translation needs to be transparent as the homeserver will not
+ # make sense of it otherwise.
+ self.assertEquals(resp, {"rejected": ["spqr"]})
+ self.assertEquals(gcm.num_requests, 2)
diff --git a/tests/test_pushgateway_api_v1.py b/tests/test_pushgateway_api_v1.py
new file mode 100644
index 00000000..5b940ec6
--- /dev/null
+++ b/tests/test_pushgateway_api_v1.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sygnal.exceptions import (
+ NotificationDispatchException,
+ TemporaryNotificationDispatchException,
+)
+from sygnal.notifications import Pushkin
+
+from tests import testutils
+
+DEVICE_RAISE_EXCEPTION = {
+ "app_id": "com.example.spqr",
+ "pushkey": "raise_exception",
+ "pushkey_ts": 1234,
+}
+
+DEVICE_REMOTE_ERROR = {
+ "app_id": "com.example.spqr",
+ "pushkey": "remote_error",
+ "pushkey_ts": 1234,
+}
+
+DEVICE_TEMPORARY_REMOTE_ERROR = {
+ "app_id": "com.example.spqr",
+ "pushkey": "temporary_remote_error",
+ "pushkey_ts": 1234,
+}
+
+DEVICE_REJECTED = {
+ "app_id": "com.example.spqr",
+ "pushkey": "reject",
+ "pushkey_ts": 1234,
+}
+
+DEVICE_ACCEPTED = {
+ "app_id": "com.example.spqr",
+ "pushkey": "accept",
+ "pushkey_ts": 1234,
+}
+
+
+class TestPushkin(Pushkin):
+ """
+ A synthetic Pushkin with simple rules.
+ """
+
+ async def dispatch_notification(self, n, device, context):
+ if device.pushkey == "raise_exception":
+ raise Exception("Bad things have occurred!")
+ elif device.pushkey == "remote_error":
+ raise NotificationDispatchException("Synthetic failure")
+ elif device.pushkey == "temporary_remote_error":
+ raise TemporaryNotificationDispatchException("Synthetic failure")
+ elif device.pushkey == "reject":
+ return [device.pushkey]
+ elif device.pushkey == "accept":
+ return []
+ raise Exception(f"Unexpected fall-through. {device.pushkey}")
+
+
+class PushGatewayApiV1TestCase(testutils.TestCase):
+ def config_setup(self, config):
+ """
+ Set up a TestPushkin for the test.
+ """
+ super(PushGatewayApiV1TestCase, self).config_setup(config)
+ config["apps"]["com.example.spqr"] = {
+ "type": "tests.test_pushgateway_api_v1.TestPushkin"
+ }
+
+ def test_good_requests_give_200(self):
+ """
+ Test that good requests give a 200 response code.
+ """
+ # 200 codes cause the result to be parsed instead of returning the code
+ self.assertNot(
+ isinstance(
+ self._request(
+ self._make_dummy_notification([DEVICE_ACCEPTED, DEVICE_REJECTED])
+ ),
+ int,
+ )
+ )
+
+ def test_accepted_devices_are_not_rejected(self):
+ """
+ Test that devices which are accepted by the Pushkin
+ do not lead to a rejection being returned to the homeserver.
+ """
+ self.assertEquals(
+ self._request(self._make_dummy_notification([DEVICE_ACCEPTED])),
+ {"rejected": []},
+ )
+
+ def test_rejected_devices_are_rejected(self):
+ """
+ Test that devices which are rejected by the Pushkin
+ DO lead to a rejection being returned to the homeserver.
+ """
+ self.assertEquals(
+ self._request(self._make_dummy_notification([DEVICE_REJECTED])),
+ {"rejected": [DEVICE_REJECTED["pushkey"]]},
+ )
+
+ def test_only_rejected_devices_are_rejected(self):
+ """
+ Test that devices which are rejected by the Pushkin
+ are the only ones to have a rejection returned to the homeserver,
+ even if other devices feature in the request.
+ """
+ self.assertEquals(
+ self._request(
+ self._make_dummy_notification([DEVICE_REJECTED, DEVICE_ACCEPTED])
+ ),
+ {"rejected": [DEVICE_REJECTED["pushkey"]]},
+ )
+
+ def test_bad_requests_give_400(self):
+ """
+ Test that bad requests lead to a 400 Bad Request response.
+ """
+ self.assertEquals(self._request({}), 400)
+
+ def test_exceptions_give_500(self):
+ """
+ Test that internal exceptions/errors lead to a 500 Internal Server Error
+ response.
+ """
+
+ self.assertEquals(
+ self._request(self._make_dummy_notification([DEVICE_RAISE_EXCEPTION])), 500
+ )
+
+ # we also check that a successful device doesn't hide the exception
+ self.assertEquals(
+ self._request(
+ self._make_dummy_notification([DEVICE_ACCEPTED, DEVICE_RAISE_EXCEPTION])
+ ),
+ 500,
+ )
+
+ self.assertEquals(
+ self._request(
+ self._make_dummy_notification([DEVICE_RAISE_EXCEPTION, DEVICE_ACCEPTED])
+ ),
+ 500,
+ )
+
+ def test_remote_errors_give_502(self):
+ """
+ Test that errors caused by remote services such as GCM or APNS
+ lead to a 502 Bad Gateway response.
+ """
+
+ self.assertEquals(
+ self._request(self._make_dummy_notification([DEVICE_REMOTE_ERROR])), 502
+ )
+
+ # we also check that a successful device doesn't hide the exception
+ self.assertEquals(
+ self._request(
+ self._make_dummy_notification([DEVICE_ACCEPTED, DEVICE_REMOTE_ERROR])
+ ),
+ 502,
+ )
+
+ self.assertEquals(
+ self._request(
+ self._make_dummy_notification([DEVICE_REMOTE_ERROR, DEVICE_ACCEPTED])
+ ),
+ 502,
+ )
diff --git a/tests/testutils.py b/tests/testutils.py
new file mode 100644
index 00000000..f1155b67
--- /dev/null
+++ b/tests/testutils.py
@@ -0,0 +1,218 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from io import BytesIO
+from threading import Condition
+
+from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.trial import unittest
+from twisted.web.http_headers import Headers
+from twisted.web.server import NOT_DONE_YET
+from twisted.web.test.requesthelper import DummyRequest as UnaugmentedDummyRequest
+
+from sygnal.http import PushGatewayApiServer
+from sygnal.sygnal import Sygnal, merge_left_with_defaults, CONFIG_DEFAULTS
+
+REQ_PATH = b"/_matrix/push/v1/notify"
+
+
+class TestCase(unittest.TestCase):
+ def config_setup(self, config):
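+ # use an in-memory SQLite database so each test runs against a
+ # fresh, throwaway store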
+ config["db"]["dbfile"] = ":memory:"
+
+ def setUp(self):
+ reactor = ExtendedMemoryReactorClock()
+
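+ # minimal configuration: no apps (each test registers its own via
+ # config_setup), an empty db section and bare-bones logging;
+ # merge_left_with_defaults fills in the rest from CONFIG_DEFAULTS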
+ config = {"apps": {}, "db": {}, "log": {"setup": {"version": 1}}}
+ config = merge_left_with_defaults(CONFIG_DEFAULTS, config)
+
+ self.config_setup(config)
+
+ self.sygnal = Sygnal(config, reactor)
+ self.sygnal.database.start()
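+ # the HTTP resource tree serving the notify endpoint under test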
+ self.v1api = PushGatewayApiServer(self.sygnal)
+
+ start_deferred = ensureDeferred(
+ self.sygnal._make_pushkins_then_start(0, [], None)
+ )
+
+ while not start_deferred.called:
+ # we need to advance until the pushkins have started up
+ self.sygnal.reactor.advance(1)
+ self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)
+
+ def tearDown(self):
+ super().tearDown()
+ self.sygnal.database.close()
+
+ def _make_dummy_notification(self, devices):
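+ """
+ Makes a dummy notification payload, in the form a homeserver
+ submits to the push gateway, targeting the given devices.
+ """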
+ return {
+ "notification": {
+ "id": "$3957tyerfgewrf384",
+ "room_id": "!slw48wfj34rtnrf:example.com",
+ "type": "m.room.message",
+ "sender": "@exampleuser:matrix.org",
+ "sender_display_name": "Major Tom",
+ "room_name": "Mission Control",
+ "room_alias": "#exampleroom:matrix.org",
+ "prio": "high",
+ "content": {
+ "msgtype": "m.text",
+ "body": "I'm floating in a most peculiar way.",
+ },
+ "counts": {"unread": 2, "missed_calls": 1},
+ "devices": devices,
+ }
+ }
+
+ def _make_request(self, payload, headers=None):
+ """
+ Make a dummy request to the notify endpoint with the specified payload and headers.
+ Args:
+ payload: payload to be JSON encoded
+ headers (dict, optional): A L{dict} mapping header names as L{bytes}
+ to L{list}s of header values as L{bytes}
+
+ Returns (DummyRequest):
+ A dummy request corresponding to the request arguments supplied.
+
+ """
+ pathparts = REQ_PATH.split(b"/")
+ if pathparts[0] == b"":
+ pathparts = pathparts[1:]
+ dreq = DummyRequest(pathparts)
+ dreq.requestHeaders = Headers(headers or {})
+ dreq.responseCode = 200 # default to 200
+
+ if isinstance(payload, dict):
+ payload = json.dumps(payload)
+
+ dreq.content = BytesIO(payload.encode())
+ dreq.method = "POST"
+
+ return dreq
+
+ def _collect_request(self, request):
+ """
+ Collects (waits until done and then returns the result of) the request.
+ Args:
+ request (Request): a request to collect
+
+ Returns (dict or int):
+ If successful (200 response received), the response is JSON decoded
+ and the resultant dict is returned.
+ If the response code is not 200, returns the response code.
+ """
+ resource = self.v1api.site.getResourceFor(request)
+ rendered = resource.render(request)
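+ # render() returns the response body directly for synchronous
+ # resources, or NOT_DONE_YET if the request will be finished
+ # asynchronously; handle both cases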
+
+ if request.responseCode != 200:
+ return request.responseCode
+
+ if isinstance(rendered, (str, bytes)):
+ return json.loads(rendered)
+ elif rendered == NOT_DONE_YET:
+
+ while not request.finished:
+ # we need to advance until the request has been finished
+ self.sygnal.reactor.advance(1)
+ self.sygnal.reactor.wait_for_work(lambda: request.finished)
+
+ assert request.finished > 0
+
+ if request.responseCode != 200:
+ return request.responseCode
+
+ written_bytes = b"".join(request.written)
+ return json.loads(written_bytes)
+ else:
+ raise RuntimeError(f"Can't collect: {rendered}")
+
+ def _request(self, *args, **kwargs):
+ """
+ Makes and collects a request.
+ See L{_make_request} and L{_collect_request}.
+ """
+ request = self._make_request(*args, **kwargs)
+
+ return self._collect_request(request)
+
+
+class ExtendedMemoryReactorClock(MemoryReactorClock):
+ def __init__(self):
+ super().__init__()
+ self.work_notifier = Condition()
+
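+ # The memory reactor is single-threaded, so calls "from threads"
+ # can safely be scheduled as immediate delayed calls.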
+ def callFromThread(self, function, *args, **kwargs):
+ self.callLater(0, function, *args, **kwargs)
+
+ def callLater(self, when, what, *a, **kw):
+ self.work_notifier.acquire()
+ try:
+ return_value = super().callLater(when, what, *a, **kw)
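+ # wake up anyone blocked in wait_for_work: new work has arrived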
+ self.work_notifier.notify_all()
+ finally:
+ self.work_notifier.release()
+
+ return return_value
+
+ def wait_for_work(self, early_stop=lambda: False):
+ """
+ Blocks until there is work as long as the early stop condition
+ is not satisfied.
+
+ Args:
+ early_stop: Extra function, called to determine whether to stop
+ blocking.
+ Should return True iff the early stop condition is satisfied,
+ in which case no blocking will be done.
+ It is intended to be used to detect when the task you are
+ waiting for is complete, e.g. a Deferred has fired or a
+ Request has been finished.
+ """
+ self.work_notifier.acquire()
+
+ try:
+ while len(self.getDelayedCalls()) == 0 and not early_stop():
+ self.work_notifier.wait()
+ finally:
+ self.work_notifier.release()
+
+
+class DummyRequest(UnaugmentedDummyRequest):
+ """
+ Tracks the response code in the 'code' field, like a normal Request.
+ """
+
+ def __init__(self, postpath, session=None, client=None):
+ super().__init__(postpath, session, client)
+ self.code = 200
+
+ def setResponseCode(self, code, message=None):
+ super().setResponseCode(code, message)
+ self.code = code
+
+
+class DummyResponse:
+ def __init__(self, code):
+ self.code = code
+
+
+def make_async_magic_mock(ret_val):
+ """Makes an async function which, when awaited, returns `ret_val`."""
+
+ async def dummy(*_args, **_kwargs):
+ return ret_val
+
+ return dummy
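+
+
+# Example (hypothetical usage): stub out a pushkin's dispatch so that it
+# "accepts" every device, i.e. rejects no pushkeys, without contacting a
+# real push service:
+#
+# pushkin.dispatch_notification = make_async_magic_mock([])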