# NewsSummaryDatasetModule.py — 72 lines (63 loc)
from NewsSummaryDataset import NewsSummaryDataset
import pandas as pd
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from transformers import (
T5TokenizerFast as T5Tokenizer
)
# News Summary Dataset Module containing the main
# train, test and validation dataloaders to be used
# in model fine-tuning
class NewsSummaryDatasetModule(pl.LightningDataModule):
    """LightningDataModule wrapping the news-summary train/test splits.

    ``setup`` builds one ``NewsSummaryDataset`` per split; the val and
    test dataloaders both iterate the test split (no separate val set).
    """

    def __init__(
        self,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        tokenizer: T5Tokenizer,
        batch_size: int = 8,
        test_max_token_len: int = 512,
        summary_max_token_len: int = 128
    ):
        super().__init__()
        # Keep every constructor argument on the instance so setup()
        # and the dataloader factories can reach them later.
        self.train_df = train_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        # Token-length limits forwarded to NewsSummaryDataset
        self.test_max_token_len = test_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def setup(self, stage=None):
        """Instantiate the train and test datasets (invoked by Lightning)."""
        self.train_dataset = self._build_dataset(self.train_df)
        self.test_dataset = self._build_dataset(self.test_df)

    def _build_dataset(self, frame):
        # Both splits share the same tokenizer and token-length limits.
        return NewsSummaryDataset(
            frame,
            self.tokenizer,
            self.test_max_token_len,
            self.summary_max_token_len
        )

    def _make_loader(self, dataset, shuffle):
        # Common DataLoader configuration for every split.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=2
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Unshuffled loader over the test dataset (used as validation)."""
        return self._make_loader(self.test_dataset, False)

    def test_dataloader(self):
        """Unshuffled loader over the test dataset."""
        return self._make_loader(self.test_dataset, False)