From 267b24b8adf8e0315254d7c6eb98302530f1af3d Mon Sep 17 00:00:00 2001 From: Josh Schneier Date: Fri, 10 May 2024 19:43:27 -0400 Subject: [PATCH] [s3] Skip generating signed URLs if querystring_auth=False --- storages/backends/s3.py | 65 ++++++++++++++++++----------------------- tests/test_s3.py | 29 +++++++----------- 2 files changed, 38 insertions(+), 56 deletions(-) diff --git a/storages/backends/s3.py b/storages/backends/s3.py index 9822f2ef..3230f305 100644 --- a/storages/backends/s3.py +++ b/storages/backends/s3.py @@ -7,9 +7,7 @@ import warnings from datetime import datetime from datetime import timedelta -from urllib.parse import parse_qsl from urllib.parse import urlencode -from urllib.parse import urlsplit from django.contrib.staticfiles.storage import ManifestFilesMixin from django.core.exceptions import ImproperlyConfigured @@ -34,6 +32,7 @@ try: import boto3.session + import botocore import s3transfer.constants from boto3.s3.transfer import TransferConfig from botocore.config import Config @@ -330,6 +329,7 @@ def __init__(self, **settings): self._bucket = None self._connections = threading.local() + self._unsigned_connections = threading.local() if self.config is not None: warnings.warn( @@ -439,11 +439,13 @@ def get_default_settings(self): def __getstate__(self): state = self.__dict__.copy() state.pop("_connections", None) + state.pop("_unsigned_connections", None) state.pop("_bucket", None) return state def __setstate__(self, state): state["_connections"] = threading.local() + state["_unsigned_connections"] = threading.local() state["_bucket"] = None self.__dict__ = state @@ -462,6 +464,24 @@ def connection(self): ) return self._connections.connection + @property + def unsigned_connection(self): + unsigned_connection = getattr(self._unsigned_connections, "connection", None) + if unsigned_connection is None: + session = self._create_session() + config = self.client_config.merge( + Config(signature_version=botocore.UNSIGNED) + ) + self._unsigned_connections.connection = session.resource( + "s3", + region_name=self.region_name, + use_ssl=self.use_ssl, + endpoint_url=self.endpoint_url, + config=config, + verify=self.verify, + ) + return self._unsigned_connections.connection + def _create_session(self): """ If a user specifies a profile name and this class obtains access keys @@ -635,37 +655,6 @@ def get_modified_time(self, name): else: return make_naive(entry.last_modified) - def _strip_signing_parameters(self, url): - # Boto3 does not currently support generating URLs that are unsigned. Instead - # we take the signed URLs and strip any querystring params related to signing - # and expiration. - # Note that this may end up with URLs that are still invalid, especially if - # params are passed in that only work with signed URLs, e.g. response header - # params. - # The code attempts to strip all query parameters that match names of known - # parameters from v2 and v4 signatures, regardless of the actual signature - # version used. - split_url = urlsplit(url) - qs = parse_qsl(split_url.query, keep_blank_values=True) - blacklist = { - "x-amz-algorithm", - "x-amz-credential", - "x-amz-date", - "x-amz-expires", - "x-amz-signedheaders", - "x-amz-signature", - "x-amz-security-token", - "awsaccesskeyid", - "expires", - "signature", - } - filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist) - # Note: Parameters that did not have a value in the original query string will - # have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar= - joined_qs = ("=".join(keyval) for keyval in filtered_qs) - split_url = split_url._replace(query="&".join(joined_qs)) - return split_url.geturl() - def url(self, name, parameters=None, expire=None, http_method=None): # Preserve the trailing slash after normalizing the path. name = self._normalize_name(clean_name(name)) @@ -691,12 +680,14 @@ def url(self, name, parameters=None, expire=None, http_method=None): params["Bucket"] = self.bucket.name params["Key"] = name - url = self.bucket.meta.client.generate_presigned_url( + + connection = ( + self.connection if self.querystring_auth else self.unsigned_connection + ) + url = connection.meta.client.generate_presigned_url( "get_object", Params=params, ExpiresIn=expire, HttpMethod=http_method ) - if self.querystring_auth: - return url - return self._strip_signing_parameters(url) + return url def get_available_name(self, name, max_length=None): """Overwrite existing file with the same name.""" diff --git a/tests/test_s3.py b/tests/test_s3.py index 12c2d009..fc9e960a 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -35,6 +35,7 @@ class S3StorageTests(TestCase): def setUp(self): self.storage = s3.S3Storage() self.storage._connections.connection = mock.MagicMock() + self.storage._unsigned_connections.connection = mock.MagicMock() def test_s3_session(self): with override_settings(AWS_S3_SESSION_PROFILE="test_profile"): @@ -664,10 +665,10 @@ def _test_storage_mtime(self, use_tz): def test_storage_url(self): name = "test_storage_size.txt" url = "http://aws.amazon.com/%s" % name - self.storage.bucket.meta.client.generate_presigned_url.return_value = url + self.storage.connection.meta.client.generate_presigned_url.return_value = url self.storage.bucket.name = "bucket" self.assertEqual(self.storage.url(name), url) - self.storage.bucket.meta.client.generate_presigned_url.assert_called_with( + self.storage.connection.meta.client.generate_presigned_url.assert_called_with( "get_object", Params={"Bucket": self.storage.bucket.name, "Key": name}, ExpiresIn=self.storage.querystring_expire, @@ -677,7 +678,7 @@ def test_storage_url(self): custom_expire = 123 self.assertEqual(self.storage.url(name, expire=custom_expire), url) - self.storage.bucket.meta.client.generate_presigned_url.assert_called_with( + self.storage.connection.meta.client.generate_presigned_url.assert_called_with( "get_object", Params={"Bucket": self.storage.bucket.name, "Key": name}, ExpiresIn=custom_expire, @@ -687,13 +688,18 @@ def test_storage_url(self): custom_method = "HEAD" self.assertEqual(self.storage.url(name, http_method=custom_method), url) - self.storage.bucket.meta.client.generate_presigned_url.assert_called_with( + self.storage.connection.meta.client.generate_presigned_url.assert_called_with( "get_object", Params={"Bucket": self.storage.bucket.name, "Key": name}, ExpiresIn=self.storage.querystring_expire, HttpMethod=custom_method, ) + def test_url_unsigned(self): + self.storage.querystring_auth = False + self.storage.url("test_name") + self.storage.unsigned_connection.meta.client.generate_presigned_url.assert_called_once() + def test_storage_url_custom_domain_signed_urls(self): key_id = "test-key" filename = "file.txt" @@ -766,21 +772,6 @@ def test_custom_domain_parameters(self): self.assertEqual(parsed_url.path, "/filename.mp4") self.assertEqual(parsed_url.query, "version=10") - def test_strip_signing_parameters(self): - expected = "http://bucket.s3-aws-region.amazonaws.com/foo/bar" - self.assertEqual( - self.storage._strip_signing_parameters( - "%s?X-Amz-Date=12345678&X-Amz-Signature=Signature" % expected - ), - expected, - ) - self.assertEqual( - self.storage._strip_signing_parameters( - "%s?expires=12345678&signature=Signature" % expected - ), - expected, - ) - @skipIf(threading is None, "Test requires threading") def test_connection_threading(self): connections = []