-
Notifications
You must be signed in to change notification settings - Fork 442
/
data.py
28 lines (24 loc) · 974 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
@author : Hyunwoong
@when : 2019-10-29
@homepage : https://github.com/gusdnd852
"""
from conf import *
from util.data_loader import DataLoader
from util.tokenizer import Tokenizer
tokenizer = Tokenizer()

# Everything below executes at import time: dataset splits, vocabularies and
# batch iterators are built once and exposed as module-level globals.
# NOTE(review): presumably the first extension in `ext` is the source
# language (English) and the second the target (German) -- confirm in
# util.data_loader.DataLoader.
loader = DataLoader(
    ext=('.en', '.de'),
    tokenize_en=tokenizer.tokenize_en,
    tokenize_de=tokenizer.tokenize_de,
    init_token='<sos>',
    eos_token='<eos>',
)

train, valid, test = loader.make_dataset()

# Vocabularies are built from the training split only; min_freq=2 presumably
# discards tokens seen fewer than twice (torchtext semantics -- verify).
loader.build_vocab(train_data=train, min_freq=2)

# `batch_size` and `device` are not defined in this file; they presumably
# come from the star import of `conf`.
train_iter, valid_iter, test_iter = loader.make_iter(
    train, valid, test,
    batch_size=batch_size,
    device=device,
)

# Special-token indices; these lookups must stay after build_vocab above.
src_pad_idx = loader.source.vocab.stoi['<pad>']
trg_pad_idx, trg_sos_idx = (
    loader.target.vocab.stoi[special] for special in ('<pad>', '<sos>')
)

# Vocabulary sizes for the encoder (source) and decoder (target) sides.
enc_voc_size, dec_voc_size = len(loader.source.vocab), len(loader.target.vocab)