diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml new file mode 100644 index 00000000..09778b48 --- /dev/null +++ b/.buildkite/pipeline.yml @@ -0,0 +1,27 @@ + +steps: + - command: + - "python -m pip install black" + - "black --check ." + label: "Check Code Formatting" + plugins: + - docker#v3.0.1: + image: "python:3.7" + + - command: + - "python -m pip install flake8" + - "flake8 ." + label: "Check Code Style" + plugins: + - docker#v3.0.1: + image: "python:3.7" + + - wait + + - command: + - "python -m pip install -e ." + - "trial tests" + label: "Run unit tests" + plugins: + - docker#v3.0.1: + image: "python:3.7" diff --git a/.gitignore b/.gitignore index b6e20577..d4c05d27 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ *.pyc sygnal.conf +sygnal.yaml gunicorn_config.py var/ sygnal.pid sygnal.db +/_trial_temp* +/.idea diff --git a/README.rst b/README.rst index 5d3d7c97..1dc04fd9 100644 --- a/README.rst +++ b/README.rst @@ -1,47 +1,28 @@ Introduction ============ -sygnal is a reference Push Gateway for Matrix (http://matrix.org/). +Sygnal is a reference Push Gateway for `Matrix <https://matrix.org>`_. -See -http://matrix.org/docs/spec/client_server/r0.2.0.html#id51 for a high level overview of how notifications work in Matrix. +See https://matrix.org/docs/spec/client_server/r0.5.0#id134 +for a high level overview of how notifications work in Matrix. -http://matrix.org/docs/spec/push_gateway/unstable.html#post-matrix-push-r0-notify -describes the protocol that Matrix Home Servers use to send notifications to -Push Gateways such as sygnal. +https://matrix.org/docs/spec/push_gateway/r0.1.0 +describes the protocol that Matrix Home Servers use to send notifications to Push Gateways such as Sygnal. Setup ===== -sygnal is a plain WSGI app, although these instructions use gunicorn which -will create a complete, standalone webserver. 
When used with gunicorn, -sygnal can use gunicorn's extra hook to perform a clean shutdown which tries as -hard as possible to ensure no messages are lost. +Sygnal is configured through a YAML configuration file. +By default, this configuration file is assumed to be named ``sygnal.yaml`` and to be in the working directory. +To change this, set the ``SYGNAL_CONF`` environment variable to the path to your configuration file. +A sample configuration file is provided in this repository; +see ``sygnal.yaml.sample``. -There are two config files: - * sygnal.conf (The app-specific config file) - * gunicorn_config.py (gunicorn's config file) +The ``apps:`` section is where you set up different apps that are to be handled. +Each app should be given its own subsection, with the key of that subsection being the app's ``app_id``. +Keys in this section take the form of the ``app_id``, as specified when setting up a Matrix pusher +(see https://matrix.org/docs/spec/client_server/r0.5.0#post-matrix-client-r0-pushers-set). -sygnal.conf contains configuration for sygnal itself. This includes the location -and level of sygnal's log file. The [apps] section is where you set up different -apps that are to be handled. Keys in this section take the form of the app_id -and the name of the configuration key, joined by a single dot ('.'). The app_id -is as specified when setting up a Matrix pusher (see -http://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-pushers-set). So for example, the `type` for -the App ID of `com.example.myapp.ios.prod` would be specified as follows:: - - com.example.myapp.ios.prod.type = foobar - -By default sygnal.conf is assumed to be in the working directory, but the path -can be overriden by setting the `sygnal.conf` environment variable. - -The gunicorn sample config contains everything necessary to run sygnal from -gunicorn. The shutdown hook handles clean shutdown. 
You can customise other -aspects of this file as you wish to change, for example, the log location or the -bind port. - -Note that sygnal uses gevent. You should therefore not change the worker class -or the number of workers (which should be 1: in gevent, a single worker uses -multiple greenlets to handle all the requests). +See the sample configuration for examples. App Types --------- @@ -49,42 +30,44 @@ There are two supported App Types: apns This sends push notifications to iOS apps via the Apple Push Notification - Service (APNS). It expects the 'certfile' parameter to be a path relative to - sygnal's working directory of a PEM file containing the APNS certificate and - unencrypted private key. + Service (APNS). -gcm - This sends messages via Google Cloud Messaging (GCM) and hence can be used - to deliver notifications to Android apps. It expects the 'apiKey' parameter - to contain the secret GCM key. + Expected configuration depends on which kind of authentication you wish to use: -Running -======= -To run with gunicorn: + | + + For certificate-based authentication: + It expects: -gunicorn -c gunicorn_config.py sygnal:app + * the ``certfile`` parameter to be a path relative to + sygnal's working directory of a PEM file containing the APNS certificate and + unencrypted private key. -You can customise the gunicorn_config.py to determine whether this daemonizes or runs in the foreground. + For token-based authentication: + It expects: -Gunicorn maintains its own logging in addition to the app's, so the access_log -and error_log contain gunicorn's accesses and gunicorn specific errors. The log -file in sygnal.conf contains app level logging. 
+ * the 'keyfile' parameter to be a path relative to Sygnal's working directory of a p8 file + * the 'key_id' parameter + * the 'team_id' parameter + * the 'topic' parameter -Clean shutdown -============== -The code for APNS uses a grace period where it waits for errors to come down the -socket before declaring it safe for the app to shut down (due to the design of -APNS). Terminating using SIGTERM performs a clean shutdown:: +gcm + This sends messages via Google/Firebase Cloud Messaging (GCM/FCM) and hence can be used + to deliver notifications to Android apps. It expects the 'api_key' parameter + to contain the 'Server key', which can be acquired from Firebase Console at: + ``https://console.firebase.google.com/project/<PROJECT NAME>/settings/cloudmessaging/`` - kill -TERM `cat sygnal.pid` +Running +======= -Restarting sygnal using SIGHUP will handle this gracefully:: +``python -m sygnal.sygnal`` - kill -HUP `cat sygnal.pid` +Python 3.7 or higher is required. Log Rotation ============ -Gunicorn appends to files but does not use a rotating logger. -Sygnal's app logging does the same. Gunicorn will re-open all log files -(including the app's) when sent SIGUSR1. The recommended configuration is -therefore to use logrotate. +Sygnal's logging appends to files but does not use a rotating logger. +The recommended configuration is therefore to use ``logrotate``. +The log file will be automatically reopened if the log file changes, for example +due to ``logrotate``. 
+ diff --git a/gunicorn_config.py.sample b/gunicorn_config.py.sample deleted file mode 100644 index aa70f82c..00000000 --- a/gunicorn_config.py.sample +++ /dev/null @@ -1,43 +0,0 @@ -# Customise these settings to your needs -bind = '0.0.0.0:5000' -daemon = False -accesslog = 'var/access_log' -errorlog = 'var/error_log' -#accesslog = '-' -#errorlog = '-' -loglevel = 'debug' -pidfile = 'sygnal.pid' -worker_connections = 1000 -keepalive = 2 -proc_name = 'sygnal' - -# It is inadvisable to change anything below here, -# since these settings make gunicorn work appropriately -# with sygnal. - -preload_app = False -workers = 1 -worker_class = 'gevent' - - -def worker_exit(server, worker): - # Used in the hooks - try: - # This must be imported inside the hook or it won't find - # the import - import sygnal - # NB. We obviously need to clean up in the worker, not - # the arbiter process. worker_exit runs in the worker - # (despite the docs claiming it runs after the worker - # has exited) - # We use a flask hook to handle the worker setup. - # Unfortunately flask doesn't have a shutdown hook - # (it's not a standard thing in WSGI). - sygnal.shutdown() - - except: - # We swallow this exception because it's generally a completely - # useless, "No module named sygnal" due to it failing to load - # the sygnal module because an exception was thrown. - print("Failed to load sygnal - check your log file") - diff --git a/setup.cfg b/setup.cfg index d08a9014..b9ef0b54 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,15 @@ +[flake8] +# line length defaulted to by black +max-line-length = 88 + +# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes +# for error codes. The ones we ignore are: +# W503: line break before binary operator +# W504: line break after binary operator +# E203: whitespace before ':' (which is contrary to pep8?) 
+# (this is a subset of those ignored in Synapse) +ignore=W503,W504,E203 + [isort] line_length = 80 not_skip = __init__.py diff --git a/setup.py b/setup.py index 15d55a11..1a351a7b 100755 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ # Copyright 2014 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,7 @@ # limitations under the License. import os + from setuptools import setup, find_packages @@ -26,18 +28,20 @@ def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() + setup( name="matrix-sygnal", version=read("VERSION").strip(), packages=find_packages(exclude=["tests", "tests.*"]), description="Reference Push Gateway for Matrix Notifications", install_requires=[ - "flask>=1.0.2", - "gevent>=1.0.1", - "pushbaby>=0.0.9", - "grequests", - "six", - "prometheus_client>=0.7.0,<0.8" + "Twisted>=19.2.1", + "prometheus_client>=0.7.0,<0.8", + "aioapns>=1.7", + "pyyaml>=5.1.1", + "service_identity>=18.1.0", + "jaeger-client>=4.0.0", + "opentracing>=2.2.0", ], long_description=read("README.rst"), ) diff --git a/sygnal.conf.sample b/sygnal.conf.sample deleted file mode 100644 index 0da51b56..00000000 --- a/sygnal.conf.sample +++ /dev/null @@ -1,14 +0,0 @@ -[log] -loglevel = debug -logfile = var/sygnal.log - -# [metrics] -# prometheus_port = 8800 # Enable prometheus metrics and bind to this port -# prometheus_addr = 127.0.0.1 # Optional bind address -# sentry_dsn = https://... 
# Enable sentry integration - -[apps] -com.example.myapp.ios.type = apns -com.example.myapp.ios.certfile = com.example.myApp_prod_APNS.pem -com.example.myapp.android.type = gcm -com.example.myapp.android.apiKey = your_api_key_for_gcm diff --git a/sygnal.yaml.sample b/sygnal.yaml.sample new file mode 100644 index 00000000..95a6f41c --- /dev/null +++ b/sygnal.yaml.sample @@ -0,0 +1,168 @@ +## +# This is a configuration for Sygnal, the reference Push Gateway for Matrix +# See: matrix.org +## + +## Logging # +# +log: + # Specify a Python logging 'dictConfig', as described at: + # https://docs.python.org/3.7/library/logging.config.html#logging.config.dictConfig + # + setup: + version: 1 + formatters: + normal: + format: "%(asctime)s [%(process)d] %(levelname)-5s %(name)s %(message)s" + handlers: + # This handler prints to Standard Error + # + stderr: + class: "logging.StreamHandler" + formatter: "normal" + stream: "ext://sys.stderr" + + # This handler prints to Standard Output. + # + stdout: + class: "logging.StreamHandler" + formatter: "normal" + stream: "ext://sys.stdout" + + # This handler demonstrates logging to a text file on the filesystem. + # You can use logrotate(8) to perform log rotation. + # + file: + class: "logging.handlers.WatchedFileHandler" + formatter: "normal" + filename: "./sygnal.log" + loggers: + # sygnal.access contains the access logging lines. + # Comment out this section if you don't want to give access logging + # any special treatment. + # + sygnal.access: + propagate: false + handlers: ["stdout"] + level: "INFO" + + # sygnal contains log lines from Sygnal itself. + # You can comment out this section to fall back to the root logger. + # + sygnal: + propagate: false + handlers: ["stderr", "file"] + + root: + # Specify the handler(s) to send log messages to. + handlers: ["stderr"] + level: "INFO" + + + access: + # Specify whether or not to trust the IP address in the `X-Forwarded-For` + # header. 
In general, you want to enable this if and only if you are using a + # reverse proxy which is configured to emit it. + # + x_forwarded_for: false + +## HTTP Server (Matrix Push Gateway API) # +# +http: + # Specify a list of interface addresses to bind to. + # + # This example listens on the IPv4 loopback device: + bind_addresses: ['127.0.0.1'] + # This example listens on all IPv4 interfaces: + #bind_addresses: ['0.0.0.0'] + # This example listens on all IPv4 and IPv6 interfaces: + #bind_addresses: ['0.0.0.0', '::'] + + # Specify the port number to listen on. + # + port: 5000 + +## Metrics # +# +metrics: + ## Prometheus # + # + prometheus: + # Specify whether or not to enable Prometheus. + # + enabled: false + + # Specify an address for the Prometheus HTTP Server to listen on. + # + address: '127.0.0.1' + + # Specify a port for the Prometheus HTTP Server to listen on. + # + port: 8000 + + ## OpenTracing # + # + opentracing: + # Specify whether or not to enable OpenTracing. + # + enabled: false + + # Specify an implementation of OpenTracing to use. Currently only 'jaeger' + # is supported. + # + implementation: jaeger + + # Specify the service name to be reported to the tracer. + # + service_name: sygnal + + # Specify configuration values to pass to jaeger_client. + # + jaeger: + sampler: + type: 'const' + param: 1 +# local_agent: +# reporting_host: '127.0.0.1' +# reporting_port: + logging: true + + ## Sentry # + # + sentry: + # Specify whether or not to enable Sentry. + # + enabled: false + + # Specify your Sentry DSN if you enable Sentry + # + #dsn: "https://<key>@sentry.example.org/<project>" + +## Pushkins/Apps # +# +# Add a section for every push application here. +# Specify the app_id of the application as the section key, and also its type. +# For the type, you may specify a fully-qualified Python classname if desired. +# +apps: + # This is an example APNs push configuration using certificate authentication. 
+ # + #com.example.myapp.ios: + # type: apns + # certfile: com.example.myApp_prod_APNS.pem + + # This is an example APNs push configuration using key authentication. + # + #com.example.myapp2.ios: + # type: apns + # keyfile: my_key.p8 + # key_id: MY_KEY_ID + # team_id: MY_TEAM_ID + # topic: MY_TOPIC + + # This is an example GCM/FCM push configuration. + # + #com.example.myapp.android: + # type: gcm + # api_key: your_api_key_for_gcm + diff --git a/sygnal/__init__.py b/sygnal/__init__.py index 21fd9c1b..e69de29b 100644 --- a/sygnal/__init__.py +++ b/sygnal/__init__.py @@ -1,365 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from six.moves import configparser - -import json -import logging -import os -import sys -import threading -from logging.handlers import WatchedFileHandler - -import flask -from flask import Flask, request - -import prometheus_client -from prometheus_client import Counter - -import sygnal.db -from sygnal.exceptions import InvalidNotificationException - - -NOTIFS_RECEIVED_COUNTER = Counter( - "sygnal_notifications_received", "Number of notification pokes received", -) - -NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER = Counter( - "sygnal_notifications_devices_received", - "Number of devices been asked to push", -) - -NOTIFS_BY_PUSHKIN = Counter( - "sygnal_per_pushkin_type", - "Number of pushes sent via each type of pushkin", - labelnames=["pushkin"], -) - -logger = logging.getLogger(__name__) - -app = Flask('sygnal') -app.debug = False -app.config.from_object(__name__) - -CONFIG_SECTIONS = ['http', 'log', 'apps', 'db', 'metrics'] -CONFIG_DEFAULTS = { - 'port': '5000', - 'loglevel': 'info', - 'logfile': '', - 'dbfile': 'sygnal.db' -} - -pushkins = {} - -class RequestIdFilter(logging.Filter): - """A logging filter which adds the current request id to each record""" - def filter(self, record): - request_id = '' - if flask.has_request_context(): - request_id = flask.g.get('request_id', '') - record.request_id = request_id - return True - -class RequestCounter(object): - def __init__(self): - self._count = 0 - self._lock = threading.Lock() - - def get(self): - with self._lock: - c = self._count - self._count = c + 1 - return c - - -request_count = RequestCounter() - - -class Tweaks: - def __init__(self, raw): - self.sound = None - - if 'sound' in raw: - self.sound = raw['sound'] - - -class Device: - def __init__(self, raw): - self.app_id = None - self.pushkey = None - self.pushkey_ts = 0 - self.data = None - self.tweaks = None - - if 'app_id' not in raw: - raise InvalidNotificationException("Device with no app_id") - if 'pushkey' not in raw: - raise 
InvalidNotificationException("Device with no pushkey") - if 'pushkey_ts' in raw: - self.pushkey_ts = raw['pushkey_ts'] - if 'tweaks' in raw: - self.tweaks = Tweaks(raw['tweaks']) - else: - self.tweaks = Tweaks({}) - self.app_id = raw['app_id'] - self.pushkey = raw['pushkey'] - if 'data' in raw: - self.data = raw['data'] - - -class Counts: - def __init__(self, raw): - self.unread = None - self.missed_calls = None - - if 'unread' in raw: - self.unread = raw['unread'] - if 'missed_calls' in raw: - self.missed_calls = raw['missed_calls'] - - -class Notification: - def __init__(self, notif): - optional_attrs = [ - 'room_name', - 'room_alias', - 'prio', - 'membership', - 'sender_display_name', - 'content', - 'event_id', - 'room_id', - 'user_is_target', - 'type', - 'sender', - ] - for a in optional_attrs: - if a in notif: - self.__dict__[a] = notif[a] - else: - self.__dict__[a] = None - - if 'devices' not in notif or not isinstance(notif['devices'], list): - raise InvalidNotificationException("Expected list in 'devices' key") - - if 'counts' in notif: - self.counts = Counts(notif['counts']) - else: - self.counts = Counts({}) - - self.devices = [Device(d) for d in notif['devices']] - - -class Pushkin(object): - def __init__(self, name): - self.name = name - - def setup(self): - pass - - def getConfig(self, key): - if not self.cfg.has_option('apps', '%s.%s' % (self.name, key)): - return None - return self.cfg.get('apps', '%s.%s' % (self.name, key)) - - def dispatchNotification(self, n): - pass - - def shutdown(self): - pass - - -class SygnalContext: - pass - - -class ClientError(Exception): - pass - - -def parse_config(): - cfg = configparser.SafeConfigParser(CONFIG_DEFAULTS) - # Make keys case-sensitive - cfg.optionxform = str - for sect in CONFIG_SECTIONS: - try: - cfg.add_section(sect) - except configparser.DuplicateSectionError: - pass - cfg.read(os.getenv("SYGNAL_CONF", "sygnal.conf")) - return cfg - - -def make_pushkin(kind, name): - if '.' 
in kind: - toimport = kind - else: - toimport = "sygnal.%spushkin" % kind - toplevelmodule = __import__(toimport) - pushkinmodule = getattr(toplevelmodule, "%spushkin" % kind) - clarse = getattr(pushkinmodule, "%sPushkin" % kind.capitalize()) - return clarse(name) - - -@app.before_request -def log_request(): - flask.g.request_id = "%s-%i" % ( - request.method, request_count.get(), - ) - logger.info("Processing request %s", request.url) - - -@app.after_request -def log_processed_request(response): - logger.info( - "Processed request %s: %i", - request.url, response.status_code, - ) - return response - -@app.errorhandler(ClientError) -def handle_client_error(e): - resp = flask.jsonify({ 'error': { 'msg': str(e) } }) - resp.status_code = 400 - return resp - -@app.route('/') -def root(): - return "" - -@app.route('/_matrix/push/v1/notify', methods=['POST']) -def notify(): - try: - body = json.loads(request.data) - except Exception: - raise ClientError("Expecting json request body") - - if 'notification' not in body or not isinstance(body['notification'], dict): - msg = "Invalid notification: expecting object in 'notification' key" - logger.warn(msg) - flask.abort(400, msg) - - try: - notif = Notification(body['notification']) - except InvalidNotificationException as e: - logger.exception("Invalid notification") - flask.abort(400, e.message) - - if len(notif.devices) == 0: - msg = "No devices in notification" - logger.warn(msg) - flask.abort(400, msg) - - NOTIFS_RECEIVED_COUNTER.inc() - - rej = [] - - for d in notif.devices: - NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER.inc() - - appid = d.app_id - if appid not in pushkins: - logger.warn("Got notification for unknown app ID %s", appid) - rej.append(d.pushkey) - continue - - pushkin = pushkins[appid] - logger.debug( - "Sending push to pushkin %s for app ID %s", - pushkin.name, appid, - ) - - NOTIFS_BY_PUSHKIN.labels(pushkin.name).inc() - - try: - rej.extend(pushkin.dispatchNotification(notif)) - except: - 
logger.exception("Failed to send push") - flask.abort(500, "Failed to send push") - return flask.jsonify({ - "rejected": rej - }) - - -def setup(): - cfg = parse_config() - - logging.getLogger().setLevel(getattr(logging, cfg.get('log', 'loglevel').upper())) - logfile = cfg.get('log', 'logfile') - if logfile != '': - handler = WatchedFileHandler(logfile) - handler.addFilter(RequestIdFilter()) - formatter = logging.Formatter( - '%(asctime)s [%(process)d] %(levelname)-5s ' - '%(request_id)s %(name)s %(message)s' - ) - handler.setFormatter(formatter) - logging.getLogger().addHandler(handler) - else: - logging.basicConfig() - - if cfg.has_option("metrics", "sentry_dsn"): - # Only import sentry if enabled - import sentry_sdk - from sentry_sdk.integrations.flask import FlaskIntegration - sentry_sdk.init( - dsn=cfg.get("metrics", "sentry_dsn"), - integrations=[FlaskIntegration()], - ) - - if cfg.has_option("metrics", "prometheus_port"): - prometheus_client.start_http_server( - port=cfg.getint("metrics", "prometheus_port"), - addr=cfg.get("metrics", "prometheus_addr"), - ) - - ctx = SygnalContext() - ctx.database = sygnal.db.Db(cfg.get('db', 'dbfile')) - - for key,val in cfg.items('apps'): - parts = key.rsplit('.', 1) - if len(parts) < 2: - continue - if parts[1] == 'type': - try: - pushkins[parts[0]] = make_pushkin(val, parts[0]) - except: - logger.exception("Failed to load module for kind %s", val) - raise - - if len(pushkins) == 0: - logger.error("No app IDs are configured. Edit sygnal.conf to define some.") - sys.exit(1) - - for p in pushkins: - pushkins[p].cfg = cfg - pushkins[p].setup(ctx) - logger.info("Configured with app IDs: %r", pushkins.keys()) - - logger.error("Setup completed") - -def shutdown(): - logger.info("Starting shutdown...") - i = 0 - for p in pushkins.values(): - logger.info("Shutting down (%d/%d)..." 
% (i+1, len(pushkins))) - p.shutdown() - i += 1 - logger.info("Shutdown complete...") - - -setup() diff --git a/sygnal/apnspushkin.py b/sygnal/apnspushkin.py index 0fb5d98b..8b4fcad4 100644 --- a/sygnal/apnspushkin.py +++ b/sygnal/apnspushkin.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,240 +14,371 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import base64 +import logging +import os +from uuid import uuid4 + +import aioapns +from aioapns import APNs, NotificationRequest +from opentracing import logs, tags +from prometheus_client import Histogram, Counter +from twisted.internet.defer import Deferred + +from sygnal import apnstruncate +from sygnal.exceptions import ( + PushkinSetupException, + TemporaryNotificationDispatchException, + NotificationDispatchException, +) +from sygnal.notifications import Pushkin +from sygnal.utils import twisted_sleep, NotificationLoggerAdapter +logger = logging.getLogger(__name__) -from . import Pushkin -from .exceptions import PushkinSetupException, NotificationDispatchException +SEND_TIME_HISTOGRAM = Histogram( + "sygnal_apns_request_time", "Time taken to send HTTP request to APNS" +) -from pushbaby import PushBaby -import pushbaby.errors +RESPONSE_STATUS_CODES_COUNTER = Counter( + "sygnal_apns_status_codes", + "Number of HTTP response status codes received from APNS", + labelnames=["pushkin", "code"], +) -import logging -import base64 -import time -import gevent -logger = logging.getLogger(__name__) +class ApnsPushkin(Pushkin): + """ + Relays notifications to the Apple Push Notification Service. 
+ """ -create_failed_table_query = u""" -CREATE TABLE IF NOT EXISTS apns_failed (id INTEGER PRIMARY KEY, b64token TEXT NOT NULL, -last_failure_ts INTEGER NOT NULL, -last_failure_type varchar(10) not null, last_failure_code INTEGER default -1, token_invalidated_ts INTEGER default -1); -""" + # Errors for which the token should be rejected + TOKEN_ERROR_REASON = "Unregistered" + TOKEN_ERROR_CODE = 410 -create_failed_index_query = u""" -CREATE UNIQUE INDEX IF NOT EXISTS b64token on apns_failed(b64token); -""" + MAX_TRIES = 3 + RETRY_DELAY_BASE = 10 -# Max length of individual fields. Pushbaby will truncate appropriate -# fields of the push to fit the whole body in the max size, but it's -# not very fast so keep things to a sensible size. -MAX_FIELD_LENGTH = 1024 + MAX_FIELD_LENGTH = 1024 + MAX_JSON_BODY_SIZE = 4096 -class ApnsPushkin(Pushkin): - MAX_TRIES = 2 - DELETE_FEEDBACK_AFTER_SECS = 28 * 24 * 60 * 60 # a month(ish) - # These are the only ones of the errors returned in the APNS stream - # that we want to feed back. Anything else is nothing to do with the - # token. 
- ERRORS_TO_FEED_BACK = ( - pushbaby.errors.INVALID_TOKEN_SIZE, - pushbaby.errors.INVALID_TOKEN, - ) - - def __init__(self, name): - super(ApnsPushkin, self).__init__(name); - - def setup(self, ctx): - self.db = ctx.database - self.certfile = self.getConfig('certfile') - plaf = self.getConfig('platform') - if not plaf or plaf == 'production' or plaf == 'prod': - self.plaf = 'prod' - elif plaf == 'sandbox': - self.plaf = 'sandbox' + UNDERSTOOD_CONFIG_FIELDS = {"type", "platform", "certfile"} + + def __init__(self, name, sygnal, config): + super().__init__(name, sygnal, config) + + nonunderstood = set(self.cfg.keys()).difference(self.UNDERSTOOD_CONFIG_FIELDS) + if len(nonunderstood) > 0: + logger.warning( + "The following configuration fields are not understood: %s", + nonunderstood, + ) + + platform = self.get_config("platform") + if not platform or platform == "production" or platform == "prod": + self.use_sandbox = False + elif platform == "sandbox": + self.use_sandbox = True else: - raise PushkinSetupException("Invalid platform: %s" % plaf) - - self.db.query(create_failed_table_query) - self.db.query(create_failed_index_query) - - self.pushbaby = PushBaby(certfile=self.certfile, platform=self.plaf) - self.pushbaby.on_push_failed = self.on_push_failed - logger.info("APNS with cert file %s on %s platform", self.certfile, self.plaf) - - # poll feedback in a little bit, not while we're busy starting up - gevent.spawn_later(10, self.do_feedback_poll) - - def dispatchNotification(self, n): - tokens = {} - for d in n.devices: - tokens[d.pushkey] = d - - # check for tokens that have previously failed - token_set_str = u"(" + u",".join([u"?" for _ in tokens.keys()]) + u")" - feed_back_errors_set_str = u"(" + u",".join([u"?" 
for _ in ApnsPushkin.ERRORS_TO_FEED_BACK]) + u")" - q = ( - "SELECT b64token,last_failure_type,last_failure_code,token_invalidated_ts "+ - "FROM apns_failed WHERE b64token IN "+token_set_str+ - " and ("+ - "(last_failure_type = 'error' and last_failure_code in "+feed_back_errors_set_str+") "+ - "or (last_failure_type = 'feedback')"+ - ")" + raise PushkinSetupException(f"Invalid platform: {platform}") + + certfile = self.get_config("certfile") + keyfile = self.get_config("keyfile") + if not certfile and not keyfile: + raise PushkinSetupException( + "You must provide a path to an APNs certificate, or an APNs token." ) - args = [] - args.extend([unicode(t) for t in tokens.keys()]) - args.extend([(u"%d" % e) for e in ApnsPushkin.ERRORS_TO_FEED_BACK]) - rows = self.db.query(q, args, fetch='all') - - rejected = [] - for row in rows: - token_invalidated_ts = row[3] - token_pushkey_ts = tokens[row[0]].pushkey_ts - if token_pushkey_ts < token_invalidated_ts: - logger.warn( - "Rejecting token %s with ts %d. Last failure of type '%s' code %d, invalidated at %d", - row[0], token_pushkey_ts, row[1], row[2], token_invalidated_ts + + if certfile: + if not os.path.exists(certfile): + raise PushkinSetupException( + f"The APNs certificate '{certfile}' does not exist." 
) - rejected.append(row[0]) - del tokens[row[0]] - else: - logger.info("Have a failure for token %s of type '%s' at %d code %d but this token postdates it (%d): allowing.", row[0], row[1], token_invalidated_ts, row[2], token_pushkey_ts) - # This pushkey may be alive again, but we don't delete the - # failure because HSes should probably have a fresh token - # if they actually want to use it - - payload = None - if n.event_id and not n.type: - payload = self.get_payload_event_id_only(n) else: - payload = self.get_payload_full(n) - - prio = 10 - if n.prio == 'low': - prio = 5 - - tries = 0 - for t,d in tokens.items(): - while tries < ApnsPushkin.MAX_TRIES: - thispayload = payload - if 'aps' in thispayload: - thispayload = payload.copy() - thispayload['aps'] = thispayload['aps'].copy() - if d.tweaks.sound: - thispayload['aps']['sound'] = d.tweaks.sound - logger.info("Sending (attempt %i): '%s' -> %s", tries, thispayload, t) - poke_start_time = time.time() + # keyfile + if not os.path.exists(keyfile): + raise PushkinSetupException( + f"The APNs key file '{keyfile}' does not exist." + ) + if not self.get_config("key_id"): + raise PushkinSetupException("You must supply key_id.") + if not self.get_config("team_id"): + raise PushkinSetupException("You must supply team_id.") + if not self.get_config("topic"): + raise PushkinSetupException("You must supply topic.") + + if self.get_config("certfile") is not None: + self.apns_client = APNs( + client_cert=self.get_config("certfile"), use_sandbox=self.use_sandbox + ) + else: + self.apns_client = APNs( + key=self.get_config("keyfile"), + key_id=self.get_config("key_id"), + team_id=self.get_config("team_id"), + topic=self.get_config("topic"), + use_sandbox=self.use_sandbox, + ) + + # without this, aioapns will retry every second forever. + self.apns_client.pool.max_connection_attempts = 3 + + async def _dispatch_request(self, log, span, device, shaved_payload, prio): + """ + Actually attempts to dispatch the notification once. 
+ """ + + # this is no good: APNs expects ID to be in their format + # so we can't just derive a + # notif_id = context.request_id + f"-{n.devices.index(device)}" + + notif_id = str(uuid4()) + + log.info(f"Sending as APNs-ID {notif_id}") + span.set_tag("apns_id", notif_id) + + device_token = base64.b64decode(device.pushkey).hex() + + request = NotificationRequest( + device_token=device_token, + message=shaved_payload, + priority=prio, + notification_id=notif_id, + ) + + try: + with SEND_TIME_HISTOGRAM.time(): + response = await self._send_notification(request) + except aioapns.ConnectionError: + raise TemporaryNotificationDispatchException("aioapns Connection Failure") + + code = int(response.status) + + span.set_tag(tags.HTTP_STATUS_CODE, code) + + RESPONSE_STATUS_CODES_COUNTER.labels(pushkin=self.name, code=code).inc() + + if response.is_successful: + return [] + else: + # .description corresponds to the 'reason' response field + span.set_tag("apns_reason", response.description) + if ( + code == self.TOKEN_ERROR_CODE + or response.description == self.TOKEN_ERROR_REASON + ): + return [device.pushkey] + else: + if 500 <= code < 600: + raise TemporaryNotificationDispatchException( + f"{response.status} {response.description}" + ) + else: + raise NotificationDispatchException( + f"{response.status} {response.description}" + ) + + async def dispatch_notification(self, n, device, context): + log = NotificationLoggerAdapter(logger, {"request_id": context.request_id}) + + # The pushkey is kind of secret because you can use it to send push + # to someone. 
+ # span_tags = {"pushkey": device.pushkey} + span_tags = {} + + with self.sygnal.tracer.start_span( + "apns_dispatch", tags=span_tags, child_of=context.opentracing_span + ) as span_parent: + + if n.event_id and not n.type: + payload = self._get_payload_event_id_only(n) + else: + payload = self._get_payload_full(n, log) + + if payload is None: + # Nothing to do + span_parent.log_kv({logs.EVENT: "apns_no_payload"}) + return + prio = 10 + if n.prio == "low": + prio = 5 + + shaved_payload = apnstruncate.truncate( + payload, max_length=self.MAX_JSON_BODY_SIZE + ) + + for retry_number in range(self.MAX_TRIES): try: - res = self.pushbaby.send(thispayload, base64.b64decode(t), priority=prio) - logger.info("Sent (%f secs): -> %s", time.time() - poke_start_time, t) - break - except: - logger.exception("Exception sending push -> %s" % (t, )) + log.debug("Trying") + + span_tags = {"retry_num": retry_number} + + with self.sygnal.tracer.start_span( + "apns_dispatch_try", tags=span_tags, child_of=span_parent + ) as span: + return await self._dispatch_request( + log, span, device, shaved_payload, prio + ) + except TemporaryNotificationDispatchException as exc: + retry_delay = self.RETRY_DELAY_BASE * (2 ** retry_number) + if exc.custom_retry_delay is not None: + retry_delay = exc.custom_retry_delay + + log.warning( + "Temporary failure, will retry in %d seconds", + retry_delay, + exc_info=True, + ) + + span_parent.log_kv( + {"event": "temporary_fail", "retrying_in": retry_delay} + ) + + if retry_number == self.MAX_TRIES - 1: + raise NotificationDispatchException( + "Retried too many times." + ) from exc + else: + await twisted_sleep( + retry_delay, twisted_reactor=self.sygnal.reactor + ) + + def _get_payload_event_id_only(self, n): + """ + Constructs a payload for a notification where we know only the event ID. + Args: + n: The notification to construct a payload for. + + Returns: + The APNs payload as a nested dicts. 
+ """ + payload = {} - tries += 1 + if n.room_id: + payload["room_id"] = n.room_id + if n.event_id: + payload["event_id"] = n.event_id + + if n.counts.unread is not None: + payload["unread_count"] = n.counts.unread + if n.counts.missed_calls is not None: + payload["missed_calls"] = n.counts.missed_calls - if tries == ApnsPushkin.MAX_TRIES: - raise NotificationDispatchException("Max retries exceeded") + return payload - return rejected + def _get_payload_full(self, n, log): + """ + Constructs a payload for a notification. + Args: + n: The notification to construct a payload for. + log: A logger. - def get_payload_full(self, n): + Returns: + The APNs payload as nested dicts. + """ from_display = n.sender if n.sender_display_name is not None: from_display = n.sender_display_name - from_display = from_display[0:MAX_FIELD_LENGTH] + from_display = from_display[0 : self.MAX_FIELD_LENGTH] loc_key = None loc_args = None - if n.type == 'm.room.message' or n.type == 'm.room.encrypted': + if n.type == "m.room.message" or n.type == "m.room.encrypted": room_display = None if n.room_name: - room_display = n.room_name[0:MAX_FIELD_LENGTH] + room_display = n.room_name[0 : self.MAX_FIELD_LENGTH] elif n.room_alias: - room_display = n.room_alias[0:MAX_FIELD_LENGTH] + room_display = n.room_alias[0 : self.MAX_FIELD_LENGTH] content_display = None action_display = None is_image = False - if n.content and 'msgtype' in n.content and 'body' in n.content: - if 'body' in n.content: - if n.content['msgtype'] == 'm.text': - content_display = n.content['body'] - elif n.content['msgtype'] == 'm.emote': - action_display = n.content['body'] + if n.content and "msgtype" in n.content and "body" in n.content: + if "body" in n.content: + if n.content["msgtype"] == "m.text": + content_display = n.content["body"] + elif n.content["msgtype"] == "m.emote": + action_display = n.content["body"] else: - # fallback: 'body' should always be user-visible text in an m.room.message - content_display = 
n.content['body'] - if n.content['msgtype'] == 'm.image': + # fallback: 'body' should always be user-visible text + # in an m.room.message + content_display = n.content["body"] + if n.content["msgtype"] == "m.image": is_image = True if room_display: if is_image: - loc_key = 'IMAGE_FROM_USER_IN_ROOM' + loc_key = "IMAGE_FROM_USER_IN_ROOM" loc_args = [from_display, content_display, room_display] elif content_display: - loc_key = 'MSG_FROM_USER_IN_ROOM_WITH_CONTENT' + loc_key = "MSG_FROM_USER_IN_ROOM_WITH_CONTENT" loc_args = [from_display, room_display, content_display] elif action_display: - loc_key = 'ACTION_FROM_USER_IN_ROOM' + loc_key = "ACTION_FROM_USER_IN_ROOM" loc_args = [room_display, from_display, action_display] else: - loc_key = 'MSG_FROM_USER_IN_ROOM' + loc_key = "MSG_FROM_USER_IN_ROOM" loc_args = [from_display, room_display] else: if is_image: - loc_key = 'IMAGE_FROM_USER' + loc_key = "IMAGE_FROM_USER" loc_args = [from_display, content_display] elif content_display: - loc_key = 'MSG_FROM_USER_WITH_CONTENT' + loc_key = "MSG_FROM_USER_WITH_CONTENT" loc_args = [from_display, content_display] elif action_display: - loc_key = 'ACTION_FROM_USER' + loc_key = "ACTION_FROM_USER" loc_args = [from_display, action_display] else: - loc_key = 'MSG_FROM_USER' + loc_key = "MSG_FROM_USER" loc_args = [from_display] - elif n.type == 'm.call.invite': + elif n.type == "m.call.invite": is_video_call = False # This detection works only for hs that uses WebRTC for calls - if n.content and 'offer' in n.content and 'sdp' in n.content['offer']: - sdp = n.content['offer']['sdp'] - if 'm=video' in sdp: + if n.content and "offer" in n.content and "sdp" in n.content["offer"]: + sdp = n.content["offer"]["sdp"] + if "m=video" in sdp: is_video_call = True if is_video_call: - loc_key = 'VIDEO_CALL_FROM_USER' + loc_key = "VIDEO_CALL_FROM_USER" else: - loc_key = 'VOICE_CALL_FROM_USER' + loc_key = "VOICE_CALL_FROM_USER" loc_args = [from_display] - elif n.type == 'm.room.member': + elif n.type 
== "m.room.member": if n.user_is_target: - if n.membership == 'invite': + if n.membership == "invite": if n.room_name: - loc_key = 'USER_INVITE_TO_NAMED_ROOM' - loc_args = [from_display, n.room_name[0:MAX_FIELD_LENGTH]] + loc_key = "USER_INVITE_TO_NAMED_ROOM" + loc_args = [ + from_display, + n.room_name[0 : self.MAX_FIELD_LENGTH], + ] elif n.room_alias: - loc_key = 'USER_INVITE_TO_NAMED_ROOM' - loc_args = [from_display, n.room_alias[0:MAX_FIELD_LENGTH]] + loc_key = "USER_INVITE_TO_NAMED_ROOM" + loc_args = [ + from_display, + n.room_alias[0 : self.MAX_FIELD_LENGTH], + ] else: - loc_key = 'USER_INVITE_TO_CHAT' + loc_key = "USER_INVITE_TO_CHAT" loc_args = [from_display] elif n.type: # A type of message was received that we don't know about # but it was important enough for a push to have got to us - loc_key = 'MSG_FROM_USER' + loc_key = "MSG_FROM_USER" loc_args = [from_display] aps = {} if loc_key: - aps['alert'] = {'loc-key': loc_key } + aps["alert"] = {"loc-key": loc_key} if loc_args: - aps['alert']['loc-args'] = loc_args + aps["alert"]["loc-args"] = loc_args badge = None if n.counts.unread is not None: @@ -257,88 +389,25 @@ def get_payload_full(self, n): badge += n.counts.missed_calls if badge is not None: - aps['badge'] = badge + aps["badge"] = badge if loc_key: - aps['content-available'] = 1 + aps["content-available"] = 1 if loc_key is None and badge is None: - logger.info("Nothing to do for alert of type %s", n.type) - return rejected + log.info("Nothing to do for alert of type %s", n.type) + return None payload = {} if loc_key and n.room_id: - payload['room_id'] = n.room_id + payload["room_id"] = n.room_id - payload['aps'] = aps + payload["aps"] = aps return payload - def get_payload_event_id_only(self, n): - payload = {} - - if n.room_id: - payload['room_id'] = n.room_id - if n.event_id: - payload['event_id'] = n.event_id - - if n.counts.unread is not None: - payload['unread_count'] = n.counts.unread - if n.counts.missed_calls is not None: - 
payload['missed_calls'] = n.counts.missed_calls - - return payload - - def on_push_failed(self, token, identifier, status): - logger.error("Error sending push to token %s, status %s", token, status) - # We store all errors (could be useful to get failures instead of digging - # through logs) but note that not all failures mean we should stop sending - # to that token. - self.db.query( - "INSERT OR REPLACE INTO apns_failed "+ - "(b64token, last_failure_ts, last_failure_type, last_failure_code, token_invalidated_ts) "+ - " VALUES (?, ?, 'error', ?, ?)", - (base64.b64encode(token), long(time.time()), status, long(time.time())) + async def _send_notification(self, request): + return await Deferred.fromFuture( + asyncio.ensure_future(self.apns_client.send_notification(request)) ) - - def do_feedback_poll(self): - logger.info("Polling feedback...") - try: - feedback = self.pushbaby.get_all_feedback() - for fb in feedback: - b64token = unicode(base64.b64encode(fb.token)) - logger.info("Got feedback for token %s which is invalid as of %d", b64token, long(fb.ts)) - self.db.query( - "INSERT OR REPLACE INTO apns_failed "+ - "(b64token, last_failure_ts, last_failure_type, token_invalidated_ts) "+ - " VALUES (?, ?, 'feedback', ?)", - (b64token, long(time.time()), long(fb.ts)) - ) - logger.info("Stored %d feedback items", len(feedback)) - - # great, we're good until tomorrow - gevent.spawn_later(24 * 60 * 60, self.do_feedback_poll) - except: - logger.exception("Failed to poll for feedback, trying again in 10 minutes") - gevent.spawn_later(10 * 60, self.do_feedback_poll) - - self.prune_failures() - - def prune_failures(self): - """ - Delete any failures older than a set amount of time. - This is the only way we delete them - we can't delete - them once we've sent them because a token could be in use by - more than one Home Server. 
- """ - cutoff = long(time.time()) - ApnsPushkin.DELETE_FEEDBACK_AFTER_SECS - deleted = self.db.query( - "DELETE FROM apns_failed WHERE last_failure_ts < ?", - (cutoff,) - ) - logger.info("deleted %d stale items from failure table", deleted) - - def shutdown(self): - while self.pushbaby.messages_in_flight(): - gevent.wait(timeout=1.0) diff --git a/sygnal/apnstruncate.py b/sygnal/apnstruncate.py new file mode 100644 index 00000000..7b52a0e1 --- /dev/null +++ b/sygnal/apnstruncate.py @@ -0,0 +1,131 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and adapted from +# https://raw.githubusercontent.com/matrix-org/pushbaby/master/pushbaby/truncate.py +import json + + +def json_encode(payload): + return json.dumps(payload, ensure_ascii=False).encode() + + +class BodyTooLongException(Exception): + pass + + +def is_too_long(payload, max_length=2048): + """ + Returns True if the given payload dictionary is too long for a push. + Note that the maximum is now 2kB "In iOS 8 and later" although in + practice, payloads over 256 bytes (the old limit) are still + delivered to iOS 7 or earlier devices. + + Maximum is 4 kiB in the new APNs with the HTTP/2 interface. + """ + return len(json_encode(payload)) > max_length + + +def truncate(payload, max_length=2048): + """ + Truncate APNs fields to make the payload fit within the max length + specified. + Only truncates fields that are safe to do so. 
+ + Args: + payload: nested dict that will be passed to APNs + max_length: Maximum length, in bytes, that the payload should occupy + when JSON-encoded. + + Returns: + Nested dict which should comply with the maximum length restriction. + + """ + payload = payload.copy() + if "aps" not in payload: + if is_too_long(payload, max_length): + raise BodyTooLongException() + else: + return payload + aps = payload["aps"] + + # first ensure all our choppables are str objects. + # We need them to be for truncating to work and this + # makes more sense than checking every time. + for c in _choppables_for_aps(aps): + val = _choppable_get(aps, c) + if isinstance(val, bytes): + _choppable_put(aps, c, val.decode()) + + # chop off whole unicode characters until it fits (or we run out of chars) + while is_too_long(payload, max_length): + longest = _longest_choppable(aps) + if longest is None: + raise BodyTooLongException() + + txt = _choppable_get(aps, longest) + # Note that python's support for this is actually broken on some OSes + # (see test_apnstruncate.py) + txt = txt[:-1] + _choppable_put(aps, longest, txt) + payload["aps"] = aps + + return payload + + +def _choppables_for_aps(aps): + ret = [] + if "alert" not in aps: + return ret + + alert = aps["alert"] + if isinstance(alert, str): + ret.append(("alert",)) + elif isinstance(alert, dict): + if "body" in alert: + ret.append(("alert.body",)) + if "loc-args" in alert: + ret.extend([("alert.loc-args", i) for i in range(len(alert["loc-args"]))]) + + return ret + + +def _choppable_get(aps, choppable): + if choppable[0] == "alert": + return aps["alert"] + elif choppable[0] == "alert.body": + return aps["alert"]["body"] + elif choppable[0] == "alert.loc-args": + return aps["alert"]["loc-args"][choppable[1]] + + +def _choppable_put(aps, choppable, val): + if choppable[0] == "alert": + aps["alert"] = val + elif choppable[0] == "alert.body": + aps["alert"]["body"] = val + elif choppable[0] == "alert.loc-args": + 
aps["alert"]["loc-args"][choppable[1]] = val + + +def _longest_choppable(aps): + longest = None + length_of_longest = 0 + for c in _choppables_for_aps(aps): + val = _choppable_get(aps, c) + val_len = len(val.encode()) + if val_len > length_of_longest: + longest = c + length_of_longest = val_len + return longest diff --git a/sygnal/db.py b/sygnal/db.py deleted file mode 100644 index a034e6ec..00000000 --- a/sygnal/db.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2014 matrix.org -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sqlite3 -import logging -import threading -from six.moves import queue -import sys - -logger = logging.getLogger(__name__) - -class Db: - def __init__(self, dbfile): - self.dbfile = dbfile - self.db_queue = queue.Queue() - # Sqlite is blocking and does so in the c library so we can't - # use gevent's monkey patching to make it play nice. We just - # run all sqlite in a separate thread. 
- self.dbthread = threading.Thread(target=self.db_loop) - self.dbthread.setDaemon(True) - self.dbthread.start() - - def db_loop(self): - self.db = sqlite3.connect(self.dbfile) - while True: - job = self.db_queue.get() - job() - - def query(self, query, args=(), fetch=None): - res = {} - ev = threading.Event() - def runquery(): - try: - c = self.db.cursor() - c.execute(query, args) - if fetch == 1 or fetch == 'one': - res['rows'] = c.fetchone() - elif fetch == 'all': - res['rows'] = c.fetchall() - elif fetch == None: - self.db.commit() - res['rowcount'] = c.rowcount - except: - logger.exception("Caught exception running db query %s", query) - res['ex'] = sys.exc_info()[1] - ev.set() - self.db_queue.put(runquery) - ev.wait() - if 'ex' in res: - raise res['ex'] - elif 'rows' in res: - return res['rows'] - elif 'rowcount' in res: - return res['rowcount'] diff --git a/sygnal/exceptions.py b/sygnal/exceptions.py index 2d2d4d3c..b72a9260 100644 --- a/sygnal/exceptions.py +++ b/sygnal/exceptions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + class InvalidNotificationException(Exception): pass @@ -23,3 +24,13 @@ class PushkinSetupException(Exception): class NotificationDispatchException(Exception): pass + +class TemporaryNotificationDispatchException(Exception): + """ + To be used by pushkins for errors that are not our fault and are + hopefully temporary, so the request should possibly be retried soon. + """ + + def __init__(self, *args: object, custom_retry_delay=None) -> None: + super().__init__(*args) + self.custom_retry_delay = custom_retry_delay diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py index bd068d11..4effa2f6 100644 --- a/sygnal/gcmpushkin.py +++ b/sygnal/gcmpushkin.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 Leon Handreke # Copyright 2017 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,30 +14,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import absolute_import, division, print_function, unicode_literals - +import json import logging import time +from io import BytesIO +from json import JSONDecodeError -import grequests -import gevent -from requests import adapters, Session -from prometheus_client import Histogram +from opentracing import logs, tags +from prometheus_client import Histogram, Counter +from twisted.web.client import HTTPConnectionPool, Agent, FileBodyProducer, readBody +from twisted.web.http_headers import Headers -from . import Pushkin +from sygnal.exceptions import ( + TemporaryNotificationDispatchException, + NotificationDispatchException, +) +from sygnal.utils import twisted_sleep, NotificationLoggerAdapter from .exceptions import PushkinSetupException - +from .notifications import Pushkin SEND_TIME_HISTOGRAM = Histogram( - "sygnal_gcm_request_time", - "Time taken to send HTTP request", + "sygnal_gcm_request_time", "Time taken to send HTTP request to GCM" ) +RESPONSE_STATUS_CODES_COUNTER = Counter( + "sygnal_gcm_status_codes", + "Number of HTTP response status codes received from GCM", + labelnames=["pushkin", "code"], +) logger = logging.getLogger(__name__) -GCM_URL = "https://fcm.googleapis.com/fcm/send" +GCM_URL = b"https://fcm.googleapis.com/fcm/send" MAX_TRIES = 3 RETRY_DELAY_BASE = 10 MAX_BYTES_PER_FIELD = 1024 @@ -47,156 +56,319 @@ # though gcm-client 'helpfully' extracts these into a separate # list. 
BAD_PUSHKEY_FAILURE_CODES = [ - 'MissingRegistration', - 'InvalidRegistration', - 'NotRegistered', - 'InvalidPackageName', - 'MismatchSenderId', + "MissingRegistration", + "InvalidRegistration", + "NotRegistered", + "InvalidPackageName", + "MismatchSenderId", ] # Failure codes that mean the message in question will never # succeed, so don't retry, but the registration ID is fine # so we should not reject it upstream. -BAD_MESSAGE_FAILURE_CODES = [ - 'MessageTooBig', - 'InvalidDataKey', - 'InvalidTtl', -] +BAD_MESSAGE_FAILURE_CODES = ["MessageTooBig", "InvalidDataKey", "InvalidTtl"] + +DEFAULT_MAX_CONNECTIONS = 20 + class GcmPushkin(Pushkin): + """ + Pushkin that relays notifications to Google/Firebase Cloud Messaging. + """ + + UNDERSTOOD_CONFIG_FIELDS = {"type", "api_key"} + + def __init__(self, name, sygnal, config, canonical_reg_id_store): + super(GcmPushkin, self).__init__(name, sygnal, config) + + nonunderstood = set(self.cfg.keys()).difference(self.UNDERSTOOD_CONFIG_FIELDS) + if len(nonunderstood) > 0: + logger.warning( + "The following configuration fields are not understood: %s", + nonunderstood, + ) + + self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor) + self.http_pool.maxPersistentPerHost = self.get_config( + "max_connections", DEFAULT_MAX_CONNECTIONS + ) - def __init__(self, name): - super(GcmPushkin, self).__init__(name) - self.session = Session() - a = adapters.HTTPAdapter(pool_maxsize=20, pool_connections=20, pool_block=True) - self.session.mount("https://", a) + self.http_agent = Agent(reactor=sygnal.reactor, pool=self.http_pool) - def setup(self, ctx): - self.db = ctx.database + self.db = sygnal.database + self.canonical_reg_id_store = canonical_reg_id_store - self.api_key = self.getConfig('apiKey') + self.api_key = self.get_config("api_key") if not self.api_key: raise PushkinSetupException("No API key set in config") - self.canonical_reg_id_store = CanonicalRegIdStore(self.db) - def dispatchNotification(self, n): - pushkeys = 
[device.pushkey for device in n.devices if device.app_id == self.name] - # Resolve canonical IDs for all pushkeys - pushkeys = [canonical_reg_id or reg_id for (reg_id, canonical_reg_id) in - self.canonical_reg_id_store.get_canonical_ids(pushkeys).items()] - - data = GcmPushkin.build_data(n) - headers = { - "User-Agent": "sygnal", - "Content-Type": "application/json", - "Authorization": "key=%s" % (self.api_key,) - } + @classmethod + async def create(cls, name, sygnal, config): + """ + Override this if your pushkin needs to call async code in order to + be constructed. Otherwise, it defaults to just invoking the Python-standard + __init__ constructor. + + Returns: + an instance of this Pushkin + """ + logger.debug("About to set up CanonicalRegId Store") + canonical_reg_id_store = CanonicalRegIdStore() + await canonical_reg_id_store.setup(sygnal.database) + logger.debug("Finished setting up CanonicalRegId Store") + + return cls(name, sygnal, config, canonical_reg_id_store) + + async def _perform_http_request(self, body, headers): + """ + Perform an HTTP request to the FCM server with the body and headers + specified. + Args: + body (nested dict): Body. Will be JSON-encoded. + headers (Headers): HTTP Headers. + + Returns: + + """ + body_producer = FileBodyProducer(BytesIO(json.dumps(body).encode())) + try: + response = await self.http_agent.request( + b"POST", GCM_URL, headers=Headers(headers), bodyProducer=body_producer + ) + except Exception as exception: + raise TemporaryNotificationDispatchException( + "GCM request failure" + ) from exception + response_text = (await readBody(response)).decode() + return response, response_text + + async def _request_dispatch(self, n, log, body, headers, pushkeys, span): + poke_start_time = time.time() - # TODO: Implement collapse_key to queue only one message per room. 
failed = [] - for retry_number in range(0, MAX_TRIES): - body = { - "data": data, - "priority": 'normal' if n.prio == 'low' else 'high', - } - if len(pushkeys) == 1: - body['to'] = pushkeys[0] - else: - body['registration_ids'] = pushkeys - - logger.info("Sending (attempt %i): %r => %r", retry_number, data, pushkeys); - poke_start_time = time.time() - - with SEND_TIME_HISTOGRAM.time(): - req = grequests.post( - GCM_URL, json=body, headers=headers, timeout=10, - session=self.session, + with SEND_TIME_HISTOGRAM.time(): + response, response_text = await self._perform_http_request(body, headers) + + RESPONSE_STATUS_CODES_COUNTER.labels( + pushkin=self.name, code=response.code + ).inc() + + log.debug("GCM request took %f seconds", time.time() - poke_start_time) + + span.set_tag(tags.HTTP_STATUS_CODE, response.code) + + if 500 <= response.code < 600: + log.debug("%d from server, waiting to try again", response.code) + + retry_after = None + + for header_value in response.headers.getRawHeader( + b"retry-after", default=[] + ): + retry_after = int(header_value) + span.log_kv({"event": "gcm_retry_after", "retry_after": retry_after}) + + raise TemporaryNotificationDispatchException( + "GCM server error, hopefully temporary.", custom_retry_delay=retry_after + ) + elif response.code == 400: + log.error( + "%d from server, we have sent something invalid! Error: %r", + response.code, + response_text, + ) + # permanent failure: give up + raise NotificationDispatchException("Invalid request") + elif response.code == 401: + log.error( + "401 from server! Our API key is invalid? 
Error: %r", response_text + ) + # permanent failure: give up + raise NotificationDispatchException("Not authorised to push") + elif 200 <= response.code < 300: + try: + resp_object = json.loads(response_text) + except JSONDecodeError: + raise NotificationDispatchException("Invalid JSON response from GCM.") + if "results" not in resp_object: + log.error( + "%d from server but response contained no 'results' key: %r", + response.code, + response_text, ) - req.send() - - logger.debug("GCM request took %f seconds", time.time() - poke_start_time) - - if req.response is None: - success = False - logger.debug("Request failed, waiting to try again", req.exception) - elif req.response.status_code / 100 == 5: - success = False - logger.debug("%d from server, waiting to try again", req.response.status_code) - elif req.response.status_code == 400: - logger.error( - "%d from server, we have sent something invalid! Error: %r", - req.response.status_code, - req.response.text, + if len(resp_object["results"]) < len(pushkeys): + log.error( + "Sent %d notifications but only got %d responses!", + len(n.devices), + len(resp_object["results"]), ) - # permanent failure: give up - raise Exception("Invalid request") - elif req.response.status_code == 401: - logger.error( - "401 from server! Our API key is invalid? 
Error: %r", - req.response.text, + span.log_kv( + { + logs.EVENT: "gcm_response_mismatch", + "num_devices": len(n.devices), + "num_results": len(resp_object["results"]), + } ) - # permanent failure: give up - raise Exception("Not authorized to push") - elif req.response.status_code / 100 == 2: - resp_object = req.response.json() - if 'results' not in resp_object: - logger.error( - "%d from server but response contained no 'results' key: %r", - req.response.status_code, req.response.text, + + # determine which pushkeys to retry or forget about + new_pushkeys = [] + for i, result in enumerate(resp_object["results"]): + span.set_tag("gcm_regid_updated", "registration_id" in result) + if "registration_id" in result: + await self.canonical_reg_id_store.set_canonical_id( + pushkeys[i], result["registration_id"] ) - if len(resp_object['results']) < len(pushkeys): - logger.error( - "Sent %d notifications but only got %d responses!", - len(n.devices), len(resp_object['results']) + if "error" in result: + log.warning( + "Error for pushkey %s: %s", pushkeys[i], result["error"] ) - - new_pushkeys = [] - for i, result in enumerate(resp_object['results']): - if 'registration_id' in result: - self.canonical_reg_id_store.set_canonical_id( - pushkeys[i], result['registration_id'] + span.set_tag("gcm_error", result["error"]) + if result["error"] in BAD_PUSHKEY_FAILURE_CODES: + log.info( + "Reg ID %r has permanently failed with code %r: " + "rejecting upstream", + pushkeys[i], + result["error"], + ) + failed.append(pushkeys[i]) + elif result["error"] in BAD_MESSAGE_FAILURE_CODES: + log.info( + "Message for reg ID %r has permanently failed with code %r", + pushkeys[i], + result["error"], ) - if 'error' in result: - logger.warn("Error for pushkey %s: %s", pushkeys[i], result['error']) - if result['error'] in BAD_PUSHKEY_FAILURE_CODES: - logger.info( - "Reg ID %r has permanently failed with code %r: rejecting upstream", - pushkeys[i], result['error'] - ) - failed.append(pushkeys[i]) - 
elif result['error'] in BAD_MESSAGE_FAILURE_CODES: - logger.info( - "Message for reg ID %r has permanently failed with code %r", - pushkeys[i], result['error'] - ) - else: - logger.info( - "Reg ID %r has temporarily failed with code %r", - pushkeys[i], result['error'] - ) - new_pushkeys.append(pushkeys[i]) - if len(new_pushkeys) == 0: - return failed - pushkeys = new_pushkeys - - retry_delay = RETRY_DELAY_BASE * (2 ** retry_number) - if req.response and 'retry-after' in req.response.headers: + else: + log.info( + "Reg ID %r has temporarily failed with code %r", + pushkeys[i], + result["error"], + ) + new_pushkeys.append(pushkeys[i]) + return failed, new_pushkeys + + async def dispatch_notification(self, n, device, context): + log = NotificationLoggerAdapter(logger, {"request_id": context.request_id}) + + pushkeys = [ + device.pushkey for device in n.devices if device.app_id == self.name + ] + # Resolve canonical IDs for all pushkeys + + if pushkeys[0] != device.pushkey: + # Only send notifications once, to all devices at once. + return [] + + # The pushkey is kind of secret because you can use it to send push + # to someone. 
+ # span_tags = {"pushkeys": pushkeys} + span_tags = {"gcm_num_devices": len(pushkeys)} + + with self.sygnal.tracer.start_span( + "gcm_dispatch", tags=span_tags, child_of=context.opentracing_span + ) as span_parent: + reg_id_mappings = await self.canonical_reg_id_store.get_canonical_ids( + pushkeys + ) + + reg_id_mappings = { + reg_id: canonical_reg_id or reg_id + for (reg_id, canonical_reg_id) in reg_id_mappings.items() + } + + inverse_reg_id_mappings = {v: k for (k, v) in reg_id_mappings.items()} + + data = GcmPushkin._build_data(n) + headers = { + b"User-Agent": ["sygnal"], + b"Content-Type": ["application/json"], + b"Authorization": ["key=%s" % (self.api_key,)], + } + + # count the number of remapped registration IDs in the request + span_parent.set_tag( + "gcm_num_remapped_reg_ids_used", + [k != v for (k, v) in reg_id_mappings.items()].count(True), + ) + + # TODO: Implement collapse_key to queue only one message per room. + failed = [] + + body = {"data": data, "priority": "normal" if n.prio == "low" else "high"} + + for retry_number in range(0, MAX_TRIES): + mapped_pushkeys = [reg_id_mappings[pk] for pk in pushkeys] + + if len(pushkeys) == 1: + body["to"] = mapped_pushkeys[0] + else: + body["registration_ids"] = mapped_pushkeys + + log.info("Sending (attempt %i) => %r", retry_number, mapped_pushkeys) + try: - retry_delay = int(req.response.headers['retry-after']) - except: - pass - logger.info("Retrying in %d seconds", retry_delay) - gevent.sleep(seconds=retry_delay) + span_tags = {"retry_num": retry_number} + + with self.sygnal.tracer.start_span( + "gcm_dispatch_try", tags=span_tags, child_of=span_parent + ) as span: + new_failed, new_pushkeys = await self._request_dispatch( + n, log, body, headers, mapped_pushkeys, span + ) + pushkeys = new_pushkeys + failed += [ + inverse_reg_id_mappings[canonical_pk] + for canonical_pk in new_failed + ] + if len(pushkeys) == 0: + break + except TemporaryNotificationDispatchException as exc: + retry_delay = 
RETRY_DELAY_BASE * (2 ** retry_number) + if exc.custom_retry_delay is not None: + retry_delay = exc.custom_retry_delay + + log.warning( + "Temporary failure, will retry in %d seconds", + retry_delay, + exc_info=True, + ) + + span_parent.log_kv( + {"event": "temporary_fail", "retrying_in": retry_delay} + ) - logger.info("Gave up retrying reg IDs: %r", pushkeys) - return failed + await twisted_sleep( + retry_delay, twisted_reactor=self.sygnal.reactor + ) + + if len(pushkeys) > 0: + log.info("Gave up retrying reg IDs: %r", pushkeys) + # Count the number of failed devices. + span_parent.set_tag("gcm_num_failed", len(failed)) + return failed @staticmethod - def build_data(n): + def _build_data(n): + """ + Build the payload data to be sent. + Args: + n: Notification to build the payload for. + + Returns: + JSON-compatible dict + """ data = {} - for attr in ['event_id', 'type', 'sender', 'room_name', 'room_alias', 'membership', - 'sender_display_name', 'content', 'room_id']: + for attr in [ + "event_id", + "type", + "sender", + "room_name", + "room_alias", + "membership", + "sender_display_name", + "content", + "room_id", + ]: if hasattr(n, attr): data[attr] = getattr(n, attr) # Truncate fields to a sensible maximum length. 
If the whole @@ -204,40 +376,84 @@ def build_data(n): if data[attr] is not None and len(data[attr]) > MAX_BYTES_PER_FIELD: data[attr] = data[attr][0:MAX_BYTES_PER_FIELD] - data['prio'] = 'high' - if n.prio == 'low': - data['prio'] = 'normal'; + data["prio"] = "high" + if n.prio == "low": + data["prio"] = "normal" - if getattr(n, 'counts', None): - data['unread'] = n.counts.unread - data['missed_calls'] = n.counts.missed_calls + if getattr(n, "counts", None): + data["unread"] = n.counts.unread + data["missed_calls"] = n.counts.missed_calls return data class CanonicalRegIdStore(object): - TABLE_CREATE_QUERY = """ CREATE TABLE IF NOT EXISTS gcm_canonical_reg_id ( reg_id TEXT PRIMARY KEY, - canonical_reg_id TEXT NOT NULL);""" + canonical_reg_id TEXT NOT NULL + ); + """ - def __init__(self, db): + def __init__(self): + self.db = None + + async def setup(self, db): + """ + Prepares, if necessary, the database for storing canonical registration IDs. + + Separate method from the constructor because we wait for an async request + to complete, so it must be an `async def` method. + + Args: + db (adbapi.ConnectionPool): database to prepare + + """ self.db = db - self.db.query(self.TABLE_CREATE_QUERY) - def set_canonical_id(self, reg_id, canonical_reg_id): - self.db.query( - "INSERT OR REPLACE INTO gcm_canonical_reg_id VALUES (?, ?);", - (reg_id, canonical_reg_id)) + await self.db.runQuery(self.TABLE_CREATE_QUERY) - def get_canonical_ids(self, reg_ids): - # TODO: Use one DB query - return {reg_id: self._get_canonical_id(reg_id) for reg_id in reg_ids} + async def set_canonical_id(self, reg_id, canonical_reg_id): + """ + Associates a GCM registration ID with a canonical registration ID. 
+ Args: + reg_id (str): a registration ID + canonical_reg_id (str): the canonical registration ID for `reg_id` + """ + await self.db.runQuery( + "INSERT OR REPLACE INTO gcm_canonical_reg_id VALUES (?, ?);", + (reg_id, canonical_reg_id), + ) + + async def get_canonical_ids(self, reg_ids): + """ + Retrieves the canonical registration ID for multiple registration IDs. + + Args: + reg_ids (iterable): registration IDs to retrieve canonical registration + IDs for. + + Returns (dict): + mapping of registration ID to either its canonical registration ID, + or `None` if there is no entry. + """ + return {reg_id: await self.get_canonical_id(reg_id) for reg_id in reg_ids} + + async def get_canonical_id(self, reg_id): + """ + Retrieves the canonical registration ID for one registration ID. + + Args: + reg_id (str): registration ID to retrieve the canonical registration + ID for. + + Returns (dict): + its canonical registration ID, or `None` if there is no entry. + """ + rows = await self.db.runQuery( + "SELECT canonical_reg_id FROM gcm_canonical_reg_id WHERE reg_id = ?", + (reg_id,), + ) - def _get_canonical_id(self, reg_id): - rows = self.db.query( - "SELECT canonical_reg_id FROM gcm_canonical_reg_id WHERE reg_id = ?;", - (reg_id, ), fetch='all') if rows: return rows[0][0] diff --git a/sygnal/http.py b/sygnal/http.py new file mode 100644 index 00000000..78872707 --- /dev/null +++ b/sygnal/http.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import sys +import traceback +from uuid import uuid4 + +from opentracing import Format, tags, logs +from prometheus_client import Counter +from twisted.internet import defer +from twisted.internet.defer import gatherResults, ensureDeferred +from twisted.python.failure import Failure +from twisted.web import server +from twisted.web.http import proxiedLogFormatter +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET + +from sygnal.notifications import NotificationContext +from sygnal.utils import NotificationLoggerAdapter +from .exceptions import InvalidNotificationException, NotificationDispatchException +from .notifications import Notification + +logger = logging.getLogger(__name__) + +NOTIFS_RECEIVED_COUNTER = Counter( + "sygnal_notifications_received", "Number of notification pokes received" +) + +NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER = Counter( + "sygnal_notifications_devices_received", "Number of devices been asked to push" +) + +NOTIFS_BY_PUSHKIN = Counter( + "sygnal_per_pushkin_type", + "Number of pushes sent via each type of pushkin", + labelnames=["pushkin"], +) + +PUSHGATEWAY_HTTP_RESPONSES_COUNTER = Counter( + "sygnal_pushgateway_status_codes", + "HTTP Response Codes given on the Push Gateway API", + labelnames=["code"], +) + + +class V1NotifyHandler(Resource): + def __init__(self, sygnal): + super().__init__() + self.sygnal = sygnal + + isLeaf = True + + def _make_request_id(self): + """ + Generates a request ID, intended to be unique, for a request so it can + be followed through 
logging. + Returns: a request ID for the request. + """ + return str(uuid4()) + + def render_POST(self, request): + response = self._handle_request(request) + if response != NOT_DONE_YET: + PUSHGATEWAY_HTTP_RESPONSES_COUNTER.labels(code=request.code).inc() + return response + + def _handle_request(self, request): + """ + Actually handle the request. + Args: + request (Request): The request, corresponding to a POST request. + + Returns: + Either a str instance or NOT_DONE_YET. + + """ + request_id = self._make_request_id() + header_dict = { + k.decode(): v[0].decode() + for k, v in request.requestHeaders.getAllRawHeaders() + } + + # extract OpenTracing scope from the HTTP headers + span_ctx = self.sygnal.tracer.extract(Format.HTTP_HEADERS, header_dict) + span_tags = { + tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, + "request_id": request_id, + } + + root_span = self.sygnal.tracer.start_span( + "pushgateway_v1_notify", child_of=span_ctx, tags=span_tags + ) + + # if this is True, we will not close the root_span at the end of this + # function. 
+ root_span_accounted_for = False + + try: + context = NotificationContext(request_id, root_span) + + log = NotificationLoggerAdapter(logger, {"request_id": request_id}) + + try: + body = json.loads(request.content.read()) + except Exception as exc: + msg = "Expected JSON request body" + log.warning(msg, exc_info=exc) + root_span.log_kv({logs.EVENT: "error", "error.object": exc}) + request.setResponseCode(400) + return msg.encode() + + if "notification" not in body or not isinstance(body["notification"], dict): + msg = "Invalid notification: expecting object in 'notification' key" + log.warning(msg) + root_span.log_kv({logs.EVENT: "error", "message": msg}) + request.setResponseCode(400) + return msg.encode() + + try: + notif = Notification(body["notification"]) + except InvalidNotificationException as e: + log.exception("Invalid notification") + request.setResponseCode(400) + root_span.log_kv({logs.EVENT: "error", "error.object": e}) + return str(e).encode() + + if notif.event_id is not None: + root_span.set_tag("event_id", notif.event_id) + + # track whether the notification was passed with content + root_span.set_tag("has_content", notif.content is not None) + + NOTIFS_RECEIVED_COUNTER.inc() + + if len(notif.devices) == 0: + msg = "No devices in notification" + log.warning(msg) + request.setResponseCode(400) + return msg.encode() + + rej = [] + deferreds = [] + + pushkins = self.sygnal.pushkins + + for d in notif.devices: + NOTIFS_RECEIVED_DEVICE_PUSH_COUNTER.inc() + + appid = d.app_id + if appid not in pushkins: + log.warning("Got notification for unknown app ID %s", appid) + rej.append(d.pushkey) + continue + + pushkin = pushkins[appid] + log.debug( + "Sending push to pushkin %s for app ID %s", pushkin.name, appid + ) + + NOTIFS_BY_PUSHKIN.labels(pushkin.name).inc() + + async def dispatch_checked(): + """ + Dispatches a notification and checks the Pushkin + returns a list. 
+ Returns (list): + The result + """ + result = await pushkin.dispatch_notification(notif, d, context) + if not isinstance(result, list): + raise TypeError("Pushkin should return list.") + return result + + deferreds.append(ensureDeferred(dispatch_checked())) + + def callback(rejected_lists): + # combine all rejected pushkeys into one list + + rejected = sum(rejected_lists, rej) + + request.write(json.dumps({"rejected": rejected}).encode()) + + log.info( + "Successfully delivered notifications" " with %d rejected pushkeys", + len(rejected), + ) + + request.finish() + + def errback(failure: Failure): + # due to gatherResults, errors will be wrapped in FirstError. + if issubclass(failure.type, defer.FirstError): + subfailure = failure.value.subFailure + if issubclass(subfailure.type, NotificationDispatchException): + request.setResponseCode(502) + log.warning( + "Failed to dispatch notification.", exc_info=subfailure + ) + else: + request.setResponseCode(500) + log.error( + "Exception whilst dispatching notification.", + exc_info=subfailure, + ) + else: + request.setResponseCode(500) + log.error( + "Exception whilst dispatching notification.", exc_info=failure + ) + + request.finish() + + aggregate = gatherResults(deferreds, consumeErrors=True) + aggregate.addCallback(callback) + aggregate.addErrback(errback) + + def count_deferred_code(_): + PUSHGATEWAY_HTTP_RESPONSES_COUNTER.labels(code=request.code).inc() + root_span.set_tag(tags.HTTP_STATUS_CODE, request.code) + if not 200 <= request.code < 300: + root_span.set_tag(tags.ERROR, True) + root_span.finish() + + aggregate.addCallback(count_deferred_code) + root_span_accounted_for = True + + # we have to try and send the notifications first, + # so we can find out which ones to reject + return NOT_DONE_YET + except Exception as exc_val: + root_span.set_tag(tags.ERROR, True) + + # [2] corresponds to the traceback + trace = traceback.format_tb(sys.exc_info()[2]) + root_span.log_kv( + { + logs.EVENT: tags.ERROR, + 
logs.MESSAGE: str(exc_val), + logs.ERROR_OBJECT: exc_val, + logs.ERROR_KIND: type(exc_val), + logs.STACK: trace, + } + ) + raise + finally: + if not root_span_accounted_for: + root_span.finish() + + +class PushGatewayApiServer(object): + def __init__(self, sygnal): + """ + Initialises the /_matrix/push/* (Push Gateway API) server. + Args: + sygnal (Sygnal): the Sygnal object + """ + root = Resource() + matrix = Resource() + push = Resource() + v1 = Resource() + + # Note that using plain strings here will lead to silent failure + root.putChild(b"_matrix", matrix) + matrix.putChild(b"push", push) + push.putChild(b"v1", v1) + v1.putChild(b"notify", V1NotifyHandler(sygnal)) + + use_x_forwarded_for = sygnal.config["log"]["access"]["x_forwarded_for"] + + log_formatter = proxiedLogFormatter if use_x_forwarded_for else None + + self.site = server.Site( + root, reactor=sygnal.reactor, logFormatter=log_formatter + ) diff --git a/sygnal/notifications.py b/sygnal/notifications.py new file mode 100644 index 00000000..59b884c5 --- /dev/null +++ b/sygnal/notifications.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .exceptions import InvalidNotificationException + + +class Tweaks: + def __init__(self, raw): + self.sound = None + + if "sound" in raw: + self.sound = raw["sound"] + + +class Device: + def __init__(self, raw): + self.app_id = None + self.pushkey = None + self.pushkey_ts = 0 + self.data = None + self.tweaks = None + + if "app_id" not in raw: + raise InvalidNotificationException("Device with no app_id") + if "pushkey" not in raw: + raise InvalidNotificationException("Device with no pushkey") + if "pushkey_ts" in raw: + self.pushkey_ts = raw["pushkey_ts"] + if "tweaks" in raw: + self.tweaks = Tweaks(raw["tweaks"]) + else: + self.tweaks = Tweaks({}) + self.app_id = raw["app_id"] + self.pushkey = raw["pushkey"] + if "data" in raw: + self.data = raw["data"] + + +class Counts: + def __init__(self, raw): + self.unread = None + self.missed_calls = None + + if "unread" in raw: + self.unread = raw["unread"] + if "missed_calls" in raw: + self.missed_calls = raw["missed_calls"] + + +class Notification: + def __init__(self, notif): + optional_attrs = [ + "room_name", + "room_alias", + "prio", + "membership", + "sender_display_name", + "content", + "event_id", + "room_id", + "user_is_target", + "type", + "sender", + ] + for a in optional_attrs: + if a in notif: + self.__dict__[a] = notif[a] + else: + self.__dict__[a] = None + + if "devices" not in notif or not isinstance(notif["devices"], list): + raise InvalidNotificationException("Expected list in 'devices' key") + + if "counts" in notif: + self.counts = Counts(notif["counts"]) + else: + self.counts = Counts({}) + + self.devices = [Device(d) for d in notif["devices"]] + + +class Pushkin(object): + def __init__(self, name, sygnal, config): + self.name = name + self.cfg = config + self.sygnal = sygnal + + def get_config(self, key, default=None): + if key not in self.cfg: + return default + return self.cfg[key] + + async def dispatch_notification(self, n, device, context): + """ + Args: + n: The notification to dispatch 
via this pushkin + device: The device to dispatch the notification for. + context (NotificationContext): the request context + + Returns: + A list of rejected pushkeys, to be reported back to the homeserver + """ + pass + + @classmethod + async def create(cls, name, sygnal, config): + """ + Override this if your pushkin needs to call async code in order to + be constructed. Otherwise, it defaults to just invoking the Python-standard + __init__ constructor. + + Returns: + an instance of this Pushkin + """ + return cls(name, sygnal, config) + + +class NotificationContext(object): + def __init__(self, request_id, opentracing_span): + """ + Args: + request_id (str): An ID for the request, or None to have it + generated automatically. + opentracing_span (Span): The span for the API request triggering + the notification. + """ + self.request_id = request_id + self.opentracing_span = opentracing_span diff --git a/sygnal/sygnal.py b/sygnal/sygnal.py new file mode 100644 index 00000000..44293a4a --- /dev/null +++ b/sygnal/sygnal.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# Copyright 2018, 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import importlib +import logging +import logging.config +import os +import sys + +import opentracing +import prometheus_client + +# import twisted.internet.reactor +import yaml +from opentracing.scope_managers.asyncio import AsyncioScopeManager +from twisted.enterprise.adbapi import ConnectionPool +from twisted.internet import asyncioreactor +from twisted.internet.defer import ensureDeferred +from twisted.python import log as twisted_log + +from sygnal.http import PushGatewayApiServer + +logger = logging.getLogger(__name__) + +CONFIG_DEFAULTS = { + "http": {"port": 5000, "bind_addresses": ["127.0.0.1"]}, + "log": {"setup": {}, "access": {"x_forwarded_for": False}}, + "db": {"dbfile": "sygnal.db"}, + "metrics": { + "prometheus": {"enabled": False, "address": "127.0.0.1", "port": 8000}, + "opentracing": { + "enabled": False, + "implementation": None, + "jaeger": {}, + "service_name": "sygnal", + }, + "sentry": {"enabled": False}, + }, + "apps": {}, +} + + +class Sygnal(object): + def __init__(self, config, custom_reactor, tracer=opentracing.tracer): + """ + Object that holds state for the entirety of a Sygnal instance. + Args: + config (dict): Configuration for this Sygnal + custom_reactor: a Twisted Reactor to use. + tracer (optional): an OpenTracing tracer. The default is the no-op tracer. 
+ """ + self.config = config + self.reactor = custom_reactor + self.pushkins = {} + self.tracer = tracer + + logging_dict_config = config["log"]["setup"] + logging.config.dictConfig(logging_dict_config) + + logger.debug("Started logging") + + observer = twisted_log.PythonLoggingObserver(loggerName="sygnal.access") + observer.start() + + sentrycfg = config["metrics"]["sentry"] + if sentrycfg["enabled"] is True: + import sentry_sdk + + logger.info("Initialising Sentry") + sentry_sdk.init(sentrycfg["dsn"]) + + promcfg = config["metrics"]["prometheus"] + if promcfg["enabled"] is True: + prom_addr = promcfg["address"] + prom_port = int(promcfg["port"]) + logger.info( + "Starting Prometheus Server on %s port %d", prom_addr, prom_port + ) + + prometheus_client.start_http_server(port=prom_port, addr=prom_addr or "") + + tracecfg = config["metrics"]["opentracing"] + if tracecfg["enabled"] is True: + if tracecfg["implementation"] == "jaeger": + try: + import jaeger_client + + jaeger_cfg = jaeger_client.Config( + config=tracecfg["jaeger"], + service_name=tracecfg["service_name"], + scope_manager=AsyncioScopeManager(), + ) + + self.tracer = jaeger_cfg.initialize_tracer() + + logger.info("Enabled OpenTracing support with Jaeger") + except ModuleNotFoundError: + logger.critical( + "You have asked for OpenTracing with Jaeger but do not have" + " the Python package 'jaeger_client' installed." + ) + raise + else: + logger.error( + "Unknown OpenTracing implementation: %s.", tracecfg["impl"] + ) + sys.exit(1) + + self.database = ConnectionPool( + "sqlite3", + config["db"]["dbfile"], + cp_reactor=self.reactor, + cp_min=1, + cp_max=1, + check_same_thread=False, + ) + + async def _make_pushkin(self, app_name, app_config): + """ + Load and instantiate a pushkin. + Args: + app_name (str): The pushkin's app_id + app_config (dict): The pushkin's configuration + + Returns (Pushkin): + A pushkin of the desired type. + """ + app_type = app_config["type"] + if "." 
in app_type: + kind_split = app_type.rsplit(".", 1) + to_import = kind_split[0] + to_construct = kind_split[1] + else: + to_import = f"sygnal.{app_type}pushkin" + to_construct = f"{app_type.capitalize()}Pushkin" + + logger.info("Importing pushkin module: %s", to_import) + pushkin_module = importlib.import_module(to_import) + logger.info("Creating pushkin: %s", to_construct) + clarse = getattr(pushkin_module, to_construct) + return await clarse.create(app_name, self, app_config) + + async def _make_pushkins_then_start(self, port, bind_addresses, pushgateway_api): + for app_id, app_cfg in self.config["apps"].items(): + try: + self.pushkins[app_id] = await self._make_pushkin(app_id, app_cfg) + except Exception: + logger.exception( + "Failed to load and create pushkin for kind %s", app_cfg["type"] + ) + sys.exit(1) + + if len(self.pushkins) == 0: + logger.error("No app IDs are configured. Edit sygnal.yaml to define some.") + sys.exit(1) + + logger.info("Configured with app IDs: %r", self.pushkins.keys()) + + for interface in bind_addresses: + logger.info("Starting listening on %s port %d", interface, port) + self.reactor.listenTCP(port, pushgateway_api.site, interface=interface) + + def run(self): + """ + Attempt to run Sygnal and then exit the application. + """ + port = int(self.config["http"]["port"]) + bind_addresses = self.config["http"]["bind_addresses"] + pushgateway_api = PushGatewayApiServer(self) + + ensureDeferred( + self._make_pushkins_then_start(port, bind_addresses, pushgateway_api) + ) + self.reactor.run() + + +def parse_config(): + """ + Find and load Sygnal's configuration file. + Returns (dict): + A loaded configuration. 
+ """ + config_path = os.getenv("SYGNAL_CONF", "sygnal.yaml") + try: + with open(config_path) as file_handle: + return yaml.safe_load(file_handle) + except FileNotFoundError: + logger.critical( + "Could not find configuration file!\n" "Path: %s\n" "Absolute Path: %s", + config_path, + os.path.realpath(config_path), + ) + raise + + +def check_config(config): + """ + Lightly check the configuration and issue warnings as appropriate. + Args: + config: The loaded configuration. + """ + UNDERSTOOD_CONFIG_FIELDS = CONFIG_DEFAULTS.keys() + + def check_section(section_name, known_keys, cfgpart=config): + nonunderstood = set(cfgpart[section_name].keys()).difference(known_keys) + if len(nonunderstood) > 0: + logger.warning( + f"The following configuration fields in '{section_name}' " + f"are not understood: %s", + nonunderstood, + ) + + nonunderstood = set(config.keys()).difference(UNDERSTOOD_CONFIG_FIELDS) + if len(nonunderstood) > 0: + logger.warning( + "The following configuration fields are not understood: %s", nonunderstood + ) + + check_section("http", {"port", "bind_addresses"}) + check_section("log", {"setup", "access"}) + check_section( + "access", {"file", "enabled", "x_forwarded_for"}, cfgpart=config["log"] + ) + check_section("db", {"dbfile"}) + check_section("metrics", {"opentracing", "sentry", "prometheus"}) + check_section( + "opentracing", + {"enabled", "implementation", "jaeger", "service_name"}, + cfgpart=config["metrics"], + ) + check_section( + "prometheus", {"enabled", "address", "port"}, cfgpart=config["metrics"] + ) + check_section("sentry", {"enabled", "dsn"}, cfgpart=config["metrics"]) + + +def merge_left_with_defaults(defaults, loaded_config): + """ + Merge two configurations, with one of them overriding the other. + Args: + defaults (dict): A configuration of defaults + loaded_config (dict): A configuration, as loaded from disk. + + Returns (dict): + A merged configuration, with loaded_config preferred over defaults. 
+ """ + result = defaults.copy() + + if loaded_config is None: + return result + + # copy defaults or override them + for k, v in result.items(): + if isinstance(v, dict): + if k in loaded_config: + result[k] = merge_left_with_defaults(v, loaded_config[k]) + else: + result[k] = copy.deepcopy(v) + elif k in loaded_config: + result[k] = loaded_config[k] + + # copy things with no defaults + for k, v in loaded_config.items(): + if k not in result: + result[k] = v + + return result + + +if __name__ == "__main__": + # TODO we don't want to have to install the reactor, when we can get away with + # it + asyncioreactor.install() + + # we remove the global reactor to make it evident when it has accidentally + # been used: + # ! twisted.internet.reactor = None + # TODO can't do this ^ yet, since twisted.internet.task.{coiterate,cooperate} + # (indirectly) depend on the globally-installed reactor and there's no way + # to pass in a custom one. + # and twisted.web.client uses twisted.internet.task.cooperate + + config = parse_config() + config = merge_left_with_defaults(CONFIG_DEFAULTS, config) + check_config(config) + sygnal = Sygnal(config, custom_reactor=asyncioreactor.AsyncioSelectorReactor()) + sygnal.run() diff --git a/sygnal/utils.py b/sygnal/utils.py new file mode 100644 index 00000000..0da68868 --- /dev/null +++ b/sygnal/utils.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from logging import LoggerAdapter + +from twisted.internet.defer import Deferred + + +async def twisted_sleep(delay, twisted_reactor): + """ + Creates a Deferred which will fire in a set time. + This allows you to `await` on it and have an async analogue to + L{time.sleep}. + Args: + delay: Delay in seconds + twisted_reactor: Reactor to use for sleeping. + + Returns: + a Deferred which fires in `delay` seconds. + """ + deferred = Deferred() + twisted_reactor.callLater(delay, deferred.callback, None) + await deferred + + +class NotificationLoggerAdapter(LoggerAdapter): + def process(self, msg, kwargs): + return f"[{self.extra['request_id']}] {msg}", kwargs diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_apns.py b/tests/test_apns.py new file mode 100644 index 00000000..e29277e2 --- /dev/null +++ b/tests/test_apns.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import patch, MagicMock + +from aioapns.common import NotificationResult + +from sygnal import apnstruncate +from tests import testutils + +PUSHKIN_ID = "com.example.apns" + +TEST_CERTFILE_PATH = "/path/to/my/certfile.pem" + +DEVICE_EXAMPLE = {"app_id": "com.example.apns", "pushkey": "spqr", "pushkey_ts": 42} + + +class ApnsTestCase(testutils.TestCase): + def setUp(self): + self.apns_mock_class = patch("sygnal.apnspushkin.APNs").start() + self.apns_mock = MagicMock() + self.apns_mock_class.return_value = self.apns_mock + + # pretend our certificate exists + patch("os.path.exists", lambda x: x == TEST_CERTFILE_PATH).start() + self.addCleanup(patch.stopall) + + super(ApnsTestCase, self).setUp() + + self.apns_pushkin_snotif = MagicMock() + self.sygnal.pushkins[PUSHKIN_ID]._send_notification = self.apns_pushkin_snotif + + def config_setup(self, config): + super(ApnsTestCase, self).config_setup(config) + config["apps"][PUSHKIN_ID] = {"type": "apns", "certfile": TEST_CERTFILE_PATH} + + def test_payload_truncation(self): + """ + Tests that APNS message bodies will be truncated to fit the limits of + APNS. + """ + # Arrange + method = self.apns_pushkin_snotif + method.return_value = testutils.make_async_magic_mock( + NotificationResult("notID", "200") + ) + self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 200 + + # Act + self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + + # Assert + self.assertEquals(1, method.call_count) + ((notification_req,), _kwargs) = method.call_args + payload = notification_req.message + + self.assertLessEqual(len(apnstruncate.json_encode(payload)), 200) + + def test_payload_truncation_test_validity(self): + """ + This tests that L{test_payload_truncation_success} is a valid test + by showing that not limiting the truncation size would result in a + longer message. 
+ """ + # Arrange + method = self.apns_pushkin_snotif + method.return_value = testutils.make_async_magic_mock( + NotificationResult("notID", "200") + ) + self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 4096 + + # Act + self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + + # Assert + self.assertEquals(1, method.call_count) + ((notification_req,), _kwargs) = method.call_args + payload = notification_req.message + + self.assertGreater(len(apnstruncate.json_encode(payload)), 200) + + def test_expected(self): + """ + Tests the expected case: a good response from APNS means we pass on + a good response to the homeserver. + """ + # Arrange + method = self.apns_pushkin_snotif + method.side_effect = testutils.make_async_magic_mock( + NotificationResult("notID", "200") + ) + + # Act + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + + # Assert + self.assertEquals(1, method.call_count) + ((notification_req,), _kwargs) = method.call_args + + self.assertEquals( + { + "room_id": "!slw48wfj34rtnrf:example.com", + "aps": { + "alert": { + "loc-key": "MSG_FROM_USER_IN_ROOM_WITH_CONTENT", + "loc-args": [ + "Major Tom", + "Mission Control", + "I'm floating in a most peculiar way.", + ], + }, + "badge": 3, + "content-available": 1, + }, + }, + notification_req.message, + ) + + self.assertEquals({"rejected": []}, resp) + + def test_rejection(self): + """ + Tests the rejection case: a rejection response from APNS leads to us + passing on a rejection to the homeserver. 
+ """ + # Arrange + method = self.apns_pushkin_snotif + method.side_effect = testutils.make_async_magic_mock( + NotificationResult("notID", "410", description="Unregistered") + ) + + # Act + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + + # Assert + self.assertEquals(1, method.call_count) + self.assertEquals({"rejected": ["spqr"]}, resp) + + def test_no_retry_on_4xx(self): + """ + Test that we don't retry when we get a 4xx error but do not mark as + rejected. + """ + # Arrange + method = self.apns_pushkin_snotif + method.side_effect = testutils.make_async_magic_mock( + NotificationResult("notID", "429", description="TooManyRequests") + ) + + # Act + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + + # Assert + self.assertEquals(1, method.call_count) + self.assertEquals(502, resp) + + def test_retry_on_5xx(self): + """ + Test that we DO retry when we get a 5xx error and do not mark as + rejected. + """ + # Arrange + method = self.apns_pushkin_snotif + method.side_effect = testutils.make_async_magic_mock( + NotificationResult("notID", "503", description="ServiceUnavailable") + ) + + # Act + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + + # Assert + self.assertGreater(method.call_count, 1) + self.assertEquals(502, resp) diff --git a/tests/test_apnstruncate.py b/tests/test_apnstruncate.py new file mode 100644 index 00000000..7eb07ffd --- /dev/null +++ b/tests/test_apnstruncate.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

# Copied and adapted from
# https://raw.githubusercontent.com/matrix-org/pushbaby/master/tests/test_truncate.py


import string
import unittest

from sygnal.apnstruncate import truncate, json_encode


def simplestring(length, offset=0):
    """
    Deterministically generates a string.
    Args:
        length: Length of the string
        offset: Offset of the string (index into the alphabet at which the
            string starts; wraps around)

    Returns:
        A string formed of lowercase ASCII characters.
    """
    letters = string.ascii_lowercase
    return "".join(letters[(i + offset) % len(letters)] for i in range(length))


def sillystring(length, offset=0):
    """
    Deterministically generates a string.
    Args:
        length: Length of the string (in characters, not bytes)
        offset: Offset of the string (index into the emoji list at which the
            string starts; wraps around)

    Returns:
        A string formed of weird and wonderful UTF-8 emoji characters.
    """
    chars = ["\U0001F430", "\U0001F431", "\U0001F432", "\U0001F433"]
    return "".join(chars[(i + offset) % len(chars)] for i in range(length))


def payload_for_aps(aps):
    """
    Returns the APNS payload for an 'aps' dictionary.
    """
    return {"aps": aps}


class TruncateTestCase(unittest.TestCase):
    """
    Tests for sygnal.apnstruncate.truncate.

    Most tests measure the encoded size of the payload with an empty text
    field (the 'overhead') and then ask for a limit of overhead + N, so that
    N is the budget left over for the text itself.
    """

    def test_dont_truncate(self):
        """
        Tests that truncation is not performed if unnecessary.
        """
        # This shouldn't need to be truncated
        txt = simplestring(20)
        aps = {"alert": txt}
        self.assertEqual(txt, truncate(payload_for_aps(aps), 256)["aps"]["alert"])

    def test_truncate_alert(self):
        """
        Tests that the 'alert' string field will be truncated when needed.
        """
        overhead = len(json_encode(payload_for_aps({"alert": ""})))
        txt = simplestring(10)
        aps = {"alert": txt}
        self.assertEqual(
            txt[:5], truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]
        )

    def test_truncate_alert_body(self):
        """
        Tests that the 'alert' 'body' field will be truncated when needed.
        """
        overhead = len(json_encode(payload_for_aps({"alert": {"body": ""}})))
        txt = simplestring(10)
        aps = {"alert": {"body": txt}}
        self.assertEqual(
            txt[:5],
            truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["body"],
        )

    def test_truncate_loc_arg(self):
        """
        Tests that the 'alert' 'loc-args' field will be truncated when needed.
        (Tests with one loc arg)
        """
        overhead = len(json_encode(payload_for_aps({"alert": {"loc-args": [""]}})))
        txt = simplestring(10)
        aps = {"alert": {"loc-args": [txt]}}
        self.assertEqual(
            txt[:5],
            truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["loc-args"][0],
        )

    def test_truncate_loc_args(self):
        """
        Tests that the 'alert' 'loc-args' field will be truncated when needed.
        (Tests with two loc args)
        """
        overhead = len(json_encode(payload_for_aps({"alert": {"loc-args": ["", ""]}})))
        txt = simplestring(10)
        txt2 = simplestring(10, 3)
        aps = {"alert": {"loc-args": [txt, txt2]}}
        # A budget of 10 characters is expected to leave 5 for each argument.
        self.assertEqual(
            txt[:5],
            truncate(payload_for_aps(aps), overhead + 10)["aps"]["alert"]["loc-args"][
                0
            ],
        )
        self.assertEqual(
            txt2[:5],
            truncate(payload_for_aps(aps), overhead + 10)["aps"]["alert"]["loc-args"][
                1
            ],
        )

    def test_python_unicode_support(self):
        """
        Tests Python's unicode support :-
        a one character unicode string should have a length of one, even if it's one
        multibyte character.
        OS X, for example, is broken, and counts the number of surrogate pairs.
        I have no great desire to manually parse UTF-8 to work around this since
        it works fine on Linux.
        """
        if len("\U0001F430") != 1:
            msg = (
                "Unicode support is broken in your Python binary. "
                + "Truncating messages with multibyte unicode characters will fail."
            )
            self.fail(msg)

    def test_truncate_string_with_multibyte(self):
        """
        Tests that truncation works as expected on strings containing one
        multibyte character.
        """
        overhead = len(json_encode(payload_for_aps({"alert": ""})))
        txt = "\U0001F430" + simplestring(30)
        aps = {"alert": txt}
        # NB. The number of characters of the string we get is dependent
        # on the json encoding used.
        self.assertEqual(
            txt[:17], truncate(payload_for_aps(aps), overhead + 20)["aps"]["alert"]
        )

    def test_truncate_multibyte(self):
        """
        Tests that truncation works as expected on strings containing only
        multibyte characters.
        """
        overhead = len(json_encode(payload_for_aps({"alert": ""})))
        txt = sillystring(30)
        aps = {"alert": txt}
        trunc = truncate(payload_for_aps(aps), overhead + 30)
        # The string is all 4 byte characters so the truncated UTF-8 string
        # should be a multiple of 4 bytes long
        self.assertEqual(len(trunc["aps"]["alert"].encode("utf-8")) % 4, 0)
        # NB. The number of characters of the string we get is dependent
        # on the json encoding used.
        self.assertEqual(txt[:7], trunc["aps"]["alert"])
import json

from sygnal.gcmpushkin import GcmPushkin
from tests import testutils
from tests.testutils import DummyResponse

DEVICE_EXAMPLE = {"app_id": "com.example.gcm", "pushkey": "spqr", "pushkey_ts": 42}
DEVICE_EXAMPLE2 = {"app_id": "com.example.gcm", "pushkey": "spqr2", "pushkey_ts": 42}


class TestGcmPushkin(GcmPushkin):
    """
    A GCM pushkin with the ability to make HTTP requests removed and instead
    can be preloaded with virtual requests.

    The body and headers of the most recent request, and the number of
    requests made so far, are recorded so that tests can inspect them.
    """

    def __init__(self, name, sygnal, config, canonical_reg_id_store):
        super().__init__(name, sygnal, config, canonical_reg_id_store)
        self.preloaded_response = None
        self.preloaded_response_payload = None
        self.last_request_body = None
        self.last_request_headers = None
        self.num_requests = 0

    def preload_with_response(self, code, response_payload):
        """
        Preloads a fake GCM response.

        Args:
            code: HTTP status code of the fake response
            response_payload: JSON-serialisable body of the fake response
        """
        self.preloaded_response = DummyResponse(code)
        self.preloaded_response_payload = response_payload

    async def _perform_http_request(self, body, headers):
        """
        Records the request instead of sending it, then returns the preloaded
        response together with its JSON-encoded payload.
        """
        self.last_request_body = body
        self.last_request_headers = headers
        self.num_requests += 1
        return self.preloaded_response, json.dumps(self.preloaded_response_payload)


class GcmTestCase(testutils.TestCase):
    def config_setup(self, config):
        """
        Registers a TestGcmPushkin under the 'com.example.gcm' app ID.
        """
        super().config_setup(config)
        config["apps"]["com.example.gcm"] = {
            "type": "tests.test_gcm.TestGcmPushkin",
            "api_key": "kii",
        }

    def test_expected(self):
        """
        Tests the expected case: a good response from GCM leads to a good
        response from Sygnal.
        """
        gcm = self.sygnal.pushkins["com.example.gcm"]
        gcm.preload_with_response(
            200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
        )

        req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

        resp = self._collect_request(req)

        self.assertEqual(resp, {"rejected": []})
        self.assertEqual(gcm.num_requests, 1)

    def test_rejected(self):
        """
        Tests the rejected case: a pushkey rejected by GCM leads to Sygnal
        informing the homeserver of the rejection.
        """
        gcm = self.sygnal.pushkins["com.example.gcm"]
        gcm.preload_with_response(
            200, {"results": [{"registration_id": "spqr", "error": "NotRegistered"}]}
        )

        req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

        resp = self._collect_request(req)

        self.assertEqual(resp, {"rejected": ["spqr"]})
        self.assertEqual(gcm.num_requests, 1)

    def test_regenerated_id(self):
        """
        Tests that pushkeys regenerated by GCM are kept track of - that is,
        the new device ID is used in lieu of the old one if we are aware that
        it has changed.
        """
        gcm = self.sygnal.pushkins["com.example.gcm"]
        gcm.preload_with_response(
            200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
        )

        req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

        resp = self._collect_request(req)

        self.assertEqual(resp, {"rejected": []})

        gcm.preload_with_response(
            200, {"results": [{"registration_id": "spqr_new", "message_id": "msg43"}]}
        )

        req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

        resp = self._collect_request(req)

        # the second request must have been addressed to the regenerated ID
        self.assertEqual(gcm.last_request_body["to"], "spqr_new")

        self.assertEqual(resp, {"rejected": []})
        self.assertEqual(gcm.num_requests, 2)

    def test_batching(self):
        """
        Tests that multiple GCM devices have their notification delivered to GCM
        together, instead of being delivered separately.
        """
        gcm = self.sygnal.pushkins["com.example.gcm"]
        gcm.preload_with_response(
            200,
            {
                "results": [
                    {"registration_id": "spqr", "message_id": "msg42"},
                    {"registration_id": "spqr2", "message_id": "msg42"},
                ]
            },
        )

        req = self._make_request(
            self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2])
        )

        resp = self._collect_request(req)

        self.assertEqual(resp, {"rejected": []})
        self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
        self.assertEqual(gcm.num_requests, 1)

    def test_batching_individual_failure(self):
        """
        Tests that multiple GCM devices have their notification delivered to GCM
        together, instead of being delivered separately,
        and that if only one device ID is rejected, then only that device is
        reported to the homeserver as rejected.
        """
        gcm = self.sygnal.pushkins["com.example.gcm"]
        gcm.preload_with_response(
            200,
            {
                "results": [
                    {"registration_id": "spqr", "message_id": "msg42"},
                    {"registration_id": "spqr2", "error": "NotRegistered"},
                ]
            },
        )

        req = self._make_request(
            self._make_dummy_notification([DEVICE_EXAMPLE, DEVICE_EXAMPLE2])
        )

        resp = self._collect_request(req)

        self.assertEqual(resp, {"rejected": ["spqr2"]})
        self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
        self.assertEqual(gcm.num_requests, 1)

    def test_regenerated_failure(self):
        """
        Tests that use of a regenerated device ID does not cause confusion
        when reporting a rejection to the homeserver.
        The homeserver doesn't know about the regenerated ID so the rejection
        must be translated back into the one provided by the homeserver.
        """
        gcm = self.sygnal.pushkins["com.example.gcm"]
        gcm.preload_with_response(
            200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
        )

        req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

        resp = self._collect_request(req)

        self.assertEqual(resp, {"rejected": []})

        # imagine there is some non-negligible time between these two,
        # and the device in question is unregistered

        gcm.preload_with_response(
            200,
            {"results": [{"registration_id": "spqr_new", "error": "NotRegistered"}]},
        )

        req = self._make_request(self._make_dummy_notification([DEVICE_EXAMPLE]))

        resp = self._collect_request(req)

        self.assertEqual(gcm.last_request_body["to"], "spqr_new")

        # the ID translation needs to be transparent as the homeserver will not
        # make sense of it otherwise.
        self.assertEqual(resp, {"rejected": ["spqr"]})
        self.assertEqual(gcm.num_requests, 2)
from sygnal.exceptions import (
    NotificationDispatchException,
    TemporaryNotificationDispatchException,
)
from sygnal.notifications import Pushkin

from tests import testutils

# Each device below carries a pushkey that TestPushkin.dispatch_notification
# maps onto one specific outcome.
DEVICE_RAISE_EXCEPTION = {
    "app_id": "com.example.spqr",
    "pushkey": "raise_exception",
    "pushkey_ts": 1234,
}

DEVICE_REMOTE_ERROR = {
    "app_id": "com.example.spqr",
    "pushkey": "remote_error",
    "pushkey_ts": 1234,
}

DEVICE_TEMPORARY_REMOTE_ERROR = {
    "app_id": "com.example.spqr",
    "pushkey": "temporary_remote_error",
    "pushkey_ts": 1234,
}

DEVICE_REJECTED = {
    "app_id": "com.example.spqr",
    "pushkey": "reject",
    "pushkey_ts": 1234,
}

DEVICE_ACCEPTED = {
    "app_id": "com.example.spqr",
    "pushkey": "accept",
    "pushkey_ts": 1234,
}


class TestPushkin(Pushkin):
    """
    A synthetic Pushkin with simple rules.
    """

    async def dispatch_notification(self, n, device, context):
        """
        Chooses the outcome purely from the device's pushkey.

        Returns:
            A list of rejected pushkeys: the device's own pushkey for
            "reject", an empty list for "accept".

        Raises:
            Exception: for "raise_exception", or for any unrecognised pushkey.
            NotificationDispatchException: for "remote_error".
            TemporaryNotificationDispatchException: for
                "temporary_remote_error".
        """
        if device.pushkey == "raise_exception":
            raise Exception("Bad things have occurred!")
        elif device.pushkey == "remote_error":
            raise NotificationDispatchException("Synthetic failure")
        elif device.pushkey == "temporary_remote_error":
            raise TemporaryNotificationDispatchException("Synthetic failure")
        elif device.pushkey == "reject":
            return [device.pushkey]
        elif device.pushkey == "accept":
            return []
        raise Exception(f"Unexpected fall-through. {device.pushkey}")


class PushGatewayApiV1TestCase(testutils.TestCase):
    def config_setup(self, config):
        """
        Set up a TestPushkin for the test.
        """
        super().config_setup(config)
        config["apps"]["com.example.spqr"] = {
            "type": "tests.test_pushgateway_api_v1.TestPushkin"
        }

    def test_good_requests_give_200(self):
        """
        Test that good requests give a 200 response code.
        """
        # 200 codes cause the result to be parsed instead of returning the code
        self.assertNotIsInstance(
            self._request(
                self._make_dummy_notification([DEVICE_ACCEPTED, DEVICE_REJECTED])
            ),
            int,
        )

    def test_accepted_devices_are_not_rejected(self):
        """
        Test that devices which are accepted by the Pushkin
        do not lead to a rejection being returned to the homeserver.
        """
        self.assertEqual(
            self._request(self._make_dummy_notification([DEVICE_ACCEPTED])),
            {"rejected": []},
        )

    def test_rejected_devices_are_rejected(self):
        """
        Test that devices which are rejected by the Pushkin
        DO lead to a rejection being returned to the homeserver.
        """
        self.assertEqual(
            self._request(self._make_dummy_notification([DEVICE_REJECTED])),
            {"rejected": [DEVICE_REJECTED["pushkey"]]},
        )

    def test_only_rejected_devices_are_rejected(self):
        """
        Test that devices which are rejected by the Pushkin
        are the only ones to have a rejection returned to the homeserver,
        even if other devices feature in the request.
        """
        self.assertEqual(
            self._request(
                self._make_dummy_notification([DEVICE_REJECTED, DEVICE_ACCEPTED])
            ),
            {"rejected": [DEVICE_REJECTED["pushkey"]]},
        )

    def test_bad_requests_give_400(self):
        """
        Test that bad requests lead to a 400 Bad Request response.
        """
        self.assertEqual(self._request({}), 400)

    def test_exceptions_give_500(self):
        """
        Test that internal exceptions/errors lead to a 500 Internal Server Error
        response.
        """

        self.assertEqual(
            self._request(self._make_dummy_notification([DEVICE_RAISE_EXCEPTION])), 500
        )

        # we also check that a successful device doesn't hide the exception
        self.assertEqual(
            self._request(
                self._make_dummy_notification([DEVICE_ACCEPTED, DEVICE_RAISE_EXCEPTION])
            ),
            500,
        )

        self.assertEqual(
            self._request(
                self._make_dummy_notification([DEVICE_RAISE_EXCEPTION, DEVICE_ACCEPTED])
            ),
            500,
        )

    def test_remote_errors_give_502(self):
        """
        Test that errors caused by remote services such as GCM or APNS
        lead to a 502 Bad Gateway response.
        """

        self.assertEqual(
            self._request(self._make_dummy_notification([DEVICE_REMOTE_ERROR])), 502
        )

        # we also check that a successful device doesn't hide the exception
        self.assertEqual(
            self._request(
                self._make_dummy_notification([DEVICE_ACCEPTED, DEVICE_REMOTE_ERROR])
            ),
            502,
        )

        self.assertEqual(
            self._request(
                self._make_dummy_notification([DEVICE_REMOTE_ERROR, DEVICE_ACCEPTED])
            ),
            502,
        )
import json
from io import BytesIO
from threading import Condition

from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.trial import unittest
from twisted.web.http_headers import Headers
from twisted.web.server import NOT_DONE_YET
from twisted.web.test.requesthelper import DummyRequest as UnaugmentedDummyRequest

from sygnal.http import PushGatewayApiServer
from sygnal.sygnal import Sygnal, merge_left_with_defaults, CONFIG_DEFAULTS

REQ_PATH = b"/_matrix/push/v1/notify"


class TestCase(unittest.TestCase):
    """
    Base class for tests that need a running Sygnal instance driven by a
    fake reactor. Subclasses override config_setup to register pushkins.
    """

    def config_setup(self, config):
        """
        Hook for adjusting the configuration before Sygnal is created.
        The base implementation points the database at an in-memory file.
        """
        config["db"]["dbfile"] = ":memory:"

    def setUp(self):
        reactor = ExtendedMemoryReactorClock()

        config = {"apps": {}, "db": {}, "log": {"setup": {"version": 1}}}
        config = merge_left_with_defaults(CONFIG_DEFAULTS, config)

        self.config_setup(config)

        self.sygnal = Sygnal(config, reactor)
        self.sygnal.database.start()
        self.v1api = PushGatewayApiServer(self.sygnal)

        start_deferred = ensureDeferred(
            self.sygnal._make_pushkins_then_start(0, [], None)
        )

        while not start_deferred.called:
            # we need to advance until the pushkins have started up
            self.sygnal.reactor.advance(1)
            self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)

    def tearDown(self):
        super().tearDown()
        self.sygnal.database.close()

    def _make_dummy_notification(self, devices):
        """
        Builds a notification payload addressed to the given devices.

        Args:
            devices (list): device dicts to place in the 'devices' field

        Returns (dict):
            A payload with fixed example contents and the supplied devices.
        """
        return {
            "notification": {
                "id": "$3957tyerfgewrf384",
                "room_id": "!slw48wfj34rtnrf:example.com",
                "type": "m.room.message",
                "sender": "@exampleuser:matrix.org",
                "sender_display_name": "Major Tom",
                "room_name": "Mission Control",
                "room_alias": "#exampleroom:matrix.org",
                "prio": "high",
                "content": {
                    "msgtype": "m.text",
                    "body": "I'm floating in a most peculiar way.",
                },
                "counts": {"unread": 2, "missed_calls": 1},
                "devices": devices,
            }
        }

    def _make_request(self, payload, headers=None):
        """
        Make a dummy request to the notify endpoint with the specified payload.
        Args:
            payload: payload to be JSON encoded (a dict is JSON encoded;
                anything else is expected to already be a string)
            headers (dict, optional): A L{dict} mapping header names as L{bytes}
                to L{list}s of header values as L{bytes}

        Returns (DummyRequest):
            A dummy request corresponding to the request arguments supplied.

        """
        pathparts = REQ_PATH.split(b"/")
        if pathparts[0] == b"":
            # drop the empty segment produced by the leading slash
            pathparts = pathparts[1:]
        dreq = DummyRequest(pathparts)
        dreq.requestHeaders = Headers(headers or {})
        dreq.responseCode = 200  # default to 200

        if isinstance(payload, dict):
            payload = json.dumps(payload)

        dreq.content = BytesIO(payload.encode())
        dreq.method = "POST"

        return dreq

    def _collect_request(self, request):
        """
        Collects (waits until done and then returns the result of) the request.
        Args:
            request (Request): a request to collect

        Returns (dict or int):
            If successful (200 response received), the response is JSON decoded
            and the resultant dict is returned.
            If the response code is not 200, returns the response code.

        Raises:
            RuntimeError: if the resource rendered something that is neither
                a body nor NOT_DONE_YET.
        """
        resource = self.v1api.site.getResourceFor(request)
        rendered = resource.render(request)

        if request.responseCode != 200:
            return request.responseCode

        # A synchronously rendered body is normally bytes in Twisted;
        # str is still accepted for backwards compatibility.
        if isinstance(rendered, (bytes, str)):
            return json.loads(rendered)
        elif rendered == NOT_DONE_YET:

            while not request.finished:
                # we need to advance until the request has been finished
                self.sygnal.reactor.advance(1)
                self.sygnal.reactor.wait_for_work(lambda: request.finished)

            assert request.finished > 0

            # the response code may have been set while the request was pending
            if request.responseCode != 200:
                return request.responseCode

            written_bytes = b"".join(request.written)
            return json.loads(written_bytes)
        else:
            raise RuntimeError(f"Can't collect: {rendered}")

    def _request(self, *args, **kwargs):
        """
        Makes and collects a request.
        See L{_make_request} and L{_collect_request}.
        """
        request = self._make_request(*args, **kwargs)

        return self._collect_request(request)


class ExtendedMemoryReactorClock(MemoryReactorClock):
    """
    A MemoryReactorClock which notifies a condition variable whenever a call
    is scheduled, so that wait_for_work can block until there is work.
    """

    def __init__(self):
        super().__init__()
        self.work_notifier = Condition()

    def callFromThread(self, function, *args):
        """
        Schedules the function as a delayed call with zero delay.
        """
        self.callLater(0, function, *args)

    def callLater(self, when, what, *a, **kw):
        """
        Schedules the call as usual, then wakes up anything blocked in
        wait_for_work.
        """
        with self.work_notifier:
            return_value = super().callLater(when, what, *a, **kw)
            self.work_notifier.notify_all()

        return return_value

    def wait_for_work(self, early_stop=lambda: False):
        """
        Blocks until there is work as long as the early stop condition
        is not satisfied.

        Args:
            early_stop: Extra function called that determines whether to stop
                blocking.
                Should return true iff the early stop condition is satisfied,
                in which case no blocking will be done.
                It is intended to be used to detect when the task you are
                waiting for is complete, e.g. a Deferred has fired or a
                Request has been finished.
        """
        with self.work_notifier:
            while len(self.getDelayedCalls()) == 0 and not early_stop():
                self.work_notifier.wait()


class DummyRequest(UnaugmentedDummyRequest):
    """
    Tracks the response code in the 'code' field, like a normal Request.
    """

    def __init__(self, postpath, session=None, client=None):
        super().__init__(postpath, session, client)
        self.code = 200

    def setResponseCode(self, code, message=None):
        super().setResponseCode(code, message)
        self.code = code


class DummyResponse(object):
    """
    A minimal stand-in for an HTTP response: it only carries a status code.
    """

    def __init__(self, code):
        self.code = code


def make_async_magic_mock(ret_val):
    """
    Returns a coroutine function which accepts any arguments, ignores them
    and returns ret_val.
    """

    async def dummy(*_args, **_kwargs):
        return ret_val

    return dummy