-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathevaluate_mELMO_dELMO.py
executable file
·121 lines (95 loc) · 7.19 KB
/
evaluate_mELMO_dELMO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import core.animation as anim
from core.utils import match_length, inference_err, calculate_average_error
import numpy as np
# Evaluated models, in reporting order. Each entry is:
#   (result key passed to calculate_average_error,
#    filename substring identifying that model's .bvh output,
#    keyword arguments for Animation.load_bvh,
#    header printed before the model's summary).
# Substrings are tested in this order and the first match wins; a .bvh file
# whose stem matches none of them is treated as ground truth.
# NOTE(review): model_20 output is loaded with upsample=3 and a 20-frame trim
# instead of 60; presumably it is stored at a third of the ground-truth frame
# rate -- TODO confirm against Animation.load_bvh.
_MODELS = (
    ("upsample", "model_20", {"upsample": 3, "ftrim": 20, "btrim": 20},
     "------------------ELMO_20 interpolation------------------"),
    ("base", "model_baseline", {"ftrim": 60, "btrim": 60},
     "------------------ELMO Baseline------------------"),
    ("future", "model_latency", {"ftrim": 60, "btrim": 60},
     "------------------ELMO Future------------------"),
    ("future_aug", "model_latsyn", {"ftrim": 60, "btrim": 60},
     "------------------ELMO Future Augmented------------------"),
)

# load_bvh keyword arguments for the ground-truth clips.
_GT_LOAD_KWARGS = {"ftrim": 60, "btrim": 60}

# Indices into the inference_err result that are accumulated, in the order
# (position, rotation, linear velocity, angular velocity). The naming follows
# the script's original variable names -- TODO confirm against
# core.utils.inference_err.
_ERR_INDICES = (8, 9, 10, 11)


def _collect_bvh_paths(data_path):
    """Return the paths of all ``.bvh`` files found recursively under *data_path*."""
    return [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(data_path)
        for name in files
        if name.endswith(".bvh")
    ]


def _group_by_model(file_paths):
    """Split *file_paths* into ground truth and per-model outputs.

    Returns a dict mapping every key in ``_MODELS`` plus ``"gt"`` to a sorted
    list of paths. The sort is what pairs files across groups: each model's
    i-th file is compared with the i-th ground-truth file.
    """
    groups = {key: [] for key, *_ in _MODELS}
    groups["gt"] = []
    for path in file_paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        key = next((k for k, tag, *_ in _MODELS if tag in stem), "gt")
        groups[key].append(path)
    for paths in groups.values():
        paths.sort()
    return groups


def _check_pairing(groups, data_path):
    """Raise unless every model has exactly one output per ground-truth clip."""
    if not groups["gt"]:
        raise FileNotFoundError(
            "no ground-truth .bvh files found under %r" % data_path)
    counts = {key: len(paths) for key, paths in groups.items()}
    if len(set(counts.values())) != 1:
        # Index-based pairing would silently compare the wrong clips.
        raise ValueError(
            "every model needs exactly one .bvh per ground-truth clip; "
            "got file counts %r" % counts)


def main():
    """Evaluate the mELMO/dELMO model outputs against ground-truth motion.

    Walks ``data_path`` for ``.bvh`` files and pairs each model output with
    its ground-truth clip by sorted filename order. For every pair it
    computes errors with ``inference_err``. It then calls
    ``calculate_average_error`` once per model, passing ``result_path`` and
    the model key (presumably so it can write per-file results). Finally it
    prints a joint/pelvis summary for each model.

    Raises:
        FileNotFoundError: no ground-truth ``.bvh`` file was found.
        ValueError: the models do not all have one output per ground-truth
            clip, so the files cannot be paired by index.
    """
    data_path = './datasets/evaluation_dataset/mELMO_dELMO/'
    # NOTE(review): data_path is relative ('./...'), so this nests the whole
    # dataset path under results/ ('.../results/./datasets/...'). Kept as-is
    # because downstream tooling may expect this location.
    result_path = './datasets/evaluation_dataset/results/' + data_path
    os.makedirs(result_path, exist_ok=True)

    groups = _group_by_model(_collect_bvh_paths(data_path))
    _check_pairing(groups, data_path)

    model_keys = [key for key, *_ in _MODELS]
    file_names, lengths = [], []
    joint_names = None
    # errors[key] holds four lists (pos, rot, linvel, angvel), one entry per clip.
    errors = {key: ([], [], [], []) for key in model_keys}

    for i, gt_path in enumerate(groups["gt"]):
        gt = anim.Animation()
        gt.load_bvh(gt_path, **_GT_LOAD_KWARGS)
        outputs = {}
        for key, _tag, load_kwargs, _title in _MODELS:
            outputs[key] = anim.Animation()
            outputs[key].load_bvh(groups[key][i], **load_kwargs)

        clips = [gt] + [outputs[key] for key in model_keys]
        # Presumably equalises the frame counts so clips compare frame by frame.
        match_length(clips)
        file_names.append(os.path.splitext(os.path.basename(gt_path))[0])
        # Column labels: a leading 'length' entry followed by the gt joint names.
        joint_names = np.insert(gt.joints, 0, 'length')
        for clip in clips:
            clip.compute_world_transform(fix_root=True)

        metrics = {key: inference_err(outputs[key], gt) for key in model_keys}
        # The last element of inference_err's result is used as the clip
        # length. It is taken from the baseline model, as in the original
        # script; presumably it is identical for every model after match_length.
        lengths.append(metrics["base"][-1])
        for key in model_keys:
            m = metrics[key]
            # Scale each error by that length; presumably so that
            # calculate_average_error can form a length-weighted mean.
            for bucket, idx in zip(errors[key], _ERR_INDICES):
                bucket.append(np.multiply(m[idx], m[-1]))

    # Called in _MODELS order (upsample, base, future, future_aug).
    averages = {
        key: calculate_average_error(lengths, *errors[key], joint_names,
                                     file_names, result_path, key)
        for key in model_keys
    }

    for key, _tag, _load_kwargs, title in _MODELS:
        pos, rot, linvel, angvel = averages[key]
        print(title)
        # Entry 0 is reported as the pelvis; the remaining entries are
        # averaged together as "joint" errors.
        print("joint p: %f, joint r: %f, joint lv: %f, joint av: %f" % (
            np.mean(pos[1:]), np.mean(rot[1:]),
            np.mean(linvel[1:]), np.mean(angvel[1:])))
        print("pelv p: %f, pelv r: %f, pelv lv: %f, pelv av: %f" % (
            pos[0], rot[0], linvel[0], angvel[0]))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()