-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathrun_tests.py
136 lines (130 loc) · 4.33 KB
/
run_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from hot_pytorch.tests import test_set_util, test_dense_util, test_sparse_util
from hot_pytorch.tests import test_dense_linear, test_sparse_linear
from hot_pytorch.tests import test_dense_attn
from hot_pytorch.tests import test_sparse_attn
import sys
import traceback


def run_group(label, module, test_names):
    """Run one named group of tests and print a pass/fail status line.

    Each name in *test_names* is looked up on *module* and called with no
    arguments, in order. The first exception aborts the rest of the group.

    Args:
        label: Human-readable group name used in the status line.
        module: Object (normally a test module) holding the test callables.
        test_names: Attribute names of zero-argument test callables on *module*.

    Returns:
        True if every test in the group ran without raising, else False.
    """
    try:
        # Resolve names inside the try so that a missing or renamed test
        # function counts as a failure of this group instead of aborting
        # the whole run.
        for name in test_names:
            getattr(module, name)()
    except Exception:
        # Print the full traceback (stderr) instead of str(e): bare `assert`
        # failures carry an empty message and would otherwise show nothing.
        traceback.print_exc()
        print(label + ' test failed')
        return False
    print(label + ' test passed')
    return True


def main():
    """Run every test group in order and return a process exit status.

    Groups run independently, so a failure in one group does not skip the
    groups after it.

    Returns:
        0 if all groups passed, 1 if at least one group failed.
    """
    # (status-line label, module holding the tests, test function names)
    groups = [
        ('Set util', test_set_util, ('test_masking',)),
        ('Dense util', test_dense_util,
         ('test_diag', 'test_nondiag', 'test_rotate', 'test_batch',
          'test_batch_fn')),
        ('Sparse util', test_sparse_util,
         ('test_make_batch', 'test_transpose', 'test_diag', 'test_nondiag')),
        ('Dense linear subroutine', test_dense_linear,
         ('test_mask', 'test_1_0', 'test_1_1', 'test_1_2', 'test_2_0',
          'test_2_1', 'test_2_2')),
        ('Dense linear forward', test_dense_linear,
         ('test_forward', 'test_pool')),
        ('Dense linear backward', test_dense_linear, ('test_backward',)),
        ('Dense attn subroutine', test_dense_attn,
         ('test_attn', 'test_kernel_attn')),
        ('Dense attn forward', test_dense_attn, ('test_selfattn',)),
        ('Dense kernel attn forward', test_dense_attn,
         ('test_kernelselfattn',)),
        ('Dense attn backward', test_dense_attn, ('test_backward',)),
        ('Sparse linear subroutine', test_sparse_linear,
         ('test_unique', 'test_loop_mask', 'test_1_0', 'test_1_1',
          'test_2_0', 'test_2_1', 'test_2_2')),
        ('Sparse linear forward', test_sparse_linear,
         ('test_forward', 'test_pool')),
        ('Sparse linear backward', test_sparse_linear, ('test_backward',)),
        ('Sparse attn subroutine', test_sparse_attn,
         ('test_attn', 'test_kernel_attn')),
        ('Sparse attn forward', test_sparse_attn, ('test_selfattn',)),
        ('Sparse kernel attn forward', test_sparse_attn,
         ('test_kernelselfattn',)),
        ('Sparse attn backward', test_sparse_attn, ('test_backward',)),
    ]

    failed = 0
    for label, module, names in groups:
        if not run_group(label, module, names):
            failed += 1
    return 1 if failed else 0


if __name__ == '__main__':
    # Propagate failures to the shell / CI instead of always exiting 0.
    sys.exit(main())