diff --git a/cognite/neat/rules/importers/_spreadsheet2rules.py b/cognite/neat/rules/importers/_spreadsheet2rules.py index c14035223..9ac5d5c8d 100644 --- a/cognite/neat/rules/importers/_spreadsheet2rules.py +++ b/cognite/neat/rules/importers/_spreadsheet2rules.py @@ -4,6 +4,7 @@ """ from collections import UserDict, defaultdict +from dataclasses import dataclass from pathlib import Path from typing import Literal, cast, overload @@ -49,6 +50,12 @@ def role(self) -> RoleTypes: def has_schema_field(self) -> bool: return self.get("schema") in [schema.value for schema in SchemaCompleteness.__members__.values()] + @property + def schema(self) -> SchemaCompleteness | None: + if not self.has_schema_field: + return None + return SchemaCompleteness(self["schema"]) + def is_valid(self, issue_list: IssueList, filepath: Path) -> bool: if not self.has_role_field: issue_list.append(issues.spreadsheet_file.RoleMissingOrUnsupportedError(filepath)) @@ -61,6 +68,14 @@ def is_valid(self, issue_list: IssueList, filepath: Path) -> bool: return True +@dataclass +class ReadResult: + sheets: dict[str, dict | list] + read_info_by_sheet: dict[str, SpreadsheetRead] + role: RoleTypes + schema: SchemaCompleteness | None + + class SpreadsheetReader: def __init__(self, issue_list: IssueList, is_reference: bool = False): self.issue_list = issue_list @@ -79,7 +94,7 @@ def sheet_names(self, role: RoleTypes) -> set[str]: def to_reference_sheet(cls, sheet_name: str) -> str: return f"Ref{sheet_name}" - def read(self, filepath: Path) -> Rules | None: + def read(self, filepath: Path) -> None | ReadResult: with pd.ExcelFile(filepath) as excel_file: if self.metadata_sheet_name not in excel_file.sheet_names: self.issue_list.append( @@ -95,21 +110,10 @@ def read(self, filepath: Path) -> Rules | None: return None sheets, read_info_by_sheet = self._read_sheets(metadata, excel_file) - if self.issue_list.has_errors: - return None - - rules_cls = RULES_PER_ROLE[metadata.role] - with _handle_issues( - self.issue_list, - error_cls=issues.spreadsheet.InvalidSheetError, - error_args={"read_info_by_sheet": read_info_by_sheet}, - ) as future: - rules = rules_cls.model_validate(sheets) # type: ignore[attr-defined] - - if future.result == "failure" or self.issue_list.has_errors: + if sheets is None or self.issue_list.has_errors: return None - return rules + return ReadResult(sheets, read_info_by_sheet, metadata.role, metadata.schema) def _read_sheets( self, metadata: MetadataRaw, excel_file: ExcelFile @@ -172,38 +176,56 @@ def to_rules( issue_list.append(issues.spreadsheet_file.SpreadsheetNotFoundError(self.filepath)) return self._return_or_raise(issue_list, errors) - user_rules: Rules | None = None + user_result: ReadResult | None = None if not is_reference: - user_rules = SpreadsheetReader(issue_list, is_reference=False).read(self.filepath) - if issue_list.has_errors: + user_result = SpreadsheetReader(issue_list, is_reference=False).read(self.filepath) + if user_result is None or issue_list.has_errors: return self._return_or_raise(issue_list, errors) - reference_rules: Rules | None = None + reference_result: ReadResult | None = None if is_reference or ( - user_rules - and user_rules.metadata.role != RoleTypes.domain_expert - and cast(DMSRules | InformationRules, user_rules).metadata.schema_ == SchemaCompleteness.extended + user_result + and user_result.role != RoleTypes.domain_expert + and user_result.schema == SchemaCompleteness.extended ): - reference_rules = SpreadsheetReader(issue_list, is_reference=True).read(self.filepath) + reference_result = SpreadsheetReader(issue_list, is_reference=True).read(self.filepath) if issue_list.has_errors: return self._return_or_raise(issue_list, errors) - if user_rules and reference_rules and user_rules.metadata.role != reference_rules.metadata.role: + if user_result and reference_result and user_result.role != reference_result.role: issue_list.append(issues.spreadsheet_file.RoleMismatchError(self.filepath)) return self._return_or_raise(issue_list, errors) - if user_rules and reference_rules: - rules = user_rules - rules.reference = reference_rules - elif user_rules: - rules = user_rules - elif reference_rules: - rules = reference_rules + if user_result and reference_result: + user_result.sheets["reference"] = reference_result.sheets + sheets = user_result.sheets + role = user_result.role + read_info_by_sheet = user_result.read_info_by_sheet + read_info_by_sheet.update(reference_result.read_info_by_sheet) + elif user_result: + sheets = user_result.sheets + role = user_result.role + read_info_by_sheet = user_result.read_info_by_sheet + elif reference_result: + sheets = reference_result.sheets + role = reference_result.role + read_info_by_sheet = reference_result.read_info_by_sheet else: raise ValueError( "No rules were generated. This should have been caught earlier. " f"Bug in {type(self).__name__}." ) + rules_cls = RULES_PER_ROLE[role] + with _handle_issues( + issue_list, + error_cls=issues.spreadsheet.InvalidSheetError, + error_args={"read_info_by_sheet": read_info_by_sheet}, + ) as future: + rules = rules_cls.model_validate(sheets) # type: ignore[attr-defined] + + if future.result == "failure" or issue_list.has_errors: + return self._return_or_raise(issue_list, errors) + return self._to_output( rules, issue_list, diff --git a/tests/tests_unit/rules/test_models/test_dms_architect_rules.py b/tests/tests_unit/rules/test_models/test_dms_architect_rules.py index 713058650..c578ba30d 100644 --- a/tests/tests_unit/rules/test_models/test_dms_architect_rules.py +++ b/tests/tests_unit/rules/test_models/test_dms_architect_rules.py @@ -718,43 +718,9 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]: id="No casing standardization", ) - DMSRules( - metadata=DMSMetadata( - schema_="complete", - space="sp_enterprise", - external_id="enterprise_model", - version="1", - creator="Alice", - created="2021-01-01T00:00:00", - updated="2021-01-01T00:00:00", - ), - properties=SheetList[DMSProperty]( - data=[ - DMSProperty( - class_="Asset", - property_="children", - value_type="Asset", - relation="multiedge", - view="Asset", - view_property="children", - ), - ] - ), - containers=SheetList[DMSContainer]( - data=[ - DMSContainer(container="Asset", class_="Asset"), - ] - ), - views=SheetList[DMSView]( - data=[ - DMSView(view="Asset", class_="Asset"), - ] - ), - ) - dms_rules = DMSRules( metadata=DMSMetadata( - schema_="extended", + schema_="complete", space="sp_solution", external_id="solution_model", version="1",