From d1ec72b5e56ffac78c51a2998dfbfb5b4e916bf2 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 2 Jan 2025 18:09:02 -0800 Subject: [PATCH] Reworked salesforce connector to use bulk api (#3581) --- backend/.gitignore | 1 + .../onyx/connectors/salesforce/connector.py | 347 ++++---- .../connectors/salesforce/doc_conversion.py | 10 +- .../connectors/salesforce/salesforce_calls.py | 210 +++++ .../shelve_stuff/shelve_functions.py | 209 +++++ .../salesforce/shelve_stuff/shelve_utils.py | 29 + .../shelve_stuff/test_salesforce_shelves.py | 737 +++++++++++++++++ .../connectors/salesforce/sqlite_functions.py | 386 +++++++++ backend/onyx/connectors/salesforce/utils.py | 72 ++ .../salesforce/test_salesforce_sqlite.py | 746 ++++++++++++++++++ 10 files changed, 2545 insertions(+), 202 deletions(-) create mode 100644 backend/onyx/connectors/salesforce/salesforce_calls.py create mode 100644 backend/onyx/connectors/salesforce/shelve_stuff/shelve_functions.py create mode 100644 backend/onyx/connectors/salesforce/shelve_stuff/shelve_utils.py create mode 100644 backend/onyx/connectors/salesforce/shelve_stuff/test_salesforce_shelves.py create mode 100644 backend/onyx/connectors/salesforce/sqlite_functions.py create mode 100644 backend/onyx/connectors/salesforce/utils.py create mode 100644 backend/tests/unit/onyx/connectors/salesforce/test_salesforce_sqlite.py diff --git a/backend/.gitignore b/backend/.gitignore index b1c4f4db71d..9c2da46d35a 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -9,3 +9,4 @@ api_keys.py vespa-app.zip dynamic_config_storage/ celerybeat-schedule* +onyx/connectors/salesforce/data/ \ No newline at end of file diff --git a/backend/onyx/connectors/salesforce/connector.py b/backend/onyx/connectors/salesforce/connector.py index 26524e9a0ed..6ada66387f4 100644 --- a/backend/onyx/connectors/salesforce/connector.py +++ b/backend/onyx/connectors/salesforce/connector.py @@ -1,11 +1,7 @@ import os -from collections.abc import Iterator -from datetime import datetime -from datetime import UTC from typing import Any from simple_salesforce import Salesforce -from simple_salesforce import SFType from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource @@ -20,36 +16,24 @@ from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import SlimDocument -from onyx.connectors.salesforce.doc_conversion import extract_sections +from onyx.connectors.salesforce.doc_conversion import extract_section +from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel +from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type +from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type +from onyx.connectors.salesforce.sqlite_functions import get_child_ids +from onyx.connectors.salesforce.sqlite_functions import get_record +from onyx.connectors.salesforce.sqlite_functions import init_db +from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv +from onyx.connectors.salesforce.utils import SalesforceObject from onyx.utils.logger import setup_logger -from shared_configs.utils import batch_list logger = setup_logger() -# max query length is 20,000 characters, leave 5000 characters for slop -_MAX_QUERY_LENGTH = 10000 -# There are 22 extra characters per ID so 200 * 22 = 4400 characters which is -# still well under the max query length -_MAX_ID_BATCH_SIZE = 200 - 
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"] _ID_PREFIX = "SALESFORCE_" -def _build_time_filter_for_salesforce( - start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None -) -> str: - if start is None or end is None: - return "" - start_datetime = datetime.fromtimestamp(start, UTC) - end_datetime = datetime.fromtimestamp(end, UTC) - return ( - f" WHERE LastModifiedDate > {start_datetime.isoformat()} " - f"AND LastModifiedDate < {end_datetime.isoformat()}" - ) - - class SalesforceConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, @@ -64,7 +48,10 @@ def __init__( else _DEFAULT_PARENT_OBJECT_TYPES ) - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + def load_credentials( + self, + credentials: dict[str, Any], + ) -> dict[str, Any] | None: self._sf_client = Salesforce( username=credentials["sf_username"], password=credentials["sf_password"], @@ -78,203 +65,146 @@ def sf_client(self) -> Salesforce: raise ConnectorMissingCredentialError("Salesforce") return self._sf_client - def _get_sf_type_object_json(self, type_name: str) -> Any: - sf_object = SFType( - type_name, self.sf_client.session_id, self.sf_client.sf_instance - ) - return sf_object.describe() - - def _get_name_from_id(self, id: str) -> str: - try: - user_object_info = self.sf_client.query( - f"SELECT Name FROM User WHERE Id = '{id}'" - ) - name = user_object_info.get("Records", [{}])[0].get("Name", "Null User") - return name - except Exception: - logger.warning(f"Couldnt find name for object id: {id}") - return "Null User" + def _extract_primary_owners( + self, sf_object: SalesforceObject + ) -> list[BasicExpertInfo] | None: + object_dict = sf_object.data + if not (last_modified_by_id := object_dict.get("LastModifiedById")): + return None + if not (last_modified_by := get_record(last_modified_by_id)): + return None + if not (last_modified_by_name := last_modified_by.data.get("Name")): + return None + primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)] + return primary_owners def _convert_object_instance_to_document( - self, object_dict: dict[str, Any] + self, sf_object: SalesforceObject ) -> Document: + object_dict = sf_object.data salesforce_id = object_dict["Id"] onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}" base_url = f"https://{self.sf_client.sf_instance}" extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"]) extracted_semantic_identifier = object_dict.get("Name", "Unknown Object") - extracted_primary_owners = [ - BasicExpertInfo( - display_name=self._get_name_from_id(object_dict["LastModifiedById"]) - ) - ] + + sections = [extract_section(sf_object, base_url)] + for id in get_child_ids(sf_object.id): + if not (child_object := get_record(id)): + continue + sections.append(extract_section(child_object, base_url)) doc = Document( id=onyx_salesforce_id, - sections=extract_sections(object_dict, base_url), + sections=sections, source=DocumentSource.SALESFORCE, semantic_identifier=extracted_semantic_identifier, doc_updated_at=extracted_doc_updated_at, - primary_owners=extracted_primary_owners, + primary_owners=self._extract_primary_owners(sf_object), metadata={}, ) return doc - def _is_valid_child_object(self, child_relationship: dict) -> bool: - if not child_relationship["childSObject"]: - return False - if not child_relationship["relationshipName"]: - return False - - sf_type = child_relationship["childSObject"] - object_description = self._get_sf_type_object_json(sf_type) - if not object_description["queryable"]: - 
return False - - try: - query = f"SELECT Count() FROM {sf_type} LIMIT 1" - result = self.sf_client.query(query) - if result["totalSize"] == 0: - return False - except Exception as e: - logger.warning(f"Object type {sf_type} doesn't support query: {e}") - return False - - if child_relationship["field"]: - if child_relationship["field"] == "RelatedToId": - return False - else: - return False - - return True - - def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]: - logger.debug(f"Fetching children for SF type: {sf_type}") - object_description = self._get_sf_type_object_json(sf_type) - - children_objects: list[dict] = [] - for child_relationship in object_description["childRelationships"]: - if self._is_valid_child_object(child_relationship): - children_objects.append( - { - "relationship_name": child_relationship["relationshipName"], - "object_type": child_relationship["childSObject"], - } - ) - return children_objects - - def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]: - object_description = self._get_sf_type_object_json(sf_type) - - fields = [ - field.get("name") - for field in object_description["fields"] - if field.get("type", "base64") != "base64" - ] - - return fields - - def _get_parent_object_ids( - self, parent_sf_type: str, time_filter_query: str - ) -> list[str]: - """Fetch all IDs for a given parent object type.""" - logger.debug(f"Fetching IDs for parent type: {parent_sf_type}") - query = f"SELECT Id FROM {parent_sf_type}{time_filter_query}" - query_result = self.sf_client.query_all(query) - ids = [record["Id"] for record in query_result["records"]] - logger.debug(f"Found {len(ids)} IDs for parent type: {parent_sf_type}") - return ids - - def _process_id_batch( - self, - id_batch: list[str], - queries: list[str], - ) -> dict[str, dict[str, Any]]: - """Process a batch of IDs using the given queries.""" - # Initialize results dictionary for this batch - logger.debug(f"Processing batch of {len(id_batch)} IDs") - query_results: dict[str, dict[str, Any]] = {} - - # For each query, fetch and combine results for the batch - for query in queries: - id_filter = f" WHERE Id IN {tuple(id_batch)}" - batch_query = query + id_filter - logger.debug(f"Executing query with length: {len(batch_query)}") - query_result = self.sf_client.query_all(batch_query) - logger.debug(f"Retrieved {len(query_result['records'])} records for query") - - for record_dict in query_result["records"]: - query_results.setdefault(record_dict["Id"], {}).update(record_dict) - - # Convert results to documents - return query_results - - def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]: - """ - parent_sf_type is a string that represents the Salesforce object type. 
-        This function generates queries that will fetch:
-        - all the fields of the parent object type
-        - all the fields of the child objects of the parent object type
-        """
-        logger.debug(f"Generating queries for parent type: {parent_sf_type}")
-        parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
-        logger.debug(f"Found {len(parent_fields)} fields for parent type")
-        child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
-        logger.debug(f"Found {len(child_sf_types)} child types")
-
-        query = f"SELECT {', '.join(parent_fields)}"
-        for child_object_dict in child_sf_types:
-            fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
-            query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
-
-            if len(query_addition) + len(query) > _MAX_QUERY_LENGTH:
-                query += f"\n FROM {parent_sf_type}"
-                yield query
-                query = "SELECT Id" + query_addition
-            else:
-                query += query_addition
-
-        query += f"\n FROM {parent_sf_type}"
-
-        yield query
-
-    def _batch_retrieval(
-        self,
-        id_batches: list[list[str]],
-        queries: list[str],
-    ) -> GenerateDocumentsOutput:
-        doc_batch: list[Document] = []
-        # For each batch of IDs, perform all queries and convert to documents
-        # so they can be yielded in batches
-        for id_batch in id_batches:
-            query_results = self._process_id_batch(id_batch, queries)
-            for doc in query_results.values():
-                doc_batch.append(self._convert_object_instance_to_document(doc))
-                if len(doc_batch) >= self.batch_size:
-                    yield doc_batch
-                    doc_batch = []
-
-        yield doc_batch
-
     def _fetch_from_salesforce(
         self,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> GenerateDocumentsOutput:
-        logger.debug(f"Starting Salesforce fetch from {start} to {end}")
-        time_filter_query = _build_time_filter_for_salesforce(start, end)
+        init_db()
+        all_object_types: set[str] = set(self.parent_object_list)
 
+        logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
+        logger.debug(f"Parent object types: {self.parent_object_list}")
+
+        # Collecting all child object types takes about 20 seconds
         for parent_object_type in self.parent_object_list:
-            logger.info(f"Processing parent object type: {parent_object_type}")
+            child_types = get_all_children_of_sf_type(
+                self.sf_client, parent_object_type
+            )
+            all_object_types.update(child_types)
+            logger.debug(
+                f"Found {len(child_types)} child types for {parent_object_type}"
+            )
+
+        logger.info(f"Found total of {len(all_object_types)} object types to fetch")
+        logger.debug(f"All object types: {all_object_types}")
+
+        # Checkpoint: we've found all object types, now it's time to fetch the data
+        logger.info("Starting to fetch CSVs for all object types")
+        # This takes about 30 minutes the first time and under 2 minutes for updates
+        object_type_to_csv_path = fetch_all_csvs_in_parallel(
+            sf_client=self.sf_client,
+            object_types=all_object_types,
+            start=start,
+            end=end,
+        )
 
-            all_ids = self._get_parent_object_ids(parent_object_type, time_filter_query)
-            logger.info(f"Found {len(all_ids)} IDs for {parent_object_type}")
-            id_batches = batch_list(all_ids, _MAX_ID_BATCH_SIZE)
+        updated_ids: set[str] = set()
+        # To test the rest of the flow against data that has already been fetched
+        # and loaded into sqlite (takes about 10 seconds), uncomment the following:
+        # from onyx.connectors.salesforce.sqlite_functions import find_ids_by_type
+        # for object_type in self.parent_object_list:
+        #     updated_ids.update(list(find_ids_by_type(object_type)))
+
+        # This takes 10-70 minutes the first time (it is unclear why the
range is so big) + total_types = len(object_type_to_csv_path) + logger.info(f"Starting to process {total_types} object types") + + for i, (object_type, csv_paths) in enumerate( + object_type_to_csv_path.items(), 1 + ): + logger.info(f"Processing object type {object_type} ({i}/{total_types})") + # If path is None, it means it failed to fetch the csv + if csv_paths is None: + continue + # Go through each csv path and use it to update the db + for csv_path in csv_paths: + logger.debug(f"Updating {object_type} with {csv_path}") + new_ids = update_sf_db_with_csv( + object_type=object_type, + csv_download_path=csv_path, + ) + updated_ids.update(new_ids) + logger.debug( + f"Added {len(new_ids)} new/updated records for {object_type}" + ) + # Remove the csv file after it has been used + # to successfully update the db + os.remove(csv_path) - # Generate all queries we'll need - queries = list(self._generate_query_per_parent_type(parent_object_type)) - logger.info(f"Generated {len(queries)} queries for {parent_object_type}") - yield from self._batch_retrieval(id_batches, queries) + logger.info(f"Found {len(updated_ids)} total updated records") + logger.info( + f"Starting to process parent objects of types: {self.parent_object_list}" + ) + + docs_to_yield: list[Document] = [] + docs_processed = 0 + # Takes 15-20 seconds per batch + for parent_type, parent_id_batch in get_affected_parent_ids_by_type( + updated_ids=list(updated_ids), + parent_types=self.parent_object_list, + ): + logger.info( + f"Processing batch of {len(parent_id_batch)} {parent_type} objects" + ) + for parent_id in parent_id_batch: + if not (parent_object := get_record(parent_id, parent_type)): + logger.warning( + f"Failed to get parent object {parent_id} for {parent_type}" + ) + continue + + docs_to_yield.append( + self._convert_object_instance_to_document(parent_object) + ) + docs_processed += 1 + + if len(docs_to_yield) >= self.batch_size: + yield docs_to_yield + docs_to_yield = [] + + yield docs_to_yield def load_from_state(self) -> GenerateDocumentsOutput: return self._fetch_from_salesforce() @@ -305,9 +235,9 @@ def retrieve_all_slim_documents( if __name__ == "__main__": - connector = SalesforceConnector( - requested_objects=os.environ["REQUESTED_OBJECTS"].split(",") - ) + import time + + connector = SalesforceConnector(requested_objects=["Account"]) connector.load_credentials( { @@ -316,5 +246,20 @@ def retrieve_all_slim_documents( "sf_security_token": os.environ["SF_SECURITY_TOKEN"], } ) - document_batches = connector.load_from_state() - print(next(document_batches)) + start_time = time.time() + doc_count = 0 + section_count = 0 + text_count = 0 + for doc_batch in connector.load_from_state(): + doc_count += len(doc_batch) + print(f"doc_count: {doc_count}") + for doc in doc_batch: + section_count += len(doc.sections) + for section in doc.sections: + text_count += len(section.text) + end_time = time.time() + + print(f"Doc count: {doc_count}") + print(f"Section count: {section_count}") + print(f"Text count: {text_count}") + print(f"Time taken: {end_time - start_time}") diff --git a/backend/onyx/connectors/salesforce/doc_conversion.py b/backend/onyx/connectors/salesforce/doc_conversion.py index e0b8b861f74..908b39e80a4 100644 --- a/backend/onyx/connectors/salesforce/doc_conversion.py +++ b/backend/onyx/connectors/salesforce/doc_conversion.py @@ -2,6 +2,7 @@ from collections import OrderedDict from onyx.connectors.models import Section +from onyx.connectors.salesforce.utils import SalesforceObject # All of these types of keys are 
handled by specific fields in the doc # conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs) @@ -102,6 +103,13 @@ def _extract_dict_text(raw_dict: dict) -> str: return natural_language_for_dict +def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section: + return Section( + text=_extract_dict_text(salesforce_object.data), + link=f"{base_url}/{salesforce_object.id}", + ) + + def _field_value_is_child_object(field_value: dict) -> bool: """ Checks if the field value is a child object. @@ -115,7 +123,7 @@ def _field_value_is_child_object(field_value: dict) -> bool: ) -def extract_sections(salesforce_object: dict, base_url: str) -> list[Section]: +def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]: """ This goes through the salesforce_object and extracts the top level fields as a Section. It also goes through the child objects and extracts them as Sections. diff --git a/backend/onyx/connectors/salesforce/salesforce_calls.py b/backend/onyx/connectors/salesforce/salesforce_calls.py new file mode 100644 index 00000000000..f569b28b873 --- /dev/null +++ b/backend/onyx/connectors/salesforce/salesforce_calls.py @@ -0,0 +1,210 @@ +import os +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any + +from pytz import UTC +from simple_salesforce import Salesforce +from simple_salesforce import SFType +from simple_salesforce.bulk2 import SFBulk2Handler +from simple_salesforce.bulk2 import SFBulk2Type + +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type +from onyx.connectors.salesforce.utils import get_object_type_path +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def _build_time_filter_for_salesforce( + start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None +) -> str: + if start is None or end is None: + return "" + start_datetime = datetime.fromtimestamp(start, UTC) + end_datetime = datetime.fromtimestamp(end, UTC) + return ( + f" WHERE LastModifiedDate > {start_datetime.isoformat()} " + f"AND LastModifiedDate < {end_datetime.isoformat()}" + ) + + +def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any: + sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance) + return sf_object.describe() + + +def _is_valid_child_object( + sf_client: Salesforce, child_relationship: dict[str, Any] +) -> bool: + if not child_relationship["childSObject"]: + return False + if not child_relationship["relationshipName"]: + return False + + sf_type = child_relationship["childSObject"] + object_description = _get_sf_type_object_json(sf_client, sf_type) + if not object_description["queryable"]: + return False + + if child_relationship["field"]: + if child_relationship["field"] == "RelatedToId": + return False + else: + return False + + return True + + +def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]: + object_description = _get_sf_type_object_json(sf_client, sf_type) + + child_object_types = set() + for child_relationship in object_description["childRelationships"]: + if _is_valid_child_object(sf_client, child_relationship): + logger.debug( + f"Found valid child object {child_relationship['childSObject']}" + ) + child_object_types.add(child_relationship["childSObject"]) + return child_object_types + + +def _get_all_queryable_fields_of_sf_type( + sf_client: Salesforce, + sf_type: str, +) -> 
list[str]:
+    object_description = _get_sf_type_object_json(sf_client, sf_type)
+    fields: list[dict[str, Any]] = object_description["fields"]
+    valid_fields: set[str] = set()
+    compound_field_names: set[str] = set()
+    for field in fields:
+        if compound_field_name := field.get("compoundFieldName"):
+            compound_field_names.add(compound_field_name)
+        if field.get("type", "base64") == "base64":
+            continue
+        if field_name := field.get("name"):
+            valid_fields.add(field_name)
+
+    return list(valid_fields - compound_field_names)
+
+
+def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
+    """
+    Send a small COUNT() query to check whether the object type has any records
+    before running a bulk query. Returns False (treat as empty and skip the bulk
+    query) if the type has no records or the probe query fails; True otherwise.
+    """
+    try:
+        query = f"SELECT Count() FROM {sf_type} LIMIT 1"
+        result = sf_client.query(query)
+        if result["totalSize"] == 0:
+            return False
+    except Exception as e:
+        if "OPERATION_TOO_LARGE" not in str(e):
+            logger.warning(f"Object type {sf_type} doesn't support query: {e}")
+        return False
+    return True
+
+
+def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
+    # Check if csvs for this object type already exist
+    if os.path.exists(get_object_type_path(sf_type)):
+        existing_csvs = [
+            os.path.join(get_object_type_path(sf_type), f)
+            for f in os.listdir(get_object_type_path(sf_type))
+            if f.endswith(".csv")
+        ]
+        # If the csvs already exist, return their paths.
+        # This is likely due to a previous run that failed
+        # after downloading the csvs but before the data was
+        # written to the db.
+        if existing_csvs:
+            return existing_csvs
+    return None
+
+
+def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
+    queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
+    query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
+    return query
+
+
+def _bulk_retrieve_from_salesforce(
+    sf_client: Salesforce,
+    sf_type: str,
+    time_filter: str,
+) -> tuple[str, list[str] | None]:
+    if not _check_if_object_type_is_empty(sf_client, sf_type):
+        return sf_type, None
+
+    if existing_csvs := _check_for_existing_csvs(sf_type):
+        return sf_type, existing_csvs
+
+    query = _build_bulk_query(sf_client, sf_type, time_filter)
+
+    bulk_2_handler = SFBulk2Handler(
+        session_id=sf_client.session_id,
+        bulk2_url=sf_client.bulk2_url,
+        proxies=sf_client.proxies,
+        session=sf_client.session,
+    )
+    bulk_2_type = SFBulk2Type(
+        object_name=sf_type,
+        bulk2_url=bulk_2_handler.bulk2_url,
+        headers=bulk_2_handler.headers,
+        session=bulk_2_handler.session,
+    )
+
+    logger.info(f"Downloading {sf_type}")
+    logger.info(f"Query: {query}")
+
+    try:
+        # This downloads the query results as csv files with random names into
+        # the target path for this object type
+        results = bulk_2_type.download(
+            query=query,
+            path=get_object_type_path(sf_type),
+            max_records=1000000,
+        )
+        all_download_paths = [result["file"] for result in results]
+        logger.info(f"Downloaded {sf_type} to {all_download_paths}")
+        return sf_type, all_download_paths
+    except Exception as e:
+        logger.warning(
+            f"Failed to download salesforce csv for object type {sf_type}: {e}"
+        )
+        return sf_type, None
+
+
+def fetch_all_csvs_in_parallel(
+    sf_client: Salesforce,
+    object_types: set[str],
+    start: SecondsSinceUnixEpoch | None,
+    end: SecondsSinceUnixEpoch | None,
+) -> dict[str, list[str] | None]:
+    """
+    Fetches the csvs for the given object types in parallel.
+    Returns a dict mapping each object type to its list of downloaded csv paths,
+    or None if the download failed.
+    """
+    time_filter = _build_time_filter_for_salesforce(start, end)
+    time_filter_for_each_object_type: dict[str, str] = {}
+    # We do this outside of the
thread pool executor because this requires + # a database connection and we don't want to block the thread pool + # executor from running + for sf_type in object_types: + """Only add time filter if there is at least one object of the type + in the database. We aren't worried about partially completed object update runs + because this occurs after we check for existing csvs which covers this case""" + if has_at_least_one_object_of_type(sf_type): + time_filter_for_each_object_type[sf_type] = time_filter + else: + time_filter_for_each_object_type[sf_type] = "" + + # Run the bulk retrieve in parallel + with ThreadPoolExecutor() as executor: + results = executor.map( + lambda object_type: _bulk_retrieve_from_salesforce( + sf_client=sf_client, + sf_type=object_type, + time_filter=time_filter_for_each_object_type[object_type], + ), + object_types, + ) + return dict(results) diff --git a/backend/onyx/connectors/salesforce/shelve_stuff/shelve_functions.py b/backend/onyx/connectors/salesforce/shelve_stuff/shelve_functions.py new file mode 100644 index 00000000000..c57000b4059 --- /dev/null +++ b/backend/onyx/connectors/salesforce/shelve_stuff/shelve_functions.py @@ -0,0 +1,209 @@ +import csv +import shelve + +from onyx.connectors.salesforce.shelve_stuff.shelve_utils import ( + get_child_to_parent_shelf_path, +) +from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path +from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path +from onyx.connectors.salesforce.shelve_stuff.shelve_utils import ( + get_parent_to_child_shelf_path, +) +from onyx.connectors.salesforce.utils import SalesforceObject +from onyx.connectors.salesforce.utils import validate_salesforce_id +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def _update_relationship_shelves( + child_id: str, + parent_ids: set[str], +) -> None: + """Update the relationship shelf when a record is updated.""" + try: + # Convert child_id to string once + str_child_id = str(child_id) + + # First update child to parent mapping + with shelve.open( + get_child_to_parent_shelf_path(), + flag="c", + protocol=None, + writeback=True, + ) as child_to_parent_db: + old_parent_ids = set(child_to_parent_db.get(str_child_id, [])) + child_to_parent_db[str_child_id] = list(parent_ids) + + # Calculate differences outside the next context manager + parent_ids_to_remove = old_parent_ids - parent_ids + parent_ids_to_add = parent_ids - old_parent_ids + + # Only sync once at the end + child_to_parent_db.sync() + + # Then update parent to child mapping in a single transaction + if not parent_ids_to_remove and not parent_ids_to_add: + return + with shelve.open( + get_parent_to_child_shelf_path(), + flag="c", + protocol=None, + writeback=True, + ) as parent_to_child_db: + # Process all removals first + for parent_id in parent_ids_to_remove: + str_parent_id = str(parent_id) + existing_children = set(parent_to_child_db.get(str_parent_id, [])) + if str_child_id in existing_children: + existing_children.remove(str_child_id) + parent_to_child_db[str_parent_id] = list(existing_children) + + # Then process all additions + for parent_id in parent_ids_to_add: + str_parent_id = str(parent_id) + existing_children = set(parent_to_child_db.get(str_parent_id, [])) + existing_children.add(str_child_id) + parent_to_child_db[str_parent_id] = list(existing_children) + + # Single sync at the end + parent_to_child_db.sync() + + except Exception as e: + logger.error(f"Error updating relationship shelves: {e}") + 
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}") + raise + + +def get_child_ids(parent_id: str) -> set[str]: + """Get all child IDs for a given parent ID. + + Args: + parent_id: The ID of the parent object + + Returns: + A set of child object IDs + """ + with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db: + return set(parent_to_child_db.get(parent_id, [])) + + +def update_sf_db_with_csv( + object_type: str, + csv_download_path: str, +) -> list[str]: + """Update the SF DB with a CSV file using shelve storage.""" + updated_ids = [] + shelf_path = get_object_shelf_path(object_type) + + # First read the CSV to get all the data + with open(csv_download_path, "r", newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + id = row["Id"] + parent_ids = set() + field_to_remove: set[str] = set() + # Update relationship shelves for any parent references + for field, value in row.items(): + if validate_salesforce_id(value) and field != "Id": + parent_ids.add(value) + field_to_remove.add(field) + if not value: + field_to_remove.add(field) + _update_relationship_shelves(id, parent_ids) + for field in field_to_remove: + # We use this to extract the Primary Owner later + if field != "LastModifiedById": + del row[field] + + # Update the main object shelf + with shelve.open(shelf_path) as object_type_db: + object_type_db[id] = row + # Update the ID-to-type mapping shelf + with shelve.open(get_id_type_shelf_path()) as id_type_db: + id_type_db[id] = object_type + + updated_ids.append(id) + + # os.remove(csv_download_path) + return updated_ids + + +def get_type_from_id(object_id: str) -> str | None: + """Get the type of an object from its ID.""" + # Look up the object type from the ID-to-type mapping + with shelve.open(get_id_type_shelf_path()) as id_type_db: + if object_id not in id_type_db: + logger.warning(f"Object ID {object_id} not found in ID-to-type mapping") + return None + return id_type_db[object_id] + + +def get_record( + object_id: str, object_type: str | None = None +) -> SalesforceObject | None: + """ + Retrieve the record and return it as a SalesforceObject. + The object type will be looked up from the ID-to-type mapping shelf. + """ + if object_type is None: + if not (object_type := get_type_from_id(object_id)): + return None + + shelf_path = get_object_shelf_path(object_type) + with shelve.open(shelf_path) as db: + if object_id not in db: + logger.warning(f"Object ID {object_id} not found in {shelf_path}") + return None + data = db[object_id] + return SalesforceObject( + id=object_id, + type=object_type, + data=data, + ) + + +def find_ids_by_type(object_type: str) -> list[str]: + """ + Find all object IDs for rows of the specified type. + """ + shelf_path = get_object_shelf_path(object_type) + try: + with shelve.open(shelf_path) as db: + return list(db.keys()) + except FileNotFoundError: + return [] + + +def get_affected_parent_ids_by_type( + updated_ids: set[str], parent_types: list[str] +) -> dict[str, set[str]]: + """Get IDs of objects that are of the specified parent types and are either in the updated_ids + or have children in the updated_ids. 
+ + Args: + updated_ids: List of IDs that were updated + parent_types: List of object types to filter by + + Returns: + A dictionary of IDs that match the criteria + """ + affected_ids_by_type: dict[str, set[str]] = {} + + # Check each updated ID + for updated_id in updated_ids: + # Add the ID itself if it's of a parent type + updated_type = get_type_from_id(updated_id) + if updated_type in parent_types: + affected_ids_by_type.setdefault(updated_type, set()).add(updated_id) + continue + + # Get parents of this ID and add them if they're of a parent type + with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db: + parent_ids = child_to_parent_db.get(updated_id, []) + for parent_id in parent_ids: + parent_type = get_type_from_id(parent_id) + if parent_type in parent_types: + affected_ids_by_type.setdefault(parent_type, set()).add(parent_id) + + return affected_ids_by_type diff --git a/backend/onyx/connectors/salesforce/shelve_stuff/shelve_utils.py b/backend/onyx/connectors/salesforce/shelve_stuff/shelve_utils.py new file mode 100644 index 00000000000..0ea937af126 --- /dev/null +++ b/backend/onyx/connectors/salesforce/shelve_stuff/shelve_utils.py @@ -0,0 +1,29 @@ +import os + +from onyx.connectors.salesforce.utils import BASE_DATA_PATH +from onyx.connectors.salesforce.utils import get_object_type_path + + +def get_object_shelf_path(object_type: str) -> str: + """Get the path to the shelf file for a specific object type.""" + base_path = get_object_type_path(object_type) + os.makedirs(base_path, exist_ok=True) + return os.path.join(base_path, "data.shelf") + + +def get_id_type_shelf_path() -> str: + """Get the path to the ID-to-type mapping shelf.""" + os.makedirs(BASE_DATA_PATH, exist_ok=True) + return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g") + + +def get_parent_to_child_shelf_path() -> str: + """Get the path to the parent-to-child mapping shelf.""" + os.makedirs(BASE_DATA_PATH, exist_ok=True) + return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g") + + +def get_child_to_parent_shelf_path() -> str: + """Get the path to the child-to-parent mapping shelf.""" + os.makedirs(BASE_DATA_PATH, exist_ok=True) + return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g") diff --git a/backend/onyx/connectors/salesforce/shelve_stuff/test_salesforce_shelves.py b/backend/onyx/connectors/salesforce/shelve_stuff/test_salesforce_shelves.py new file mode 100644 index 00000000000..74df54d312b --- /dev/null +++ b/backend/onyx/connectors/salesforce/shelve_stuff/test_salesforce_shelves.py @@ -0,0 +1,737 @@ +import csv +import os +import shutil + +from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type +from onyx.connectors.salesforce.shelve_stuff.shelve_functions import ( + get_affected_parent_ids_by_type, +) +from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids +from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record +from onyx.connectors.salesforce.shelve_stuff.shelve_functions import ( + update_sf_db_with_csv, +) +from onyx.connectors.salesforce.utils import BASE_DATA_PATH +from onyx.connectors.salesforce.utils import get_object_type_path + +_VALID_SALESFORCE_IDS = [ + "001bm00000fd9Z3AAI", + "001bm00000fdYTdAAM", + "001bm00000fdYTeAAM", + "001bm00000fdYTfAAM", + "001bm00000fdYTgAAM", + "001bm00000fdYThAAM", + "001bm00000fdYTiAAM", + "001bm00000fdYTjAAM", + "001bm00000fdYTkAAM", + "001bm00000fdYTlAAM", + "001bm00000fdYTmAAM", + "001bm00000fdYTnAAM", + 
"001bm00000fdYToAAM", + "500bm00000XoOxtAAF", + "500bm00000XoOxuAAF", + "500bm00000XoOxvAAF", + "500bm00000XoOxwAAF", + "500bm00000XoOxxAAF", + "500bm00000XoOxyAAF", + "500bm00000XoOxzAAF", + "500bm00000XoOy0AAF", + "500bm00000XoOy1AAF", + "500bm00000XoOy2AAF", + "500bm00000XoOy3AAF", + "500bm00000XoOy4AAF", + "500bm00000XoOy5AAF", + "500bm00000XoOy6AAF", + "500bm00000XoOy7AAF", + "500bm00000XoOy8AAF", + "500bm00000XoOy9AAF", + "500bm00000XoOyAAAV", + "500bm00000XoOyBAAV", + "500bm00000XoOyCAAV", + "500bm00000XoOyDAAV", + "500bm00000XoOyEAAV", + "500bm00000XoOyFAAV", + "500bm00000XoOyGAAV", + "500bm00000XoOyHAAV", + "500bm00000XoOyIAAV", + "003bm00000EjHCjAAN", + "003bm00000EjHCkAAN", + "003bm00000EjHClAAN", + "003bm00000EjHCmAAN", + "003bm00000EjHCnAAN", + "003bm00000EjHCoAAN", + "003bm00000EjHCpAAN", + "003bm00000EjHCqAAN", + "003bm00000EjHCrAAN", + "003bm00000EjHCsAAN", + "003bm00000EjHCtAAN", + "003bm00000EjHCuAAN", + "003bm00000EjHCvAAN", + "003bm00000EjHCwAAN", + "003bm00000EjHCxAAN", + "003bm00000EjHCyAAN", + "003bm00000EjHCzAAN", + "003bm00000EjHD0AAN", + "003bm00000EjHD1AAN", + "003bm00000EjHD2AAN", + "550bm00000EXc2tAAD", + "006bm000006kyDpAAI", + "006bm000006kyDqAAI", + "006bm000006kyDrAAI", + "006bm000006kyDsAAI", + "006bm000006kyDtAAI", + "006bm000006kyDuAAI", + "006bm000006kyDvAAI", + "006bm000006kyDwAAI", + "006bm000006kyDxAAI", + "006bm000006kyDyAAI", + "006bm000006kyDzAAI", + "006bm000006kyE0AAI", + "006bm000006kyE1AAI", + "006bm000006kyE2AAI", + "006bm000006kyE3AAI", + "006bm000006kyE4AAI", + "006bm000006kyE5AAI", + "006bm000006kyE6AAI", + "006bm000006kyE7AAI", + "006bm000006kyE8AAI", + "006bm000006kyE9AAI", + "006bm000006kyEAAAY", + "006bm000006kyEBAAY", + "006bm000006kyECAAY", + "006bm000006kyEDAAY", + "006bm000006kyEEAAY", + "006bm000006kyEFAAY", + "006bm000006kyEGAAY", + "006bm000006kyEHAAY", + "006bm000006kyEIAAY", + "006bm000006kyEJAAY", + "005bm000009zy0TAAQ", + "005bm000009zy25AAA", + "005bm000009zy26AAA", + "005bm000009zy28AAA", + "005bm000009zy29AAA", + "005bm000009zy2AAAQ", + "005bm000009zy2BAAQ", +] + + +def clear_sf_db() -> None: + """ + Clears the SF DB by deleting all files in the data directory. + """ + shutil.rmtree(BASE_DATA_PATH) + + +def create_csv_file( + object_type: str, records: list[dict], filename: str = "test_data.csv" +) -> None: + """ + Creates a CSV file for the given object type and records. + + Args: + object_type: The Salesforce object type (e.g. "Account", "Contact") + records: List of dictionaries containing the record data + filename: Name of the CSV file to create (default: test_data.csv) + """ + if not records: + return + + # Get all unique fields from records + fields: set[str] = set() + for record in records: + fields.update(record.keys()) + fields = set(sorted(list(fields))) # Sort for consistent order + + # Create CSV file + csv_path = os.path.join(get_object_type_path(object_type), filename) + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + for record in records: + writer.writerow(record) + + # Update the database with the CSV + update_sf_db_with_csv(object_type, csv_path) + + +def create_csv_with_example_data() -> None: + """ + Creates CSV files with example data, organized by object type. 
+ """ + example_data: dict[str, list[dict]] = { + "Account": [ + { + "Id": _VALID_SALESFORCE_IDS[0], + "Name": "Acme Inc.", + "BillingCity": "New York", + "Industry": "Technology", + }, + { + "Id": _VALID_SALESFORCE_IDS[1], + "Name": "Globex Corp", + "BillingCity": "Los Angeles", + "Industry": "Manufacturing", + }, + { + "Id": _VALID_SALESFORCE_IDS[2], + "Name": "Initech", + "BillingCity": "Austin", + "Industry": "Software", + }, + { + "Id": _VALID_SALESFORCE_IDS[3], + "Name": "TechCorp Solutions", + "BillingCity": "San Francisco", + "Industry": "Software", + "AnnualRevenue": 5000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[4], + "Name": "BioMed Research", + "BillingCity": "Boston", + "Industry": "Healthcare", + "AnnualRevenue": 12000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[5], + "Name": "Green Energy Co", + "BillingCity": "Portland", + "Industry": "Energy", + "AnnualRevenue": 8000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[6], + "Name": "DataFlow Analytics", + "BillingCity": "Seattle", + "Industry": "Technology", + "AnnualRevenue": 3000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[7], + "Name": "Cloud Nine Services", + "BillingCity": "Denver", + "Industry": "Cloud Computing", + "AnnualRevenue": 7000000, + }, + ], + "Contact": [ + { + "Id": _VALID_SALESFORCE_IDS[40], + "FirstName": "John", + "LastName": "Doe", + "Email": "john.doe@acme.com", + "Title": "CEO", + }, + { + "Id": _VALID_SALESFORCE_IDS[41], + "FirstName": "Jane", + "LastName": "Smith", + "Email": "jane.smith@acme.com", + "Title": "CTO", + }, + { + "Id": _VALID_SALESFORCE_IDS[42], + "FirstName": "Bob", + "LastName": "Johnson", + "Email": "bob.j@globex.com", + "Title": "Sales Director", + }, + { + "Id": _VALID_SALESFORCE_IDS[43], + "FirstName": "Sarah", + "LastName": "Chen", + "Email": "sarah.chen@techcorp.com", + "Title": "Product Manager", + "Phone": "415-555-0101", + }, + { + "Id": _VALID_SALESFORCE_IDS[44], + "FirstName": "Michael", + "LastName": "Rodriguez", + "Email": "m.rodriguez@biomed.com", + "Title": "Research Director", + "Phone": "617-555-0202", + }, + { + "Id": _VALID_SALESFORCE_IDS[45], + "FirstName": "Emily", + "LastName": "Green", + "Email": "emily.g@greenenergy.com", + "Title": "Sustainability Lead", + "Phone": "503-555-0303", + }, + { + "Id": _VALID_SALESFORCE_IDS[46], + "FirstName": "David", + "LastName": "Kim", + "Email": "david.kim@dataflow.com", + "Title": "Data Scientist", + "Phone": "206-555-0404", + }, + { + "Id": _VALID_SALESFORCE_IDS[47], + "FirstName": "Rachel", + "LastName": "Taylor", + "Email": "r.taylor@cloudnine.com", + "Title": "Cloud Architect", + "Phone": "303-555-0505", + }, + ], + "Opportunity": [ + { + "Id": _VALID_SALESFORCE_IDS[62], + "Name": "Acme Server Upgrade", + "Amount": 50000, + "Stage": "Prospecting", + "CloseDate": "2024-06-30", + }, + { + "Id": _VALID_SALESFORCE_IDS[63], + "Name": "Globex Manufacturing Line", + "Amount": 150000, + "Stage": "Negotiation", + "CloseDate": "2024-03-15", + }, + { + "Id": _VALID_SALESFORCE_IDS[64], + "Name": "Initech Software License", + "Amount": 75000, + "Stage": "Closed Won", + "CloseDate": "2024-01-30", + }, + { + "Id": _VALID_SALESFORCE_IDS[65], + "Name": "TechCorp AI Implementation", + "Amount": 250000, + "Stage": "Needs Analysis", + "CloseDate": "2024-08-15", + "Probability": 60, + }, + { + "Id": _VALID_SALESFORCE_IDS[66], + "Name": "BioMed Lab Equipment", + "Amount": 500000, + "Stage": "Value Proposition", + "CloseDate": "2024-09-30", + "Probability": 75, + }, + { + "Id": _VALID_SALESFORCE_IDS[67], + "Name": "Green Energy Solar Project", + 
"Amount": 750000, + "Stage": "Proposal", + "CloseDate": "2024-07-15", + "Probability": 80, + }, + { + "Id": _VALID_SALESFORCE_IDS[68], + "Name": "DataFlow Analytics Platform", + "Amount": 180000, + "Stage": "Negotiation", + "CloseDate": "2024-05-30", + "Probability": 90, + }, + { + "Id": _VALID_SALESFORCE_IDS[69], + "Name": "Cloud Nine Infrastructure", + "Amount": 300000, + "Stage": "Qualification", + "CloseDate": "2024-10-15", + "Probability": 40, + }, + ], + } + + # Create CSV files for each object type + for object_type, records in example_data.items(): + create_csv_file(object_type, records) + + +def test_query() -> None: + """ + Tests querying functionality by verifying: + 1. All expected Account IDs are found + 2. Each Account's data matches what was inserted + """ + # Expected test data for verification + expected_accounts: dict[str, dict[str, str | int]] = { + _VALID_SALESFORCE_IDS[0]: { + "Name": "Acme Inc.", + "BillingCity": "New York", + "Industry": "Technology", + }, + _VALID_SALESFORCE_IDS[1]: { + "Name": "Globex Corp", + "BillingCity": "Los Angeles", + "Industry": "Manufacturing", + }, + _VALID_SALESFORCE_IDS[2]: { + "Name": "Initech", + "BillingCity": "Austin", + "Industry": "Software", + }, + _VALID_SALESFORCE_IDS[3]: { + "Name": "TechCorp Solutions", + "BillingCity": "San Francisco", + "Industry": "Software", + "AnnualRevenue": 5000000, + }, + _VALID_SALESFORCE_IDS[4]: { + "Name": "BioMed Research", + "BillingCity": "Boston", + "Industry": "Healthcare", + "AnnualRevenue": 12000000, + }, + _VALID_SALESFORCE_IDS[5]: { + "Name": "Green Energy Co", + "BillingCity": "Portland", + "Industry": "Energy", + "AnnualRevenue": 8000000, + }, + _VALID_SALESFORCE_IDS[6]: { + "Name": "DataFlow Analytics", + "BillingCity": "Seattle", + "Industry": "Technology", + "AnnualRevenue": 3000000, + }, + _VALID_SALESFORCE_IDS[7]: { + "Name": "Cloud Nine Services", + "BillingCity": "Denver", + "Industry": "Cloud Computing", + "AnnualRevenue": 7000000, + }, + } + + # Get all Account IDs + account_ids = find_ids_by_type("Account") + + # Verify we found all expected accounts + assert len(account_ids) == len( + expected_accounts + ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}" + assert set(account_ids) == set( + expected_accounts.keys() + ), "Found account IDs don't match expected IDs" + + # Verify each account's data + for acc_id in account_ids: + combined = get_record(acc_id) + assert combined is not None, f"Could not find account {acc_id}" + + expected = expected_accounts[acc_id] + + # Verify account data matches + for key, value in expected.items(): + value = str(value) + assert ( + combined.data[key] == value + ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}" + + print("All query tests passed successfully!") + + +def test_upsert() -> None: + """ + Tests upsert functionality by: + 1. Updating an existing account + 2. Creating a new account + 3. Verifying both operations were successful + """ + # Create CSV for updating an existing account and adding a new one + update_data: list[dict[str, str | int]] = [ + { + "Id": _VALID_SALESFORCE_IDS[0], + "Name": "Acme Inc. 
Updated", + "BillingCity": "New York", + "Industry": "Technology", + "Description": "Updated company info", + }, + { + "Id": _VALID_SALESFORCE_IDS[2], + "Name": "New Company Inc.", + "BillingCity": "Miami", + "Industry": "Finance", + "AnnualRevenue": 1000000, + }, + ] + + create_csv_file("Account", update_data, "update_data.csv") + + # Verify the update worked + updated_record = get_record(_VALID_SALESFORCE_IDS[0]) + assert updated_record is not None, "Updated record not found" + assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated" + assert ( + updated_record.data["Description"] == "Updated company info" + ), "Description not added" + + # Verify the new record was created + new_record = get_record(_VALID_SALESFORCE_IDS[2]) + assert new_record is not None, "New record not found" + assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect" + assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect" + + print("All upsert tests passed successfully!") + + +def test_relationships() -> None: + """ + Tests relationship shelf updates and queries by: + 1. Creating test data with relationships + 2. Verifying the relationships are correctly stored + 3. Testing relationship queries + """ + # Create test data for each object type + test_data: dict[str, list[dict[str, str | int]]] = { + "Case": [ + { + "Id": _VALID_SALESFORCE_IDS[13], + "AccountId": _VALID_SALESFORCE_IDS[0], + "Subject": "Test Case 1", + }, + { + "Id": _VALID_SALESFORCE_IDS[14], + "AccountId": _VALID_SALESFORCE_IDS[0], + "Subject": "Test Case 2", + }, + ], + "Contact": [ + { + "Id": _VALID_SALESFORCE_IDS[48], + "AccountId": _VALID_SALESFORCE_IDS[0], + "FirstName": "Test", + "LastName": "Contact", + } + ], + "Opportunity": [ + { + "Id": _VALID_SALESFORCE_IDS[62], + "AccountId": _VALID_SALESFORCE_IDS[0], + "Name": "Test Opportunity", + "Amount": 100000, + } + ], + } + + # Create and update CSV files for each object type + for object_type, records in test_data.items(): + create_csv_file(object_type, records, "relationship_test.csv") + + # Test relationship queries + # All these objects should be children of Acme Inc. + child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0]) + assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}" + assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship" + assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship" + assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship" + assert ( + _VALID_SALESFORCE_IDS[62] in child_ids + ), "Opportunity not found in relationship" + + # Test querying relationships for a different account (should be empty) + other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) + assert ( + len(other_account_children) == 0 + ), "Expected no children for different account" + + print("All relationship tests passed successfully!") + + +def test_account_with_children() -> None: + """ + Tests querying all accounts and retrieving their child objects. + This test verifies that: + 1. All accounts can be retrieved + 2. Child objects are correctly linked + 3. 
Child object data is complete and accurate + """ + # First get all account IDs + account_ids = find_ids_by_type("Account") + assert len(account_ids) > 0, "No accounts found" + + # For each account, get its children and verify the data + for account_id in account_ids: + account = get_record(account_id) + assert account is not None, f"Could not find account {account_id}" + + # Get all child objects + child_ids = get_child_ids(account_id) + + # For Acme Inc., verify specific relationships + if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc. + assert ( + len(child_ids) == 4 + ), f"Expected 4 children for Acme Inc., found {len(child_ids)}" + + # Get all child records + child_records = [] + for child_id in child_ids: + child_record = get_record(child_id) + if child_record is not None: + child_records.append(child_record) + # Verify Cases + cases = [r for r in child_records if r.type == "Case"] + assert ( + len(cases) == 2 + ), f"Expected 2 cases for Acme Inc., found {len(cases)}" + case_subjects = {case.data["Subject"] for case in cases} + assert "Test Case 1" in case_subjects, "Test Case 1 not found" + assert "Test Case 2" in case_subjects, "Test Case 2 not found" + + # Verify Contacts + contacts = [r for r in child_records if r.type == "Contact"] + assert ( + len(contacts) == 1 + ), f"Expected 1 contact for Acme Inc., found {len(contacts)}" + contact = contacts[0] + assert contact.data["FirstName"] == "Test", "Contact first name mismatch" + assert contact.data["LastName"] == "Contact", "Contact last name mismatch" + + # Verify Opportunities + opportunities = [r for r in child_records if r.type == "Opportunity"] + assert ( + len(opportunities) == 1 + ), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}" + opportunity = opportunities[0] + assert ( + opportunity.data["Name"] == "Test Opportunity" + ), "Opportunity name mismatch" + assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch" + + print("All account with children tests passed successfully!") + + +def test_relationship_updates() -> None: + """ + Tests that relationships are properly updated when a child object's parent reference changes. + This test verifies: + 1. Initial relationship is created correctly + 2. When parent reference is updated, old relationship is removed + 3. New relationship is created correctly + """ + # Create initial test data - Contact linked to Acme Inc. 
+    initial_contact = [
+        {
+            "Id": _VALID_SALESFORCE_IDS[40],
+            "AccountId": _VALID_SALESFORCE_IDS[0],
+            "FirstName": "Test",
+            "LastName": "Contact",
+        }
+    ]
+    create_csv_file("Contact", initial_contact, "initial_contact.csv")
+
+    # Verify initial relationship
+    acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
+    assert (
+        _VALID_SALESFORCE_IDS[40] in acme_children
+    ), "Initial relationship not created"
+
+    # Update contact to be linked to Globex Corp instead
+    updated_contact = [
+        {
+            "Id": _VALID_SALESFORCE_IDS[40],
+            "AccountId": _VALID_SALESFORCE_IDS[1],
+            "FirstName": "Test",
+            "LastName": "Contact",
+        }
+    ]
+    create_csv_file("Contact", updated_contact, "updated_contact.csv")
+
+    # Verify old relationship is removed
+    acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
+    assert (
+        _VALID_SALESFORCE_IDS[40] not in acme_children
+    ), "Old relationship not removed"
+
+    # Verify new relationship is created
+    globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
+    assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
+
+    print("All relationship update tests passed successfully!")
+
+
+def test_get_affected_parent_ids() -> None:
+    """
+    Tests get_affected_parent_ids_by_type functionality by verifying:
+    1. IDs that are directly in the parent_types list are included
+    2. IDs that have children in the updated_ids list are included
+    3. IDs that are neither of the above are not included
+    """
+    # Create test data with relationships
+    test_data = {
+        "Account": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[0],
+                "Name": "Parent Account 1",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[1],
+                "Name": "Parent Account 2",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[2],
+                "Name": "Not Affected Account",
+            },
+        ],
+        "Contact": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[40],
+                "AccountId": _VALID_SALESFORCE_IDS[0],
+                "FirstName": "Child",
+                "LastName": "Contact",
+            }
+        ],
+    }
+
+    # Create and update CSV files for test data
+    for object_type, records in test_data.items():
+        create_csv_file(object_type, records)
+
+    # get_affected_parent_ids_by_type returns a dict keyed by parent type,
+    # so the assertions below check the set of affected Account IDs
+
+    # Test Case 1: Account directly in updated_ids and parent_types
+    updated_ids = {_VALID_SALESFORCE_IDS[1]}  # Parent Account 2
+    parent_types = ["Account"]
+    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
+    assert (
+        _VALID_SALESFORCE_IDS[1] in affected_ids.get("Account", set())
+    ), "Direct parent ID not included"
+
+    # Test Case 2: Account with child in updated_ids
+    updated_ids = {_VALID_SALESFORCE_IDS[40]}  # Child Contact
+    parent_types = ["Account"]
+    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
+    assert (
+        _VALID_SALESFORCE_IDS[0] in affected_ids.get("Account", set())
+    ), "Parent of updated child not included"
+
+    # Test Case 3: Both direct and indirect affects
+    updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]}  # Both cases
+    parent_types = ["Account"]
+    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
+    affected_account_ids = affected_ids.get("Account", set())
+    assert len(affected_account_ids) == 2, "Expected exactly two affected parent IDs"
+    assert (
+        _VALID_SALESFORCE_IDS[0] in affected_account_ids
+    ), "Parent of child not included"
+    assert (
+        _VALID_SALESFORCE_IDS[1] in affected_account_ids
+    ), "Direct parent ID not included"
+    assert (
+        _VALID_SALESFORCE_IDS[2] not in affected_account_ids
+    ), "Unaffected ID incorrectly included"
+
+    # Test Case 4: No matches
+    updated_ids = {_VALID_SALESFORCE_IDS[40]}  # Child Contact
+    parent_types = ["Opportunity"]  # Wrong type
+    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
+    assert len(affected_ids) == 0, "Should return an empty dict when no matches"
+
+    
print("All get_affected_parent_ids tests passed successfully!") + + +def main_build() -> None: + clear_sf_db() + create_csv_with_example_data() + test_query() + test_upsert() + test_relationships() + test_account_with_children() + test_relationship_updates() + test_get_affected_parent_ids() + + +if __name__ == "__main__": + main_build() diff --git a/backend/onyx/connectors/salesforce/sqlite_functions.py b/backend/onyx/connectors/salesforce/sqlite_functions.py new file mode 100644 index 00000000000..eb8d72ba4f3 --- /dev/null +++ b/backend/onyx/connectors/salesforce/sqlite_functions.py @@ -0,0 +1,386 @@ +import csv +import json +import os +import sqlite3 +from collections.abc import Iterator +from contextlib import contextmanager + +from onyx.connectors.salesforce.utils import get_sqlite_db_path +from onyx.connectors.salesforce.utils import SalesforceObject +from onyx.connectors.salesforce.utils import validate_salesforce_id +from onyx.utils.logger import setup_logger +from shared_configs.utils import batch_list + +logger = setup_logger() + + +@contextmanager +def get_db_connection( + isolation_level: str | None = None, +) -> Iterator[sqlite3.Connection]: + """Get a database connection with proper isolation level and error handling. + + Args: + isolation_level: SQLite isolation level. None = default "DEFERRED", + can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation. + """ + # 60 second timeout for locks + conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0) + + if isolation_level is not None: + conn.isolation_level = isolation_level + try: + yield conn + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def init_db() -> None: + """Initialize the SQLite database with required tables if they don't exist.""" + if os.path.exists(get_sqlite_db_path()): + return + + # Create database directory if it doesn't exist + os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True) + + with get_db_connection("EXCLUSIVE") as conn: + cursor = conn.cursor() + + # Enable WAL mode for better concurrent access and write performance + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.execute("PRAGMA temp_store=MEMORY") + cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache + + # Main table for storing Salesforce objects + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS salesforce_objects ( + id TEXT PRIMARY KEY, + object_type TEXT NOT NULL, + data TEXT NOT NULL, -- JSON serialized data + last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management + ) WITHOUT ROWID -- Optimize for primary key lookups + """ + ) + + # Table for parent-child relationships with covering index + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS relationships ( + child_id TEXT NOT NULL, + parent_id TEXT NOT NULL, + PRIMARY KEY (child_id, parent_id) + ) WITHOUT ROWID -- Optimize for primary key lookups + """ + ) + + # New table for caching parent-child relationships with object types + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS relationship_types ( + child_id TEXT NOT NULL, + parent_id TEXT NOT NULL, + parent_type TEXT NOT NULL, + PRIMARY KEY (child_id, parent_id, parent_type) + ) WITHOUT ROWID + """ + ) + + # Always recreate indexes to ensure they exist + cursor.execute("DROP INDEX IF EXISTS idx_object_type") + cursor.execute("DROP INDEX IF EXISTS idx_parent_id") + cursor.execute("DROP INDEX IF EXISTS idx_child_parent") + cursor.execute("DROP INDEX IF EXISTS 
idx_object_type_id") + cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup") + + # Create covering indexes for common queries + cursor.execute( + """ + CREATE INDEX idx_object_type + ON salesforce_objects(object_type, id) + WHERE object_type IS NOT NULL + """ + ) + + cursor.execute( + """ + CREATE INDEX idx_parent_id + ON relationships(parent_id, child_id) + """ + ) + + cursor.execute( + """ + CREATE INDEX idx_child_parent + ON relationships(child_id) + WHERE child_id IS NOT NULL + """ + ) + + # New composite index for fast parent type lookups + cursor.execute( + """ + CREATE INDEX idx_relationship_types_lookup + ON relationship_types(parent_type, child_id, parent_id) + """ + ) + + # Analyze tables to help query planner + cursor.execute("ANALYZE relationships") + cursor.execute("ANALYZE salesforce_objects") + cursor.execute("ANALYZE relationship_types") + + conn.commit() + + +def _update_relationship_tables( + conn: sqlite3.Connection, child_id: str, parent_ids: set[str] +) -> None: + """Update the relationship tables when a record is updated. + + Args: + conn: The database connection to use (must be in a transaction) + child_id: The ID of the child record + parent_ids: Set of parent IDs to link to + """ + try: + cursor = conn.cursor() + + # Get existing parent IDs + cursor.execute( + "SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,) + ) + old_parent_ids = {row[0] for row in cursor.fetchall()} + + # Calculate differences + parent_ids_to_remove = old_parent_ids - parent_ids + parent_ids_to_add = parent_ids - old_parent_ids + + # Remove old relationships + if parent_ids_to_remove: + cursor.executemany( + "DELETE FROM relationships WHERE child_id = ? AND parent_id = ?", + [(child_id, pid) for pid in parent_ids_to_remove], + ) + # Also remove from relationship_types + cursor.executemany( + "DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?", + [(child_id, pid) for pid in parent_ids_to_remove], + ) + + # Add new relationships + if parent_ids_to_add: + # First add to relationships table + cursor.executemany( + "INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)", + [(child_id, pid) for pid in parent_ids_to_add], + ) + + # Then get the types of the parent objects and add to relationship_types + for parent_id in parent_ids_to_add: + cursor.execute( + "SELECT object_type FROM salesforce_objects WHERE id = ?", + (parent_id,), + ) + result = cursor.fetchone() + if result: + parent_type = result[0] + cursor.execute( + """ + INSERT INTO relationship_types (child_id, parent_id, parent_type) + VALUES (?, ?, ?) 
+ """, + (child_id, parent_id, parent_type), + ) + + except Exception as e: + logger.error(f"Error updating relationship tables: {e}") + logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}") + raise + + +def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]: + """Update the SF DB with a CSV file using SQLite storage.""" + updated_ids = [] + + # Use IMMEDIATE to get a write lock at the start of the transaction + with get_db_connection("IMMEDIATE") as conn: + cursor = conn.cursor() + + with open(csv_download_path, "r", newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + if "Id" not in row: + logger.warning( + f"Row {row} does not have an Id field in {csv_download_path}" + ) + continue + id = row["Id"] + parent_ids = set() + field_to_remove: set[str] = set() + + # Process relationships and clean data + for field, value in row.items(): + if validate_salesforce_id(value) and field != "Id": + parent_ids.add(value) + field_to_remove.add(field) + if not value: + field_to_remove.add(field) + + # Remove unwanted fields + for field in field_to_remove: + if field != "LastModifiedById": + del row[field] + + # Update main object data + cursor.execute( + """ + INSERT OR REPLACE INTO salesforce_objects (id, object_type, data) + VALUES (?, ?, ?) + """, + (id, object_type, json.dumps(row)), + ) + + # Update relationships using the same connection + _update_relationship_tables(conn, id, parent_ids) + updated_ids.append(id) + + conn.commit() + + return updated_ids + + +def get_child_ids(parent_id: str) -> set[str]: + """Get all child IDs for a given parent ID.""" + with get_db_connection() as conn: + cursor = conn.cursor() + + # Force index usage with INDEXED BY + cursor.execute( + "SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?", + (parent_id,), + ) + child_ids = {row[0] for row in cursor.fetchall()} + return child_ids + + +def get_type_from_id(object_id: str) -> str | None: + """Get the type of an object from its ID.""" + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,) + ) + result = cursor.fetchone() + if not result: + logger.warning(f"Object ID {object_id} not found") + return None + return result[0] + + +def get_record( + object_id: str, object_type: str | None = None +) -> SalesforceObject | None: + """Retrieve the record and return it as a SalesforceObject.""" + if object_type is None: + object_type = get_type_from_id(object_id) + if not object_type: + return None + + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)) + result = cursor.fetchone() + if not result: + logger.warning(f"Object ID {object_id} not found") + return None + + data = json.loads(result[0]) + return SalesforceObject(id=object_id, type=object_type, data=data) + + +def find_ids_by_type(object_type: str) -> list[str]: + """Find all object IDs for rows of the specified type.""" + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,) + ) + return [row[0] for row in cursor.fetchall()] + + +def get_affected_parent_ids_by_type( + updated_ids: list[str], + parent_types: list[str], + batch_size: int = 500, +) -> Iterator[tuple[str, set[str]]]: + """Get IDs of objects that are of the specified parent types and are either in the + updated_ids or have 
children in the updated_ids. Yields tuples of (parent_type, affected_ids). + """ + # SQLite typically has a limit of 999 variables + updated_ids_batches = batch_list(updated_ids, batch_size) + updated_parent_ids: set[str] = set() + + with get_db_connection() as conn: + cursor = conn.cursor() + + for batch_ids in updated_ids_batches: + id_placeholders = ",".join(["?" for _ in batch_ids]) + + for parent_type in parent_types: + affected_ids: set[str] = set() + + # Get directly updated objects of parent types - using index on object_type + cursor.execute( + f""" + SELECT id FROM salesforce_objects + WHERE id IN ({id_placeholders}) + AND object_type = ? + """, + batch_ids + [parent_type], + ) + affected_ids.update(row[0] for row in cursor.fetchall()) + + # Get parent objects of updated objects - using optimized relationship_types table + cursor.execute( + f""" + SELECT DISTINCT parent_id + FROM relationship_types + INDEXED BY idx_relationship_types_lookup + WHERE parent_type = ? + AND child_id IN ({id_placeholders}) + """, + [parent_type] + batch_ids, + ) + affected_ids.update(row[0] for row in cursor.fetchall()) + + # Remove any parent IDs that have already been processed + new_affected_ids = affected_ids - updated_parent_ids + # Add the new affected IDs to the set of updated parent IDs + updated_parent_ids.update(new_affected_ids) + + if new_affected_ids: + yield parent_type, new_affected_ids + + +def has_at_least_one_object_of_type(object_type: str) -> bool: + """Check if there is at least one object of the specified type in the database. + + Args: + object_type: The Salesforce object type to check + + Returns: + bool: True if at least one object exists, False otherwise + """ + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?", + (object_type,), + ) + count = cursor.fetchone()[0] + return count > 0 diff --git a/backend/onyx/connectors/salesforce/utils.py b/backend/onyx/connectors/salesforce/utils.py new file mode 100644 index 00000000000..11ada9d019a --- /dev/null +++ b/backend/onyx/connectors/salesforce/utils.py @@ -0,0 +1,72 @@ +import os +from dataclasses import dataclass +from typing import Any + + +@dataclass +class SalesforceObject: + id: str + type: str + data: dict[str, Any] + + def to_dict(self) -> dict[str, Any]: + return { + "ID": self.id, + "Type": self.type, + "Data": self.data, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject": + return cls( + id=data["Id"], + type=data["Type"], + data=data, + ) + + +# This defines the base path for all data files relative to this file +# AKA BE CAREFUL WHEN MOVING THIS FILE +BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data") + + +def get_sqlite_db_path() -> str: + """Get the path to the sqlite db file.""" + return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite") + + +def get_object_type_path(object_type: str) -> str: + """Get the directory path for a specific object type.""" + type_dir = os.path.join(BASE_DATA_PATH, object_type) + os.makedirs(type_dir, exist_ok=True) + return type_dir + + +_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" +_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)} + + +def validate_salesforce_id(salesforce_id: str) -> bool: + """Validate the checksum portion of an 18-character Salesforce ID. 
+ + Args: + salesforce_id: An 18-character Salesforce ID + + Returns: + bool: True if the checksum is valid, False otherwise + """ + if len(salesforce_id) != 18: + return False + + chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]] + + checksum = salesforce_id[15:18] + calculated_checksum = "" + + for chunk in chunks: + result_string = "".join( + "1" if char.isupper() else "0" for char in reversed(chunk) + ) + calculated_checksum += _LOOKUP[result_string] + + return checksum == calculated_checksum diff --git a/backend/tests/unit/onyx/connectors/salesforce/test_salesforce_sqlite.py b/backend/tests/unit/onyx/connectors/salesforce/test_salesforce_sqlite.py new file mode 100644 index 00000000000..3afc4f11731 --- /dev/null +++ b/backend/tests/unit/onyx/connectors/salesforce/test_salesforce_sqlite.py @@ -0,0 +1,746 @@ +import csv +import os +import shutil + +from onyx.connectors.salesforce.sqlite_functions import find_ids_by_type +from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type +from onyx.connectors.salesforce.sqlite_functions import get_child_ids +from onyx.connectors.salesforce.sqlite_functions import get_record +from onyx.connectors.salesforce.sqlite_functions import init_db +from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv +from onyx.connectors.salesforce.utils import BASE_DATA_PATH +from onyx.connectors.salesforce.utils import get_object_type_path + +_VALID_SALESFORCE_IDS = [ + "001bm00000fd9Z3AAI", + "001bm00000fdYTdAAM", + "001bm00000fdYTeAAM", + "001bm00000fdYTfAAM", + "001bm00000fdYTgAAM", + "001bm00000fdYThAAM", + "001bm00000fdYTiAAM", + "001bm00000fdYTjAAM", + "001bm00000fdYTkAAM", + "001bm00000fdYTlAAM", + "001bm00000fdYTmAAM", + "001bm00000fdYTnAAM", + "001bm00000fdYToAAM", + "500bm00000XoOxtAAF", + "500bm00000XoOxuAAF", + "500bm00000XoOxvAAF", + "500bm00000XoOxwAAF", + "500bm00000XoOxxAAF", + "500bm00000XoOxyAAF", + "500bm00000XoOxzAAF", + "500bm00000XoOy0AAF", + "500bm00000XoOy1AAF", + "500bm00000XoOy2AAF", + "500bm00000XoOy3AAF", + "500bm00000XoOy4AAF", + "500bm00000XoOy5AAF", + "500bm00000XoOy6AAF", + "500bm00000XoOy7AAF", + "500bm00000XoOy8AAF", + "500bm00000XoOy9AAF", + "500bm00000XoOyAAAV", + "500bm00000XoOyBAAV", + "500bm00000XoOyCAAV", + "500bm00000XoOyDAAV", + "500bm00000XoOyEAAV", + "500bm00000XoOyFAAV", + "500bm00000XoOyGAAV", + "500bm00000XoOyHAAV", + "500bm00000XoOyIAAV", + "003bm00000EjHCjAAN", + "003bm00000EjHCkAAN", + "003bm00000EjHClAAN", + "003bm00000EjHCmAAN", + "003bm00000EjHCnAAN", + "003bm00000EjHCoAAN", + "003bm00000EjHCpAAN", + "003bm00000EjHCqAAN", + "003bm00000EjHCrAAN", + "003bm00000EjHCsAAN", + "003bm00000EjHCtAAN", + "003bm00000EjHCuAAN", + "003bm00000EjHCvAAN", + "003bm00000EjHCwAAN", + "003bm00000EjHCxAAN", + "003bm00000EjHCyAAN", + "003bm00000EjHCzAAN", + "003bm00000EjHD0AAN", + "003bm00000EjHD1AAN", + "003bm00000EjHD2AAN", + "550bm00000EXc2tAAD", + "006bm000006kyDpAAI", + "006bm000006kyDqAAI", + "006bm000006kyDrAAI", + "006bm000006kyDsAAI", + "006bm000006kyDtAAI", + "006bm000006kyDuAAI", + "006bm000006kyDvAAI", + "006bm000006kyDwAAI", + "006bm000006kyDxAAI", + "006bm000006kyDyAAI", + "006bm000006kyDzAAI", + "006bm000006kyE0AAI", + "006bm000006kyE1AAI", + "006bm000006kyE2AAI", + "006bm000006kyE3AAI", + "006bm000006kyE4AAI", + "006bm000006kyE5AAI", + "006bm000006kyE6AAI", + "006bm000006kyE7AAI", + "006bm000006kyE8AAI", + "006bm000006kyE9AAI", + "006bm000006kyEAAAY", + "006bm000006kyEBAAY", + "006bm000006kyECAAY", + "006bm000006kyEDAAY", + "006bm000006kyEEAAY", + 
"006bm000006kyEFAAY", + "006bm000006kyEGAAY", + "006bm000006kyEHAAY", + "006bm000006kyEIAAY", + "006bm000006kyEJAAY", + "005bm000009zy0TAAQ", + "005bm000009zy25AAA", + "005bm000009zy26AAA", + "005bm000009zy28AAA", + "005bm000009zy29AAA", + "005bm000009zy2AAAQ", + "005bm000009zy2BAAQ", +] + + +def _clear_sf_db() -> None: + """ + Clears the SF DB by deleting all files in the data directory. + """ + shutil.rmtree(BASE_DATA_PATH, ignore_errors=True) + + +def _create_csv_file( + object_type: str, records: list[dict], filename: str = "test_data.csv" +) -> None: + """ + Creates a CSV file for the given object type and records. + + Args: + object_type: The Salesforce object type (e.g. "Account", "Contact") + records: List of dictionaries containing the record data + filename: Name of the CSV file to create (default: test_data.csv) + """ + if not records: + return + + # Get all unique fields from records + fields: set[str] = set() + for record in records: + fields.update(record.keys()) + fields = set(sorted(list(fields))) # Sort for consistent order + + # Create CSV file + csv_path = os.path.join(get_object_type_path(object_type), filename) + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + for record in records: + writer.writerow(record) + + # Update the database with the CSV + update_sf_db_with_csv(object_type, csv_path) + + +def _create_csv_with_example_data() -> None: + """ + Creates CSV files with example data, organized by object type. + """ + example_data: dict[str, list[dict]] = { + "Account": [ + { + "Id": _VALID_SALESFORCE_IDS[0], + "Name": "Acme Inc.", + "BillingCity": "New York", + "Industry": "Technology", + }, + { + "Id": _VALID_SALESFORCE_IDS[1], + "Name": "Globex Corp", + "BillingCity": "Los Angeles", + "Industry": "Manufacturing", + }, + { + "Id": _VALID_SALESFORCE_IDS[2], + "Name": "Initech", + "BillingCity": "Austin", + "Industry": "Software", + }, + { + "Id": _VALID_SALESFORCE_IDS[3], + "Name": "TechCorp Solutions", + "BillingCity": "San Francisco", + "Industry": "Software", + "AnnualRevenue": 5000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[4], + "Name": "BioMed Research", + "BillingCity": "Boston", + "Industry": "Healthcare", + "AnnualRevenue": 12000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[5], + "Name": "Green Energy Co", + "BillingCity": "Portland", + "Industry": "Energy", + "AnnualRevenue": 8000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[6], + "Name": "DataFlow Analytics", + "BillingCity": "Seattle", + "Industry": "Technology", + "AnnualRevenue": 3000000, + }, + { + "Id": _VALID_SALESFORCE_IDS[7], + "Name": "Cloud Nine Services", + "BillingCity": "Denver", + "Industry": "Cloud Computing", + "AnnualRevenue": 7000000, + }, + ], + "Contact": [ + { + "Id": _VALID_SALESFORCE_IDS[40], + "FirstName": "John", + "LastName": "Doe", + "Email": "john.doe@acme.com", + "Title": "CEO", + }, + { + "Id": _VALID_SALESFORCE_IDS[41], + "FirstName": "Jane", + "LastName": "Smith", + "Email": "jane.smith@acme.com", + "Title": "CTO", + }, + { + "Id": _VALID_SALESFORCE_IDS[42], + "FirstName": "Bob", + "LastName": "Johnson", + "Email": "bob.j@globex.com", + "Title": "Sales Director", + }, + { + "Id": _VALID_SALESFORCE_IDS[43], + "FirstName": "Sarah", + "LastName": "Chen", + "Email": "sarah.chen@techcorp.com", + "Title": "Product Manager", + "Phone": "415-555-0101", + }, + { + "Id": _VALID_SALESFORCE_IDS[44], + "FirstName": "Michael", + "LastName": "Rodriguez", + "Email": "m.rodriguez@biomed.com", + "Title": 
"Research Director", + "Phone": "617-555-0202", + }, + { + "Id": _VALID_SALESFORCE_IDS[45], + "FirstName": "Emily", + "LastName": "Green", + "Email": "emily.g@greenenergy.com", + "Title": "Sustainability Lead", + "Phone": "503-555-0303", + }, + { + "Id": _VALID_SALESFORCE_IDS[46], + "FirstName": "David", + "LastName": "Kim", + "Email": "david.kim@dataflow.com", + "Title": "Data Scientist", + "Phone": "206-555-0404", + }, + { + "Id": _VALID_SALESFORCE_IDS[47], + "FirstName": "Rachel", + "LastName": "Taylor", + "Email": "r.taylor@cloudnine.com", + "Title": "Cloud Architect", + "Phone": "303-555-0505", + }, + ], + "Opportunity": [ + { + "Id": _VALID_SALESFORCE_IDS[62], + "Name": "Acme Server Upgrade", + "Amount": 50000, + "Stage": "Prospecting", + "CloseDate": "2024-06-30", + }, + { + "Id": _VALID_SALESFORCE_IDS[63], + "Name": "Globex Manufacturing Line", + "Amount": 150000, + "Stage": "Negotiation", + "CloseDate": "2024-03-15", + }, + { + "Id": _VALID_SALESFORCE_IDS[64], + "Name": "Initech Software License", + "Amount": 75000, + "Stage": "Closed Won", + "CloseDate": "2024-01-30", + }, + { + "Id": _VALID_SALESFORCE_IDS[65], + "Name": "TechCorp AI Implementation", + "Amount": 250000, + "Stage": "Needs Analysis", + "CloseDate": "2024-08-15", + "Probability": 60, + }, + { + "Id": _VALID_SALESFORCE_IDS[66], + "Name": "BioMed Lab Equipment", + "Amount": 500000, + "Stage": "Value Proposition", + "CloseDate": "2024-09-30", + "Probability": 75, + }, + { + "Id": _VALID_SALESFORCE_IDS[67], + "Name": "Green Energy Solar Project", + "Amount": 750000, + "Stage": "Proposal", + "CloseDate": "2024-07-15", + "Probability": 80, + }, + { + "Id": _VALID_SALESFORCE_IDS[68], + "Name": "DataFlow Analytics Platform", + "Amount": 180000, + "Stage": "Negotiation", + "CloseDate": "2024-05-30", + "Probability": 90, + }, + { + "Id": _VALID_SALESFORCE_IDS[69], + "Name": "Cloud Nine Infrastructure", + "Amount": 300000, + "Stage": "Qualification", + "CloseDate": "2024-10-15", + "Probability": 40, + }, + ], + } + + # Create CSV files for each object type + for object_type, records in example_data.items(): + _create_csv_file(object_type, records) + + +def _test_query() -> None: + """ + Tests querying functionality by verifying: + 1. All expected Account IDs are found + 2. 
Each Account's data matches what was inserted + """ + # Expected test data for verification + expected_accounts: dict[str, dict[str, str | int]] = { + _VALID_SALESFORCE_IDS[0]: { + "Name": "Acme Inc.", + "BillingCity": "New York", + "Industry": "Technology", + }, + _VALID_SALESFORCE_IDS[1]: { + "Name": "Globex Corp", + "BillingCity": "Los Angeles", + "Industry": "Manufacturing", + }, + _VALID_SALESFORCE_IDS[2]: { + "Name": "Initech", + "BillingCity": "Austin", + "Industry": "Software", + }, + _VALID_SALESFORCE_IDS[3]: { + "Name": "TechCorp Solutions", + "BillingCity": "San Francisco", + "Industry": "Software", + "AnnualRevenue": 5000000, + }, + _VALID_SALESFORCE_IDS[4]: { + "Name": "BioMed Research", + "BillingCity": "Boston", + "Industry": "Healthcare", + "AnnualRevenue": 12000000, + }, + _VALID_SALESFORCE_IDS[5]: { + "Name": "Green Energy Co", + "BillingCity": "Portland", + "Industry": "Energy", + "AnnualRevenue": 8000000, + }, + _VALID_SALESFORCE_IDS[6]: { + "Name": "DataFlow Analytics", + "BillingCity": "Seattle", + "Industry": "Technology", + "AnnualRevenue": 3000000, + }, + _VALID_SALESFORCE_IDS[7]: { + "Name": "Cloud Nine Services", + "BillingCity": "Denver", + "Industry": "Cloud Computing", + "AnnualRevenue": 7000000, + }, + } + + # Get all Account IDs + account_ids = find_ids_by_type("Account") + + # Verify we found all expected accounts + assert len(account_ids) == len( + expected_accounts + ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}" + assert set(account_ids) == set( + expected_accounts.keys() + ), "Found account IDs don't match expected IDs" + + # Verify each account's data + for acc_id in account_ids: + combined = get_record(acc_id) + assert combined is not None, f"Could not find account {acc_id}" + + expected = expected_accounts[acc_id] + + # Verify account data matches + for key, value in expected.items(): + value = str(value) + assert ( + combined.data[key] == value + ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}" + + print("All query tests passed successfully!") + + +def _test_upsert() -> None: + """ + Tests upsert functionality by: + 1. Updating an existing account + 2. Creating a new account + 3. Verifying both operations were successful + """ + # Create CSV for updating an existing account and adding a new one + update_data: list[dict[str, str | int]] = [ + { + "Id": _VALID_SALESFORCE_IDS[0], + "Name": "Acme Inc. Updated", + "BillingCity": "New York", + "Industry": "Technology", + "Description": "Updated company info", + }, + { + "Id": _VALID_SALESFORCE_IDS[2], + "Name": "New Company Inc.", + "BillingCity": "Miami", + "Industry": "Finance", + "AnnualRevenue": 1000000, + }, + ] + + _create_csv_file("Account", update_data, "update_data.csv") + + # Verify the update worked + updated_record = get_record(_VALID_SALESFORCE_IDS[0]) + assert updated_record is not None, "Updated record not found" + assert updated_record.data["Name"] == "Acme Inc. 
Updated", "Name not updated" + assert ( + updated_record.data["Description"] == "Updated company info" + ), "Description not added" + + # Verify the new record was created + new_record = get_record(_VALID_SALESFORCE_IDS[2]) + assert new_record is not None, "New record not found" + assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect" + assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect" + + print("All upsert tests passed successfully!") + + +def _test_relationships() -> None: + """ + Tests relationship shelf updates and queries by: + 1. Creating test data with relationships + 2. Verifying the relationships are correctly stored + 3. Testing relationship queries + """ + # Create test data for each object type + test_data: dict[str, list[dict[str, str | int]]] = { + "Case": [ + { + "Id": _VALID_SALESFORCE_IDS[13], + "AccountId": _VALID_SALESFORCE_IDS[0], + "Subject": "Test Case 1", + }, + { + "Id": _VALID_SALESFORCE_IDS[14], + "AccountId": _VALID_SALESFORCE_IDS[0], + "Subject": "Test Case 2", + }, + ], + "Contact": [ + { + "Id": _VALID_SALESFORCE_IDS[48], + "AccountId": _VALID_SALESFORCE_IDS[0], + "FirstName": "Test", + "LastName": "Contact", + } + ], + "Opportunity": [ + { + "Id": _VALID_SALESFORCE_IDS[62], + "AccountId": _VALID_SALESFORCE_IDS[0], + "Name": "Test Opportunity", + "Amount": 100000, + } + ], + } + + # Create and update CSV files for each object type + for object_type, records in test_data.items(): + _create_csv_file(object_type, records, "relationship_test.csv") + + # Test relationship queries + # All these objects should be children of Acme Inc. + child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0]) + assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}" + assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship" + assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship" + assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship" + assert ( + _VALID_SALESFORCE_IDS[62] in child_ids + ), "Opportunity not found in relationship" + + # Test querying relationships for a different account (should be empty) + other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) + assert ( + len(other_account_children) == 0 + ), "Expected no children for different account" + + print("All relationship tests passed successfully!") + + +def _test_account_with_children() -> None: + """ + Tests querying all accounts and retrieving their child objects. + This test verifies that: + 1. All accounts can be retrieved + 2. Child objects are correctly linked + 3. Child object data is complete and accurate + """ + # First get all account IDs + account_ids = find_ids_by_type("Account") + assert len(account_ids) > 0, "No accounts found" + + # For each account, get its children and verify the data + for account_id in account_ids: + account = get_record(account_id) + assert account is not None, f"Could not find account {account_id}" + + # Get all child objects + child_ids = get_child_ids(account_id) + + # For Acme Inc., verify specific relationships + if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc. 
+ assert ( + len(child_ids) == 4 + ), f"Expected 4 children for Acme Inc., found {len(child_ids)}" + + # Get all child records + child_records = [] + for child_id in child_ids: + child_record = get_record(child_id) + if child_record is not None: + child_records.append(child_record) + # Verify Cases + cases = [r for r in child_records if r.type == "Case"] + assert ( + len(cases) == 2 + ), f"Expected 2 cases for Acme Inc., found {len(cases)}" + case_subjects = {case.data["Subject"] for case in cases} + assert "Test Case 1" in case_subjects, "Test Case 1 not found" + assert "Test Case 2" in case_subjects, "Test Case 2 not found" + + # Verify Contacts + contacts = [r for r in child_records if r.type == "Contact"] + assert ( + len(contacts) == 1 + ), f"Expected 1 contact for Acme Inc., found {len(contacts)}" + contact = contacts[0] + assert contact.data["FirstName"] == "Test", "Contact first name mismatch" + assert contact.data["LastName"] == "Contact", "Contact last name mismatch" + + # Verify Opportunities + opportunities = [r for r in child_records if r.type == "Opportunity"] + assert ( + len(opportunities) == 1 + ), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}" + opportunity = opportunities[0] + assert ( + opportunity.data["Name"] == "Test Opportunity" + ), "Opportunity name mismatch" + assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch" + + print("All account with children tests passed successfully!") + + +def _test_relationship_updates() -> None: + """ + Tests that relationships are properly updated when a child object's parent reference changes. + This test verifies: + 1. Initial relationship is created correctly + 2. When parent reference is updated, old relationship is removed + 3. New relationship is created correctly + """ + # Create initial test data - Contact linked to Acme Inc. + initial_contact = [ + { + "Id": _VALID_SALESFORCE_IDS[40], + "AccountId": _VALID_SALESFORCE_IDS[0], + "FirstName": "Test", + "LastName": "Contact", + } + ] + _create_csv_file("Contact", initial_contact, "initial_contact.csv") + + # Verify initial relationship + acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0]) + assert ( + _VALID_SALESFORCE_IDS[40] in acme_children + ), "Initial relationship not created" + + # Update contact to be linked to Globex Corp instead + updated_contact = [ + { + "Id": _VALID_SALESFORCE_IDS[40], + "AccountId": _VALID_SALESFORCE_IDS[1], + "FirstName": "Test", + "LastName": "Contact", + } + ] + _create_csv_file("Contact", updated_contact, "updated_contact.csv") + + # Verify old relationship is removed + acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0]) + assert ( + _VALID_SALESFORCE_IDS[40] not in acme_children + ), "Old relationship not removed" + + # Verify new relationship is created + globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) + assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created" + + print("All relationship update tests passed successfully!") + + +def _test_get_affected_parent_ids() -> None: + """ + Tests get_affected_parent_ids functionality by verifying: + 1. IDs that are directly in the parent_types list are included + 2. IDs that have children in the updated_ids list are included + 3. 
IDs that are neither of the above are not included + """ + # Create test data with relationships + test_data = { + "Account": [ + { + "Id": _VALID_SALESFORCE_IDS[0], + "Name": "Parent Account 1", + }, + { + "Id": _VALID_SALESFORCE_IDS[1], + "Name": "Parent Account 2", + }, + { + "Id": _VALID_SALESFORCE_IDS[2], + "Name": "Not Affected Account", + }, + ], + "Contact": [ + { + "Id": _VALID_SALESFORCE_IDS[40], + "AccountId": _VALID_SALESFORCE_IDS[0], + "FirstName": "Child", + "LastName": "Contact", + } + ], + } + + # Create and update CSV files for test data + for object_type, records in test_data.items(): + _create_csv_file(object_type, records) + + # Test Case 1: Account directly in updated_ids and parent_types + updated_ids = [_VALID_SALESFORCE_IDS[1]] # Parent Account 2 + parent_types = ["Account"] + affected_ids_by_type = dict( + get_affected_parent_ids_by_type(updated_ids, parent_types) + ) + assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type" + assert ( + _VALID_SALESFORCE_IDS[1] in affected_ids_by_type["Account"] + ), "Direct parent ID not included" + + # Test Case 2: Account with child in updated_ids + updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact + parent_types = ["Account"] + affected_ids_by_type = dict( + get_affected_parent_ids_by_type(updated_ids, parent_types) + ) + assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type" + assert ( + _VALID_SALESFORCE_IDS[0] in affected_ids_by_type["Account"] + ), "Parent of updated child not included" + + # Test Case 3: Both direct and indirect affects + updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]] # Both cases + parent_types = ["Account"] + affected_ids_by_type = dict( + get_affected_parent_ids_by_type(updated_ids, parent_types) + ) + assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type" + affected_ids = affected_ids_by_type["Account"] + assert len(affected_ids) == 2, "Expected exactly two affected parent IDs" + assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included" + assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included" + assert ( + _VALID_SALESFORCE_IDS[2] not in affected_ids + ), "Unaffected ID incorrectly included" + + # Test Case 4: No matches + updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact + parent_types = ["Opportunity"] # Wrong type + affected_ids_by_type = dict( + get_affected_parent_ids_by_type(updated_ids, parent_types) + ) + assert len(affected_ids_by_type) == 0, "Should return empty dict when no matches" + + print("All get_affected_parent_ids tests passed successfully!") + + +def test_salesforce_sqlite() -> None: + _clear_sf_db() + init_db() + _create_csv_with_example_data() + _test_query() + _test_upsert() + _test_relationships() + _test_account_with_children() + _test_relationship_updates() + _test_get_affected_parent_ids() + _clear_sf_db()
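
For reference, the helpers added in sqlite_functions.py are meant to be composed roughly as follows during an incremental sync: ingest the CSVs produced by the bulk API layer into SQLite, collect the updated record IDs, resolve which parent objects need re-indexing, and then pull each affected parent together with its children. The sketch below is illustrative only and is not part of the patch: rebuild_affected_parents and csv_paths_by_type are hypothetical names, producing the CSV files is out of scope, and "Account" is used here simply as an example parent type.

from onyx.connectors.salesforce.sqlite_functions import (
    get_affected_parent_ids_by_type,
    get_child_ids,
    get_record,
    init_db,
    update_sf_db_with_csv,
)


def rebuild_affected_parents(csv_paths_by_type: dict[str, list[str]]) -> None:
    """Ingest downloaded CSVs and walk every parent whose data changed.

    csv_paths_by_type is assumed to map an object type (e.g. "Contact") to the
    CSV files that were downloaded for it.
    """
    init_db()  # creates the SQLite file, tables, and indexes if missing

    updated_ids: list[str] = []
    for object_type, csv_paths in csv_paths_by_type.items():
        for csv_path in csv_paths:
            # Upserts each row and refreshes the relationship tables per record
            updated_ids.extend(update_sf_db_with_csv(object_type, csv_path))

    # Yields (parent_type, ids) for parents updated directly or through a child
    for parent_type, parent_ids in get_affected_parent_ids_by_type(
        updated_ids, ["Account"]
    ):
        for parent_id in parent_ids:
            parent = get_record(parent_id, parent_type)
            if parent is None:
                continue
            children = [
                child
                for child_id in get_child_ids(parent_id)
                if (child := get_record(child_id)) is not None
            ]
            # A real caller would turn parent + children into a Document here
            print(f"{parent.data.get('Name', parent.id)}: {len(children)} children")

One design note on the schema: the relationship_types table exists so the parent lookup above can filter by parent_type without joining back to salesforce_objects for every child ID, and the covering index idx_relationship_types_lookup (parent_type, child_id, parent_id) lets that query be answered from the index alone.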