Skip to content

Commit

Permalink
missed compute call for slow batching. this was hidden by queueing an…
Browse files Browse the repository at this point in the history
…d dequeueing since this would cast to tensors.
  • Loading branch information
bnb32 committed Dec 29, 2024
1 parent 1e7afa9 commit 046a279
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions sup3r/preprocessing/batch_queues/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def sample_batches(self, n_batches) -> None:
def needed_batches(self):
"""Number of batches needed to either fill or the queue or hit the
epoch limit."""
remaining = self.n_batches - self._batch_count - self.queue_len - 1
remaining = self.n_batches - self._batch_count - self.queue_len
return min(self.queue_cap - self.queue_len, remaining)

def enqueue_batches(self) -> None:
Expand Down Expand Up @@ -309,14 +309,14 @@ def __next__(self) -> DsetTuple:
if self._batch_count < self.n_batches:
self.timer.start()
samples = self.get_batch()
self._batch_count += 1
if self.sample_shape[2] == 1:
if isinstance(samples, (list, tuple)):
samples = tuple(s[..., 0, :] for s in samples)
else:
samples = samples[..., 0, :]
batch = self.post_proc(samples)
self.timer.stop()
self._batch_count += 1
if self.verbose:
logger.debug(
'Batch step %s finished in %s.',
Expand Down
16 changes: 8 additions & 8 deletions sup3r/preprocessing/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Dict, Optional, Tuple
from warnings import warn

import dask.array as da
import numpy as np

from sup3r.preprocessing.base import Container
Expand Down Expand Up @@ -194,9 +193,8 @@ def _reshape_samples(self, samples):
new_shape[2] // self.batch_size,
new_shape[-1],
]
out = compute_if_dask(samples)
# (lats, lons, batch_size, times, feats)
out = np.reshape(out, new_shape)
out = np.reshape(samples, new_shape)
# (batch_size, lats, lons, times, feats)
return np.transpose(out, axes=(2, 0, 1, 3, 4))

Expand All @@ -223,25 +221,27 @@ def _stack_samples(self, samples):
(batch_size, samp_shape[0], samp_shape[1], samp_shape[2], n_feats)
"""
if isinstance(samples[0], tuple):
lr = da.stack([s[0] for s in samples], axis=0)
hr = da.stack([s[1] for s in samples], axis=0)
lr = np.stack([s[0] for s in samples], axis=0)
hr = np.stack([s[1] for s in samples], axis=0)
return (lr, hr)
return da.stack(samples, axis=0)
return np.stack(samples, axis=0)

def _fast_batch(self):
"""Get batch of samples with adjacent time slices."""
out = self.data.sample(self.get_sample_index(n_obs=self.batch_size))
out = compute_if_dask(out)
if isinstance(out, tuple):
return tuple(self._reshape_samples(o) for o in out)
return self._reshape_samples(out)

def _slow_batch(self):
"""Get batch of samples with random time slices."""
samples = [
out = [
self.data.sample(self.get_sample_index(n_obs=1))
for _ in range(self.batch_size)
]
return self._stack_samples(samples)
out = compute_if_dask(out)
return self._stack_samples(out)

def _fast_batch_possible(self):
return self.batch_size * self.sample_shape[2] <= self.data.shape[2]
Expand Down
2 changes: 2 additions & 0 deletions sup3r/preprocessing/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def compute_if_dask(arr):
compute_if_dask(arr.stop),
compute_if_dask(arr.step),
)
if isinstance(arr, (tuple, list)):
return type(arr)(compute_if_dask(a) for a in arr)
return arr.compute() if hasattr(arr, 'compute') else arr


Expand Down

0 comments on commit 046a279

Please sign in to comment.