From 7800eddef850df6c8f4c214839b9c387ab87c680 Mon Sep 17 00:00:00 2001 From: lbluque Date: Wed, 29 May 2024 10:15:10 -0700 Subject: [PATCH] TST: fix cluster tests --- smol/cofe/space/cluster.py | 6 +++--- tests/test_cofe/test_cluster.py | 28 ++++++++++++++-------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/smol/cofe/space/cluster.py b/smol/cofe/space/cluster.py index 2b6c740f5..17576b7bd 100644 --- a/smol/cofe/space/cluster.py +++ b/smol/cofe/space/cluster.py @@ -193,9 +193,9 @@ def from_file(cls, filename: str): """ filename = str(filename) - with zopen(filename) as f: - contents = f.read() - fname = filename.lower() + with zopen(filename, mode="rt", errors="replace") as file: + contents = file.read() + fname = os.path.basename(filename) if fnmatch(fname, "*.json*") or fnmatch(fname, "*.mson*"): return cls.from_str(contents, fmt="json") diff --git a/tests/test_cofe/test_cluster.py b/tests/test_cofe/test_cluster.py index b5ce9e33f..8d9288787 100644 --- a/tests/test_cofe/test_cluster.py +++ b/tests/test_cofe/test_cluster.py @@ -1,10 +1,9 @@ -import json import os from itertools import combinations import numpy as np import pytest -from ruamel import yaml +from ruamel.yaml import YAML from smol.cofe.space import Cluster from smol.cofe.space.domain import get_site_spaces @@ -70,23 +69,24 @@ def test_to_from(cluster, tmpdir): cluster2 = Cluster.from_str(yml, "yaml") assert cluster == cluster2 + YAML() with open(os.path.join(tmpdir, "cluster.yaml"), "w") as f: - yaml.dump(yml, f) + f.write(yml) with open(os.path.join(tmpdir, "cluster.json"), "w") as f: - json.dump(js, f) + f.write(js) - # cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.yaml")) - # assert cluster == cluster2 - # cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.json")) - # assert cluster == cluster2 + cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.yaml")) + assert cluster == cluster2 + cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.json")) + assert cluster == cluster2 - # cluster.to("yaml", os.path.join(tmpdir, "cluster.yaml")) - # cluster.to("json", os.path.join(tmpdir, "cluster.json")) + cluster.to("yaml", os.path.join(tmpdir, "cluster.yaml")) + cluster.to("json", os.path.join(tmpdir, "cluster.json")) - # cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.yaml")) - # assert cluster == cluster2 - # cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.json")) - # assert cluster == cluster2 + cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.yaml")) + assert cluster == cluster2 + cluster2 = Cluster.from_file(os.path.join(tmpdir, "cluster.json")) + assert cluster == cluster2 with pytest.raises(ValueError): cluster.to("bad_format")