Skip to content

Commit

Permalink
namedtuple -> DsetTuple missing attr fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Dec 27, 2024
1 parent 3165ce0 commit 6c036e0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
18 changes: 12 additions & 6 deletions sup3r/preprocessing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,27 @@ class DsetTuple:
while being serializable"""

def __init__(self, **kwargs):
self.dset_names = list(kwargs)
self.__dict__.update(kwargs)

@property
def dsets(self):
"""Dictionary with only dset names and associated values."""
return {k: v for k, v in self.__dict__.items() if k in self.dset_names}

def __iter__(self):
return iter(self.__dict__.values())
return iter(self.dsets.values())

def __getitem__(self, key):
if isinstance(key, int):
key = list(self.__dict__)[key]
return self.__dict__[key]
key = list(self.dsets)[key]
return self.dsets[key]

def __len__(self):
return len(self.__dict__)
return len(self.dsets)

def __repr__(self):
return f"DsetTuple({self.__dict__})"
return f'DsetTuple({self.dsets})'


class Sup3rDataset:
Expand Down Expand Up @@ -237,7 +243,7 @@ def __getitem__(self, keys):
if len(self._ds) == 1:
return out[-1]
if all(isinstance(o, Sup3rX) for o in out):
return type(self)(**dict(zip(self._ds._fields, out)))
return type(self)(**dict(zip(self._ds.dset_names, out)))
return out

@property
Expand Down
4 changes: 2 additions & 2 deletions sup3r/preprocessing/batch_queues/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, samplers, **kwargs):
--------
:class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue`
"""
self.BATCH_MEMBERS = samplers[0]._fields
self.BATCH_MEMBERS = samplers[0].dset_names
super().__init__(samplers, **kwargs)
self.check_enhancement_factors()

Expand All @@ -30,7 +30,7 @@ def __init__(self, samplers, **kwargs):
def queue_shape(self):
"""Shape of objects stored in the queue."""
queue_shapes = [(self.batch_size, *self.lr_shape)]
hr_mems = len(self.Batch._fields) - 1
hr_mems = len(self.BATCH_MEMBERS) - 1
queue_shapes += [(self.batch_size, *self.hr_shape)] * hr_mems
return queue_shapes

Expand Down

0 comments on commit 6c036e0

Please sign in to comment.