diff --git a/README.md b/README.md index 1677c71..c990e2b 100644 --- a/README.md +++ b/README.md @@ -40,3 +40,15 @@ ## Installation ## Example + + +```python +INSTALLED_APPS = [ + ... + "dysession", # add dysession to installed apps + # 'django.contrib.sessions', # remove this default session + ... +] + +SESSION_ENGINE = "dysession.backends.db" +``` \ No newline at end of file diff --git a/dysession/aws/dynamodb.py b/dysession/aws/dynamodb.py index db1972f..d4c0a22 100644 --- a/dysession/aws/dynamodb.py +++ b/dysession/aws/dynamodb.py @@ -1,13 +1,17 @@ from datetime import datetime -from typing import Any, Dict, Literal, Optional, Union +import logging +from typing import Any, Callable, Dict, Literal, Optional, Union import boto3 from botocore import client as botoClitent from django.utils import timezone from dysession.aws.error import DynamodbItemNotFound, DynamodbTableNotFound -from dysession.backends.error import (SessionKeyDoesNotExist, - SessionKeyDuplicated) +from dysession.backends.error import ( + SessionExpired, + SessionKeyDoesNotExist, + SessionKeyDuplicated, +) from dysession.backends.model import SessionDataModel from ..settings import get_config @@ -78,33 +82,36 @@ def key_exists(session_key: str, table_name: Optional[str] = None, client=None) }, ProjectionExpression=f"{pk}", ) - return "Item" in response -def get_item(session_key: str, table_name: Optional[str] = None, client=None) -> bool: - - if client is None: - client = boto3.client("dynamodb", region_name=get_config()["DYNAMODB_REGION"]) +def get_item(session_key: str, table_name: Optional[str] = None) -> SessionDataModel: if table_name is None: table_name = get_config()["DYNAMODB_TABLENAME"] assert type(session_key) is str, "session_key should be string type" + logging.info("Get Item from DynamoDB") + pk = get_config()["PARTITION_KEY_NAME"] - response = client.get_item( - TableName=table_name, + resource = boto3.resource("dynamodb", region_name=get_config()["DYNAMODB_REGION"]) + table = resource.Table(table_name) + + response = table.get_item( Key={ - pk: {"S": session_key}, + pk: session_key, }, ) if "Item" not in response: raise DynamodbItemNotFound() - return response + model = SessionDataModel(session_key=session_key) + for k, v in response["Item"].items(): + model[k] = v + return model def insert_session_item( @@ -119,6 +126,9 @@ def insert_session_item( if table_name is None: table_name = get_config()["DYNAMODB_TABLENAME"] + if key_exists(data.session_key): + raise SessionKeyDuplicated + resource = boto3.resource("dynamodb", region_name=get_config()["DYNAMODB_REGION"]) table = resource.Table(table_name) pk = get_config()["PARTITION_KEY_NAME"] @@ -137,28 +147,54 @@ def insert_session_item( class DynamoDB: - def __init__(self, client) -> None: + def __init__(self, client=None) -> None: self.client = client def get( - self, session_key: Optional[str] = None, ttl: Optional[datetime] = None + self, + session_key: Optional[str] = None, + table_name: Optional[str] = None, + expired_time_fn: Callable[[], datetime] = datetime.now, ) -> Dict[str, Any]: """Return session data if dynamodb partision key is matched with inputed session_key""" if session_key is None: raise ValueError("session_key should be str type") + if table_name is None: + table_name = get_config()["DYNAMODB_TABLENAME"] - # if not found then raise - # raise SessionKeyDoesNotExist - # if key is expired - # raise SessionExpired - - def set(self, session_key: Optional[str] = None, session_data=None) -> None: - return - # Partision key duplicated - raise SessionKeyDuplicated + now = expired_time_fn() - def exists(self, session_key: Optional[str] = None) -> bool: - return False + try: + model = get_item(session_key=session_key, table_name=table_name) + if get_config()["TTL_ATTRIBUTE_NAME"] in model: + time = model[get_config()["TTL_ATTRIBUTE_NAME"]] + if time < int(now.timestamp()): + raise SessionExpired # if not found then raise - raise SessionKeyDoesNotExist + except DynamodbItemNotFound: + raise SessionKeyDoesNotExist + # if key is expired + except SessionExpired: + raise SessionExpired + + return model + + def set( + self, + data: SessionDataModel, + table_name: Optional[str] = None, + return_consumed_capacity: Literal["INDEXES", "TOTAL", "NONE"] = "TOTAL", + ignore_duplicated: bool = True, + ) -> None: + try: + insert_session_item(data, table_name, return_consumed_capacity) + except SessionKeyDuplicated: + if not ignore_duplicated: + raise SessionKeyDuplicated + + def exists(self, session_key: str) -> bool: + if type(session_key) is not str: + raise TypeError(f"session_key should be type of str instead of {type(session_key)}.") + + return key_exists(session_key=session_key) diff --git a/dysession/backends/db.py b/dysession/backends/db.py index febc3d8..c7bf334 100644 --- a/dysession/backends/db.py +++ b/dysession/backends/db.py @@ -14,6 +14,7 @@ SessionKeyDuplicated, ) from dysession.backends.model import SessionDataModel +from dysession.settings import get_config class SessionStore(SessionBase): @@ -21,12 +22,14 @@ class SessionStore(SessionBase): def __init__(self, session_key: Optional[str], **kwargs: Any) -> None: super().__init__(session_key, **kwargs) - # self.client = boto3.client("dynamodb") - self.db = DynamoDB(client=boto3.client("dynamodb")) + self.db = DynamoDB( + client=boto3.client("dynamodb", region_name=get_config()["DYNAMODB_REGION"]) + ) + self._get_session() def _get_session_from_ddb(self) -> SessionDataModel: try: - return self.db.get(session_key=self.session_key, ttl=timezone.now()) + return self.db.get(session_key=self.session_key) except (SessionKeyDoesNotExist, SessionExpired, SuspiciousOperation) as e: if isinstance(e, SuspiciousOperation): logger = logging.getLogger(f"django.security.{e.__class__.__name__}") @@ -40,10 +43,12 @@ def _get_session(self, no_load=False) -> SessionDataModel: """ self.accessed = True try: - return self._session_cache + if isinstance(self._session_cache, SessionDataModel): + return self._session_cache + raise AttributeError except AttributeError: if self.session_key is None or no_load: - self._session_cache = SessionDataModel() + self._session_cache = SessionDataModel(self.session_key) else: self._session_cache = self.load() return self._session_cache @@ -63,7 +68,13 @@ def clear(self): super().clear() self._session_cache = SessionDataModel() - # ====== Methods that subclass must implement + def items(self): + return self._session.items() + + def __str__(self): + return str(self._get_session()) + + # Methods that subclass must implement def exists(self, session_key: str) -> bool: """ Return True if the given session_key already exists. @@ -87,22 +98,24 @@ def create(self) -> None: self.modified = True return - def save(self, must_create: bool = ...) -> None: + def save(self, must_create: bool = False) -> None: """ Save the session data. If 'must_create' is True, create a new session object (or raise CreateError). Otherwise, only update an existing object and don't create one (raise UpdateError if needed). """ try: - self.db.set( - session_key=self._session_key, - session_data=self._get_session(must_create), - ) + if self._session_key is None: + return self.create() + + data = self._get_session(no_load=must_create) + data.session_key = self._session_key + self.db.set(data=data) except SessionKeyDuplicated: if must_create: raise SessionKeyDuplicated - def delete(self, request, *args, **kwargs): + def delete(self, session_key=None): """ Delete the session data under this key. If the key is None, use the current session key value. diff --git a/dysession/backends/model.py b/dysession/backends/model.py index 68f486e..b6176ce 100644 --- a/dysession/backends/model.py +++ b/dysession/backends/model.py @@ -1,21 +1,35 @@ +import json from typing import Any, Optional class SessionDataModel: + + NOTFOUND_ALLOW_LIST = ["_auth_user_id", "_auth_user_backend", "_auth_user_hash"] + def __init__(self, session_key: Optional[str] = None) -> None: if type(session_key) is not str and session_key is not None: raise TypeError("session_key should be type str or None") self.session_key = session_key - self.__variables_names = set() + self.__variables_names = set(["session_key"]) def __getitem__(self, key) -> Any: - return getattr(self, key) + # Set SESSION_EXPIRE_AT_BROWSER_CLOSE to False + # https://docs.djangoproject.com/en/4.1/topics/http/sessions/#browser-length-sessions-vs-persistent-sessions + if key == "_session_expiry": + return False + + try: + return getattr(self, key) + except AttributeError: + if key in self.NOTFOUND_ALLOW_LIST: + raise KeyError + raise def __setitem__(self, key, value): - if key == "session_key": - raise ValueError() + # if key == "session_key": + # raise ValueError() setattr(self, key, value) self.__variables_names.add(key) @@ -28,7 +42,7 @@ def __iter__(self): return iter(self.__variables_names) def __is_empty(self): - return len(self.__variables_names) == 0 + return "session_key" in self.__variables_names and len(self.__variables_names) == 1 is_empty = property(__is_empty) @@ -49,3 +63,14 @@ def pop(self, key, default=...): if default is Ellipsis: raise return default + + def items(self): + for key in self.__variables_names: + yield (key, self[key]) + + def __str__(self) -> str: + data = {} + for key in self.__variables_names: + data[key] = getattr(self, key) + + return json.dumps(data) diff --git a/dysession/management/commands/__init__.py b/dysession/management/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dysession/management/commands/dysession_clear.py b/dysession/management/commands/dysession_clear.py index 2c9a2e7..4a1d29c 100644 --- a/dysession/management/commands/dysession_clear.py +++ b/dysession/management/commands/dysession_clear.py @@ -23,9 +23,6 @@ def add_arguments(self, parser: CommandParser) -> None: def handle(self, *args: Any, **options: Any) -> Optional[str]: userids = options.get("uid", None) if userids: - print(f"Ready to clear {userids} session data.") - return - - print("Clearing whole session data") - return + ... + return diff --git a/dysession/settings.py b/dysession/settings.py index b53a728..7b48c5d 100644 --- a/dysession/settings.py +++ b/dysession/settings.py @@ -17,6 +17,19 @@ @lru_cache def get_config() -> Dict[str, Union[str, int]]: + """Return cached django-dysession config in dictionary type + + Contain Items: + * DYNAMODB_TABLENAME + * PARTITION_KEY_NAME + * SORT_KEY_NAME + * TTL_ATTRIBUTE_NAME + * CACHE_PERIOD + * DYNAMODB_REGION + + Returns: + Dict[str, Union[str, int]] + """ config = DEFAULT_CONFIG.copy() custom_config = getattr(settings, "DYSESSION_CONFIG", {}) config.update(custom_config) diff --git a/tests/test_aws_dynamodb.py b/tests/test_aws_dynamodb.py index 9416904..ed81d8c 100644 --- a/tests/test_aws_dynamodb.py +++ b/tests/test_aws_dynamodb.py @@ -337,37 +337,6 @@ def test_check_key_wrong_type(self): ) # Get Item - @mock_dynamodb - def test_get_item_without_client(self): - - options = { - "pk": get_config()["PARTITION_KEY_NAME"], - "sk": get_config()["SORT_KEY_NAME"], - "table": "sessions", - "region": "ap-northeast-1", - } - - try: - check_dynamodb_table_exists(table_name=options["table"]) - except DynamodbTableNotFound: - create_dynamodb_table( - options={ - "pk": options["pk"], - "sk": options["sk"], - "table": options["table"], - }, - ) - - session_key = "aaaaaaaaaa" - model = SessionDataModel(session_key) - model["a"] = 1 - model["b"] = "qwerty" - - resp = insert_session_item(data=model) - self.assertEqual(resp["ResponseMetadata"]["HTTPStatusCode"], 200) - - resp = get_item(session_key=session_key, table_name=options["table"]) - self.assertIn("Item", resp) @mock_dynamodb def test_get_item_using_not_exist_key(self): @@ -394,15 +363,15 @@ def test_get_item_using_not_exist_key(self): with self.assertRaises(DynamodbItemNotFound): resp = get_item( - session_key="not_exist_key", table_name=options["table"], client=client + session_key="not_exist_key", table_name=options["table"] ) # Insert Item @parameterized.expand( [ - ["aaaaaaaaa"], - ["bbbbbbbbb"], - ["ccccccccc"], + ["aaaaaaaa"], + ["bbbbbbbb"], + ["cccccccc"], ] ) @mock_dynamodb @@ -435,15 +404,13 @@ def test_insert_item_with_tablename(self, session_key: str): model["d"] = 4 model["e"] = "qwerty" - insert_session_item(data=model, table_name=options["table"]) - resp = insert_session_item(data=model) self.assertEqual(resp["ResponseMetadata"]["HTTPStatusCode"], 200) resp = get_item( - session_key=session_key, table_name=options["table"], client=client + session_key=session_key, table_name=options["table"] ) - self.assertIn("Item", resp) + self.assertIsInstance(resp, SessionDataModel) @parameterized.expand( [ @@ -485,5 +452,5 @@ def test_insert_item_without_tablename(self, session_key: str): resp = insert_session_item(data=model) self.assertEqual(resp["ResponseMetadata"]["HTTPStatusCode"], 200) - resp = get_item(session_key=session_key, client=client) - self.assertIn("Item", resp) + resp = get_item(session_key=session_key) + self.assertIsInstance(resp, SessionDataModel) diff --git a/tests/test_backend_db.py b/tests/test_backend_db.py new file mode 100644 index 0000000..ac9ba55 --- /dev/null +++ b/tests/test_backend_db.py @@ -0,0 +1,235 @@ +import time +from datetime import datetime +from typing import Any + +import boto3 +from django.test import TestCase +from moto import mock_dynamodb +from parameterized import parameterized + +from dysession.aws.dynamodb import ( + DynamoDB, + check_dynamodb_table_exists, + create_dynamodb_table, + destory_dynamodb_table, + get_item, + insert_session_item, + key_exists, +) +from dysession.aws.error import DynamodbItemNotFound, DynamodbTableNotFound +from dysession.backends.error import ( + SessionExpired, + SessionKeyDoesNotExist, + SessionKeyDuplicated, +) +from dysession.backends.model import SessionDataModel +from dysession.settings import get_config + + +class DynamoDBTestCase(TestCase): + @mock_dynamodb + def create_dynamodb_table(self): + self.options = { + "pk": get_config()["PARTITION_KEY_NAME"], + "sk": get_config()["SORT_KEY_NAME"], + "table": "sessions", + "region": "ap-northeast-1", + } + + self.client = client = boto3.client( + "dynamodb", region_name=self.options["region"] + ) + try: + check_dynamodb_table_exists(table_name=self.options["table"], client=client) + except DynamodbTableNotFound: + create_dynamodb_table( + options={ + "pk": self.options["pk"], + "sk": self.options["sk"], + "table": self.options["table"], + }, + client=client, + ) + + @parameterized.expand( + [ + ["aaaaaaa"], + ["bbbbbbb"], + ] + ) + @mock_dynamodb + def test_get_datamodel_from_dynamodb_controller(self, session_key: str): + + self.create_dynamodb_table() + model = SessionDataModel(session_key) + model["a"] = 1 + model["b"] = {"z": "z", "x": 1} + model["c"] = False + model["d"] = [7, 8, 9] + model["e"] = "qwerty" + + resp = insert_session_item(data=model) + self.assertEqual(resp["ResponseMetadata"]["HTTPStatusCode"], 200) + + db = DynamoDB(self.client) + model = db.get(session_key) + self.assertIsInstance(model, SessionDataModel) + self.assertEqual(model.a, 1) + self.assertEqual(len(model.b), 2) + self.assertIn("z", model.b) + self.assertIn("x", model.b) + self.assertEqual(model.b["z"], "z") + self.assertEqual(model.b["x"], 1) + self.assertEqual(model.c, False) + self.assertListEqual(model.d, [7, 8, 9]) + self.assertEqual(model.e, "qwerty") + + @mock_dynamodb + def test_get_datamodel_from_dynamodb_controller_with_missing_session_key(self): + + db = DynamoDB(self.client) + with self.assertRaises(ValueError): + model = db.get() + + @mock_dynamodb + def test_get_nonexist_datamodel_from_dynamodb_controller(self): + + self.create_dynamodb_table() + + db = DynamoDB(self.client) + with self.assertRaises(SessionKeyDoesNotExist): + model = db.get("not_exist") + + @mock_dynamodb + def test_get_expired_datamodel_from_dynamodb_controller(self): + + session_key = "test_get_expired_datamodel_from_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + insert_session_item(data=model, table_name=self.options["table"]) + + # Make sure the item expired + time.sleep(2) + + db = DynamoDB(self.client) + with self.assertRaises(SessionExpired): + model = db.get(session_key=session_key) + + @mock_dynamodb + def test_get_not_expired_datamodel_from_dynamodb_controller(self): + + session_key = "test_get_not_expired_datamodel_from_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + 50 + insert_session_item(data=model, table_name=self.options["table"]) + + # Make sure the item expired + time.sleep(2) + + db = DynamoDB(self.client) + + model = db.get(session_key=session_key) + + @mock_dynamodb + def test_set_datamodel_via_dynamodb_controller(self): + + session_key = "test_set_datamodel_via_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + 50 + + db = DynamoDB(self.client) + db.set(model, get_config()["DYNAMODB_TABLENAME"]) + + query_model = db.get(session_key=session_key) + self.assertEqual(model.a, query_model.a) + + @mock_dynamodb + def test_set_duplicated_datamodel_via_dynamodb_controller_ignore_duplicated(self): + + session_key = "test_set_duplicated_datamodel_via_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + 50 + + db = DynamoDB(self.client) + db.set(model, get_config()["DYNAMODB_TABLENAME"]) + query_model = db.get(session_key=session_key) + self.assertEqual(model.a, query_model.a) + + db.set(model, get_config()["DYNAMODB_TABLENAME"], ignore_duplicated=True) + + @mock_dynamodb + def test_set_duplicated_datamodel_via_dynamodb_controller(self): + + session_key = "test_set_duplicated_datamodel_via_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + 50 + + db = DynamoDB(self.client) + db.set(model, get_config()["DYNAMODB_TABLENAME"]) + query_model = db.get(session_key=session_key) + self.assertEqual(model.a, query_model.a) + + with self.assertRaises(SessionKeyDuplicated): + db.set(model, get_config()["DYNAMODB_TABLENAME"], ignore_duplicated=False) + + @mock_dynamodb + def test_exist_check_via_dynamodb_controller(self): + + session_key = "test_set_duplicated_datamodel_via_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + 50 + + db = DynamoDB(self.client) + db.set(model, get_config()["DYNAMODB_TABLENAME"]) + query_model = db.get(session_key=session_key) + self.assertEqual(model.a, query_model.a) + + self.assertTrue(db.exists(session_key)) + self.assertFalse(db.exists(session_key + "_not_exists")) + + @parameterized.expand( + [ + [1], + [1.03], + [True], + [(1, 2, 3)], + [[1, 2, 3]], + ] + ) + @mock_dynamodb + def test_exist_check_input_type_error_via_dynamodb_controller( + self, error_input: Any + ): + + session_key = "test_set_duplicated_datamodel_via_dynamodb_controller" + self.create_dynamodb_table() + + model = SessionDataModel(session_key) + model["a"] = 1 + model[get_config()["TTL_ATTRIBUTE_NAME"]] = int(datetime.now().timestamp()) + 50 + + db = DynamoDB(self.client) + db.set(model, get_config()["DYNAMODB_TABLENAME"]) + query_model = db.get(session_key=session_key) + self.assertEqual(model.a, query_model.a) + + with self.assertRaises(TypeError): + db.exists(error_input) diff --git a/tests/test_backend_model.py b/tests/test_backend_model.py index b4facdd..5cd60f6 100644 --- a/tests/test_backend_model.py +++ b/tests/test_backend_model.py @@ -31,12 +31,6 @@ def test_set_attribute(self): model = SessionDataModel() model["good_key"] = 0 - def test_set_attribute_session_key(self): - model = SessionDataModel() - - with self.assertRaises(ValueError): - model["session_key"] = 0 - def test_get_attribute(self): model = SessionDataModel() model["good_key"] = 0 @@ -111,4 +105,51 @@ def test_iter(self): model["c"] = 1 model["d"] = 1 - self.assertEqual(set(model), set(["a", "b", "c", "d"])) + self.assertEqual(set(model), set(["a", "b", "c", "d", "session_key"])) + + def test_items(self): + model = SessionDataModel("session_key") + + model["a"] = 1 + model["b"] = 2 + model["c"] = 3 + model["d"] = 4 + + + keys = [] + values = [] + for k, v in model.items(): + keys.append(k) + values.append(v) + + + self.assertEqual(set(keys), set(["a", "b", "c", "d", "session_key"])) + self.assertEqual(set(values), set([1, 2, 3, 4, "session_key"])) + + def test_str_magic_method(self): + model = SessionDataModel("session_key") + model["a"] = 1 + model["b"] = 2 + model["c"] = 3 + model["d"] = 4 + + import json + + data = json.loads(str(model)) + + for k in ["session_key", "a", "b", "c", "d"]: + self.assertIn(k, data.keys()) + + def test_get__session_expiry(self): + model = SessionDataModel() + self.assertFalse(model["_session_expiry"]) + + def test_not_found_allow_list(self): + model = SessionDataModel() + + for k in ["_auth_user_id", "_auth_user_backend", "_auth_user_hash"]: + with self.assertRaises(KeyError): + model[k] + + with self.assertRaises(KeyError): + self.assertIsNone(model.get(k)) \ No newline at end of file diff --git a/tests/test_commands.py b/tests/test_commands.py index 30da24c..9358fbc 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -27,7 +27,6 @@ def call_command(self, command: str, *args, stdout=None, stderr=None, **kwargs): # def test_init_dynamodb_table(self): # client = boto3.client("dynamodb") -# print(client) # pass @@ -35,14 +34,11 @@ def call_command(self, command: str, *args, stdout=None, stderr=None, **kwargs): # def test_call_help(self): # out = StringIO() # call_command("dysession_clear", "-h", stdout=out) -# print(out.read()) # call_command("dysession_clear", *["-u", "XD"], "-h", stdout=out) -# print(out.read()) # class DysessionDestoryTestCase(CommandTestCase, TestCase): # def test_call_help(self): # out = StringIO() # call_command("dysession_destory", "-h", stdout=out) -# print(out.read())