-
Notifications
You must be signed in to change notification settings - Fork 0
/
statistics.py
83 lines (76 loc) · 3.74 KB
/
statistics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def init():
    """Zero every tracked statistic, including ``globalBestValAcc``.

    This is invoked once at import time (immediately below) so that all of
    the module-level names exist before any other function reads them.
    """
    global trainAcc, valAcc, testAcc
    global trainLoss, trainLoss_a, trainLoss_b, trainLoss_mixup
    global valLoss, testLoss
    global numCorrect, numTotal
    global localBestValAcc, globalBestValAcc
    global earlyStopCountForSupervisedModel, earlyStopCountForMainClassifier
    global improved

    # Per-phase accuracy and loss accumulators.
    trainAcc = valAcc = testAcc = 0
    trainLoss = trainLoss_a = trainLoss_b = trainLoss_mixup = 0
    valLoss = testLoss = 0
    # numCorrect / numTotal tallies.
    numCorrect = numTotal = 0
    # Best validation accuracies (within the round and across rounds).
    localBestValAcc = globalBestValAcc = 0
    # Early-stopping counters and the improvement flag.
    earlyStopCountForSupervisedModel = earlyStopCountForMainClassifier = 0
    improved = False


init()
def initRound():
    """Zero every per-round statistic but keep ``globalBestValAcc``.

    This resets the same names as ``init()`` with one deliberate exception:
    the best validation accuracy recorded across rounds survives the reset.
    """
    global trainAcc, valAcc, testAcc
    global trainLoss, trainLoss_a, trainLoss_b, trainLoss_mixup
    global valLoss, testLoss
    global numCorrect, numTotal
    global localBestValAcc
    global earlyStopCountForSupervisedModel, earlyStopCountForMainClassifier
    global improved

    # Per-phase accuracy and loss accumulators.
    trainAcc = valAcc = testAcc = 0
    trainLoss = trainLoss_a = trainLoss_b = trainLoss_mixup = 0
    valLoss = testLoss = 0
    # numCorrect / numTotal tallies.
    numCorrect = numTotal = 0
    # Within-round best accuracy and early-stopping state.
    # globalBestValAcc is intentionally NOT touched here.
    localBestValAcc = 0
    earlyStopCountForSupervisedModel = earlyStopCountForMainClassifier = 0
    improved = False
def summary():
    """Print every tracked statistic as ``name = value``, one per line.

    The accuracy/loss/count fields come first, in their established order.
    They are followed by the early-stopping counters and the ``improved``
    flag, which ``init()`` and ``initRound()`` maintain but which were
    previously never reported.
    """
    names = (
        'trainAcc', 'valAcc', 'testAcc',
        'trainLoss', 'trainLoss_a', 'trainLoss_b', 'trainLoss_mixup',
        'valLoss', 'testLoss',
        'numCorrect', 'numTotal',
        'localBestValAcc', 'globalBestValAcc',
        # Added so the early-stopping state is visible in the dump too.
        'earlyStopCountForSupervisedModel', 'earlyStopCountForMainClassifier',
        'improved',
    )
    # The statistics are module-level globals, so read them from this
    # module's namespace by name instead of spelling out one print per field.
    module_state = globals()
    for name in names:
        print(name + ' =', module_state[name])
def reset(mode):
    """Zero the statistics that belong to a single phase.

    Args:
        mode: ``'train'``, ``'val'`` or ``'test'``. For ``'train'``, the
            ``improved`` flag is also cleared. Any other value leaves the
            phase-specific fields alone and only zeroes numCorrect/numTotal.
    """
    global trainAcc, trainLoss, trainLoss_a, trainLoss_b, trainLoss_mixup
    global valAcc, valLoss
    global testAcc, testLoss
    global numCorrect, numTotal
    global improved

    # The three modes are mutually exclusive, so at most one branch runs.
    if mode == 'train':
        improved = False
        trainAcc = trainLoss = 0
        trainLoss_a = trainLoss_b = trainLoss_mixup = 0
    elif mode == 'val':
        valAcc = valLoss = 0
    elif mode == 'test':
        testAcc = testLoss = 0

    # NOTE(review): the copy this was recovered from had lost its
    # indentation, so it is ambiguous whether these tallies belonged to the
    # 'test' branch or to the whole function. They are placed at function
    # level so every phase starts from zero counts. Confirm against the
    # original file.
    numCorrect = numTotal = 0