From 3a25d01e1a7a02d01fa228304344aeb750c6e08a Mon Sep 17 00:00:00 2001 From: Sufiyan Adhikari Date: Sun, 21 Apr 2024 16:15:00 +0530 Subject: [PATCH] support non-authoritative generic sources --- .../source/generic_importer_source.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/beancount_import/source/generic_importer_source.py b/beancount_import/source/generic_importer_source.py index a2df6c2..0cdf646 100644 --- a/beancount_import/source/generic_importer_source.py +++ b/beancount_import/source/generic_importer_source.py @@ -35,8 +35,8 @@ class ImporterSource(DescriptionBasedSource): def __init__(self, directory: str, - account: str, importer: ImporterProtocol, + account: Optional[str]=None, # use None for importers that are not authoritative and would not clear any postings **kwargs) -> None: super().__init__(**kwargs) self.directory = os.path.expanduser(directory) @@ -57,11 +57,16 @@ def name(self) -> str: return self.importer.name() def prepare(self, journal: 'JournalEditor', results: SourceResults) -> None: - results.add_account(self.account) + if self.account: + results.add_account(self.account) entries = OrderedDict() #type: Dict[Hashable, List[Directive]] for f in self.files: f_entries = self.importer.extract(f, existing_entries=journal.entries) + # if the importer is not authoritative, add all entries to pending + if not self.account: + results.add_pending_entries(map(self._make_import_result, f_entries)) + continue # collect all entries in current statement, grouped by hash hashed_entries = OrderedDict() #type: Dict[Hashable, Directive] for entry in f_entries: @@ -77,14 +82,15 @@ def prepare(self, journal: 'JournalEditor', results: SourceResults) -> None: n = len(entries[key_]) entries.setdefault(key_, []).extend(hashed_entries[key_][n:]) - get_pending_and_invalid_entries( - raw_entries=list(itertools.chain.from_iterable(entries.values())), - journal_entries=journal.all_entries, - account_set=set([self.account]), - get_key_from_posting=_get_key_from_posting, - get_key_from_raw_entry=self._get_key_from_imported_entry, - make_import_result=self._make_import_result, - results=results) + if self.account: + get_pending_and_invalid_entries( + raw_entries=list(itertools.chain.from_iterable(entries.values())), + journal_entries=journal.all_entries, + account_set=set([self.account]), + get_key_from_posting=_get_key_from_posting, + get_key_from_raw_entry=self._get_key_from_imported_entry, + make_import_result=self._make_import_result, + results=results) def _add_description(self, entry: Transaction): if not isinstance(entry, Transaction): return None