diff --git a/enterprise_access/apps/api_client/lms_client.py b/enterprise_access/apps/api_client/lms_client.py index b73a7b7b..8949c8bd 100755 --- a/enterprise_access/apps/api_client/lms_client.py +++ b/enterprise_access/apps/api_client/lms_client.py @@ -42,7 +42,6 @@ def all_pages_enterprise_group_members_cache_key( class LmsApiClient(BaseOAuthClient): - """ API client for calls to the LMS service. """ @@ -64,21 +63,32 @@ def enterprise_group_members_endpoint(self, group_uuid): "learners/", ) - def get_enterprise_customer_data(self, enterprise_customer_uuid): + def get_enterprise_customer_data(self, enterprise_customer_uuid=None, enterprise_customer_slug=None): """ - Gets the data for an EnterpriseCustomer for the given uuid. + Gets the data for an EnterpriseCustomer for the given uuid or slug. Arguments: enterprise_customer_uuid (string): id of the enterprise customer + enterprise_customer_slug (string): slug of the enterprise customer Returns: dictionary containing enterprise customer metadata """ + if enterprise_customer_uuid: + # Returns a dict + endpoint = f'{self.enterprise_customer_endpoint}{enterprise_customer_uuid}/' + elif enterprise_customer_slug: + # Returns a list of dicts + endpoint = f'{self.enterprise_customer_endpoint}?slug={enterprise_customer_slug}' + else: + raise ValueError('Either enterprise_customer_uuid or enterprise_customer_slug is required.') try: - endpoint = f'{self.enterprise_customer_endpoint}{enterprise_customer_uuid}/' response = self.client.get(endpoint, timeout=settings.LMS_CLIENT_TIMEOUT) response.raise_for_status() - return response.json() + payload = response.json() + if results := payload.get('results'): + return results[0] + return payload except requests.exceptions.HTTPError as exc: logger.exception(exc) raise @@ -406,6 +416,7 @@ class LmsUserApiClient(BaseUserApiClient): enterprise_api_base_url = f"{settings.LMS_URL}/enterprise/api/v1/" enterprise_learner_portal_api_base_url = f"{settings.LMS_URL}/enterprise_learner_portal/api/v1/" + enterprise_learner_endpoint = f"{enterprise_api_base_url}enterprise-learner/" default_enterprise_enrollment_intentions_learner_status_endpoint = ( f'{enterprise_api_base_url}default-enterprise-enrollment-intentions/learner-status/' ) @@ -413,6 +424,59 @@ class LmsUserApiClient(BaseUserApiClient): f'{enterprise_learner_portal_api_base_url}enterprise_course_enrollments/' ) + def get_enterprise_customers_for_user(self, username, traverse_pagination=False): + """ + Fetches enterprise learner data for a given username. + + Arguments: + username (str): Username of the learner + + Returns: + dict: Dictionary representation of the JSON response from the API + """ + query_params = { + 'username': username, + } + results = [] + initial_response_data = None + current_response = None + next_url = self.enterprise_learner_endpoint + try: + while next_url: + current_response = self.get( + next_url, + params=query_params, + timeout=settings.LMS_CLIENT_TIMEOUT + ) + current_response.raise_for_status() + data = current_response.json() + + if not initial_response_data: + # Store the initial response data (first page) for later use + initial_response_data = data + + # Collect results from the current page + results.extend(data.get('results', [])) + + # If pagination is enabled, continue with the next page; otherwise, break + next_url = data.get('next') if traverse_pagination else None + + consolidated_response = { + **initial_response_data, + 'next': None, + 'previous': None, + 'count': len(results), + 'num_pages': 1, + 'results': results, + } + return consolidated_response + except requests.exceptions.HTTPError as exc: + logger.exception( + f"Failed to fetch enterprise learner for learner {username}: {exc} " + f"Response content: {current_response.content if current_response else None}" + ) + raise + def get_default_enterprise_enrollment_intentions_learner_status(self, enterprise_customer_uuid): """ Fetches learner status from the default enterprise enrollment intentions endpoint. diff --git a/enterprise_access/apps/api_client/tests/test_lms_client.py b/enterprise_access/apps/api_client/tests/test_lms_client.py index c12eaf15..048aab50 100644 --- a/enterprise_access/apps/api_client/tests/test_lms_client.py +++ b/enterprise_access/apps/api_client/tests/test_lms_client.py @@ -109,21 +109,48 @@ def test_get_enterprise_admin_users(self, mock_oauth_client, mock_json): timeout=settings.LMS_CLIENT_TIMEOUT, ) + @ddt.data( + { + 'enterprise_uuid': 'some-uuid', + 'enterprise_slug': None + }, + { + 'enterprise_uuid': None, + 'enterprise_slug': 'some-slug', + }, + { + 'enterprise_uuid': 'some-uuid', + 'enterprise_slug': 'some-slug', + }, + ) @mock.patch('requests.Response.json') @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') - def test_get_enterprise_customer_data(self, mock_oauth_client, mock_json): + @ddt.unpack + def test_get_enterprise_customer_data( + self, + mock_oauth_client, + mock_json, + enterprise_uuid, + enterprise_slug, + ): """ Verify client hits the right URL for entepriseCustomer data. """ - mock_json.return_value = { + mock_enterprise_customer = { 'uuid': 'some-uuid', 'slug': 'some-test-slug', } + mock_json.return_value = mock_enterprise_customer + if not enterprise_uuid and enterprise_slug: + mock_json.return_value = {'results': [mock_enterprise_customer]} mock_oauth_client.return_value.get.return_value = requests.Response() mock_oauth_client.return_value.get.return_value.status_code = 200 client = LmsApiClient() - customer_data = client.get_enterprise_customer_data('some-uuid') + customer_data = client.get_enterprise_customer_data( + enterprise_customer_uuid=enterprise_uuid, + enterprise_customer_slug=enterprise_slug, + ) assert customer_data['uuid'] == 'some-uuid' assert customer_data['slug'] == 'some-test-slug' @@ -132,8 +159,16 @@ def test_get_enterprise_customer_data(self, mock_oauth_client, mock_json): 'http://edx-platform.example.com/' 'enterprise/api/v1/' 'enterprise-customer/' - 'some-uuid/' + f'{enterprise_uuid}/' ) + if not enterprise_uuid and enterprise_slug: + expected_url = ( + 'http://edx-platform.example.com/' + 'enterprise/api/v1/' + 'enterprise-customer/' + f'?slug={enterprise_slug}' + ) + mock_oauth_client.return_value.get.assert_called_with( expected_url, timeout=settings.LMS_CLIENT_TIMEOUT, diff --git a/enterprise_access/apps/bffs/context.py b/enterprise_access/apps/bffs/context.py index ea056078..4f1a3b8c 100644 --- a/enterprise_access/apps/bffs/context.py +++ b/enterprise_access/apps/bffs/context.py @@ -1,6 +1,8 @@ """ HandlerContext for bffs app. """ + +from enterprise_access.apps.api_client.lms_client import LmsApiClient, LmsUserApiClient from enterprise_access.apps.bffs import serializers @@ -15,8 +17,10 @@ class HandlerContext: data: A dictionary to store data loaded and processed by the handlers. errors: A list to store errors that occur during request processing. warnings: A list to store warnings that occur during the request processing. - enterprise_customer_uuid: The enterprise customer the user is associated with. + enterprise_customer_uuid: The enterprise customer uuid associated with this request. + enterprise_customer_slug: The enterprise customer slug associated with this request. lms_user_id: The id associated with the authenticated user. + enterprise_features: A dictionary to store enterprise features associated with the authenticated user. """ def __init__(self, request): @@ -26,14 +30,21 @@ def __init__(self, request): request: The incoming HTTP request. """ self._request = request + self._errors = [] # Stores any errors that occur during processing + self._warnings = [] # Stores any warnings that occur during processing + self._enterprise_customer_uuid = None + self._enterprise_customer_slug = None + self._lms_user_id = getattr(self.user, 'lms_user_id', None) + self._enterprise_features = {} + self.data = {} # Stores processed data for the response - self.errors = [] # Stores any errors that occur during processing - self.warnings = [] # Stores any warnings that occur during processing - self.enterprise_customer_uuid = None - self.lms_user_id = None - # Set common context attributes - self.initialize_common_context_data() + # API clients + self.lms_api_client = LmsApiClient() + self.lms_user_api_client = LmsUserApiClient(request) + + # Initialize common context data + self._initialize_common_context_data() @property def request(self): @@ -43,23 +54,221 @@ def request(self): def user(self): return self._request.user - def initialize_common_context_data(self): + @property + def errors(self): + return self._errors + + @property + def warnings(self): + return self._warnings + + @property + def enterprise_customer_uuid(self): + return self._enterprise_customer_uuid + + @property + def enterprise_customer_slug(self): + return self._enterprise_customer_slug + + @property + def lms_user_id(self): + return self._lms_user_id + + @property + def enterprise_features(self): + return self._enterprise_features + + @property + def enterprise_customer(self): + return self.data.get('enterprise_customer') + + def _initialize_common_context_data(self): """ - Initialize commonly used context attributes, such as enterprise customer UUID and LMS user ID. + Initializes common context data, like enterprise customer UUID and user ID. """ enterprise_uuid_query_param = self.request.query_params.get('enterprise_customer_uuid') + enterprise_slug_query_param = self.request.query_params.get('enterprise_customer_slug') + enterprise_uuid_post_param = None + enterprise_slug_post_param = None if self.request.method == 'POST': enterprise_uuid_post_param = self.request.data.get('enterprise_customer_uuid') + enterprise_slug_post_param = self.request.data.get('enterprise_customer_slug') enterprise_customer_uuid = enterprise_uuid_query_param or enterprise_uuid_post_param - if enterprise_customer_uuid: - self.enterprise_customer_uuid = enterprise_customer_uuid + self._enterprise_customer_uuid = enterprise_customer_uuid + enterprise_customer_slug = enterprise_slug_query_param or enterprise_slug_post_param + self._enterprise_customer_slug = enterprise_customer_slug + + # Initialize the enterprise customer users metatata derived from the LMS + self._initialize_enterprise_customer_users() + + if not self.enterprise_customer: + # If no enterprise customer is found, return early + return + + # Otherwise, update the enterprise customer UUID and slug if not already set + if not self.enterprise_customer_slug: + self._enterprise_customer_slug = self.enterprise_customer.get('slug') + if not self.enterprise_customer_uuid: + self._enterprise_customer_uuid = self.enterprise_customer.get('uuid') + + def _initialize_enterprise_customer_users(self): + """ + Initializes the enterprise customer users for the request user. + """ + try: + enterprise_customer_users_data = self.lms_user_api_client.get_enterprise_customers_for_user( + self.user.username, + traverse_pagination=True + ) + except Exception as e: # pylint: disable=broad-except + self.add_error( + user_message='Error retrieving linked enterprise customers', + developer_message=str(e) + ) + return + + # Set enterprise features from the response + self._enterprise_features = enterprise_customer_users_data.get('enterprise_features', {}) + + # Parse the enterprise customer user data + enterprise_customer_users = enterprise_customer_users_data.get('results', []) + active_enterprise_customer = self._get_active_enterprise_customer(enterprise_customer_users) + enterprise_customer_user_for_requested_customer = next( + ( + enterprise_customer_user + for enterprise_customer_user in enterprise_customer_users + if self._enterprise_customer_matches_slug_or_uuid(enterprise_customer_user.get('enterprise_customer')) + ), + None + ) + + # If no enterprise customer user is found for the requested customer (i.e., request user not explicitly + # linked), but the request user is staff, attempt to retrieve enterprise customer metadata from the + # `/enterprise-customer/` LMS API endpoint instead. + if not enterprise_customer_user_for_requested_customer: + staff_enterprise_customer = self._get_staff_enterprise_customer() else: - raise ValueError("enterprise_customer_uuid is required for this request.") + staff_enterprise_customer = None + + # Determine the enterprise customer user to display + requested_enterprise_customer = ( + enterprise_customer_user_for_requested_customer.get('enterprise_customer') + if enterprise_customer_user_for_requested_customer else None + ) + enterprise_customer = self._determine_enterprise_customer_for_display( + active_enterprise_customer=active_enterprise_customer, + requested_enterprise_customer=requested_enterprise_customer, + staff_enterprise_customer=staff_enterprise_customer, + ) + + # Update the context data with the enterprise customer user information + self.data.update({ + 'enterprise_customer': enterprise_customer, + 'active_enterprise_customer': active_enterprise_customer, + 'all_linked_enterprise_customer_users': enterprise_customer_users, + 'staff_enterprise_customer': staff_enterprise_customer, + }) + + def _get_active_enterprise_customer(self, enterprise_customer_users): + """ + Get the active enterprise customer user from the list of enterprise customer users. + """ + active_enterprise_customer_user = next( + ( + enterprise_customer_user + for enterprise_customer_user in enterprise_customer_users + if enterprise_customer_user.get('active', False) + ), + None + ) + if active_enterprise_customer_user: + return active_enterprise_customer_user.get('enterprise_customer') + return None + + def _get_staff_enterprise_customer(self): + """ + Retrieve enterprise customer metadata from `enterprise-customer` LMS API endpoint + if there is no enterprise customer user for the request enterprise and the user is staff. + """ + has_enterprise_customer_slug_or_uuid = self.enterprise_customer_slug or self.enterprise_customer_uuid + if has_enterprise_customer_slug_or_uuid and self.user.is_staff: + try: + staff_enterprise_customer = self.lms_api_client.get_enterprise_customer_data( + enterprise_customer_uuid=self.enterprise_customer_uuid, + enterprise_customer_slug=self.enterprise_customer_slug, + ) + return staff_enterprise_customer + except Exception as e: # pylint: disable=broad-except + self.add_error( + user_message='Error retrieving enterprise customer data', + developer_message=str(e) + ) + return None + + def _determine_enterprise_customer_for_display( + self, + active_enterprise_customer=None, + requested_enterprise_customer=None, + staff_enterprise_customer=None, + ): + """ + Determine the enterprise customer user for display. + + Returns: + The enterprise customer user for display. + """ + if not self.enterprise_customer_slug and not self.enterprise_customer_uuid: + # No enterprise customer specified in the request, so return the active enterprise customer + return active_enterprise_customer + + # If the requested enterprise does not match the active enterprise customer user's slug/uuid + # and there is a linked enterprise customer user for the requested enterprise, return the + # linked enterprise customer. + request_matches_active_enterprise_customer = self._request_matches_active_enterprise_customer( + active_enterprise_customer + ) + if not request_matches_active_enterprise_customer and requested_enterprise_customer: + return requested_enterprise_customer + + # If the request user is staff and the requested enterprise does not match the active enterprise + # customer user's slug/uuid, return the staff-enterprise customer. + if staff_enterprise_customer: + return staff_enterprise_customer + + # Otherwise, return the active enterprise customer. + return active_enterprise_customer + + def _request_matches_active_enterprise_customer(self, active_enterprise_customer): + """ + Check if the request matches the active enterprise customer. + """ + slug_matches_active_enterprise_customer = ( + active_enterprise_customer and active_enterprise_customer.get('slug') == self.enterprise_customer_slug + ) + uuid_matches_active_enterprise_customer = ( + active_enterprise_customer and active_enterprise_customer.get('uuid') == self.enterprise_customer_uuid + ) + return ( + slug_matches_active_enterprise_customer or uuid_matches_active_enterprise_customer + ) + + def _enterprise_customer_matches_slug_or_uuid(self, enterprise_customer): + """ + Check if the enterprise customer matches the slug or UUID. + Args: + enterprise_customer: The enterprise customer data. + Returns: + True if the enterprise customer matches the slug or UUID, otherwise False. + """ + if not enterprise_customer: + return False - # Set lms_user_id from the authenticated user object in the request - self.lms_user_id = getattr(self.user, 'lms_user_id', None) + return ( + enterprise_customer.get('slug') == self.enterprise_customer_slug or + enterprise_customer.get('uuid') == self.enterprise_customer_uuid + ) def add_error(self, **kwargs): """ diff --git a/enterprise_access/apps/bffs/handlers.py b/enterprise_access/apps/bffs/handlers.py index da5c130b..8c173397 100644 --- a/enterprise_access/apps/bffs/handlers.py +++ b/enterprise_access/apps/bffs/handlers.py @@ -5,6 +5,7 @@ import logging from enterprise_access.apps.api_client.license_manager_client import LicenseManagerUserApiClient +from enterprise_access.apps.api_client.lms_client import LmsUserApiClient from enterprise_access.apps.bffs.context import HandlerContext logger = logging.getLogger(__name__) @@ -60,7 +61,10 @@ def __init__(self, context): context (HandlerContext): The context object containing request information and data. """ super().__init__(context) + + # API Clients self.license_manager_client = LicenseManagerUserApiClient(self.context.request) + self.lms_user_api_client = LmsUserApiClient(self.context.request) def load_and_process(self): """ @@ -69,14 +73,15 @@ def load_and_process(self): The method in this class simply calls common learner logic to ensure the context is set up. """ try: + # Transform enterprise customer data + self.transform_enterprise_customers() + # Retrieve and process subscription licenses. Handles activation and auto-apply logic. - # TODO: retrieve enterprise customer metadata,ENT-9629 - # self.load_enterprise_customer() self.load_and_process_subscription_licenses() # Retrieve default enterprise courses and enroll in the redeemable ones - self.load_default_enterprise_courses() - self.enroll_in_redeemable_default_courses() + self.load_default_enterprise_enrollment_intentions() + self.enroll_in_redeemable_default_enterprise_enrollment_intentions() except Exception as e: # pylint: disable=broad-exception-caught logger.exception("Error loading learner portal handler") self.add_error( @@ -84,14 +89,78 @@ def load_and_process(self): developer_message=f"Error: {e}", ) + def transform_enterprise_customers(self): + """ + Transform enterprise customer metadata retrieved by self.context. + """ + for customer_record_key in ('enterprise_customer', 'active_enterprise_customer', 'staff_enterprise_customer'): + if not (customer_record := self.context.data.get(customer_record_key)): + continue + self.context.data[customer_record_key] = self.transform_enterprise_customer(customer_record) + + if enterprise_customer_users := self.context.data.get('all_linked_enterprise_customer_users'): + self.context.data['all_linked_enterprise_customer_users'] = [ + self.transform_enterprise_customer_user(enterprise_customer_user) + for enterprise_customer_user in enterprise_customer_users + ] + + def transform_enterprise_customer_user(self, enterprise_customer_user): + """ + Transform the enterprise customer user data. + + Args: + enterprise_customer_user: The enterprise customer user data. + Returns: + The transformed enterprise customer user data. + """ + enterprise_customer = enterprise_customer_user.get('enterprise_customer') + return { + **enterprise_customer_user, + 'enterprise_customer': self.transform_enterprise_customer(enterprise_customer), + } + + def transform_enterprise_customer(self, enterprise_customer): + """ + Transform the enterprise customer data. + + Args: + enterprise_customer: The enterprise customer data. + Returns: + The transformed enterprise customer data. + """ + if not enterprise_customer or not enterprise_customer.get('enable_learner_portal', False): + # If the enterprise customer does not exist or the learner portal is not enabled, return None + return None + + # Learner Portal is enabled, so transform the enterprise customer data. + identity_provider = enterprise_customer.get("identity_provider") + disable_search = bool( + not enterprise_customer.get("enable_integrated_customer_learner_portal_search", False) and + identity_provider + ) + show_integration_warning = bool(not disable_search and identity_provider) + + return { + **enterprise_customer, + 'disable_search': disable_search, + 'show_integration_warning': show_integration_warning, + } + def load_subscription_licenses(self): """ Load subscription licenses for the learner. """ - subscriptions_result = self.license_manager_client.get_subscription_licenses_for_learner( - enterprise_customer_uuid=self.context.enterprise_customer_uuid - ) - self.transform_subscriptions_result(subscriptions_result) + try: + subscriptions_result = self.license_manager_client.get_subscription_licenses_for_learner( + enterprise_customer_uuid=self.context.enterprise_customer_uuid + ) + self.transform_subscriptions_result(subscriptions_result) + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Error loading subscription licenses") + self.add_error( + user_message="An error occurred while loading subscription licenses.", + developer_message=f"Error: {e}", + ) def get_subscription_licenses(self): """ @@ -165,14 +234,11 @@ def process_subscription_licenses(self): This method is called after `load_subscription_licenses` to handle further actions based on the loaded data. - - If the `subscriptions` field does not exist, raises a KeyError """ - if not self.context.data['subscriptions']: - raise KeyError("Unable to retrieve subscriptions.") - - # Check if user already has 'activated' license(s). If so, no further action is needed. - if self.check_has_activated_license(): + if not self.context.data['subscriptions'] or self.check_has_activated_license(): + # Skip processing if: + # - there is no subscriptions data + # - user already has an activated license return # Check if there are 'assigned' licenses that need to be activated @@ -228,7 +294,9 @@ def check_and_activate_assigned_license(self): # Update the subscriptions.subscription_licenses_by_status context with the modified licenses data updated_activated_licenses = subscription_licenses_by_status.get('activated', []) updated_activated_licenses.extend(activated_licenses) - subscription_licenses_by_status['activated'] = updated_activated_licenses + if updated_activated_licenses: + subscription_licenses_by_status['activated'] = updated_activated_licenses + remaining_assigned_licenses = [ subscription_license for subscription_license in assigned_licenses @@ -238,6 +306,7 @@ def check_and_activate_assigned_license(self): subscription_licenses_by_status['assigned'] = remaining_assigned_licenses else: subscription_licenses_by_status.pop('assigned', None) + self.context.data['subscriptions']['subscription_licenses_by_status'] = subscription_licenses_by_status # Update the subscriptions.subscription_licenses context with the modified licenses data @@ -248,6 +317,7 @@ def check_and_activate_assigned_license(self): updated_subscription_licenses.append(activated_license) break updated_subscription_licenses.append(subscription_license) + self.context.data['subscriptions']['subscription_licenses'] = updated_subscription_licenses def check_and_auto_apply_license(self): @@ -260,13 +330,14 @@ def check_and_auto_apply_license(self): # Skip auto-apply if user already has an activated license or assigned licenses return - customer_agreement = self.context.data['subscriptions'].get('customer_agreement', {}) + customer_agreement = self.context.data['subscriptions'].get('customer_agreement') or {} has_subscription_plan_for_auto_apply = ( bool(customer_agreement.get('subscription_for_auto_applied_licenses')) and customer_agreement.get('net_days_until_expiration') > 0 ) + enterprise_customer = self.context.data.get('enterprise_customer', {}) idp_or_univeral_link_enabled = ( - # TODO: IDP from customer, ENT-9629 + enterprise_customer.get('identity_provider') or customer_agreement.get('enable_auto_applied_subscriptions_with_universal_link') ) is_eligible_for_auto_apply = has_subscription_plan_for_auto_apply and idp_or_univeral_link_enabled @@ -289,44 +360,53 @@ def check_and_auto_apply_license(self): developer_message=f"Customer agreement UUID: {customer_agreement.get('uuid')}, Error: {e}", ) - def load_default_enterprise_courses(self): + def load_default_enterprise_enrollment_intentions(self): """ Load default enterprise course enrollments (stubbed) """ - mock_catalog_uuid = 'f09ff39b-f456-4a03-b53b-44cd70f52108' - - self.context.data['default_enterprise_courses'] = [ - { - 'current_course_run_key': 'course-v1:edX+DemoX+Demo_Course', - 'applicable_catalog_uuids': [mock_catalog_uuid], - }, - { - 'current_course_run_key': 'course-v1:edX+SampleX+Sample_Course', - 'applicable_catalog_uuids': [mock_catalog_uuid], - }, - ] + client = self.lms_user_api_client + try: + default_enrollment_intentions = client.get_default_enterprise_enrollment_intentions_learner_status( + enterprise_customer_uuid=self.context.enterprise_customer_uuid, + ) + self.context.data['default_enterprise_enrollment_intentions'] = default_enrollment_intentions + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Error loading default enterprise courses") + self.add_error( + user_message="An error occurred while loading default enterprise courses.", + developer_message=f"Error: {e}", + ) - def enroll_in_redeemable_default_courses(self): + def enroll_in_redeemable_default_enterprise_enrollment_intentions(self): """ Enroll in redeemable courses. """ - default_enterprise_courses = self.context.data.get('default_enterprise_courses', []) + default_enterprise_enrollment_intentions = self.context.data.get('default_enterprise_enrollment_intentions', {}) + needs_enrollment = default_enterprise_enrollment_intentions.get('needs_enrollment', {}) + needs_enrollment_enrollable = needs_enrollment.get('enrollable', []) + activated_subscription_licenses = self.get_subscription_licenses_by_status().get('activated', []) - if not (default_enterprise_courses or activated_subscription_licenses): - # Skip enrollment if there are no default enterprise courses or activated subscription licenses + if not (needs_enrollment_enrollable or activated_subscription_licenses): + # Skip enrollment if there are no: + # - default enterprise enrollment intentions that should be enrolled OR + # - activated subscription licenses return redeemable_default_courses = [] - for course in default_enterprise_courses: + for enrollment_intention in default_enterprise_enrollment_intentions: for subscription_license in activated_subscription_licenses: subscription_plan = subscription_license.get('subscription_plan', {}) - if subscription_plan.get('enterprise_catalog_uuid') in course.get('applicable_catalog_uuids'): - redeemable_default_courses.append((course, subscription_license)) + subscription_catalog = subscription_plan.get('enterprise_catalog_uuid') + applicable_catalog_to_enrollment_intention = enrollment_intention.get( + 'applicable_enterprise_catalog_uuids' + ) + if subscription_catalog in applicable_catalog_to_enrollment_intention: + redeemable_default_courses.append((enrollment_intention, subscription_license)) break for redeemable_course, subscription_license in redeemable_default_courses: - # Enroll in redeemable courses (stubbed) + # TODO: enroll in redeemable courses (stubbed) if not self.context.data.get('default_enterprise_enrollment_realizations'): self.context.data['default_enterprise_enrollment_realizations'] = [] @@ -351,10 +431,11 @@ def load_and_process(self): This method overrides the `load_and_process` method in `BaseLearnerPortalHandler`. """ - # Call the common learner logic from the base class + super().load_and_process() + try: # Load data specific to the dashboard route - self.context.data['enterprise_course_enrollments'] = self.get_enterprise_course_enrollments() + self.load_enterprise_course_enrollments() except Exception as e: # pylint: disable=broad-exception-caught logger.exception("Error retrieving enterprise_course_enrollments") self.add_error( @@ -362,35 +443,22 @@ def load_and_process(self): developer_message=f"Error: {e}", ) - def get_enterprise_course_enrollments(self): + def load_enterprise_course_enrollments(self): """ Loads enterprise course enrollments data. Returns: list: A list of enterprise course enrollments. """ - # Placeholder logic for loading enterprise course enrollments data - return [ - { - "certificate_download_url": None, - "emails_enabled": False, - "course_run_id": "course-v1:BabsonX+MIS01x+1T2019", - "course_run_status": "in_progress", - "created": "2023-09-29T14:24:45.409031+00:00", - "start_date": "2019-03-19T10:00:00Z", - "end_date": "2024-12-31T04:30:00Z", - "display_name": "AI for Leaders", - "course_run_url": "https://learning.edx.org/course/course-v1:BabsonX+MIS01x+1T2019/home", - "due_dates": [], - "pacing": "self", - "org_name": "BabsonX", - "is_revoked": False, - "is_enrollment_active": True, - "mode": "verified", - "resume_course_run_url": None, - "course_key": "BabsonX+MIS01x", - "course_type": "verified-audit", - "product_source": "edx", - "enroll_by": "2024-12-21T23:59:59Z", - } - ] + try: + enterprise_course_enrollments = self.lms_user_api_client.get_enterprise_course_enrollments( + enterprise_customer_uuid=self.context.enterprise_customer_uuid, + is_active=True, + ) + self.context.data['enterprise_course_enrollments'] = enterprise_course_enrollments + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception("Error retrieving enterprise course enrollments") + self.add_error( + user_message="An error occurred while retrieving enterprise course enrollments.", + developer_message=f"Error: {e}", + ) diff --git a/enterprise_access/apps/bffs/tests/test_context.py b/enterprise_access/apps/bffs/tests/test_context.py index 3f452417..2f0b449d 100644 --- a/enterprise_access/apps/bffs/tests/test_context.py +++ b/enterprise_access/apps/bffs/tests/test_context.py @@ -1,77 +1,257 @@ """ Text for the BFF context """ + +from unittest import mock + import ddt -from django.test import RequestFactory, TestCase -from faker import Faker from rest_framework.exceptions import ValidationError from enterprise_access.apps.bffs.context import HandlerContext -from enterprise_access.apps.core.tests.factories import UserFactory +from enterprise_access.apps.bffs.tests.utils import TestHandlerContextMixin @ddt.ddt -class TestHandlerContext(TestCase): - def setUp(self): - super().setUp() - self.factory = RequestFactory() - self.mock_user = UserFactory() - self.faker = Faker() - - self.mock_enterprise_customer_uuid = self.faker.uuid4() - self.request = self.factory.get('sample/api/call') - self.request.user = self.mock_user - self.request.query_params = { - 'enterprise_customer_uuid': self.mock_enterprise_customer_uuid +class TestHandlerContext(TestHandlerContextMixin): + """ + Test the HandlerContext class + """ + + @ddt.data( + {'raises_exception': False}, + {'raises_exception': True}, + ) + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + @ddt.unpack + def test_handler_context_init(self, mock_get_enterprise_customers_for_user, raises_exception): + if raises_exception: + mock_get_enterprise_customers_for_user.side_effect = Exception('Mock exception') + else: + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + + context = HandlerContext(self.request) + + self.assertEqual(context.request, self.request) + self.assertEqual(context.user, self.mock_user) + + expected_data = {} + if not raises_exception: + expected_data = { + 'enterprise_customer': self.mock_enterprise_customer, + 'active_enterprise_customer': self.mock_enterprise_customer, + 'staff_enterprise_customer': None, + 'all_linked_enterprise_customer_users': [ + { + **self.mock_enterprise_learner_response_data['results'][0], + 'enterprise_customer': self.mock_enterprise_customer, + }, + { + **self.mock_enterprise_learner_response_data['results'][1], + 'enterprise_customer': self.mock_enterprise_customer_2, + } + ], + } + + self.assertEqual(context.data, expected_data) + if raises_exception: + self.assertEqual(context.enterprise_features, {}) + else: + self.assertEqual(context.enterprise_customer_slug, self.mock_enterprise_customer_slug) + self.assertEqual( + context.enterprise_features, + self.mock_enterprise_learner_response_data['enterprise_features'] + ) + + expected_errors = ( + [ + { + 'developer_message': 'Mock exception', + 'user_message': 'Error retrieving linked enterprise customers' + } + ] if raises_exception else [] + ) + self.assertEqual(context.errors, expected_errors) + self.assertEqual(context.warnings, []) + self.assertEqual(context.enterprise_customer_uuid, self.mock_enterprise_customer_uuid) + expected_slug = None if raises_exception else self.mock_enterprise_customer_slug + self.assertEqual(context.enterprise_customer_slug, expected_slug) + self.assertEqual(context.lms_user_id, self.mock_user.lms_user_id) + expected_enterprise_customer = None if raises_exception else self.mock_enterprise_customer + self.assertEqual(context.enterprise_customer, expected_enterprise_customer) + + @ddt.data( + {'raises_exception': False}, + {'raises_exception': True}, + ) + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsApiClient.get_enterprise_customer_data') + @ddt.unpack + def test_handler_context_init_staff_user_unlinked( + self, + mock_get_enterprise_customer_data, + mock_get_enterprise_customers_for_user, + raises_exception, + ): + mock_get_enterprise_customers_for_user.return_value = { + **self.mock_enterprise_learner_response_data, + 'results': [], } - self.context = HandlerContext(self.request) - def test_handler_context_init(self): - self.assertEqual(self.context.request, self.request) - self.assertEqual(self.context.user, self.mock_user) - self.assertEqual(self.context.data, {}) - self.assertEqual(self.context.errors, []) - self.assertEqual(self.context.warnings, []) - self.assertEqual(self.context.enterprise_customer_uuid, self.mock_enterprise_customer_uuid) - self.assertEqual(self.context.lms_user_id, self.mock_user.lms_user_id) + if raises_exception: + mock_get_enterprise_customer_data.side_effect = Exception('Mock exception') + else: + mock_get_enterprise_customer_data.return_value = self.mock_enterprise_customer + + request = self.request + request.user = self.mock_staff_user + context = HandlerContext(request) + + self.assertEqual(context.request, request) + self.assertEqual(context.user, self.mock_staff_user) + + expected_data = { + 'enterprise_customer': self.mock_enterprise_customer, + 'active_enterprise_customer': None, + 'staff_enterprise_customer': self.mock_enterprise_customer, + 'all_linked_enterprise_customer_users': [], + } + if raises_exception: + expected_data.update({ + 'enterprise_customer': None, + 'staff_enterprise_customer': None, + }) + self.assertEqual(context.data, expected_data) + expected_errors = ( + [ + { + 'developer_message': 'Mock exception', + 'user_message': 'Error retrieving enterprise customer data' + } + ] if raises_exception else [] + ) + self.assertEqual(context.errors, expected_errors) + self.assertEqual(context.warnings, []) + self.assertEqual(context.enterprise_features, self.mock_enterprise_learner_response_data['enterprise_features']) + self.assertEqual(context.enterprise_customer_uuid, self.mock_enterprise_customer_uuid) + expected_slug = None if raises_exception else self.mock_enterprise_customer_slug + self.assertEqual(context.enterprise_customer_slug, expected_slug) + self.assertEqual(context.lms_user_id, self.mock_staff_user.lms_user_id) + expected_enterprise_customer = None if raises_exception else self.mock_enterprise_customer + self.assertEqual(context.enterprise_customer, expected_enterprise_customer) @ddt.data( + # No enterprise customer uuid/slug in the request; returns active enterprise customer user + { + 'has_query_params': False, + 'has_payload_data': False, + 'has_enterprise_customer_uuid_param': False, + 'has_enterprise_customer_slug_param': False, + }, + # Enterprise customer uuid in the request; returns enterprise customer user with that uuid + { + 'has_query_params': True, + 'has_payload_data': False, + 'has_enterprise_customer_uuid_param': True, + 'has_enterprise_customer_slug_param': False, + }, + { + 'has_query_params': False, + 'has_payload_data': True, + 'has_enterprise_customer_uuid_param': True, + 'has_enterprise_customer_slug_param': False, + }, + { + 'has_query_params': True, + 'has_payload_data': True, + 'has_enterprise_customer_uuid_param': True, + 'has_enterprise_customer_slug_param': False, + }, + # Enterprise customer slug in the request; returns enterprise customer user with that slug { - 'query_params': True, - 'data': True, + 'has_query_params': True, + 'has_payload_data': False, + 'has_enterprise_customer_uuid_param': False, + 'has_enterprise_customer_slug_param': True, }, { - 'query_params': False, - 'data': True, + 'has_query_params': False, + 'has_payload_data': True, + 'has_enterprise_customer_uuid_param': False, + 'has_enterprise_customer_slug_param': True, }, { - 'query_params': True, - 'data': False, + 'has_query_params': True, + 'has_payload_data': True, + 'has_enterprise_customer_uuid_param': False, + 'has_enterprise_customer_slug_param': True, }, + # Both enterprise customer uuid and slug in the request; returns enterprise customer user with that uuid { - 'query_params': False, - 'data': False, + 'has_query_params': True, + 'has_payload_data': False, + 'has_enterprise_customer_uuid_param': True, + 'has_enterprise_customer_slug_param': True, + }, + { + 'has_query_params': False, + 'has_payload_data': True, + 'has_enterprise_customer_uuid_param': True, + 'has_enterprise_customer_slug_param': True, + }, + { + 'has_query_params': True, + 'has_payload_data': True, + 'has_enterprise_customer_uuid_param': True, + 'has_enterprise_customer_slug_param': True, }, ) + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') @ddt.unpack - def test_handler_context_enterprise_customer_uuid_param(self, query_params, data): - if not query_params: - self.request.query_params = {} - - if data: - self.request = self.factory.post('sample/api/call') - self.request.data = { - 'enterprise_customer_uuid': self.mock_enterprise_customer_uuid - } + def test_handler_context_enterprise_customer_params( + self, + mock_get_enterprise_customers_for_user, + has_query_params, + has_payload_data, + has_enterprise_customer_uuid_param, + has_enterprise_customer_slug_param, + ): + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + request = self.request + + query_params = {} + if has_query_params: + if has_enterprise_customer_uuid_param: + query_params['enterprise_customer_uuid'] = self.mock_enterprise_customer_uuid_2 + if has_enterprise_customer_slug_param: + query_params['enterprise_customer_slug'] = self.mock_enterprise_customer_slug_2 - if not (query_params or data): - with self.assertRaises(ValueError): - HandlerContext(self.request) + if has_payload_data: + # Switch to a POST request + request = self.factory.post('sample/api/call') + request.user = self.mock_user + request.data = {} + if has_enterprise_customer_uuid_param: + request.data['enterprise_customer_uuid'] = self.mock_enterprise_customer_uuid_2 + if has_enterprise_customer_slug_param: + request.data['enterprise_customer_slug'] = self.mock_enterprise_customer_slug_2 + + # Set the query params, if any. + request.query_params = query_params + + context = HandlerContext(request) + + if has_enterprise_customer_slug_param or has_enterprise_customer_uuid_param: + self.assertEqual(context.enterprise_customer_uuid, self.mock_enterprise_customer_uuid_2) + self.assertEqual(context.enterprise_customer_slug, self.mock_enterprise_customer_slug_2) else: - self.assertEqual(self.context.enterprise_customer_uuid, self.mock_enterprise_customer_uuid) - self.assertEqual(self.context.lms_user_id, self.mock_user.lms_user_id) + self.assertEqual(context.enterprise_customer_uuid, self.mock_enterprise_customer_uuid) + self.assertEqual(context.enterprise_customer_slug, self.mock_enterprise_customer_slug) + + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + def test_handler_context_add_error_serializer(self, mock_get_enterprise_customers_for_user): + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + context = HandlerContext(self.request) - def test_handler_context_add_error_serializer(self): expected_output = { "developer_message": "No enterprise uuid associated to the user mock-id", "user_message": "You may not be associated with the enterprise.", @@ -81,19 +261,26 @@ def test_handler_context_add_error_serializer(self): **expected_output, "status": 403 # Add an attribute that is not explicitly defined in the serializer to verify } - self.context.add_error( + context.add_error( **arguments ) - self.assertEqual(expected_output, self.context.errors[0]) + self.assertEqual(expected_output, context.errors[0]) + + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + def test_handler_context_add_error_serializer_is_valid(self, mock_get_enterprise_customers_for_user): + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + context = HandlerContext(self.request) - def test_handler_context_add_error_serializer_is_valid(self): malformed_output = { "developer_message": "No enterprise uuid associated to the user mock-id", } with self.assertRaises(ValidationError): - self.context.add_error(**malformed_output) + context.add_error(**malformed_output) - def test_handler_context_add_warning_serializer(self): + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + def test_handler_context_add_warning_serializer(self, mock_get_enterprise_customers_for_user): + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + context = HandlerContext(self.request) expected_output = { "developer_message": "Heuristic Expiration", "user_message": "The data received might be out-dated", @@ -103,14 +290,17 @@ def test_handler_context_add_warning_serializer(self): **expected_output, "status": 113 # Add an attribute that is not explicitly defined in the serializer to verify } - self.context.add_warning( + context.add_warning( **arguments ) - self.assertEqual(expected_output, self.context.warnings[0]) + self.assertEqual(expected_output, context.warnings[0]) - def test_handler_context_add_warning_serializer_is_valid(self): + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + def test_handler_context_add_warning_serializer_is_valid(self, mock_get_enterprise_customers_for_user): + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + context = HandlerContext(self.request) malformed_output = { "user_message": "The data received might be out-dated", } with self.assertRaises(ValidationError): - self.context.add_error(**malformed_output) + context.add_error(**malformed_output) diff --git a/enterprise_access/apps/bffs/tests/test_handlers.py b/enterprise_access/apps/bffs/tests/test_handlers.py index 369980ba..a1afc0e3 100644 --- a/enterprise_access/apps/bffs/tests/test_handlers.py +++ b/enterprise_access/apps/bffs/tests/test_handlers.py @@ -1,40 +1,26 @@ -from django.test import RequestFactory, TestCase -from faker import Faker -from rest_framework.exceptions import ValidationError +from unittest import mock from enterprise_access.apps.bffs.context import HandlerContext from enterprise_access.apps.bffs.handlers import BaseHandler, BaseLearnerPortalHandler, DashboardHandler -from enterprise_access.apps.core.tests.factories import UserFactory +from enterprise_access.apps.bffs.tests.utils import TestHandlerContextMixin -class TestBaseHandlerMixin(TestCase): - def setUp(self): - super().setUp() - self.factory = RequestFactory() - self.mock_user = UserFactory() - self.faker = Faker() - - self.mock_enterprise_customer_uuid = self.faker.uuid4() - self.request = self.factory.get('sample/api/call') - self.request.query_params = { - 'enterprise_customer_uuid': self.mock_enterprise_customer_uuid - } - self.request.user = self.mock_user - self.context = HandlerContext(self.request) +class TestBaseHandler(TestHandlerContextMixin): + """ + Test BaseHandler + """ - -class TestBaseHandler(TestBaseHandlerMixin): - def setUp(self): - super().setUp() - self.base_handler = BaseHandler(self.context) - - def test_base_handler_uninitialized_load_and_process(self): - base_handler = self.base_handler + def test_base_handler_load_and_process_not_implemented(self): + context = HandlerContext(self.request) + base_handler = BaseHandler(context) with self.assertRaises(NotImplementedError): base_handler.load_and_process() - def test_base_handler_add_error(self): - base_handler = self.base_handler + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + def test_base_handler_add_error(self, mock_get_enterprise_customers_for_user): + mock_get_enterprise_customers_for_user.return_value = {'results': []} + context = HandlerContext(self.request) + base_handler = BaseHandler(context) expected_output = { "developer_message": "No enterprise uuid associated to the user mock-uuid", "user_message": "You may not be associated with the enterprise.", @@ -50,7 +36,8 @@ def test_base_handler_add_error(self): self.assertEqual(expected_output, base_handler.context.errors[0]) def test_base_handler_add_warning(self): - base_handler = self.base_handler + context = HandlerContext(self.request) + base_handler = BaseHandler(context) expected_output = { "developer_message": "Heuristic Expiration", "user_message": "The data received might be out-dated", @@ -66,48 +53,214 @@ def test_base_handler_add_warning(self): self.assertEqual(expected_output, base_handler.context.warnings[0]) -class TestBaseLearnerPortalHandler(TestBaseHandlerMixin): +class TestBaseLearnerPortalHandler(TestHandlerContextMixin): + """ + Test BaseLearnerPortalHandler + """ + def setUp(self): super().setUp() - self.base_learner_portal_handler = BaseLearnerPortalHandler(self.context) + self.expected_enterprise_customer = { + **self.mock_enterprise_customer, + 'disable_search': False, + 'show_integration_warning': False, + } + self.expected_enterprise_customer_2 = { + **self.mock_enterprise_customer_2, + 'disable_search': False, + 'show_integration_warning': False, + } + self.mock_subscription_licenses_data = { + 'customer_agreement': None, + 'results': [], + } + self.mock_default_enterprise_enrollment_intentions_learner_status_data = { + "lms_user_id": self.mock_user.id, + "user_email": self.mock_user.email, + "enterprise_customer_uuid": self.mock_enterprise_customer_uuid, + "enrollment_statuses": { + "needs_enrollment": { + "enrollable": [], + "not_enrollable": [], + }, + 'already_enrolled': [], + }, + "metadata": { + "total_default_enterprise_enrollment_intentions": 0, + "total_needs_enrollment": { + "enrollable": 0, + "not_enrollable": 0 + }, + "total_already_enrolled": 0 + } + } + + def get_expected_enterprise_customer(self, enterprise_customer_user): + enterprise_customer = enterprise_customer_user.get('enterprise_customer') + return ( + self.expected_enterprise_customer + if enterprise_customer.get('uuid') == self.mock_enterprise_customer_uuid + else self.expected_enterprise_customer_2 + ) - # TODO: Test pertaining to currently stubbed out functions deferred for future tickets + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + @mock.patch( + 'enterprise_access.apps.api_client.lms_client.LmsUserApiClient' + '.get_default_enterprise_enrollment_intentions_learner_status' + ) + @mock.patch( + 'enterprise_access.apps.api_client.license_manager_client.LicenseManagerUserApiClient' + '.get_subscription_licenses_for_learner' + ) + def test_load_and_process( + self, + mock_get_subscription_licenses_for_learner, + mock_get_default_enrollment_intentions_learner_status, + mock_get_enterprise_customers_for_user, + ): + """ + Test load_and_process method + """ + mock_get_enterprise_customers_for_user.return_value = self.mock_enterprise_learner_response_data + mock_get_subscription_licenses_for_learner.return_value = self.mock_subscription_licenses_data + mock_get_default_enrollment_intentions_learner_status.return_value =\ + self.mock_default_enterprise_enrollment_intentions_learner_status_data + context = HandlerContext(self.request) + handler = BaseLearnerPortalHandler(context) -class TestDashboardHandler(TestBaseHandlerMixin): - def setUp(self): - super().setUp() - self.dashboard_handler = DashboardHandler(self.context) + handler.load_and_process() - # TODO: Update tests once stubbed out function updated - def test_load_and_process(self): - expected_output = [ + # Enterprise Customer related assertions + actual_enterprise_customer = handler.context.data.get('enterprise_customer') + actual_active_enterprise_customer = handler.context.data.get('active_enterprise_customer') + actual_linked_ecus = handler.context.data.get('all_linked_enterprise_customer_users') + expected_linked_ecus = [ { - "certificate_download_url": None, - "emails_enabled": False, - "course_run_id": "course-v1:BabsonX+MIS01x+1T2019", - "course_run_status": "in_progress", - "created": "2023-09-29T14:24:45.409031+00:00", - "start_date": "2019-03-19T10:00:00Z", - "end_date": "2024-12-31T04:30:00Z", - "display_name": "AI for Leaders", - "course_run_url": "https://learning.edx.org/course/course-v1:BabsonX+MIS01x+1T2019/home", - "due_dates": [], - "pacing": "self", - "org_name": "BabsonX", - "is_revoked": False, - "is_enrollment_active": True, - "mode": "verified", - "resume_course_run_url": None, - "course_key": "BabsonX+MIS01x", - "course_type": "verified-audit", - "product_source": "edx", - "enroll_by": "2024-12-21T23:59:59Z", + **enterprise_customer_user, + 'enterprise_customer': self.get_expected_enterprise_customer(enterprise_customer_user), } + for enterprise_customer_user in self.mock_enterprise_learner_response_data['results'] ] - dashboard_handler = self.dashboard_handler + actual_staff_enterprise_customer = handler.context.data.get('staff_enterprise_customer') + expected_staff_enterprise_customer = None + self.assertEqual(actual_enterprise_customer, self.expected_enterprise_customer) + self.assertEqual(actual_active_enterprise_customer, self.expected_enterprise_customer) + self.assertEqual(actual_linked_ecus, expected_linked_ecus) + self.assertEqual(actual_staff_enterprise_customer, expected_staff_enterprise_customer) + + # Base subscriptions related assertions + actual_subscriptions = handler.context.data.get('subscriptions') + expected_subscriptions = { + 'customer_agreement': None, + 'subscription_licenses': [], + 'subscription_licenses_by_status': {}, + } + self.assertEqual(actual_subscriptions, expected_subscriptions) + + # Default enterprise enrollment intentions related assertions + actual_default_enterprise_enrollment_intentions = ( + handler.context.data.get('default_enterprise_enrollment_intentions') + ) + expected_default_enterprise_enrollment_intentions = ( + self.mock_default_enterprise_enrollment_intentions_learner_status_data + ) + self.assertEqual( + actual_default_enterprise_enrollment_intentions, + expected_default_enterprise_enrollment_intentions + ) + + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + def test_load_and_process_without_learner_portal_enabled(self, mock_get_enterprise_customers_for_user): + """ + Test load_and_process method without learner portal enabled. No enterprise + customer metadata should be returned. + """ + mock_get_enterprise_customers_for_user.return_value = { + **self.mock_enterprise_learner_response_data, + 'results': [{ + **self.mock_enterprise_customer, + 'enable_learner_portal': False, + }], + } + context = HandlerContext(self.request) + handler = BaseLearnerPortalHandler(context) + + handler.load_and_process() + + actual_enterprise_customer = handler.context.data.get('enterprise_customer') + expected_enterprise_customer = None + self.assertEqual(actual_enterprise_customer, expected_enterprise_customer) + + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_customers_for_user') + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsApiClient.get_enterprise_customer_data') + def test_load_and_process_staff_enterprise_customer( + self, + mock_get_enterprise_customer_data, + mock_get_enterprise_customers_for_user, + ): + mock_get_enterprise_customers_for_user.return_value = { + **self.mock_enterprise_learner_response_data, + 'results': [], + } + mock_get_enterprise_customer_data.return_value = self.mock_enterprise_customer + request = self.request + request.user = self.mock_staff_user + context = HandlerContext(request) + handler = BaseLearnerPortalHandler(context) + + handler.load_and_process() + + actual_enterprise_customer = handler.context.data.get('enterprise_customer') + expected_enterprise_customer = self.expected_enterprise_customer + self.assertEqual(actual_enterprise_customer, expected_enterprise_customer) + actual_staff_enterprise_customer = handler.context.data.get('staff_enterprise_customer') + expected_staff_enterprise_customer = self.expected_enterprise_customer + self.assertEqual(actual_staff_enterprise_customer, expected_staff_enterprise_customer) + + +class TestDashboardHandler(TestHandlerContextMixin): + """ + Test DashboardHandler + """ + + def setUp(self): + super().setUp() + + self.mock_enterprise_course_enrollment = { + "certificate_download_url": None, + "emails_enabled": False, + "course_run_id": "course-v1:BabsonX+MIS01x+1T2019", + "course_run_status": "in_progress", + "created": "2023-09-29T14:24:45.409031+00:00", + "start_date": "2019-03-19T10:00:00Z", + "end_date": "2024-12-31T04:30:00Z", + "display_name": "AI for Leaders", + "course_run_url": "https://learning.edx.org/course/course-v1:BabsonX+MIS01x+1T2019/home", + "due_dates": [], + "pacing": "self", + "org_name": "BabsonX", + "is_revoked": False, + "is_enrollment_active": True, + "mode": "verified", + "resume_course_run_url": None, + "course_key": "BabsonX+MIS01x", + "course_type": "verified-audit", + "product_source": "edx", + "enroll_by": "2024-12-21T23:59:59Z", + } + self.mock_enterprise_course_enrollments = [self.mock_enterprise_course_enrollment] + + @mock.patch('enterprise_access.apps.api_client.lms_client.LmsUserApiClient.get_enterprise_course_enrollments') + def test_load_and_process(self, mock_get_enterprise_course_enrollments): + mock_get_enterprise_course_enrollments.return_value = self.mock_enterprise_course_enrollments + + context = HandlerContext(self.request) + dashboard_handler = DashboardHandler(context) + dashboard_handler.load_and_process() + self.assertEqual( - dashboard_handler.context.data['enterprise_course_enrollments'], - expected_output + dashboard_handler.context.data.get('enterprise_course_enrollments'), + self.mock_enterprise_course_enrollments, ) diff --git a/enterprise_access/apps/bffs/tests/utils.py b/enterprise_access/apps/bffs/tests/utils.py new file mode 100644 index 00000000..09dfdf60 --- /dev/null +++ b/enterprise_access/apps/bffs/tests/utils.py @@ -0,0 +1,60 @@ +""" +Test utilities for BFFs. +""" + +from django.test import RequestFactory, TestCase +from faker import Faker + +from enterprise_access.apps.core.tests.factories import UserFactory + + +class TestHandlerContextMixin(TestCase): + """ + Mixin for HandlerContext tests + """ + + def setUp(self): + super().setUp() + self.maxDiff = None + self.factory = RequestFactory() + self.mock_user = UserFactory() + self.mock_staff_user = UserFactory(is_staff=True) + self.faker = Faker() + + self.mock_enterprise_customer_uuid = self.faker.uuid4() + self.mock_enterprise_customer_slug = 'mock-slug' + self.mock_enterprise_customer_uuid_2 = self.faker.uuid4() + self.mock_enterprise_customer_slug_2 = 'mock-slug-2' + + # Mock request + self.request = self.factory.get('sample/api/call') + self.request.user = self.mock_user + self.request.query_params = { + 'enterprise_customer_uuid': self.mock_enterprise_customer_uuid + } + self.request.data = {} + + # Mock enterprise customer data + self.mock_enterprise_customer = { + 'uuid': self.mock_enterprise_customer_uuid, + 'slug': self.mock_enterprise_customer_slug, + 'enable_learner_portal': True, + } + self.mock_enterprise_customer_2 = { + 'uuid': self.mock_enterprise_customer_uuid_2, + 'slug': self.mock_enterprise_customer_slug_2, + 'enable_learner_portal': True, + } + self.mock_enterprise_learner_response_data = { + 'results': [ + { + 'active': True, + 'enterprise_customer': self.mock_enterprise_customer, + }, + { + 'active': False, + 'enterprise_customer': self.mock_enterprise_customer_2, + }, + ], + 'enterprise_features': {'feature_flag': True} + }