-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
76 lines (51 loc) · 2.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from dictionary import Dictionary
import numpy as np
from sparse_code import Sparse_code
import pickle
def evaluate_model(dataset, predictions, dictionary):
    """Return the classification accuracy, in percent, of ``predictions``.

    ``dataset.labels_test`` is read with a stride of ``dataset.n_patch`` per
    test image: the true label of image ``i`` is taken from position
    ``i * n_patch`` (presumably one label entry per patch, with all patches
    of an image stored consecutively -- TODO confirm against the dataset
    builder).

    Parameters
    ----------
    dataset
        Object exposing ``labels_test`` (indexable sequence of labels) and
        ``n_patch`` (int stride between consecutive images' labels).
    predictions
        Sequence with one entry per test image; element ``[0]`` of each entry
        is used as the predicted subject ID.
    dictionary
        Unused; kept so existing callers keep working.

    Returns
    -------
    float
        Percentage (0-100) of images whose predicted ID equals the true label,
        after both are cast to ``int`` (truncation toward zero).

    Raises
    ------
    ValueError
        If ``predictions`` is empty, because accuracy is undefined.
    """
    n_imgs = len(predictions)
    if n_imgs == 0:
        raise ValueError('predictions is empty; accuracy is undefined')

    n_patches = dataset.n_patch
    labels = np.asarray(dataset.labels_test)

    # Pick the label of the first patch of every image in one vectorized
    # gather instead of growing an array with np.append (quadratic).
    # np.ravel keeps the result 1-D even for a single image or for
    # column-vector labels, mirroring the flattening np.append used to do.
    true_labels = np.ravel(labels[np.arange(n_imgs) * n_patches]).astype(int)
    pred_ids = np.ravel([p[0] for p in predictions]).astype(int)

    correct = int(np.count_nonzero(true_labels == pred_ids))
    # Same operation order as the original (ratio first, then * 100) so the
    # returned float is bit-identical.
    accuracy = float(correct) / n_imgs
    return accuracy * 100
# --------------------------------------- MAIN Function ----------------------------------------------------------------------
def main(data_path='data/lfw_158_sbj.pkl',
         dict_path='data/lfw_158_sbj_dictionary.pkl'):
    """Run the pipeline: load data, learn dictionaries, classify, report accuracy.

    Parameters
    ----------
    data_path : str
        Path to the pickled dataset object. The object must provide
        ``get_data_info()``, ``num_sbj``, ``gallery_lda`` and ``n_patch``, plus
        whatever ``Sparse_code.klimaps_classify_learned_dict`` and
        ``evaluate_model`` read from it.
    dict_path : str
        Destination file for the pickled learned ``Dictionary`` object.
    """
    print('\nLoading data from disk... ')
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only point data_path at files from a trusted source.
    with open(data_path, 'rb') as input_file:
        dataset = pickle.load(input_file)
    info = dataset.get_data_info()
    print(info)
    print('Number of sbjs considered: ' + str(dataset.num_sbj))

    # Dictionary Learning --------------------------------------------------------------------------------------------------
    print('\nBuilding Dictionaries...')
    dictionaries = Dictionary()
    dictionaries.dict_learn(Y=dataset.gallery_lda, N=dataset.num_sbj, n_patch=dataset.n_patch,
                            max_iter=10, init_method='data', sparsity=6)
    with open(dict_path, "wb") as output_file:
        print('\nSaving Dictionary to disk...')
        pickle.dump(dictionaries, output_file, protocol=pickle.HIGHEST_PROTOCOL)
        print('Done... Saved to: ' + dict_path)

    #Classification ----------------------------------------------------------------------------------------------
    print('\nClassifying...')
    sc = Sparse_code()
    IDs = sc.klimaps_classify_learned_dict(dataset, dictionaries)
    print('\nNumber of images: ' + str(len(IDs)))
    accuracy = evaluate_model(dataset, IDs, dictionaries)
    print('\nAccuracy: ' + str(accuracy) + '%')
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()