-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
52 lines (43 loc) · 1.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import labelsgen
import inputgen
import model as nn
import numpy as np
import os
import uuid
import sys
import log
# How to use: train.py <labels file> <training epochs> <optional: model name>
#
# Builds the model from model.create_model_for_training(), pairs inputs from
# inputgen with labels read via labelsgen.csv2arr(<labels file>), shuffles
# both with one shared permutation, fits for <training epochs>, evaluates on
# the same data, and saves the result as models/<model name>.h5 (a random
# UUID is used as the name when none is given).

# Getting all arguments from terminal
argv = sys.argv

# Validate the command line before doing any expensive work (building the
# model, generating inputs) so a bad invocation fails fast.
if len(argv) <= 1:
    log.error("The first required argument is missing: the name of the labels file without extension in the labels folder")
    sys.exit(1)
if len(argv) <= 2:
    log.error("The second required argument is missing: how many epochs to train")
    sys.exit(1)
try:
    epochs = int(argv[2])
except ValueError:
    log.error(f"The second argument (how many epochs to train) must be an integer, got: {argv[2]!r}")
    sys.exit(1)

# Creating the model structure
model = nn.create_model_for_training()
# Generating some random input (features) x
x = inputgen.generate_inputs(nn.get_batch_size())
# Using existing labels from argv[1] (labels) y
y = labelsgen.csv2arr(argv[1])

# Shuffle features and labels with the SAME permutation so each x row stays
# paired with its label.
# NOTE(review): the permutation is sized from x; this assumes len(y) >= len(x)
# (extra y rows are silently dropped) — confirm labels match the input count.
randomize = np.random.permutation(len(x))
x = x[randomize]
y = y[randomize]
# NOTE(review): .get() is not an np.ndarray method; this presumably relies on
# inputgen/labelsgen returning CuPy arrays (.get() copies to host memory) —
# confirm against those modules.
x = np.asarray(x.get())
y = np.asarray(y.get())

# Fitting the model
print('\n')
model.fit(x, y, epochs=epochs, batch_size=nn.get_batch_size())
# Evaluating the model (on the training data itself, so this measures fit,
# not generalization)
model.evaluate(x, y)
model.summary()

# Saving the fitted model. uuid4() must be CALLED: formatting the bare
# function would embed "<function uuid4 at 0x...>" in the filename.
# os.path.join yields "models\<name>.h5" on Windows (as before) and the
# correct "models/<name>.h5" on POSIX instead of a literal backslash name.
model_name = argv[3] if len(argv) > 3 else str(uuid.uuid4())
model.save(os.path.join('models', f'{model_name}.h5'))