diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 1537ca3bb..6c022118d 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -271,7 +271,7 @@ def sample_batches(self, n_batches) -> None: def needed_batches(self): """Number of batches needed to either fill or the queue or hit the epoch limit.""" - remaining = self.n_batches - self._batch_count - self.queue_len - 1 + remaining = self.n_batches - self._batch_count - self.queue_len return min(self.queue_cap - self.queue_len, remaining) def enqueue_batches(self) -> None: @@ -309,7 +309,6 @@ def __next__(self) -> DsetTuple: if self._batch_count < self.n_batches: self.timer.start() samples = self.get_batch() - self._batch_count += 1 if self.sample_shape[2] == 1: if isinstance(samples, (list, tuple)): samples = tuple(s[..., 0, :] for s in samples) @@ -317,6 +316,7 @@ def __next__(self) -> DsetTuple: samples = samples[..., 0, :] batch = self.post_proc(samples) self.timer.stop() + self._batch_count += 1 if self.verbose: logger.debug( 'Batch step %s finished in %s.', diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index b824cd360..e576c1521 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -7,7 +7,6 @@ from typing import Dict, Optional, Tuple from warnings import warn -import dask.array as da import numpy as np from sup3r.preprocessing.base import Container @@ -194,9 +193,8 @@ def _reshape_samples(self, samples): new_shape[2] // self.batch_size, new_shape[-1], ] - out = compute_if_dask(samples) # (lats, lons, batch_size, times, feats) - out = np.reshape(out, new_shape) + out = np.reshape(samples, new_shape) # (batch_size, lats, lons, times, feats) return np.transpose(out, axes=(2, 0, 1, 3, 4)) @@ -223,25 +221,27 @@ def _stack_samples(self, samples): (batch_size, samp_shape[0], samp_shape[1], samp_shape[2], n_feats) """ if isinstance(samples[0], tuple): - lr = da.stack([s[0] for s in samples], axis=0) - hr = da.stack([s[1] for s in samples], axis=0) + lr = np.stack([s[0] for s in samples], axis=0) + hr = np.stack([s[1] for s in samples], axis=0) return (lr, hr) - return da.stack(samples, axis=0) + return np.stack(samples, axis=0) def _fast_batch(self): """Get batch of samples with adjacent time slices.""" out = self.data.sample(self.get_sample_index(n_obs=self.batch_size)) + out = compute_if_dask(out) if isinstance(out, tuple): return tuple(self._reshape_samples(o) for o in out) return self._reshape_samples(out) def _slow_batch(self): """Get batch of samples with random time slices.""" - samples = [ + out = [ self.data.sample(self.get_sample_index(n_obs=1)) for _ in range(self.batch_size) ] - return self._stack_samples(samples) + out = compute_if_dask(out) + return self._stack_samples(out) def _fast_batch_possible(self): return self.batch_size * self.sample_shape[2] <= self.data.shape[2] diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index 8a51fadbf..f52e91b72 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -228,6 +228,8 @@ def compute_if_dask(arr): compute_if_dask(arr.stop), compute_if_dask(arr.step), ) + if isinstance(arr, (tuple, list)): + return type(arr)(compute_if_dask(a) for a in arr) return arr.compute() if hasattr(arr, 'compute') else arr