From 55eab8e650b0cb845a22b4abd9935416db20c230 Mon Sep 17 00:00:00 2001 From: Alex Lee Date: Wed, 25 Sep 2024 13:39:17 +0200 Subject: [PATCH 1/3] Fix for issue 7369 --- pymc/backends/ndarray.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 23f05488b97..b729e417ae5 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -115,6 +115,8 @@ def record(self, point, sampler_stats=None) -> None: if sampler_stats is not None: for data, vars in zip(self._stats, sampler_stats): for key, val in vars.items(): + if isinstance(val, np.ndarray) and val.shape[0] == 1: + val = val.item() data[key][self.draw_idx] = val self.draw_idx += 1 From 5b936a7d89aceb73f6017fc2c57dff7645c759c2 Mon Sep 17 00:00:00 2001 From: Alex Lee Date: Thu, 26 Sep 2024 12:39:59 +0200 Subject: [PATCH 2/3] Responding to PR comments --- pymc/backends/ndarray.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index b729e417ae5..e69bbae4bf1 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -114,9 +114,13 @@ def record(self, point, sampler_stats=None) -> None: raise ValueError("Unknown sampler_stats") if sampler_stats is not None: for data, vars in zip(self._stats, sampler_stats): - for key, val in vars.items(): - if isinstance(val, np.ndarray) and val.shape[0] == 1: - val = val.item() + compressed_vars = {} + for k, v in vars.items(): + if isinstance(v, np.ndarray) and v.shape[0] == 1: + compressed_vars[k] = v.item() + else: + compressed_vars[k] = v + for key, val in compressed_vars.items(): data[key][self.draw_idx] = val self.draw_idx += 1 From e5591b7292bd6e5bef19596797d4089e23b8df7c Mon Sep 17 00:00:00 2001 From: Alex Lee Date: Mon, 30 Sep 2024 09:42:05 +0200 Subject: [PATCH 3/3] Responding to PR comments --- pymc/backends/ndarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index e69bbae4bf1..fadc5b7e530 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -116,8 +116,8 @@ def record(self, point, sampler_stats=None) -> None: for data, vars in zip(self._stats, sampler_stats): compressed_vars = {} for k, v in vars.items(): - if isinstance(v, np.ndarray) and v.shape[0] == 1: - compressed_vars[k] = v.item() + if isinstance(v, np.ndarray) and v.shape[0] == 1 and len(v.shape) == 1: + compressed_vars[k] = v.reshape(1, 1) else: compressed_vars[k] = v for key, val in compressed_vars.items():