<!DOCTYPE html>
<html lang="zh-cn">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>Naruto's AI blog</title>
<meta name="renderer" content="webkit" />
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1"/>
<meta http-equiv="Cache-Control" content="no-transform" />
<meta http-equiv="Cache-Control" content="no-siteapp" />
<meta name="theme-color" content="#f8f5ec" />
<meta name="msapplication-navbutton-color" content="#f8f5ec">
<meta name="apple-mobile-web-app-capable" content="yes">
<meta name="apple-mobile-web-app-status-bar-style" content="#f8f5ec">
<meta name="author" content="Naruto" /><meta name="description" content="Artificial Intelligence serves humanity." /><meta name="keywords" content="Deep Learning, Medical Field, Artificial Intelligence" />
<meta name="generator" content="Hugo 0.66.0 with theme even" />
<link rel="canonical" href="https://Naruto-AI-WY.github.io/" />
<link href="https://Naruto-AI-WY.github.io/index.xml" rel="alternate" type="application/rss+xml" title="Naruto's AI blog" />
<link href="https://Naruto-AI-WY.github.io/index.xml" rel="feed" type="application/rss+xml" title="Naruto's AI blog" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png">
<link rel="icon" type="image/png" sizes="32x32" href="/favicon-32x32.png">
<link rel="icon" type="image/png" sizes="16x16" href="/favicon-16x16.png">
<link rel="manifest" href="/manifest.json">
<link rel="mask-icon" href="/safari-pinned-tab.svg" color="#5bbad5">
<script async src="//busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js"></script>
<link href="/dist/even.c2a46f00.min.css" rel="stylesheet">
<meta property="og:title" content="Naruto's AI blog" />
<meta property="og:description" content="Artificial Intelligence serves humanity." />
<meta property="og:type" content="website" />
<meta property="og:url" content="https://Naruto-AI-WY.github.io/" />
<meta property="og:updated_time" content="2020-03-17T21:38:52+08:00" />
<meta itemprop="name" content="Naruto's AI blog">
<meta itemprop="description" content="Artificial Intelligence serves humanity."><meta name="twitter:card" content="summary"/>
<meta name="twitter:title" content="Naruto's AI blog"/>
<meta name="twitter:description" content="Artificial Intelligence serves humanity."/>
<!--[if lte IE 9]>
<script src="https://cdnjs.cloudflare.com/ajax/libs/classlist/1.1.20170427/classList.min.js"></script>
<![endif]-->
<!--[if lt IE 9]>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/html5shiv.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dest/respond.min.js"></script>
<![endif]-->
</head>
<body>
<div id="mobile-navbar" class="mobile-navbar">
<div class="mobile-header-logo">
<a href="/" class="logo">Artificial Intelligence</a>
</div>
<div class="mobile-navbar-icon">
<span></span>
<span></span>
<span></span>
</div>
</div>
<nav id="mobile-menu" class="mobile-menu slideout-menu">
<ul class="mobile-menu-list">
<a href="/">
<li class="mobile-menu-item">Home</li>
</a><a href="/post/">
<li class="mobile-menu-item">Archives</li>
</a><a href="/tags/">
<li class="mobile-menu-item">Tags</li>
</a><a href="/categories/">
<li class="mobile-menu-item">Categories</li>
</a><a href="/about/">
<li class="mobile-menu-item">About</li>
</a>
</ul>
</nav>
<div class="container" id="mobile-panel">
<header id="header" class="header">
<div class="logo-wrapper">
<a href="/" class="logo">Artificial Intelligence</a>
</div>
<nav class="site-navbar">
<ul id="menu" class="menu">
<li class="menu-item">
<a class="menu-item-link" href="/">Home</a>
</li><li class="menu-item">
<a class="menu-item-link" href="/post/">Archives</a>
</li><li class="menu-item">
<a class="menu-item-link" href="/tags/">Tags</a>
</li><li class="menu-item">
<a class="menu-item-link" href="/categories/">Categories</a>
</li><li class="menu-item">
<a class="menu-item-link" href="/about/">About</a>
</li>
</ul>
</nav>
</header>
<main id="main" class="main">
<div class="content-wrapper">
<div id="content" class="content">
<section id="posts" class="posts">
<article class="post">
<header class="post-header">
<h1 class="post-title"><a class="post-link" href="/post/datawhale_cv5/">Datawhale Beginner CV Competition: Model Ensembling</a></h1>
<div class="post-meta">
<span class="post-time"> 2020-06-02 </span>
<div class="post-category">
<a href="/categories/computer-vision/"> computer vision </a>
</div>
<span class="more-meta"> ~2302 words </span>
<span class="more-meta"> ~5 min read </span>
</div>
</header>
<div class="post-content">
<div class="post-summary">
<h2 id="集成学习方法">Ensemble Learning Methods</h2>
<p>In machine learning, ensemble learning can improve prediction accuracy to some extent. Common ensemble methods include Stacking, Bagging, and Boosting, and all of them are closely tied to how the validation set is split.</p>
<h3 id="bagging">Bagging</h3>
<p>Bagging is short for bootstrap aggregating. The bootstrap is a resampling-with-replacement method used to estimate the distribution and confidence interval of a statistic. The steps are:</p>
<ul>
<li>Draw a number of samples from the original data with replacement</li>
<li>Compute the desired statistic T on the drawn samples</li>
<li>Repeat the above N times (usually N &gt; 1000) to obtain N values of T</li>
<li>Use these N values to compute a confidence interval for the statistic</li>
</ul>
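<p>A minimal sketch of the bootstrap steps above, using only the Python standard library (the sample data here is synthetic, purely for illustration): draw resamples with replacement, compute the statistic T on each, and read a confidence interval off the empirical distribution.</p>

```python
import random
import statistics

random.seed(0)
sample = [random.gauss(5.0, 2.0) for _ in range(200)]  # synthetic original sample

# repeat N times: resample with replacement, compute the statistic T (here, the mean)
N = 1000
stats = sorted(
    statistics.mean(random.choices(sample, k=len(sample)))
    for _ in range(N)
)

# 95% confidence interval from the empirical distribution of T
lo, hi = stats[int(0.025 * N)], stats[int(0.975 * N)]
```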
<p>In Bagging, the bootstrap draws N datasets with replacement from the full dataset and one model is trained on each; the final prediction combines the N model outputs. Concretely, classification takes a majority vote over the N predictions, and regression averages them.</p>
<p>Random Forest is a classic example of Bagging. Put simply, a random forest builds a forest of many decision trees in a random way, and the individual trees are independent of one another.</p>
<p>Training each tree relies on the bootstrap. A random forest samples randomly along two axes: over rows (the data points) and over columns (the features). Row sampling is done with replacement: from N data points it draws N, possibly with duplicates, so each tree sees less than the full sample, which makes overfitting less likely. Column sampling then picks m features out of M (m ≪ M). Finally, a decision tree is trained on the sampled data.</p>
<p>At prediction time, every tree in the forest predicts on the input and the trees vote; the class with the most votes wins. This is exactly the idea stated above: each individual classifier (each tree) is fairly weak, but combined by voting they become strong.</p>
<p><img src="/img/bagging.jpeg" alt="CV5"></p>
<h3 id="boosting">Boosting</h3>
<ul>
<li>Boosting is a machine-learning technique for reducing bias in supervised learning. It likewise trains a sequence of weak classifiers and combines them into a strong one. The best-known representative is AdaBoost (Adaptive Boosting).</li>
<li>Algorithm: start by giving every training example an equal weight and train for t rounds; after each round, increase the weights of the examples the learner got wrong, so that later rounds pay more attention to the hard examples. This yields multiple predictors to combine.</li>
</ul>
<p><img src="/img/boosting.jpeg" alt="CV5"></p>
<ul>
<li>XGBoost</li>
</ul>
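<p>The reweighting step described above can be sketched in a few lines (a toy single round with made-up numbers, not a full AdaBoost implementation): the weak learner's weighted error determines its vote weight, misclassified examples are up-weighted, and the weights are renormalised.</p>

```python
import math

weights = [0.25, 0.25, 0.25, 0.25]    # start with equal weights
correct = [True, True, False, True]   # toy round: the weak learner errs on example 2

err = sum(w for w, c in zip(weights, correct) if not c)  # weighted error
alpha = 0.5 * math.log((1 - err) / err)                  # this learner's vote weight

# up-weight mistakes, down-weight correct examples, then renormalise
weights = [w * math.exp(alpha if not c else -alpha) for w, c in zip(weights, correct)]
total = sum(weights)
weights = [w / total for w in weights]
```

After one round, the misclassified example carries half of the total weight, so the next weak learner is forced to focus on it.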
<h3 id="stacking">Stacking</h3>
<p>Stacking is an ensemble technique that aggregates multiple classification or regression models through a meta-classifier or meta-regressor. The base-level models are trained on the full training set, and the meta-model is then trained on the base models' outputs.<br>
The base level usually combines different learning algorithms, so stacking ensembles are typically heterogeneous. The following algorithm summarizes the logic of stacking:</p>
<p><img src="/img/stacking.png" alt="CV5"></p>
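<p>A toy illustration of the idea (all numbers are hypothetical; real stacking would fit the meta-model on out-of-fold predictions): the base models' predictions become the meta-model's input features, and here the meta-model is a simple least-squares blend.</p>

```python
base_a = [1.0, 2.0, 3.0, 4.0]   # hypothetical predictions of base model A
base_b = [1.5, 2.5, 2.5, 4.5]   # hypothetical predictions of base model B
y      = [1.2, 2.2, 2.9, 4.2]   # true targets

# meta-model: fit y ~ w_a * base_a + w_b * base_b via the 2x2 normal equations
saa = sum(a * a for a in base_a)
sab = sum(a * b for a, b in zip(base_a, base_b))
sbb = sum(b * b for b in base_b)
ta = sum(a * t for a, t in zip(base_a, y))
tb = sum(b * t for b, t in zip(base_b, y))
det = saa * sbb - sab * sab
w_a = (ta * sbb - tb * sab) / det
w_b = (saa * tb - sab * ta) / det

blend = [w_a * a + w_b * b for a, b in zip(base_a, base_b)]

def sse(pred):
    """Sum of squared errors against the true targets."""
    return sum((p - t) ** 2 for p, t in zip(pred, y))
```

Because either base model alone is a feasible blend (weights (1, 0) or (0, 1)), the least-squares blend can never be worse than the better base model on the data it was fit on.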
<p>Since deep-learning models generally need long training cycles, the hold-out method is recommended when hardware is limited; if accuracy is the priority, cross-validation is worth the extra cost.</p>
<p>Suppose we build a 10-fold cross-validation and train 10 CNN models.<br>
<img src="/img/crossvalidation.png" alt="CV5"><br>
These 10 CNNs can then be ensembled in either of two ways:</p>
<ul>
<li>Average the predicted probabilities, then decode them into characters;</li>
<li>Take a majority vote over the predicted characters to get the final character.</li>
</ul>
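<p>Both strategies can be sketched with made-up probability vectors (three models, three candidate classes for one character position):</p>

```python
from collections import Counter

model_probs = [            # hypothetical per-model class probabilities
    [0.1, 0.7, 0.2],
    [0.2, 0.5, 0.3],
    [0.1, 0.6, 0.3],
]
n_classes = 3

# strategy 1: average the probabilities, then decode
avg = [sum(p[c] for p in model_probs) / len(model_probs) for c in range(n_classes)]
char_by_avg = max(range(n_classes), key=avg.__getitem__)

# strategy 2: each model votes with its argmax, majority wins
votes = [max(range(n_classes), key=p.__getitem__) for p in model_probs]
char_by_vote = Counter(votes).most_common(1)[0][0]
```

The two strategies can disagree when models are confidently wrong; averaging probabilities uses more information, while voting is more robust to a single badly calibrated model.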
<h2 id="深度学习中的集成学习">Ensemble Learning in Deep Learning</h2>
<p>Deep learning itself includes several techniques in the spirit of ensemble learning that are worth borrowing:</p>
<h3 id="dropout">Dropout</h3>
<p>Dropout is a useful trick for training deep neural networks: in each training batch, a random subset of units is switched off, while during prediction all units stay active.<br>
<img src="/img/Droopout.png" alt="CV5"></p>
<p>Dropout appears in most existing CNN architectures; it effectively mitigates overfitting and can also improve accuracy at prediction time.</p>
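<p>The train/predict asymmetry can be sketched as "inverted dropout" (a common formulation; this is a sketch, not the exact PyTorch internals): during training each unit is zeroed with probability p and the survivors are scaled by 1/(1-p), so that at prediction time all units can simply stay active with no rescaling.</p>

```python
import random

def dropout(x, p, training):
    """Zero each unit with probability p during training and rescale the rest;
    at prediction time, return the input unchanged (all units active)."""
    if not training:
        return list(x)
    return [0.0 if random.random() < p else v / (1 - p) for v in x]
```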
<p>With Dropout added, the network structure looks like this:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="c1"># define the model</span>
<span class="k">class</span> <span class="nc">SVHN_Model1</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">SVHN_Model1</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="c1"># CNN feature-extraction module</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cnn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.25</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.25</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="p">)</span>
<span class="c1"># six parallel classification heads, one per character position</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc4</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc5</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc6</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
<span class="n">feat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cnn</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">feat</span> <span class="o">=</span> <span class="n">feat</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">feat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">c1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c4</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c5</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc5</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c6</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc6</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="k">return</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span><span class="p">,</span> <span class="n">c6</span>
</code></pre></td></tr></table>
</div>
</div><h3 id="tta">TTA</h3>
<p>Test-Time Augmentation (TTA) is another common ensembling trick: data augmentation is useful not only during training but also at prediction time. The same sample is predicted several times under augmentation (say, three times) and the results are averaged.</p>
<p><img src="/img/tta.png" alt="CV5"></p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">test_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">tta</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="n">test_pred_tta</span> <span class="o">=</span> <span class="bp">None</span>
<span class="c1"># number of TTA rounds</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tta</span><span class="p">):</span>
<span class="n">test_pred</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">test_loader</span><span class="p">):</span>
<span class="n">c0</span><span class="p">,</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">c0</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">c1</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span>
<span class="n">c2</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">c3</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span>
<span class="n">c4</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">c5</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">numpy</span><span class="p">()],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">test_pred</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
<span class="n">test_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">(</span><span class="n">test_pred</span><span class="p">)</span>
<span class="k">if</span> <span class="n">test_pred_tta</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
<span class="n">test_pred_tta</span> <span class="o">=</span> <span class="n">test_pred</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">test_pred_tta</span> <span class="o">+=</span> <span class="n">test_pred</span>
<span class="k">return</span> <span class="n">test_pred_tta</span>
</code></pre></td></tr></table>
</div>
</div><h3 id="snapshot">Snapshot</h3>
<p>Suppose we have trained 10 CNNs: we can simply average their predictions. But if we only trained a single CNN model, how can we still do model ensembling?</p>
<p>In the paper Snapshot Ensembles, the authors propose training with a cyclical learning rate, saving a number of the more accurate checkpoints along the way, and finally ensembling those checkpoints.<br>
Original paper: <a href="https://openreview.net/pdf?id=BJYwwY9ll">SNAPSHOT ENSEMBLES: TRAIN 1, GET M FOR FREE</a></p>
<p><img src="/img/Snapshot.png" alt="CV5"></p>
<p>Because the learning rate in a cyclical schedule periodically rises and falls, the CNN model is likely to escape one local optimum and settle into another. In the Snapshot paper the authors show experimentally that this method improves model accuracy to some extent, at the cost of longer training time.</p>
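<p>The schedule's shape can be sketched as a cyclic cosine decay (the form used in the Snapshot Ensembles paper; <code>lr_max</code> and the cycle length here are illustrative): within each cycle the learning rate falls from its maximum towards zero, then jumps back up, and a checkpoint is taken at each cycle's end.</p>

```python
import math

def snapshot_lr(step, steps_per_cycle, lr_max=0.1):
    """Cyclic cosine annealing: decay from lr_max towards 0 within a cycle, then restart."""
    t = step % steps_per_cycle
    return lr_max / 2 * (math.cos(math.pi * t / steps_per_cycle) + 1)
```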
<p><img src="/img/%E5%AF%B9%E6%AF%94.png" alt="CV5"></p>
<h2 id="结果后处理">Post-processing the Results</h2>
<p>Different tasks call for different solutions; models built from different ideas can not only borrow from each other but also be used to correct the final predictions.</p>
<p>In this competition, the predictions can be post-processed along the following lines:</p>
<ul>
<li>Count how frequently each character appears at each position in the images, and correct the results with rules;</li>
<li>Train a separate character-length model that predicts how many characters an image contains, and use it to correct the results.</li>
</ul>
<h2 id="小结">Summary</h2>
<ul>
<li>Ensembling improves accuracy only to a degree and costs considerable training time, so raise the accuracy of a single model first and consider ensembling afterwards;</li>
<li>The choice of ensembling method should be matched to the validation-split scheme; Dropout and TTA are useful in almost every setting.</li>
</ul>
</div>
<div class="read-more">
<a href="/post/datawhale_cv5/" class="read-more-link">Read more</a>
</div>
</div>
</article>
<article class="post">
<header class="post-header">
<h1 class="post-title"><a class="post-link" href="/post/datawhale_cv4/">Datawhale Beginner CV Competition: Model Training and Validation</a></h1>
<div class="post-meta">
<span class="post-time"> 2020-05-28 </span>
<div class="post-category">
<a href="/categories/computer-vision/"> computer vision </a>
</div>
<span class="more-meta"> ~3364 words </span>
<span class="more-meta"> ~7 min read </span>
</div>
</header>
<div class="post-content">
<div class="post-summary">
<h2 id="模型训练与验证">Model Training and Validation</h2>
<h3 id="构造验证集">Constructing a Validation Set</h3>
<p>During training, machine-learning models (deep-learning models in particular) overfit very easily. As training proceeds, a deep model's training error keeps decreasing, but its test error need not follow the same trend.<br>
During training the model can only learn from the training data; it never sees the test samples. If it fits the training set too well, it memorizes details of the individual training samples and generalizes poorly to the test set. This is called overfitting. Its counterpart is underfitting, where the model fits even the training set poorly.<br>
<img src="/img/Error.png" alt="CV3"><br>
As the figure shows, as model complexity and the number of training epochs grow, the CNN's error on the training set keeps falling, while its error on the test set first falls and then rises again; what we are after is the highest possible accuracy on the test set.</p>
<p>Overfitting has many causes. The most common is excessive model complexity, which lets the model absorb every facet of the training data, including incidental, irrelevant regularities.</p>
<p>The best remedy is to build a sample set whose distribution matches the test set as closely as possible (called a validation set), continually measure the model's accuracy on it during training, and use those measurements to steer training.</p>
<p>In a competition, the organizers provide two datasets: a training set and a test set. Participants build models on the training set and check generalization on the test set by submitting predictions; the number of submissions is limited to keep contestants from probing the leaderboard.<br>
In general, participants can also split off a local validation set for local checks. The training, validation, and test sets serve different purposes:</p>
<ul>
<li>Training set (Train Set): used to fit the model and tune its parameters;</li>
<li>Validation set (Validation Set): used to measure model accuracy and tune hyperparameters;</li>
<li>Test set (Test Set): used to check the model's generalization ability.<br>
Because the training and validation sets are disjoint, accuracy on the validation set reflects generalization to a certain degree. When splitting off a validation set, make sure its distribution matches the test set as closely as possible, otherwise validation accuracy loses its guiding value.</li>
</ul>
<p>There is no fixed rule for splitting the data, but three guidelines are worth following:</p>
<ul>
<li>For small datasets (tens of thousands of samples), a common split is 60% training, 20% validation, and 20% test.</li>
<li>For large datasets (millions of samples and up), the validation and test sets only need to be large enough in absolute terms: with 1M samples, 10k for validation and 10k for test suffice; with 10M samples, the same 10k/10k still suffices.</li>
<li>The fewer the hyperparameters, or the easier they are to tune, the smaller the validation share can be, leaving more data for training.</li>
</ul>
<p>Given how important the validation set is, how do we split one off locally? Some competitions provide a validation set; if not, participants must carve one out of the training set. Common splitting schemes:
<img src="/img/Validation_CV.png" alt="CV3"></p>
<ul>
<li>
<p>Hold-Out
Split the training set directly into two parts: a new training set and a validation set. Its advantage is that it is the simplest and most direct; its drawback is that there is only one validation set, which the model may itself overfit. Hold-out suits situations with plenty of data.</p>
</li>
<li>
<p>Cross-Validation (CV)<br>
Split the training set into K folds, train on K-1 of them, and validate on the remaining one, repeating K times so that every fold serves once as the validation set; the final validation accuracy is the average over the K runs. The advantages are a more reliable validation score and K models with useful diversity; the drawback is K training runs, which does not suit very large datasets.</p>
<p>K-fold cross-validation proceeds as follows:</p>
<ol>
<li>Split the data into a training set and a test set, and set the test set aside.</li>
<li>Split the training set into k folds.</li>
<li>In each round, use one of the k folds as the validation set and the rest as the training set.</li>
<li>After k rounds of training, we obtain k different models.</li>
<li>Evaluate the k models and pick the best-performing hyperparameters.</li>
<li>Retrain on all k folds of data with the best hyperparameters to get the final model.<br>
k is usually 10. With little data, a larger k gives the training split a larger share of the whole, though it also means training more models; with a lot of data, k can be smaller.</li>
</ol>
</li>
<li>
<p>Bootstrap Sampling<br>
In statistics, the bootstrap (bootstrap method, bootstrapping, or bootstrap sampling) draws uniformly from a given training set with replacement; that is, every time a sample is drawn, it remains equally likely to be drawn again and added to the training set once more. Cross-validation is the usual way to evaluate models in machine learning, but when data is extremely scarce the bootstrap is an option. With m samples (m small), draw one of the m at random, add it to the training set, and put it back; repeat m times to obtain a training set of m draws. Some samples will be drawn repeatedly and some never; the never-drawn samples can serve as the validation set. Because the training set then contains duplicates, the data distribution shifts and the trained estimates are biased, so this method is not widely used unless the data really is very scarce.</p>
</li>
</ul>
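<p>The K-fold split in the steps above can be sketched as an index partition (standard library only; real code might instead use e.g. sklearn's KFold):</p>

```python
def kfold_indices(n, k):
    """Partition indices 0..n-1 into k folds; each fold serves once as validation."""
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in fold_sizes:
        val = list(range(start, start + size))                     # this round's validation fold
        train = list(range(0, start)) + list(range(start + size, n))  # the other k-1 folds
        splits.append((train, val))
        start += size
    return splits
```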
<p>Sampling with replacement yields a different training/validation pair on every run; this scheme generally suits small datasets.<br>
This competition already provides a validation set, so participants can train on the training set and measure accuracy on the given validation set (or merge the two and re-split on their own).</p>
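<p>A quick sketch of the bootstrap split described above: draw m samples with replacement as the training "bag" and keep the never-drawn ("out-of-bag") samples for validation; in expectation about 1/e ≈ 36.8% of the samples end up out-of-bag.</p>

```python
import random

random.seed(0)
m = 1000
bag = [random.randrange(m) for _ in range(m)]   # training set: m draws with replacement
oob = sorted(set(range(m)) - set(bag))          # never-drawn samples -> validation set
frac_oob = len(oob) / m                         # theory: about 1/e of the samples
```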
<p>These schemes describe splits from the data-partitioning point of view; in current data competitions the common choices are hold-out and cross-validation. With plenty of data, hold-out is usually adequate. Whichever scheme is chosen, take care that the training, validation, and test sets follow the same distribution.</p>
<p>Here "distribution" usually means the label-related statistics: in a classification task it is the class distribution, which should be roughly the same across the training, validation, and test sets; if the labels carry temporal information, the validation and test sets should span matching time intervals.</p>
<h3 id="评估指标">Evaluation Metrics</h3>
<h4 id="分类问题评估指标">Classification metrics:</h4>
<ul>
<li>Accuracy</li>
<li>Precision</li>
<li>Recall</li>
<li>F1 score</li>
<li>ROC curve</li>
<li>AUC</li>
</ul>
<h4 id="回归问题评估指标">Regression metrics:</h4>
<ul>
<li>MAE</li>
<li>MSE</li>
</ul>
<p><a href="https://easyai.tech/ai-definition/accuracy-precision-recall-f1-roc-auc/">Classification metrics explained: accuracy, precision, recall, F1, ROC, AUC</a></p>
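<p>The core classification metrics listed above are easy to compute by hand for a binary task (a sketch on made-up labels):</p>

```python
def precision_recall_f1(y_true, y_pred, positive=1):
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp)   # of the predicted positives, how many are right
    recall = tp / (tp + fn)      # of the true positives, how many were found
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two
    return precision, recall, f1
```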
<h3 id="模型训练与验证-1">Model Training and Validation</h3>
<ul>
<li>Construct the training and validation sets;</li>
<li>Train and validate in every epoch, saving the model whenever the validation accuracy improves.</li>
</ul>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">train_dataset</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">num_workers</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">val_dataset</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
<span class="n">num_workers</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">SVHN_Model1</span><span class="p">()</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;sum&#39;</span><span class="p">)</span> <span class="c1"># summed cross-entropy loss (size_average is deprecated)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="mf">0.001</span><span class="p">)</span>
<span class="n">best_loss</span> <span class="o">=</span> <span class="mf">1000.0</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">20</span><span class="p">):</span>
<span class="k">print</span><span class="p">(</span><span class="s1">'Epoch: '</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">train</span><span class="p">(</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
<span class="n">val_loss</span> <span class="o">=</span> <span class="n">validate</span><span class="p">(</span><span class="n">val_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">)</span>
<span class="c1"># keep the checkpoint with the lowest validation loss</span>
<span class="k">if</span> <span class="n">val_loss</span> <span class="o"><</span> <span class="n">best_loss</span><span class="p">:</span>
<span class="n">best_loss</span> <span class="o">=</span> <span class="n">val_loss</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s1">'./model.pt'</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p>The training code for each epoch is as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
<span class="c1"># switch the model to training mode</span>
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">):</span> <span class="c1"># the loader yields one batch (batch_size=10) at a time until the whole training set has been processed</span>
<span class="n">c0</span><span class="p">,</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">c0</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">0</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c1</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c2</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">2</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c3</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">3</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c4</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">4</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c5</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">5</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">/=</span> <span class="mi">6</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
</code></pre></td></tr></table>
</div>
</div><p>The validation code for each epoch is as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="n">val_loader</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">criterion</span><span class="p">):</span>
<span class="c1"># switch the model to evaluation mode</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
<span class="n">val_loss</span> <span class="o">=</span> <span class="p">[]</span>
<span class="c1"># do not track gradients</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">val_loader</span><span class="p">):</span> <span class="c1"># the loader yields one batch (batch_size=10) at a time until the whole validation set has been processed</span>
<span class="n">c0</span><span class="p">,</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">c0</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">0</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c1</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c2</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">2</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c3</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">3</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c4</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">4</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c5</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">5</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">/=</span> <span class="mi">6</span>
<span class="n">val_loss</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">val_loss</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
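<p>The validate function above tracks only the mean loss. If you also want per-character validation accuracy, a minimal sketch could look like the following; the helper name <code>char_accuracy</code> is my own, not from the original code:</p>

```python
import torch

def char_accuracy(logits, targets):
    # fraction of samples whose arg-max class matches the label
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()

# toy check: 3 samples, 11 classes
logits = torch.zeros(3, 11)
logits[0, 2] = 1.0  # predicted class 2
logits[1, 5] = 1.0  # predicted class 5
logits[2, 0] = 1.0  # predicted class 0
targets = torch.tensor([2, 5, 1])
acc = char_accuracy(logits, targets)  # 2 of the 3 samples are correct
```

<p>Inside the validation loop you would call this once per head (c0 through c5) and average the results.</p>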
</div><h3 id="模型保存与加载">Saving and Loading Models</h3>
<p>Saving and loading models in PyTorch is very simple; the most common practice is to save and load the model parameters (the state dict):</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">model_object</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s1">'model.pt'</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s1">'model.pt'</span><span class="p">))</span>
</code></pre></td></tr></table>
</div>
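<p>Saving only the model's state dict loses the optimizer state, which matters if training may be resumed later. A sketch of a fuller checkpoint; the file name <code>checkpoint.pt</code> and the toy linear model are illustrative stand-ins, not code from this post:</p>

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for SVHN_Model1
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# save model weights, optimizer state, and the current epoch together
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 5,
}, 'checkpoint.pt')

# restore everything and continue from the next epoch
state = torch.load('checkpoint.pt')
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1
```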
</div><h3 id="模型调参流程">Model Tuning Workflow</h3>
<p>Deep learning is light on theory but highly empirical: many modelling choices can only be validated by actually training. With so many network architectures and hyperparameters, repeated experimentation is unavoidable. Training deep models also requires GPU hardware and considerable time, so training them effectively has become a craft of its own.</p>
<p>There are many deep-learning training tricks; recommended reading includes:</p>
<ul>
<li><a href="http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html">http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html</a></li>
<li><a href="http://karpathy.github.io/2019/04/25/recipe/">http://karpathy.github.io/2019/04/25/recipe/</a></li>
</ul>
<p>This section walks through a selection of common techniques and analyses them in the context of this competition. Unlike traditional machine-learning models, the accuracy of a deep model is directly tied to model complexity, data volume, regularization, data augmentation, and related factors. Knowing which stage the model is in (underfitting, overfitting, or a good fit) therefore tells you from which angle to keep optimizing it.</p>
<p>For this competition, I suggest proceeding in the following order:<br>
1. Build a simple CNN first; it does not need to be elaborate, just enough to run the full training, validation, and prediction pipeline;<br>
2. The simple CNN's loss will be fairly high, so increase model complexity and watch the validation accuracy;<br>
3. While increasing model complexity, also add data-augmentation methods, until the validation accuracy stops improving.</p>
</div>
<div class="read-more">
<a href="/post/datawhale_cv4/" class="read-more-link">Read more</a>
</div>
</div>
</article>
<article class="post">
<header class="post-header">
<h1 class="post-title"><a class="post-link" href="/post/datawhale_cv3/">Datawhale Beginner CV Competition: Character Recognition Models</a></h1>
<div class="post-meta">
<span class="post-time"> 2020-05-26 </span>
<span class="more-meta"> about 2062 words </span>
<span class="more-meta"> about 5 minutes to read </span>
</div>
</header>
<div class="post-content">
<div class="post-summary">
<h2 id="字符识别模型">Character Recognition Models</h2>
<p>Convolutional Neural Network (CNN)</p>
<h3 id="cnn介绍">Introduction to CNNs</h3>
<p>A convolutional neural network (CNN) is a special class of artificial neural network and an important branch of deep learning. CNNs perform strongly in many fields, with accuracy and speed well beyond traditional statistical learning algorithms. In computer vision in particular, CNNs are the mainstream models for image classification, image retrieval, object detection, and semantic segmentation.<br>
Each CNN layer consists of many convolution kernels; each kernel convolves over the input pixels to produce the input of the next layer. As layers are stacked, the kernels' effective receptive field grows while the spatial size of the feature maps shrinks.<br>
<img src="/img/CNN.png" alt="CV3"><br>
A CNN is a hierarchical model whose input is raw pixel data. It is built from convolution, pooling, non-linear activation functions, and fully connected layers.</p>
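<p>The shrinking of the feature maps mentioned above follows the standard output-size formula out = ⌊(in + 2·padding − kernel) / stride⌋ + 1, which applies to both convolution and pooling. A quick sketch; the 64-pixel input side is an illustrative assumption:</p>

```python
def conv_out(size, kernel, stride=1, padding=0):
    # standard convolution / pooling output-size formula
    return (size + 2 * padding - kernel) // stride + 1

side = conv_out(64, kernel=3, stride=2)    # 3x3 conv, stride 2: 64 -> 31
side = conv_out(side, kernel=2, stride=2)  # 2x2 max-pool, stride 2: 31 -> 15
```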
<p>The figure below shows the LeNet architecture, a classic character-recognition model: two convolutional layers, two pooling layers, and two fully connected layers. All convolution kernels are 5×5 with stride 1, and the pooling layers use max pooling.<br>
<img src="/img/LeNet-5.png" alt="CV3">
Through repeated convolution and pooling, the final layer of the CNN maps the input pixels to a concrete output. In a classification task this becomes a probability distribution over classes; the difference between the prediction and the true label is computed, the parameters of every layer are updated by backpropagation, and the forward pass runs again, repeating until training completes.</p>
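<p>The LeNet structure described above can be sketched in PyTorch roughly as follows. The channel counts (6 and 16), the 32×32 grayscale input, and the 10 output classes are assumptions based on the classic paper, not code from this post:</p>

```python
import torch
import torch.nn as nn

# LeNet-style network: two 5x5 convs (stride 1), two max-pool layers,
# then fully connected layers at the end
lenet = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, stride=1),   # 32x32 -> 28x28
    nn.MaxPool2d(2),                            # 28x28 -> 14x14
    nn.Conv2d(6, 16, kernel_size=5, stride=1),  # 14x14 -> 10x10
    nn.MaxPool2d(2),                            # 10x10 -> 5x5
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.ReLU(),
    nn.Linear(120, 10),                         # 10 class scores
)

out = lenet(torch.randn(1, 1, 32, 32))  # one grayscale 32x32 image
```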
<p>Compared with traditional machine-learning models, a CNN works end to end: training goes directly from image pixels to the final output, with no separate feature-extraction or model-building stage and no manual intervention.</p>
<h3 id="cnn发展">The Development of CNNs</h3>
<p>As network architectures evolved, researchers initially found that deeper networks with more parameters achieve better accuracy. The progression from AlexNet through VGG and Inception-v3 to ResNet is typical.<br>
<img src="/img/Model.png" alt="CV3"></p>
<ul>
<li>LeNet-5(1998)<br>
Original paper: <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">Gradient-Based Learning Applied to Document Recognition</a></li>
</ul>
<p><img src="/img/LeNet-5(1998).png" alt="CV3"></p>
<ul>
<li>AlexNet(2012)<br>
Original paper: <a href="https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf">ImageNet Classification with Deep Convolutional Neural Networks</a></li>
</ul>
<p><img src="/img/Alex-net.png" alt="CV3"></p>
<ul>
<li>VGG-16(2014)<br>
Original paper: <a href="https://arxiv.org/pdf/1409.1556.pdf">Very Deep Convolutional Networks for Large-Scale Image Recognition</a><br>
Homepage: <a href="http://www.robots.ox.ac.uk/~vgg/research/very_deep/">Visual Geometry Group Home Page</a></li>
</ul>
<p><img src="/img/VGG.png" alt="CV3"></p>
<ul>
<li>Inception-v1 (2014)<br>
Original paper: <a href="https://arxiv.org/pdf/1409.4842v1.pdf">Going Deeper with Convolutions</a></li>
</ul>
<p><img src="/img/Incep-net.png" alt="CV3"></p>
<ul>
<li>ResNet-50 (2015)<br>
Original paper: <a href="https://arxiv.org/pdf/1512.03385.pdf">Deep Residual Learning for Image Recognition</a></li>
</ul>
<p><img src="/img/Resnet50.png" alt="CV3"></p>
<h3 id="pytorch构建cnn模型">Building a CNN Model in PyTorch</h3>
<ul>
<li>Prepare the data</li>
<li>Define the network structure (the model)</li>
<li>Define the loss function</li>
<li>Define the optimizer</li>
<li>Train
<ul>
<li>Prepare the input data and labels as tensors (optional)</li>
<li>Forward pass: compute the network output and the loss</li>
<li>Backward pass: update the parameters (none of the following three calls may be omitted)
<ul>
<li>1. Clear the gradients computed in the previous iteration: optimizer.zero_grad()</li>
<li>2. Backpropagate to compute the gradients: loss.backward()</li>
<li>3. Update the weights: optimizer.step()</li>
</ul>
</li>
<li>Record the training loss, validation loss, and accuracy, and print training progress (optional)</li>
</ul>
</li>
<li>Plot how loss and accuracy change during training (optional)</li>
<li>Evaluate on the test set</li>
</ul>
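<p>The three mandatory calls in the list above can be seen in a minimal, self-contained optimisation step; the toy linear model and the random data are stand-ins for illustration:</p>

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)          # a batch of 8 samples
y = torch.randint(0, 2, (8,))  # integer class labels

output = model(x)              # forward pass
loss = criterion(output, y)    # compute the loss
optimizer.zero_grad()          # 1. clear gradients from the previous iteration
loss.backward()                # 2. backpropagate, filling each parameter's .grad
optimizer.step()               # 3. update the weights
```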
<p>Building a CNN in PyTorch is straightforward: define the model's parameters and the forward pass, and PyTorch derives the backward pass automatically.<br>
We will build and train a very simple CNN: two convolutional layers followed by six fully connected layers in parallel for classification.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span><span class="lnt">30
</span><span class="lnt">31
</span><span class="lnt">32
</span><span class="lnt">33
</span><span class="lnt">34
</span><span class="lnt">35
</span><span class="lnt">36
</span><span class="lnt">37
</span><span class="lnt">38
</span><span class="lnt">39
</span><span class="lnt">40
</span><span class="lnt">41
</span><span class="lnt">42
</span><span class="lnt">43
</span><span class="lnt">44
</span><span class="lnt">45
</span><span class="lnt">46
</span><span class="lnt">47
</span><span class="lnt">48
</span><span class="lnt">49
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">torch</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">deterministic</span> <span class="o">=</span> <span class="bp">False</span>
<span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">benchmark</span> <span class="o">=</span> <span class="bp">True</span>
<span class="kn">import</span> <span class="nn">torchvision.models</span> <span class="kn">as</span> <span class="nn">models</span>
<span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="kn">as</span> <span class="nn">transforms</span>
<span class="kn">import</span> <span class="nn">torchvision.datasets</span> <span class="kn">as</span> <span class="nn">datasets</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="kn">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="kn">as</span> <span class="nn">F</span>
<span class="kn">import</span> <span class="nn">torch.optim</span> <span class="kn">as</span> <span class="nn">optim</span>
<span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
<span class="kn">from</span> <span class="nn">torch.utils.data.dataset</span> <span class="kn">import</span> <span class="n">Dataset</span>
<span class="c1"># define the model</span>
<span class="k">class</span> <span class="nc">SVHN_Model1</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">SVHN_Model1</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="c1"># CNN feature-extraction module</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cnn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="c1"># first convolutional layer</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="c1"># second convolutional layer</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="p">)</span>
<span class="c1"># six fully connected classification heads in parallel, one per character position</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span> <span class="c1"># each head is a linear classifier over 11 classes</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc4</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc5</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc6</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="o">*</span><span class="mi">3</span><span class="o">*</span><span class="mi">7</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span> <span class="c1"># forward pass: compute the six per-position outputs</span>
<span class="n">feat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cnn</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="n">feat</span> <span class="o">=</span> <span class="n">feat</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">feat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># flatten the feature maps into one vector per sample</span>
<span class="n">c1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c4</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c5</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc5</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c6</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc6</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="k">return</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span><span class="p">,</span> <span class="n">c6</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">SVHN_Model1</span><span class="p">()</span>
</code></pre></td></tr></table>
</div>
</div><p><a href="http://digtime.cn/articles/159/pytorch-zhong-nn-linear-han-shu-jie-du">Pytorch 中 nn.Linear 函数解读</a></p>
<p>Next, the training code:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="c1"># loss function</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="c1"># optimizer</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="mf">0.005</span><span class="p">)</span>
<span class="n">loss_plot</span><span class="p">,</span> <span class="n">c0_plot</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="c1"># train for 10 epochs</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span>
<span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">train_loader</span><span class="p">:</span>
<span class="n">c0</span><span class="p">,</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1"># calling model(x) automatically invokes forward() to compute the outputs</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">c0</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">0</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c1</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c2</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">2</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c3</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">3</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c4</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">4</span><span class="p">])</span> <span class="o">+</span> \
<span class="n">criterion</span><span class="p">(</span><span class="n">c5</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">5</span><span class="p">])</span>
<span class="n">loss</span> <span class="o">/=</span> <span class="mi">6</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="c1"># clear previously computed gradients</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># backpropagate to compute gradients</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="c1"># update the weights</span>
<span class="n">loss_plot</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="c1"># loss.item() extracts a Python number from the tensor</span>
<span class="n">c0_plot</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">c0</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">][:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="o">*</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">c0</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="n">epoch</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p>After training completes, we can plot the loss and accuracy recorded during training, as shown in the figure below. The model's loss decreases steadily over the iterations while the character-prediction accuracy rises.<br>
<img src="/img/loss.png" alt="CV3"></p>
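<p>The per-character accuracy appended to c0_plot is simply the fraction of samples whose arg-max prediction matches the label. A minimal pure-Python sketch of the same computation (the scores below are toy values, not real model outputs):</p>

```python
def char_accuracy(logits, labels):
    """Fraction of rows whose arg-max index equals the true label.

    logits: list of per-sample score lists (one score per class)
    labels: list of true class indices
    """
    correct = sum(
        1 for scores, y in zip(logits, labels)
        if max(range(len(scores)), key=scores.__getitem__) == y
    )
    return correct / len(labels)

# toy batch of 3 samples over 11 classes (10 digits + "no character")
logits = [
    [0.1] * 11,                # flat scores -> argmax 0
    [0] * 5 + [9] + [0] * 5,   # argmax 5
    [0] * 10 + [7],            # argmax 10
]
labels = [0, 5, 3]
print(char_accuracy(logits, labels))  # prints 0.6666666666666666 (2 of 3 correct)
```

<p>In the training loop the same quantity is computed on tensors via (c0.argmax(1) == data[1][:, 0]).sum().item() / c0.shape[0].</p>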
<p>Of course, to pursue higher accuracy, we can also use a model pre-trained on the ImageNet dataset, as follows:</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="k">class</span> <span class="nc">SVHN_Model2</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">SVHN_Model2</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="n">model_conv</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">resnet18</span><span class="p">(</span><span class="n">pretrained</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="c1"># load a pretrained model</span>
        <span class="n">model_conv</span><span class="o">.</span><span class="n">avgpool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">AdaptiveAvgPool2d</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">model_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="nb">list</span><span class="p">(</span><span class="n">model_conv</span><span class="o">.</span><span class="n">children</span><span class="p">())[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># children() yields model_conv's layers in order; [:-1] drops the last layer (the original fc classifier)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cnn</span> <span class="o">=</span> <span class="n">model_conv</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc4</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc5</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">11</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
<span class="n">feat</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cnn</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<span class="c1"># print(feat.shape)</span>
<span class="n">feat</span> <span class="o">=</span> <span class="n">feat</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">feat</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">c1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc3</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c4</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc4</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="n">c5</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc5</span><span class="p">(</span><span class="n">feat</span><span class="p">)</span>
<span class="k">return</span> <span class="n">c1</span><span class="p">,</span> <span class="n">c2</span><span class="p">,</span> <span class="n">c3</span><span class="p">,</span> <span class="n">c4</span><span class="p">,</span> <span class="n">c5</span>
</code></pre></td></tr></table>
</div>
</div><p><a href="https://blog.csdn.net/whut_ldz/article/details/78845947">Referencing and modifying pre-trained models in PyTorch</a></p>
</div>
<div class="read-more">
<a href="/post/datawhale_cv3/" class="read-more-link">Read more</a>
</div>
</div>
</article>
<article class="post">
<header class="post-header">
<h1 class="post-title"><a class="post-link" href="/post/datawhale_cv2/">Getting Started with the CV Competition: Data Loading and Data Augmentation</a></h1>
<div class="post-meta">
<span class="post-time"> 2020-05-23 </span>
<div class="post-category">
<a href="/categories/deep-learning/"> Deep Learning </a>
</div>
<span class="more-meta"> About 2840 words </span>
<span class="more-meta"> Estimated reading time: 6 minutes </span>
</div>
</header>
<div class="post-content">
<div class="post-summary">
<h2 id="图像读取">Image Reading</h2>
<p>Many Python libraries can read image data; the most common are Pillow and OpenCV.<br>
<a href="https://www.cnblogs.com/ocean1100/p/9494640.html">Understanding ToTensor after PyTorch loads an image (with a PIL vs. OpenCV comparison)</a></p>
<h3 id="1pillow">(1)Pillow</h3>
<p>Pillow is a fork of the Python Imaging Library (PIL). It provides common image reading and processing operations, integrates seamlessly with IPython notebooks, and is very widely used.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="c1"># import the Pillow library</span>
<span class="c1"># read the image</span>
<span class="n">im</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s1">'cat.jpg'</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p><img src="/img/cat.png" alt="CV3"></p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageFilter</span>
<span class="n">im</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s1">'cat.jpg'</span><span class="p">)</span>
<span class="c1"># apply a blur filter:</span>
<span class="n">im2</span> <span class="o">=</span> <span class="n">im</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">ImageFilter</span><span class="o">.</span><span class="n">BLUR</span><span class="p">)</span>
<span class="n">im2</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'blur.jpg'</span><span class="p">,</span> <span class="s1">'jpeg'</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p><img src="/img/cat1.png" alt="CV3"></p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="c1"># open a jpg image file from the current directory:</span>
<span class="n">im</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s1">'cat.jpg'</span><span class="p">)</span>
<span class="n">im</span><span class="o">.</span><span class="n">thumbnail</span><span class="p">((</span><span class="n">im</span><span class="o">.</span><span class="n">width</span><span class="o">//</span><span class="mi">2</span><span class="p">,</span> <span class="n">im</span><span class="o">.</span><span class="n">height</span><span class="o">//</span><span class="mi">2</span><span class="p">))</span>
<span class="n">im</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'thumbnail.jpg'</span><span class="p">,</span> <span class="s1">'jpeg'</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p><img src="/img/cat3.png" alt="CV3"></p>
<p>Pillow's official documentation: https://pillow.readthedocs.io/en/stable/</p>
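<p>Since the ToTensor comparison linked above is ultimately about array layout, here is a small sketch of how a Pillow image maps to a NumPy array (the image is generated in memory, so no cat.jpg file is needed; Pillow and NumPy assumed installed):</p>

```python
import numpy as np
from PIL import Image

# create a 4x3 solid-red RGB image in memory instead of reading from disk
im = Image.new("RGB", (4, 3), color=(255, 0, 0))

arr = np.asarray(im)          # H x W x C, uint8, RGB channel order
print(arr.shape, arr.dtype)   # (3, 4, 3) uint8

# ToTensor-style conversion: scale to [0, 1] and move channels first (C x H x W)
tensor_like = arr.astype(np.float32) / 255.0
tensor_like = tensor_like.transpose(2, 0, 1)
print(tensor_like.shape, tensor_like.max())  # (3, 3, 4) 1.0
```

<p>Note that Pillow's (width, height) constructor order flips to (height, width) in the array shape.</p>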
<h3 id="2opencv">(2)OpenCV</h3>
<p>OpenCV is a cross-platform computer vision library originally open-sourced by Intel. It has a long history and offers a wide range of computer vision, digital image processing, and machine vision functionality. OpenCV is far more powerful than Pillow, but it also has a much steeper learning curve.</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">
<pre class="chroma"><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">cv2</span>
<span class="c1"># import the OpenCV library</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="s1">'cat.jpg'</span><span class="p">)</span>
<span class="c1"># OpenCV's default channel order is BGR; convert to RGB</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span>
</code></pre></td></tr></table>
</div>
</div><p><img src="/img/cat.png" alt="CV3"></p>
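<p>The BGR-to-RGB conversion that cvtColor performs is just a reversal of the channel axis, so it can be sketched with NumPy alone (the pixel values below are made up for illustration):</p>

```python
import numpy as np

# a 1x2 "image" with known BGR pixel values: one pure-blue, one pure-red pixel
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

rgb = bgr[..., ::-1]  # reverse the channel axis: BGR -> RGB
print(rgb[0, 0].tolist(), rgb[0, 1].tolist())  # [0, 0, 255] [255, 0, 0]
```

<p>cv2.cvtColor does the same reordering (plus proper handling of other conversion codes) in optimized C++.</p>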
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span></code></pre></td>
<td class="lntd">