diff --git a/webapp/graphite/errors.py b/webapp/graphite/errors.py index 289541798..144946309 100644 --- a/webapp/graphite/errors.py +++ b/webapp/graphite/errors.py @@ -94,6 +94,15 @@ def __str__(self): return msg +# Replace special characters "&", "<" and ">" to HTML-safe sequences. +def escape(s): + s = s.replace("&", "&") # Must be done first! + s = s.replace("<", "<") + s = s.replace(">", ">") + + return s + + # decorator which turns InputParameterExceptions into Django's HttpResponseBadRequest def handleInputParameterError(f): def new_f(*args, **kwargs): @@ -102,6 +111,6 @@ def new_f(*args, **kwargs): except InputParameterError as e: msgStr = str(e) log.warning('%s', msgStr) - return HttpResponseBadRequest(msgStr) + return HttpResponseBadRequest(escape(msgStr)) return new_f diff --git a/webapp/tests/base.py b/webapp/tests/base.py index 3039513cc..512e1e6d3 100644 --- a/webapp/tests/base.py +++ b/webapp/tests/base.py @@ -5,3 +5,15 @@ class TestCase(OriginalTestCase): def tearDown(self): stop_pools() + + # Assert that a response is unsanitized (for check XSS issues) + def assertXSS(self, response, status_code=200, msg_prefix=''): + if status_code is not None: + self.assertEqual( + response.status_code, status_code, + msg_prefix + "Couldn't retrieve content: Response code was %d" + " (expected %d)" % (response.status_code, status_code) + ) + + xss = response.content.find(b"<") != -1 or response.content.find(b">") != -1 + self.assertFalse(xss, msg=msg_prefix+str(response.content)) diff --git a/webapp/tests/test_xss.py b/webapp/tests/test_xss.py new file mode 100644 index 000000000..7a3a2c9b7 --- /dev/null +++ b/webapp/tests/test_xss.py @@ -0,0 +1,42 @@ +import logging +import sys + +try: + from django.urls import reverse +except ImportError: # Django < 1.10 + from django.core.urlresolvers import reverse + +from .base import TestCase + +# Silence logging during tests +LOGGER = logging.getLogger() + +# logging.NullHandler is a python 2.7ism +if hasattr(logging, "NullHandler"): + LOGGER.addHandler(logging.NullHandler()) + +if sys.version_info[0] >= 3: + def resp_text(r): + return r.content.decode('utf-8') +else: + def resp_text(r): + return r.content + + +class RenderXSSTest(TestCase): + def test_render_xss(self): + url = reverse('render') + xssStr = '