diff --git a/src/pytroll_watchers/local_watcher.py b/src/pytroll_watchers/local_watcher.py
index 0d2a233..abf15a6 100644
--- a/src/pytroll_watchers/local_watcher.py
+++ b/src/pytroll_watchers/local_watcher.py
@@ -1,6 +1,22 @@
 """Watcher for non-remote file systems.
 
 Either using OS-based envents (like inotify on linux), or polling.
+
+An example configuration file to retrieve data from a directory.
+
+.. code-block:: yaml
+
+  backend: local
+  fs_config:
+    directory: /data
+    file pattern: "H-000-{orig_platform_name:4s}__-{orig_platform_name:4s}_{service:3s}____-{channel_name:_<9s}-\
+                  {segment:_<9s}-{start_time:%Y%m%d%H%M}-{compression:1s}_"
+  publisher_config:
+    name: hrit_watcher
+  message_config:
+    subject: /segment/hrit/l1b/
+    atype: file
+
 """
 import logging
 import os
diff --git a/src/pytroll_watchers/publisher.py b/src/pytroll_watchers/publisher.py
index 9abcd5d..378a856 100644
--- a/src/pytroll_watchers/publisher.py
+++ b/src/pytroll_watchers/publisher.py
@@ -3,9 +3,10 @@
 import datetime
 import json
 import logging
-from contextlib import closing, suppress
+from contextlib import closing, contextmanager, suppress
 from copy import deepcopy
 
+import fsspec
 from posttroll.message import Message
 from posttroll.publisher import create_publisher_from_dict_config
 from trollsift import parse
@@ -104,12 +105,37 @@ def _build_file_location(file_item, include_dir=None):
     else:
         uid = file_item.name
     file_location["uid"] = uid
-    with suppress(AttributeError):
-        file_location["filesystem"] = json.loads(file_item.fs.to_json())
+    with suppress(AttributeError):  # fileitem is not a UPath if it cannot access .fs
+        with dummy_connect(file_item):
+            file_location["filesystem"] = json.loads(file_item.fs.to_json(include_password=False))
 
+        file_location["path"] = file_item.path
+
     return file_location
 
 
+@contextmanager
+def dummy_connect(file_item):
+    """Make the _connect method of the fsspec class a no-op.
+
+    This is for the case where only serialization of the filesystem is needed.
+    """
+    def _fake_connect(*_args, **_kwargs): ...
+
+    klass = fsspec.get_filesystem_class(file_item.protocol)
+    try:
+        original_connect = klass._connect
+    except AttributeError:
+        yield
+        return
+
+    klass._connect = _fake_connect
+    try:
+        yield
+    finally:
+        klass._connect = original_connect
+
+
 def apply_aliases(aliases, metadata):
     """Apply aliases to the metadata.
 
diff --git a/tests/test_local_watcher.py b/tests/test_local_watcher.py
index f82db53..ac121fc 100644
--- a/tests/test_local_watcher.py
+++ b/tests/test_local_watcher.py
@@ -118,3 +118,25 @@ def test_publish_paths_forbids_passing_password(tmp_path, patched_local_events,
             local_watcher.file_publisher(fs_config=local_settings,
                                          publisher_config=publisher_settings,
                                          message_config=message_settings)
+
+
+def test_publish_paths_with_ssh(tmp_path, patched_local_events):  # noqa
+    """Test publishing paths with an ssh protocol."""
+    filename = os.fspath(tmp_path / "foo.txt")
+
+    host = "localhost"
+
+    local_settings = dict(directory=tmp_path, protocol="ssh",
+                          storage_options=dict(host=host))
+    publisher_settings = dict(nameservers=False, port=1979)
+    message_settings = dict(subject="/segment/viirs/l1b/", atype="file", data=dict(sensor="viirs"))
+
+    with patched_local_events([filename]):
+        with patched_publisher() as published_messages:
+            local_watcher.file_publisher(fs_config=local_settings,
+                                         publisher_config=publisher_settings,
+                                         message_config=message_settings)
+    assert len(published_messages) == 1
+    message = Message(rawstr=published_messages[0])
+    assert message.data["uri"].startswith("ssh://")
+    assert message.data["filesystem"]["host"] == host