05.srt
1
00:00:01,150 --> 00:00:03,150
Subtitles: 赵含霖, 谢鑫鑫
2
00:00:04,750 --> 00:00:08,320
Hello everyone, I'm ZOMI
3
00:00:08,320 --> 00:00:10,120
Welcome to a class that hardly anyone watches
4
00:00:10,120 --> 00:00:12,840
but that I, ZOMI, keep on teaching anyway
5
00:00:12,840 --> 00:00:14,200
In this lesson
6
00:00:14,200 --> 00:00:17,320
I mainly want to share with you
7
00:00:17,320 --> 00:00:20,880
a concrete implementation of forward-mode automatic differentiation via operator overloading
8
00:00:20,880 --> 00:00:23,160
So in this lesson
9
00:00:23,160 --> 00:00:25,600
there are no slides
10
00:00:25,600 --> 00:00:29,560
instead, today's lesson is carried by a Jupyter notebook
11
00:00:30,040 --> 00:00:32,520
Let's start by reviewing
12
00:00:32,520 --> 00:00:34,840
what forward automatic differentiation is
13
00:00:34,840 --> 00:00:37,720
Forward automatic differentiation is also called forward mode
14
00:00:37,720 --> 00:00:39,480
or tangent mode
15
00:00:39,480 --> 00:00:44,360
or even forward accumulation of gradients
16
00:00:44,360 --> 00:00:46,800
The formula is the familiar one
17
00:00:46,800 --> 00:00:53,120
y=f(x1, x2)=ln(x1)+x1*x2-sin(x2)
18
00:00:53,120 --> 00:00:55,320
Let's see how many operations it contains
19
00:00:55,320 --> 00:00:58,000
First there is x1*x2
20
00:00:58,000 --> 00:00:58,800
multiplication
21
00:00:58,800 --> 00:01:00,680
the second is ln(x1)
22
00:01:00,680 --> 00:01:02,960
the third is sin
23
00:01:02,960 --> 00:01:04,680
the fourth is the plus sign
24
00:01:04,680 --> 00:01:06,440
and the fifth is the minus sign
25
00:01:06,440 --> 00:01:08,360
five operations in total
26
00:01:08,360 --> 00:01:13,360
which means today we need to overload five operators
27
00:01:14,760 --> 00:01:17,280
The figure below is also a familiar one
28
00:01:17,280 --> 00:01:22,000
on the left is the forward computation of the whole formula
29
00:01:22,000 --> 00:01:23,360
During the forward computation
30
00:01:23,360 --> 00:01:25,680
we first take x1 and x2
31
00:01:25,720 --> 00:01:28,640
that is, the corresponding v-1 and v0
32
00:01:28,640 --> 00:01:30,280
and assign them concrete values
33
00:01:30,280 --> 00:01:32,200
then compute step by step
34
00:01:32,200 --> 00:01:34,240
finally obtaining y
35
00:01:34,240 --> 00:01:37,120
the output equals 11.652
36
00:01:37,120 --> 00:01:41,080
The right side corresponds to what we learned in calculus
37
00:01:41,080 --> 00:01:43,640
the rule for differentiating a composite function of one variable
38
00:01:43,640 --> 00:01:45,560
Starting from the very beginning
39
00:01:45,560 --> 00:01:48,560
we differentiate only with respect to x1
40
00:01:48,560 --> 00:01:50,720
and working step by step we obtain
41
00:01:50,720 --> 00:01:55,040
the derivative of y with respect to x1, which equals 5.5
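
As a quick numeric check of the values quoted here, with x1 = 2 and x2 = 5 (a sketch; the lecture's notebook obtains the same numbers through the ADTangent class built below):

import numpy as np

x1, x2 = 2.0, 5.0
y = np.log(x1) + x1 * x2 - np.sin(x2)  # ln 2 + 10 + 0.9589... ≈ 11.652
dy_dx1 = 1 / x1 + x2                   # d/dx1[ln x1 + x1*x2 - sin x2] = 1/x1 + x2 = 5.5
print(y, dy_dx1)
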
42
00:01:55,480 --> 00:01:57,360
Back to a familiar picture
43
00:01:57,360 --> 00:01:59,320
the computational graph
44
00:01:59,320 --> 00:02:01,400
Every node in the computational graph
45
00:02:01,400 --> 00:02:03,280
is an intermediate variable
46
00:02:03,280 --> 00:02:06,800
and every edge represents a computation, or a connection
47
00:02:06,800 --> 00:02:08,000
So, simply put
48
00:02:08,000 --> 00:02:11,280
automatic differentiation can be broken into a few key steps
49
00:02:11,280 --> 00:02:12,080
The first
50
00:02:12,080 --> 00:02:15,000
is to express the differentiation rules
51
00:02:15,000 --> 00:02:16,520
Then, according to those rules
52
00:02:16,520 --> 00:02:19,520
compute all of the intermediate derivatives
53
00:02:19,520 --> 00:02:22,000
The third is to use the chain rule
54
00:02:22,000 --> 00:02:23,800
or an accumulation scheme
55
00:02:23,800 --> 00:02:25,920
to combine the per-step derivatives
56
00:02:25,920 --> 00:02:29,240
obtained in step two
57
00:02:29,240 --> 00:02:30,040
And finally
58
00:02:30,040 --> 00:02:32,320
we get the output result
59
00:02:33,600 --> 00:02:35,120
In the concrete implementation
60
00:02:35,120 --> 00:02:38,040
we use NumPy to carry out the computation
61
00:02:38,040 --> 00:02:40,160
First we need to implement a class
62
00:02:40,160 --> 00:02:40,840
This class
63
00:02:40,840 --> 00:02:44,440
is a bit like the Tensor in PyTorch
64
00:02:44,440 --> 00:02:46,600
you could call it a tensor
65
00:02:46,600 --> 00:02:47,600
Everything
66
00:02:47,600 --> 00:02:50,960
is represented through this ADTangent class
67
00:02:50,960 --> 00:02:52,520
At initialization
68
00:02:52,640 --> 00:02:54,320
there are two inputs
69
00:02:54,320 --> 00:02:55,200
the first is x
70
00:02:55,200 --> 00:02:56,320
the second is dx
71
00:02:56,320 --> 00:02:57,040
x
72
00:02:57,040 --> 00:02:59,920
is the actual numeric value
73
00:02:59,920 --> 00:03:00,880
and the second, dx
74
00:03:00,880 --> 00:03:03,680
is the corresponding derivative value
75
00:03:05,000 --> 00:03:07,640
We also need to overload the string conversion
76
00:03:07,640 --> 00:03:09,480
The way to overload this string operation
77
00:03:09,480 --> 00:03:11,480
is to write two underscores in front of the name
78
00:03:11,480 --> 00:03:12,560
and two underscores after it
79
00:03:12,560 --> 00:03:15,560
which is how you overload __str__ in Python
80
00:03:15,560 --> 00:03:17,960
Overloading __str__ makes it convenient, when you print
81
00:03:17,960 --> 00:03:19,560
this object
82
00:03:19,600 --> 00:03:22,800
to have it displayed in the format you want
83
00:03:22,800 --> 00:03:24,520
You can see that the output
84
00:03:24,520 --> 00:03:26,200
shows a value
85
00:03:26,200 --> 00:03:27,880
which is self.x
86
00:03:27,880 --> 00:03:28,320
and second
87
00:03:28,320 --> 00:03:30,280
it shows the gradient
88
00:03:30,280 --> 00:03:31,800
that is, the value of the derivative
89
00:03:31,800 --> 00:03:33,760
dx, so two numbers in all
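
A minimal sketch of the ADTangent constructor and __str__ overload described above (assuming NumPy scalars; the notebook's exact format string may differ):

import numpy as np

class ADTangent:
    # Carry the primal value x and its derivative dx together.
    def __init__(self, x, dx):
        self.x = x    # the actual numeric value
        self.dx = dx  # the derivative propagated alongside it

    # __str__ (two underscores on each side) controls what print() shows:
    # both the value and the gradient.
    def __str__(self):
        return f'value: {self.x:.4f}, grad: {self.dx}'
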
90
00:03:35,680 --> 00:03:37,160
Next, let's look at
91
00:03:37,160 --> 00:03:39,320
overloading the addition operator
92
00:03:39,320 --> 00:03:41,320
and how it's actually implemented
93
00:03:41,320 --> 00:03:43,240
When using a high-level language like Python
94
00:03:43,240 --> 00:03:44,880
to implement operator overloading
95
00:03:44,880 --> 00:03:47,560
we just saw that __str__ is one example
96
00:03:47,560 --> 00:03:49,440
and the plus sign is another
97
00:03:49,440 --> 00:03:53,080
Adding two underscores before and after the name
98
00:03:53,080 --> 00:03:57,120
means we are overloading the add operation
99
00:03:57,120 --> 00:03:57,800
Here
100
00:03:57,800 --> 00:04:00,760
the input other is a variable
101
00:04:00,760 --> 00:04:02,560
First we check
102
00:04:02,560 --> 00:04:05,560
whether other is an ADTangent object
103
00:04:05,560 --> 00:04:06,360
If it is
104
00:04:06,360 --> 00:04:10,040
then x equals self.x plus other.x
105
00:04:10,040 --> 00:04:10,920
At the same time
106
00:04:10,920 --> 00:04:13,640
we need to compute dx alongside it
107
00:04:13,640 --> 00:04:16,080
that is, the rule given by the differentiation formula
108
00:04:16,080 --> 00:04:18,280
this derivative equals self.dx
109
00:04:18,320 --> 00:04:20,320
plus other.dx
110
00:04:20,320 --> 00:04:23,280
If the input is a plain number
111
00:04:23,280 --> 00:04:25,240
that is, a float
112
00:04:25,240 --> 00:04:28,120
then in the addition we differentiate only our own value
113
00:04:28,120 --> 00:04:30,080
and the other term drops out
114
00:04:30,080 --> 00:04:31,320
that is, its derivative
115
00:04:31,320 --> 00:04:32,280
equals 0
116
00:04:32,280 --> 00:04:33,240
So for x
117
00:04:33,240 --> 00:04:34,320
in the forward computation
118
00:04:34,320 --> 00:04:36,360
it equals self.x plus other
119
00:04:36,360 --> 00:04:37,600
while for the derivative
120
00:04:37,600 --> 00:04:40,600
it's simply dx = self.dx
121
00:04:40,600 --> 00:04:42,480
Only two differentiation cases are implemented
122
00:04:42,480 --> 00:04:45,960
either the input is itself a differentiable ADTangent number
123
00:04:45,960 --> 00:04:48,320
or the input is a constant
124
00:04:48,320 --> 00:04:50,760
only these two cases for the plus sign
125
00:04:50,760 --> 00:04:53,840
other cases are not implemented
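
A sketch of the __add__ overload just described, continuing the ADTangent class sketched earlier and handling exactly the two cases mentioned (another ADTangent, or a plain float):

    def __add__(self, other):
        if isinstance(other, ADTangent):
            # sum rule: the values add, and the derivatives add too
            x = self.x + other.x
            dx = self.dx + other.dx
        elif isinstance(other, float):
            # adding a constant: its derivative is 0, so dx is unchanged
            x = self.x + other
            dx = self.dx
        else:
            return NotImplemented  # other operand types are not handled
        return ADTangent(x, dx)
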
126
00:04:53,840 --> 00:04:54,760
Next
127
00:04:54,760 --> 00:04:58,240
with the idea of overloading add in hand
128
00:04:58,240 --> 00:05:02,280
we can overload the minus sign as well
129
00:05:02,280 --> 00:05:03,240
If the input
130
00:05:03,240 --> 00:05:05,960
is likewise an ADTangent object
131
00:05:05,960 --> 00:05:07,040
then the derivative
132
00:05:07,040 --> 00:05:10,560
is just the two values subtracted, essentially the same pattern
133
00:05:10,560 --> 00:05:12,960
otherwise we differentiate with respect to ourselves alone
134
00:05:12,960 --> 00:05:14,080
where it should simply be dx = self.dx
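
The corresponding __sub__ sketch, same two cases with a minus sign:

    def __sub__(self, other):
        if isinstance(other, ADTangent):
            # difference rule: derivatives subtract just like the values
            x = self.x - other.x
            dx = self.dx - other.dx
        elif isinstance(other, float):
            # subtracting a constant leaves the derivative as self.dx
            x = self.x - other
            dx = self.dx
        else:
            return NotImplemented
        return ADTangent(x, dx)
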
135
00:05:16,480 --> 00:05:19,000
Multiplication works the same way
136
00:05:19,000 --> 00:05:20,040
Because our number
137
00:05:20,040 --> 00:05:21,760
is an object carrying a derivative
138
00:05:21,760 --> 00:05:23,720
the multiplication expands
139
00:05:23,720 --> 00:05:25,160
and becomes a bit more complicated
140
00:05:25,160 --> 00:05:27,440
it equals self.x times other.dx
141
00:05:27,440 --> 00:05:30,480
plus self.dx times other.x
142
00:05:30,480 --> 00:05:32,680
If the input is a plain number
143
00:05:32,680 --> 00:05:34,520
then it's easy
144
00:05:34,520 --> 00:05:36,000
the derivative in that case
145
00:05:36,000 --> 00:05:38,400
equals other times self.dx
146
00:05:38,400 --> 00:05:41,480
and again, other cases are not implemented
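
A sketch of the __mul__ overload with the product rule as stated:

    def __mul__(self, other):
        if isinstance(other, ADTangent):
            # product rule: d(uv) = u*dv + du*v
            x = self.x * other.x
            dx = self.x * other.dx + self.dx * other.x
        elif isinstance(other, float):
            # multiplying by a constant just scales the derivative
            x = self.x * other
            dx = other * self.dx
        else:
            return NotImplemented
        return ADTangent(x, dx)
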
147
00:05:41,480 --> 00:05:44,360
There's also a fourth operation, log
148
00:05:44,360 --> 00:05:45,920
log takes no other input
149
00:05:45,920 --> 00:05:49,800
it directly takes the log-derivative of the current value
150
00:05:49,800 --> 00:05:51,400
The same goes for sin
151
00:05:51,400 --> 00:05:52,560
the derivative of sin
152
00:05:52,560 --> 00:05:55,920
becomes cos times dx
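
Sketches of the log and sin methods (unary, so no other argument; np is NumPy as imported above):

    def log(self):
        # d(ln u) = du / u
        x = np.log(self.x)
        dx = 1 / self.x * self.dx
        return ADTangent(x, dx)

    def sin(self):
        # d(sin u) = cos(u) * du
        x = np.sin(self.x)
        dx = np.cos(self.x) * self.dx
        return ADTangent(x, dx)
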
153
00:05:57,280 --> 00:06:00,320
Back to the familiar formula
154
00:06:00,320 --> 00:06:01,320
At the start
155
00:06:01,320 --> 00:06:03,280
we initialize two variables
156
00:06:03,280 --> 00:06:03,720
The first
157
00:06:03,720 --> 00:06:05,440
is x1 equal to 2
158
00:06:05,440 --> 00:06:05,880
and the second
159
00:06:05,880 --> 00:06:08,320
is x2 equal to 5
160
00:06:08,320 --> 00:06:09,560
During the computation
161
00:06:09,560 --> 00:06:12,360
we only need to track the derivative for x1
162
00:06:12,360 --> 00:06:14,800
then for x2
163
00:06:14,800 --> 00:06:16,960
we don't need its derivative
164
00:06:16,960 --> 00:06:19,920
so x2's dx equals 0
165
00:06:19,920 --> 00:06:23,880
The implementation of the formula is then ADTangent.log(x1)
166
00:06:23,880 --> 00:06:26,760
plus x1 times x2
167
00:06:26,760 --> 00:06:30,160
minus ADTangent.sin(x2)
168
00:06:30,160 --> 00:06:32,320
and that is the forward computation proper
169
00:06:32,320 --> 00:06:35,760
print it, and the whole computation is done
170
00:06:35,760 --> 00:06:37,560
Quite interesting
171
00:06:37,560 --> 00:06:40,160
value equals 11.6521
172
00:06:40,240 --> 00:06:42,080
gradient equals 5.5
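
Putting it together, a sketch of the driver code for this example:

x1 = ADTangent(x=2.0, dx=1)  # seed dx=1: we differentiate with respect to x1
x2 = ADTangent(x=5.0, dx=0)  # seed dx=0: x2 is treated as a constant in this pass
f = ADTangent.log(x1) + x1 * x2 - ADTangent.sin(x2)
print(f)  # expected: value: 11.6521, grad: 5.5
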
173
00:06:42,080 --> 00:06:46,560
You can see it matches the familiar figure exactly
174
00:06:46,560 --> 00:06:49,280
the forward computation gives 11.652
175
00:06:49,280 --> 00:06:51,440
perhaps with slightly different rounding of the decimals
176
00:06:51,440 --> 00:06:55,960
and the derivative of y with respect to x1 is 5.5
177
00:06:55,960 --> 00:06:58,760
Why was the derivative obtained so quickly?
178
00:06:58,760 --> 00:07:01,560
Because when using the ADTangent class
179
00:07:01,560 --> 00:07:04,480
each operation, addition for instance
180
00:07:04,480 --> 00:07:08,480
computes the derivative at the same time as the value
181
00:07:08,480 --> 00:07:10,160
including dx
182
00:07:10,160 --> 00:07:11,640
and mul
183
00:07:11,640 --> 00:07:12,400
the plus sign
184
00:07:12,400 --> 00:07:13,360
and log
185
00:07:13,360 --> 00:07:14,480
and sin
186
00:07:14,480 --> 00:07:17,120
all five operations are overloaded together
187
00:07:17,120 --> 00:07:21,280
so in the end all five overloads work in concert
188
00:07:21,280 --> 00:07:22,000
Therefore
189
00:07:22,000 --> 00:07:24,720
at the end we can simply print
190
00:07:24,720 --> 00:07:29,080
and all of the computed results come out at once
191
00:07:29,080 --> 00:07:34,120
Next, let's see how the same thing is done in PyTorch and MindSpore
192
00:07:34,120 --> 00:07:36,240
The PyTorch implementation is fairly simple
193
00:07:36,280 --> 00:07:39,800
We likewise need to declare that gradients are required
194
00:07:39,800 --> 00:07:41,240
that is, on a variable
195
00:07:42,760 --> 00:07:46,600
Likewise we declare the variable x1 equal to 2
196
00:07:46,600 --> 00:07:47,360
and x2
197
00:07:47,360 --> 00:07:48,440
equal to 5
198
00:07:48,440 --> 00:07:50,640
then compute the log
199
00:07:51,960 --> 00:07:53,480
At this point you can see
200
00:07:53,480 --> 00:07:54,680
that this formula
201
00:07:54,680 --> 00:07:57,680
is the same as the one we implemented ourselves
202
00:07:57,680 --> 00:07:58,280
except
203
00:07:58,280 --> 00:08:00,760
it's implemented by the framework, Torch
204
00:08:00,760 --> 00:08:02,560
rather than our ADTangent
205
00:08:02,560 --> 00:08:03,120
Then
206
00:08:03,120 --> 00:08:08,320
it calls f.backward() to declare that a backward gradient pass is needed
207
00:08:08,320 --> 00:08:10,840
that is, having declared requires_grad=True
208
00:08:10,840 --> 00:08:12,200
marking the variable as requiring gradients
209
00:08:12,200 --> 00:08:15,000
the backward graph is then computed
210
00:08:15,000 --> 00:08:15,520
And then
211
00:08:15,520 --> 00:08:18,720
printing it gives 11.6521
212
00:08:18,720 --> 00:08:20,800
the same as our computation
213
00:08:20,800 --> 00:08:24,640
and the gradient for x1 is likewise 5.5000
214
00:08:24,640 --> 00:08:26,920
also identical
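
A sketch of the PyTorch version (using the current requires_grad API; the notebook may use the older Variable wrapper the narration mentions):

import torch

x1 = torch.tensor([2.0], requires_grad=True)
x2 = torch.tensor([5.0], requires_grad=True)
f = torch.log(x1) + x1 * x2 - torch.sin(x2)
f.backward()    # build and run the reverse-mode graph
print(f)        # tensor([11.6521], grad_fn=...)
print(x1.grad)  # tensor([5.5000])
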
215
00:08:26,920 --> 00:08:28,600
The MindSpore implementation
216
00:08:28,600 --> 00:08:29,520
is a bit special
217
00:08:29,520 --> 00:08:34,080
because MindSpore treats everything as a function; it's a functional-style framework
218
00:08:34,080 --> 00:08:40,080
so the formula we just implemented gets wrapped in a function
219
00:08:40,080 --> 00:08:40,840
And this
220
00:08:40,840 --> 00:08:43,720
is implemented inside the construct function
221
00:08:43,720 --> 00:08:48,080
Likewise we need to declare two variables
222
00:08:48,080 --> 00:08:48,520
The first
223
00:08:48,520 --> 00:08:50,200
is x1 equal to 2
224
00:08:50,200 --> 00:08:50,680
and the second
225
00:08:50,680 --> 00:08:52,480
is x2, which plays the role of y here
226
00:08:52,480 --> 00:08:53,240
Then
227
00:08:53,240 --> 00:08:55,840
we instantiate this function f
228
00:08:57,480 --> 00:08:58,040
Next
229
00:08:58,040 --> 00:09:02,400
we declare that this function needs its gradients computed
230
00:09:02,400 --> 00:09:03,800
then through the grad operation
231
00:09:03,800 --> 00:09:06,520
we wrap the function for the backward pass
232
00:09:06,520 --> 00:09:08,760
and feed in x1 and x2
233
00:09:08,760 --> 00:09:09,360
And then
234
00:09:09,360 --> 00:09:12,600
at this point we've obtained the gradient
235
00:09:12,600 --> 00:09:17,600
Here, the first print gives the forward result of Fun
236
00:09:17,600 --> 00:09:18,760
fully computed
237
00:09:18,760 --> 00:09:19,280
And the second
238
00:09:19,280 --> 00:09:21,680
is print(grad[0])
239
00:09:21,680 --> 00:09:26,000
which obtains grad via the AI framework's automatic differentiation
240
00:09:26,000 --> 00:09:27,080
Why index 0?
241
00:09:27,120 --> 00:09:29,520
Because there is another grad, for x2
242
00:09:29,520 --> 00:09:31,320
which corresponds to the y here
243
00:09:31,320 --> 00:09:31,800
but that one
244
00:09:31,800 --> 00:09:32,960
we won't print
245
00:09:32,960 --> 00:09:36,880
getting the 5.5 from the computation is all we need
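
A sketch of the MindSpore version as described: the formula lives in a Cell's construct, and a grad operation produces the gradient function (assuming current MindSpore APIs such as ops.GradOperation; details may vary by version):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

class Fun(nn.Cell):
    # everything-is-a-function style: the formula goes in construct()
    def construct(self, x, y):
        return ops.log(x) + x * y - ops.sin(y)

x = ms.Tensor(np.array([2.0], np.float32))  # x1
y = ms.Tensor(np.array([5.0], np.float32))  # x2
f = Fun()
grad_all = ops.GradOperation(get_all=True)  # gradients w.r.t. all inputs
grad = grad_all(f)(x, y)
print(f(x, y))  # forward value, ≈ 11.652
print(grad[0])  # df/dx1 = 5.5; grad[1] would be df/dx2
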
246
00:09:38,480 --> 00:09:38,960
All right
247
00:09:38,960 --> 00:09:39,880
thank you, everyone
248
00:09:39,880 --> 00:09:41,680
happy times always fly by
249
00:09:41,680 --> 00:09:43,560
and it's time to say goodbye again
250
00:09:43,560 --> 00:09:48,600
Today we studied an implementation of forward-mode automatic differentiation via operator overloading