update example code, test=docs_preview
SigureMo committed Aug 1, 2023
1 parent e8a6b31 commit fbce310
Showing 3 changed files with 49 additions and 29 deletions.
7 changes: 7 additions & 0 deletions python/paddle/io/dataloader/batch_sampler.py
@@ -58,8 +58,10 @@ class BatchSampler(Sampler):
.. code-block:: python
>>> import numpy as np
>>> from paddle.io import RandomSampler, BatchSampler, Dataset
>>> np.random.seed(2023)
>>> # init with dataset
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
@@ -80,7 +82,9 @@ class BatchSampler(Sampler):
...
>>> for batch_indices in bs:
... print(batch_indices)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
...
[96, 97, 98, 99]
>>> # init with sampler
>>> sampler = RandomSampler(RandomDataset(100))
>>> bs = BatchSampler(sampler=sampler,
@@ -89,6 +93,9 @@ class BatchSampler(Sampler):
...
>>> for batch_indices in bs:
... print(batch_indices)
[56, 12, 68, 0, 82, 66, 91, 44]
...
[53, 17, 22, 86, 52, 3, 92, 33]
"""

def __init__(
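Not part of this commit: a minimal usage sketch of how a BatchSampler like the one in the docstring example above is typically handed to paddle.io.DataLoader (when batch_sampler is passed, batch_size, shuffle and drop_last are not also set on the loader). The RandomDataset below mirrors the toy dataset the docstring defines; shapes in the comments assume that definition.

    import numpy as np
    from paddle.io import BatchSampler, DataLoader, Dataset

    class RandomDataset(Dataset):
        # same toy dataset as in the docstring example
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            image = np.random.random([784]).astype('float32')
            label = np.random.randint(0, 9, (1,)).astype('int64')
            return image, label

        def __len__(self):
            return self.num_samples

    dataset = RandomDataset(100)
    bs = BatchSampler(dataset=dataset, batch_size=16)  # yields lists of indices
    loader = DataLoader(dataset, batch_sampler=bs)     # batches are built from those indices
    for images, labels in loader:
        print(images.shape, labels.shape)              # [16, 784], [16, 1] for full batches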
48 changes: 23 additions & 25 deletions python/paddle/io/dataloader/dataset.py
@@ -55,7 +55,8 @@ class Dataset:
...
>>> dataset = RandomDataset(10)
>>> for i in range(len(dataset)):
... print(dataset[i])
... image, label = dataset[i]
... # do something
"""

def __init__(self):
@@ -109,8 +110,9 @@ class IterableDataset(Dataset):
... yield image, label
...
>>> dataset = RandomDataset(10)
>>> for img, lbl in dataset:
... print(img, lbl)
>>> for img, label in dataset:
... # do something
... ...
When :attr:`num_workers > 0`, each worker has a different copy of the dataset object and
will yield whole dataset samples, which means samples in dataset will be repeated in
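Not part of this commit: a minimal sketch of one way to avoid that per-worker duplication, assuming the paddle.io.get_worker_info() helper for sharding an IterableDataset inside __iter__; each worker then yields only its own disjoint slice of the sample range.

    import math
    from paddle.io import DataLoader, IterableDataset, get_worker_info

    class ShardedRangeDataset(IterableDataset):
        def __init__(self, start, end):
            self.start = start
            self.end = end

        def __iter__(self):
            info = get_worker_info()
            if info is None:
                # single-process loading: yield the whole range
                iter_start, iter_end = self.start, self.end
            else:
                # shard the range so each worker yields a disjoint slice
                per_worker = int(math.ceil((self.end - self.start) / float(info.num_workers)))
                iter_start = self.start + info.id * per_worker
                iter_end = min(iter_start + per_worker, self.end)
            return iter(range(iter_start, iter_end))

    loader = DataLoader(ShardedRangeDataset(2, 9), batch_size=1, num_workers=2)
    for batch in loader:
        print(batch)  # every value in 2..8 appears exactly once; order depends on worker scheduling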
Expand Down Expand Up @@ -158,7 +160,7 @@ class IterableDataset(Dataset):
... drop_last=True)
...
>>> for data in dataloader:
... print(data)
... print(data) # doctest: +SKIP
Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
[[2]])
Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
Expand Down Expand Up @@ -216,7 +218,7 @@ class IterableDataset(Dataset):
... worker_init_fn=worker_init_fn)
...
>>> for data in dataloader:
... print(data)
... print(data) # doctest: +SKIP
Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
[[2]])
Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
Expand Down Expand Up @@ -288,7 +290,7 @@ class TensorDataset(Dataset):
>>> for i in range(len(dataset)):
... input, label = dataset[i]
... print(input, label)
... # do something
"""

def __init__(self, tensors):
@@ -354,10 +356,7 @@ class ComposeDataset(Dataset):
>>> dataset = ComposeDataset([RandomDataset(10), RandomDataset(10)])
>>> for i in range(len(dataset)):
... image1, label1, image2, label2 = dataset[i]
... print(image1)
... print(label1)
... print(image2)
... print(label2)
... # do something
"""

def __init__(self, datasets):
@@ -420,7 +419,9 @@ class ChainDataset(IterableDataset):
...
>>> dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
>>> for image, label in iter(dataset):
... print(image, label)
... # do something
... ...
"""

def __init__(self, datasets):
@@ -497,31 +498,28 @@ def random_split(dataset, lengths, generator=None):
>>> import paddle
>>> from paddle.io import random_split
>>> paddle.seed(2023)
>>> a_list = paddle.io.random_split(range(10), [3, 7])
>>> print(len(a_list))
2
>>> # output of the first subset
>>> for idx, v in enumerate(a_list[0]):
... print(idx, v)
>>> # doctest: +SKIP
0 1
1 3
2 9
>>> # doctest: -SKIP
0 8
1 2
2 5
>>> # output of the second subset
>>> for idx, v in enumerate(a_list[1]):
... print(idx, v)
>>> # doctest: +SKIP
0 5
1 7
2 8
3 6
4 0
5 2
6 4
>>> # doctest: -SKIP
0 9
1 6
2 3
3 4
4 1
5 0
6 7
"""
# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore
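Not part of this commit: a short sketch of what is usually done with the subsets random_split returns. Each element is a Subset (itself a map-style dataset), so the pieces can be wrapped directly in DataLoaders for a train/validation split; the lengths and batch sizes here are made up for illustration.

    import paddle
    from paddle.io import DataLoader, random_split

    paddle.seed(2023)
    train_part, val_part = random_split(range(10), [7, 3])  # two Subset objects
    print(len(train_part), len(val_part))                   # 7 3

    train_loader = DataLoader(train_part, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_part, batch_size=2)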
23 changes: 19 additions & 4 deletions python/paddle/io/dataloader/sampler.py
@@ -170,8 +170,10 @@ class RandomSampler(Sampler):
.. code-block:: python
>>> import numpy as np
>>> from paddle.io import Dataset, RandomSampler
>>> np.random.seed(2023)
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
@@ -188,6 +190,11 @@ class RandomSampler(Sampler):
>>> for index in sampler:
... print(index)
56
12
68
...
87
"""

def __init__(
@@ -297,14 +304,22 @@ class WeightedRandomSampler(Sampler):
.. code-block:: python
>>> import numpy as np
>>> from paddle.io import WeightedRandomSampler
>>> sampler = WeightedRandomSampler(weights=[0.1, 0.3, 0.5, 0.7, 0.2],
... num_samples=5,
... replacement=True)
...
>>> np.random.seed(2023)
>>> sampler = WeightedRandomSampler(
... weights=[0.1, 0.3, 0.5, 0.7, 0.2],
... num_samples=5,
... replacement=True
... )
>>> for index in sampler:
... print(index)
2
4
3
1
1
"""

def __init__(self, weights, num_samples, replacement=True):
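Not part of this commit: a sketch of a common use of WeightedRandomSampler, feeding its weighted index stream through a BatchSampler so a DataLoader over an imbalanced dataset draws the rare class about as often as the common one. The toy dataset and the inverse-frequency weights are made up for illustration.

    import numpy as np
    from paddle.io import BatchSampler, DataLoader, Dataset, WeightedRandomSampler

    class ImbalancedDataset(Dataset):
        # toy dataset: 90 samples of class 0, 10 samples of class 1
        def __init__(self):
            self.labels = np.array([0] * 90 + [1] * 10, dtype='int64')

        def __getitem__(self, idx):
            image = np.random.random([784]).astype('float32')
            return image, self.labels[idx:idx + 1]

        def __len__(self):
            return len(self.labels)

    dataset = ImbalancedDataset()
    # weight each sample inversely to its class frequency
    class_count = np.bincount(dataset.labels)
    weights = (1.0 / class_count)[dataset.labels]

    sampler = WeightedRandomSampler(weights=weights.tolist(), num_samples=100, replacement=True)
    loader = DataLoader(dataset, batch_sampler=BatchSampler(sampler=sampler, batch_size=10))
    for images, labels in loader:
        print(labels.numpy().ravel())  # classes 0 and 1 now appear in roughly equal numbers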
