-
Notifications
You must be signed in to change notification settings - Fork 9
/
predict.py
34 lines (24 loc) · 1.05 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = "Karel Roots"
import numpy as np
"""
@input - model (Object); testing data (List); testing labels (List); model name (String); optional args
Method that predicts target values with given model and calculates the accuracy of the predicitions by mean value of correct answers.
@output - Accuracy value (float), truth values (list)
"""
def predict_accuracy(model, X_test, y_test, model_name, multi_branch=False, tl=False, subj=1, train_size=0.7):
if multi_branch:
probs = model.predict([X_test, X_test, X_test])
else:
probs = model.predict(X_test)
preds = probs.argmax(axis=-1)
equals = preds == y_test.argmax(axis=-1)
acc = np.mean(equals)
if tl:
print(
"Transfer learning classification accuracy for train_size %d ; model %s ; "
"subject %d : %f " % (train_size, model_name, subj, acc))
else:
print("Classification accuracy for %s : %f " % (model_name, acc))
return acc, equals