From f49cd51c12ff5c8c3dd3264352116266aa3d7d85 Mon Sep 17 00:00:00 2001 From: michail iliadis Date: Sun, 28 Oct 2018 12:29:58 -0700 Subject: [PATCH] retinanet, pep8, changes for pytorch0.4.1 --- BENCHMARK.md | 66 +++++ configs/baselines/retinanet_R-50-FPN_1x.yaml | 41 +++ demo/loss_retinanet_R-50-FPN_1x.jpg | Bin 0 -> 49340 bytes lib/core/test.py | 8 + lib/core/test_retinanet.py | 196 ++++++++++++++ lib/modeling/FPN.py | 10 +- lib/modeling/model_builder.py | 223 ++++++++++------ lib/modeling/retinanet_heads.py | 207 +++++++++++++++ lib/roi_data/minibatch.py | 13 +- lib/roi_data/retinanet.py | 258 +++++++++++++++++++ lib/utils/net.py | 70 ++++- lib/utils/training_stats.py | 20 +- 12 files changed, 1008 insertions(+), 104 deletions(-) create mode 100644 configs/baselines/retinanet_R-50-FPN_1x.yaml create mode 100644 demo/loss_retinanet_R-50-FPN_1x.jpg create mode 100644 lib/core/test_retinanet.py create mode 100644 lib/modeling/retinanet_heads.py create mode 100644 lib/roi_data/retinanet.py diff --git a/BENCHMARK.md b/BENCHMARK.md index 5af7fc56..d0bbd63e 100644 --- a/BENCHMARK.md +++ b/BENCHMARK.md @@ -29,6 +29,72 @@ ARl AR for large objects: area > 962 +## RetinaNet +### retinanet-R-50-FPN_1x + +- Training command: + + ``` + python tools/train_net_step.py \ + --dataset coco2017 --cfg configs/baselines/retinanet_R-50-FPN_1x.yaml \ + --bs 8 --iter_size 1 --use_tfboard + ``` + on four V100 GPUs. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+- Box AP/AR comparison:
+
+  | source    | AP50:95 | AP50 | AP75 | APs  | APm  | APl  | AR1  | AR10 | AR100 | ARs  | ARm  | ARl  |
+  |-----------|---------|------|------|------|------|------|------|------|-------|------|------|------|
+  | PyTorch   | 35.3    | 54.6 | 37.9 | 19.4 | 39.1 | 47.5 | 30.7 | 48.9 | 51.8  | 32.4 | 56.3 | 67.4 |
+  | Detectron | 35.7    | 54.7 | 38.5 | 19.5 | 39.9 | 47.5 | 30.7 | 49.1 | 52.0  | 32.0 | 56.9 | 68.0 |
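+- The twelve metric columns follow the summary order reported by the COCO API
+  (`pycocotools`). For reference, a minimal evaluation sketch is shown below;
+  the annotation and result file paths are placeholders rather than files
+  shipped with this patch:
+
+  ```
+  from pycocotools.coco import COCO
+  from pycocotools.cocoeval import COCOeval
+
+  coco_gt = COCO('annotations/instances_val2017.json')    # ground-truth annotations (placeholder path)
+  coco_dt = coco_gt.loadRes('retinanet_detections.json')  # detections in COCO result format (placeholder path)
+  coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
+  coco_eval.evaluate()
+  coco_eval.accumulate()
+  coco_eval.summarize()   # prints AP50:95, AP50, ..., ARl in the table's column order
+  print(coco_eval.stats)  # the same twelve numbers as a NumPy array
+  ```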
+ +- Total loss comparison: + + ![img](demo/loss_retinanet_R-50-FPN_1x.jpg) + + ## Faster-RCNN ### e2e_faster_rcnn-R-50-FPN_1x diff --git a/configs/baselines/retinanet_R-50-FPN_1x.yaml b/configs/baselines/retinanet_R-50-FPN_1x.yaml new file mode 100644 index 00000000..989807ac --- /dev/null +++ b/configs/baselines/retinanet_R-50-FPN_1x.yaml @@ -0,0 +1,41 @@ +DEBUG: False +MODEL: + TYPE: retinanet + CONV_BODY: FPN.fpn_ResNet50_conv5_body + NUM_CLASSES: 81 +RESNETS: + IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_networks/resnet50_caffe.pth' +NUM_GPUS: 8 +SOLVER: + WEIGHT_DECAY: 0.0001 + LR_POLICY: steps_with_decay + BASE_LR: 0.01 + GAMMA: 0.1 + MAX_ITER: 90000 + STEPS: [0, 60000, 80000] +FPN: + FPN_ON: True + MULTILEVEL_RPN: True + RPN_MAX_LEVEL: 7 + RPN_MIN_LEVEL: 3 + COARSEST_STRIDE: 128 + EXTRA_CONV_LEVELS: True +RETINANET: + RETINANET_ON: True + NUM_CONVS: 4 + ASPECT_RATIOS: (1.0, 2.0, 0.5) + SCALES_PER_OCTAVE: 3 + ANCHOR_SCALE: 4 + LOSS_GAMMA: 2.0 + LOSS_ALPHA: 0.25 +TRAIN: + SCALES: (800,) + MAX_SIZE: 1333 + RPN_STRADDLE_THRESH: -1 # default 0 +TEST: + SCALE: 800 + MAX_SIZE: 1333 + NMS: 0.5 + RPN_PRE_NMS_TOP_N: 10000 # Per FPN level + RPN_POST_NMS_TOP_N: 2000 +OUTPUT_DIR: . \ No newline at end of file diff --git a/demo/loss_retinanet_R-50-FPN_1x.jpg b/demo/loss_retinanet_R-50-FPN_1x.jpg new file mode 100644 index 0000000000000000000000000000000000000000..921df780188c419d431f81d0d2a456352aab08f4 GIT binary patch literal 49340 zcmd?R1z1(v)&RO_DUohaK|xRuq-$f)(%mYOTclHAgOnnTfFRwqLApVtLAnG4q&uZ! z@3+w75zjgI+ z`}!GpLP14C$H2rogN=g=zEE@yKt(}ALq$i!z&L3J#RYs0pc7&cUEsKhNvwDu>!KA2 zr+aY18HQVVwWLbj>x?%ZSU6zKN`Gv)e%`aQqJG*=P2PfkK;rVS?;Q8Cg zelspYa9pV9=xFFzC*wjvbvPL~Av(qd4osq(idgrph%a)wpCP#woRC+G&2U3$o%Dfq zHx3yi*Tkiblc9Yb*}pchXMbsAzYgrjxJH2602=Da@8paI{?IYePkv`GF)=W)&*0!- zpTWk)!6U%K!99zMjg3!;f0p3fIl^-|ctpfR=ZL}bIV4yKIH73Z$j^Wup2Nk)1@Hcw zKjbKgmZHdE01piXL?$#s00JBg@Y@ts9s~$$W1=aSyJvo~8LWq0vF!1?qiL=2GPfcDynY}NA z>7#J(M#oo8s7;O~m=ijI5mv7$B5#&VR=|iw_bIB~22g9N9(!`&0Lk z0F3^Kc{~>hbW#%>?jeB*EiMI2`(qfw@gr=!2?^kwszlA71SO_M9|^?r(C=;_f%pCg zNZ@GWEfS~+M35$l0mc@CNMO$2=lBS=p5$7C1cV+z)~1kvmr1uMO~d#L?lipSAphac9>Da@Leh%N>^%crXo}`>Q z+iU(QFD~D>{B&?0f>S>zWR*GZfS@iY>w)V*2*;=yBaH>`g|->%z6-edQG#|Up^~zn ze0V_mb2UeX&%{n>R=s^^DwN8wUUuQlD{&;yq?^^zq@HohbFubF;3!l!H>zAXy8v5) z^hJtzVhyC?ONc2Xv_Kf0=hR53Pr;2<)TV1Tyo$Kt#@D|80_oELDm015TR8`Np-A917jWw9--7q= zpU`wR3%e!LYg0Ts_&YD2zI{5l@5uDEhei5(2fZ~*;PgE-15Tlf;!lu}kOqf)8Yg~& zq_{88aT*K7PhntpqT9a-;IOY4cmiO$fnB}w8z>g{s_(XAA1{6Jv*{g6JKc~>vRyS2 z$ZdfH!EyeZUHC=>U+4JG2HjXyJ(!(9Obx~Q@ccDV>0uM)!kQW{^jAty%!M^Y!3OIg z2NZC^{{^c339VSpSlBH*&bj8=ehS^cKg;m%3m64^|(iztO`#-FK=+6VS|X#9{sHf%lRG$N>jd`hk6 zrrbXv*gu=hzl_1K;-&G&JP_lz%|0EYvo*VU0bwGi2ta>7Fz#TJ{($m-1w}tN*=>p2 z;JXiFBLR(LpPr<`iou+f`{iUiIG05CzUWLryi7lV%(OCYm3K39nCa-*RV1+V48{Tk z{w=`zQ+|57#hcewt4KeQr4GM~tP2{0ZWIz&IDgDO?A*s&7+;k?3U?zu{r2C;j{XrZ zC<2$@nQu!ltGI_KDOW{{MRw1m$pG)AurO@~mU>GUzAU?#co2VuV1t^lts7cgD9D~a zkQ2&Fj-IVSmOSGQu(MRv)sR4t?%c5@xM67}z zz%f;fZxwg0kiSxDJx!wrj&-0>;A9?{+)X;ZT3S0 zBoH43be$FzPs*7BYe(+6uq6*9z*41RH@=UihXe+(5Gp9vkZtG|$>Jl};^6$1K0r~7 z;}OfNm33I^LRP61RD# z014z^AEl$N{N`f}<1tvXF$h;{*mO~X=5E|3WpB-=xczH@$!~fDD44`|VcLTOISD4same z0#?(TbF>@Q7s{hFhO#Z&9ny8EnJAHC0wi+{6f#tV0A3hsVGrM9c+_*A#Qt)0O7_eUbc32 z-0Ze!FVJnsux$Bdj%gRe#HB_VoMpzE-F---8<$D;TFD}83q1mX1eo@^caG%|ZXrlu zd>RQ*0o{oO)#Di;jEM9PkK(2zMz$dbs4}2vA9`ehChpsD=$Zyf{Zg>CO_)EuMIREF z$7Vl9$Zvp{&nmcb5O<7n2^8U55L`dC$=?n|NT+#2mTA+G03WoFmbG*HWyREnV4a+8 zB$~gV;(MozP7a5o8i6Imsn$iQF6Y_0a^p@Q2Dh^yr&`MvG*?zED(^GaCAy^?O&LlH 
V+Wh#Z^U5P#w{kr)z*!QT!T% z%C%K1c;ZyrusP{T&M`YdNSP{Asx^E*$ z^_c_?EL^No^zvw5!@3HNM6{q`l>9?_(w}qp9^4+6$T-WwfQFNsB)hRZVuE4KQKi1xD{P|A$}v`B8WHH9@Y0a6 zwYe?4E^V>GATlb$cv`T4)UY^yX=1E}WaLaQB@1@^dC}o6vsfKfbAF$Qc8`0FDY$P| zkF+@+Rt}-nR*TsMGA3&1y%~9_P?~LkKSJ7nQHRpS5jjk{(zF)K$RRqekH1p$4%3ss zZFa`mmTlKn-Hf3;kKJ(F46TKU%Qb);8MJrkU1UYNs?%st8Pd7%sy#*)vvEB>Vv$u{ zL#p7$Te0bT7p6k6KRE~_M_(x_E*14dcS^cLdQahu4}~Y@V}H*`IkJ+FHu9D-)j%}^ zBIqZor!zq=n@UQ0e%vppjqru)F5CpAeG`+D5M}J9d#qwx@^~xt=<2StElGBEMwmn! z9k<2#<>;ot^>zJC*;jmfBrm{*wUHwqSkRs2Vw~9+Zhm@fh3bM7|E97ekdm~Y6H`pg z%H5cR-G=xDZW%BKG$!Nkx?!DX?UwzVpULGdipUS*lrAuS1uIqg8Xtd3aPsv^` zU(!5p375}M?ePLS;p9>O0bvWy{>U&NCZ+~Y?zr6%(QKsA$|u!?=l8`DTVb2)kltdS z@GkN;_1k@SIGe8oSO_m*DTvlh!zfKjeCkIQE1rNK1~N*I+?Fb@)S6|l^p#0NHk3a{ zn~UbCkW+WquORL`kvHrt-1bKO&$+FXA6D8f%iKbmFX{f6_q@s98Q*&`$slwDf7q1q z-E+4hqgHw{Iz3M-C+-e4)0UZy)X%Ba9|`?Kkct6`Pn-`U^Z|(O&U4!nzTd^}~ksyV7u;b!tufNwk$?Vp`#El|3)I|mDqzrID z16|M+n}<;e!bkaGXte0d(d+E!MYX-=mbgc&@n)_86Az-w=0i&%RyLfwMDOOZ-#yef z@yE|extWx&e}Coz62Am;wNxmz(6fFz>M&35vUFeIlq^@`Fj5zEPo3BNW%i-VM0&*cczeZ*@SWsH&^GG2-P z1hhu;F}O0cj(<#f5UlHZi%~jaXm;RnYuWV7wT;a(vQ-?zpqIqC6XpUeVP+#EopuOUziNYJ3>V$xw%h$`Pb zEKH{)AR-H(AZa9=GZ|63&4y7t{H)cclIu(iQ2<43e6e7F#SS{@blhJz5&=$OA@mfy0n{4d@7dQhpJMUs*Uz|mZ2MFH{QF<671!qcYMZv^xWH-%2BL| zG-zPb=>|(bE!EmH=VM4lJ3fc{PDU9`(#rba^>M9I!-_2KtaLHS0 zT4u+6Ji$SfL~aTjx#lFym`dUKZy|Qn8i8&0XIl4G%-(743YEk}@O`?w?lZ)Tl2H^rzZ&MAOd67Y{-`Cywrqd3x4L~NFWv*q zG`ji=D9ERHf0g5n)0$xZ-N|9jy$sS>eCV2L0Z_E(rP|$8WB8m-zG24s=1xD=ha%0r zWdPw2Pakss(1R>Ty9PZ0hl@vR>43Qp=3XwoS!>&5#wWR%)l5OPsi}=U#7CR5SVK9O z(n|%QnvaBz!^Z>L*2|?l%{2^Z*dPHC!` z(l;M98YW}QfAvK~rFqX{|CVVUuA9+@s$UhY@eH|TeRN68G{2@XGx^gPj-q^mr;=yA zqE*l+v{{0nrXPp+-{s*!Voawsz~m4Lb3?_FvbPD%@+DCht#PhZ2@Y?gs+t`STg_VC zJh!a<5zFLLhcY=mzM`Lve=U6hBte={rK@}WLu`c9PE15HHDIYBswxw~_9xg3PR(zB zwd(&{e}hk|XVPb9Q94zar~B0Lfx?}Qy&R> z+m)DQN1fwGMm}Cn;;z_8Sr9W9t{gdhP+j>_8vVBYJ!fH_RMSxnhg!YX<0~8M4cZsV zE6X-^tJf3Kv#mS#3B$%JpLD)^bshj9XBevDN)dQvpl@lyRo=)nxo021a`q=R(5wBp7oA8>CvA z!{YN+DFOqV=e|2m-_tZvNP{RsCF>7GPm@`arbKG zACDe)T?42TwKhd=Z=yZ3r$lm->fIT|!Zwl1Y%_O#4b7N&)q~?{O@CY_?B$m$K;mYi zWd?cSCbbc~SEzw5;`#-2E_glQO6j@%gM-m{Ywbo+7t>pf_VN~km1(FWQX@4Q)fMp% z_Gk!`qg5?F*h+GaR!FI+m*&s*9N&b*wI8rDyl~PLUHhjI&x*T{?nq)L3P|Zt;48|<(*RV k>JSA{{1*GKYy0U_z5hkmlwW=)Q~M-g;{Tx^ th)[0] + if (len(candidate_inds) == 0): + continue + + pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds)) + inds = np.argpartition( + cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:] + inds = candidate_inds[inds] + + inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose() + classes = inds_5d[:, 2] + anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4] + scores = cls_prob[:, anchor_ids, classes, y, x] + + boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32) + boxes *= stride + boxes += cell_anchors[anchor_ids, :] + + if not cfg.RETINANET.CLASS_SPECIFIC_BBOX: + box_deltas = box_pred[0, anchor_ids, :, y, x] + else: + box_cls_inds = classes * 4 + box_deltas = np.vstack( + [box_pred[0, ind:ind + 4, yi, xi] + for ind, yi, xi in zip(box_cls_inds, y, x)] + ) + pred_boxes = ( + box_utils.bbox_transform(boxes, box_deltas) + if cfg.TEST.BBOX_REG else boxes) + pred_boxes /= im_scale + pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape) + box_scores = np.zeros((pred_boxes.shape[0], 5)) + box_scores[:, 0:4] = pred_boxes + box_scores[:, 4] = scores + + for cls in range(1, cfg.MODEL.NUM_CLASSES): + inds = np.where(classes == cls - 1)[0] + if len(inds) > 0: + boxes_all[cls].extend(box_scores[inds, :]) + timers['im_detect_bbox'].toc() 
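+    # Note on the anchor placement above: (x, y) index cells of the level's
+    # prediction map, so scaling by the level stride converts them to image
+    # coordinates, and adding cell_anchors[anchor_id] shifts that level's base
+    # anchor to the cell. As an illustrative (made-up) example, with stride 8,
+    # cell anchor (-16, -16, 16, 16) and (x, y) = (3, 2), the decoded anchor is
+    # (24 - 16, 16 - 16, 24 + 16, 16 + 16) = (8, 0, 40, 32), which
+    # box_utils.bbox_transform then refines with the predicted deltas.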
+ + # Combine predictions across all levels and retain the top scoring by class + timers['misc_bbox'].tic() + detections = [] + for cls, boxes in boxes_all.items(): + cls_dets = np.vstack(boxes).astype(dtype=np.float32) + # do class specific nms here + keep = box_utils.nms(cls_dets, cfg.TEST.NMS) + cls_dets = cls_dets[keep, :] + out = np.zeros((len(keep), 6)) + out[:, 0:5] = cls_dets + out[:, 5].fill(cls) + detections.append(out) + + # detections (N, 6) format: + # detections[:, :4] - boxes + # detections[:, 4] - scores + # detections[:, 5] - classes + detections = np.vstack(detections) + # sort all again + inds = np.argsort(-detections[:, 4]) + detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :] + + # Convert the detections to image cls_ format (see core/test_engine.py) + num_classes = cfg.MODEL.NUM_CLASSES + cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)] + for c in range(1, num_classes): + inds = np.where(detections[:, 5] == c)[0] + cls_boxes[c] = detections[inds, :5] + timers['misc_bbox'].toc() + + return cls_boxes diff --git a/lib/modeling/FPN.py b/lib/modeling/FPN.py index 03cd60bc..a551699c 100644 --- a/lib/modeling/FPN.py +++ b/lib/modeling/FPN.py @@ -140,7 +140,7 @@ def __init__(self, conv_body_func, fpn_level_info, P2only=False): self.extra_pyramid_modules = nn.ModuleList() dim_in = fpn_level_info.dims[0] for i in range(HIGHEST_BACKBONE_LVL + 1, max_level + 1): - self.extra_pyramid_modules( + self.extra_pyramid_modules.append( nn.Conv2d(dim_in, fpn_dim, 3, 2, 1) ) dim_in = fpn_dim @@ -214,7 +214,7 @@ def detectron_weight_mapping(self): }) if hasattr(self, 'extra_pyramid_modules'): - for i in len(self.extra_pyramid_modules): + for i in range(len(self.extra_pyramid_modules)): p_prefix = 'extra_pyramid_modules.%d' % i d_prefix = 'fpn_%d' % (HIGHEST_BACKBONE_LVL + 1 + i) mapping_to_detectron.update({ @@ -246,9 +246,9 @@ def forward(self, x): if hasattr(self, 'extra_pyramid_modules'): blob_in = conv_body_blobs[-1] - fpn_output_blobs.insert(0, self.extra_pyramid_modules(blob_in)) + fpn_output_blobs.insert(0, self.extra_pyramid_modules[0](blob_in)) for module in self.extra_pyramid_modules[1:]: - fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0], inplace=True))) + fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0]))) if self.P2only: # use only the finest level @@ -294,7 +294,7 @@ def forward(self, top_blob, lateral_blob): lat = self.conv_lateral(lateral_blob) # Top-down 2x upsampling # td = F.upsample(top_blob, size=lat.size()[2:], mode='bilinear') - td = F.upsample(top_blob, scale_factor=2, mode='nearest') + td = F.interpolate(top_blob, scale_factor=2, mode='nearest') # Sum lateral and top-down return lat + td diff --git a/lib/modeling/model_builder.py b/lib/modeling/model_builder.py index 0c7f1f49..59a274ef 100644 --- a/lib/modeling/model_builder.py +++ b/lib/modeling/model_builder.py @@ -12,6 +12,7 @@ from model.roi_crop.functions.roi_crop import RoICropFunction from modeling.roi_xfrom.roi_align.functions.roi_align import RoIAlignFunction import modeling.rpn_heads as rpn_heads +import modeling.retinanet_heads as retinanet_heads import modeling.fast_rcnn_heads as fast_rcnn_heads import modeling.mask_rcnn_heads as mask_rcnn_heads import modeling.keypoint_rcnn_heads as keypoint_rcnn_heads @@ -62,8 +63,9 @@ def wrapper(self, *args, **kwargs): with torch.no_grad(): return net_func(self, *args, **kwargs) else: - raise ValueError('You should call this function only on inference.' 
- 'Set the network in inference mode by net.eval().') + raise ValueError( + 'You should call this function only on inference.' + 'Set the network in inference mode by net.eval().') return wrapper @@ -84,42 +86,59 @@ def __init__(self): self.RPN = rpn_heads.generic_rpn_outputs( self.Conv_Body.dim_out, self.Conv_Body.spatial_scale) - if cfg.FPN.FPN_ON: - # Only supports case when RPN and ROI min levels are the same - assert cfg.FPN.RPN_MIN_LEVEL == cfg.FPN.ROI_MIN_LEVEL - # RPN max level can be >= to ROI max level - assert cfg.FPN.RPN_MAX_LEVEL >= cfg.FPN.ROI_MAX_LEVEL - # FPN RPN max level might be > FPN ROI max level in which case we - # need to discard some leading conv blobs (blobs are ordered from - # max/coarsest level to min/finest level) - self.num_roi_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1 - - # Retain only the spatial scales that will be used for RoI heads. `Conv_Body.spatial_scale` - # may include extra scales that are used for RPN proposals, but not for RoI heads. - self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:] + if cfg.FPN.FPN_ON: + # Only supports case when RPN and ROI min levels are the same + assert cfg.FPN.RPN_MIN_LEVEL == cfg.FPN.ROI_MIN_LEVEL + # RPN max level can be >= to ROI max level + assert cfg.FPN.RPN_MAX_LEVEL >= cfg.FPN.ROI_MAX_LEVEL + # FPN RPN max level might be > FPN ROI max level in which case we + # need to discard some leading conv blobs (blobs are ordered from + # max/coarsest level to min/finest level) + self.num_roi_levels = cfg.FPN.ROI_MAX_LEVEL - cfg.FPN.ROI_MIN_LEVEL + 1 + + # Retain only the spatial scales that will be used for RoI heads. `Conv_Body.spatial_scale` + # may include extra scales that are used for RPN proposals, but + # not for RoI heads. + self.Conv_Body.spatial_scale = self.Conv_Body.spatial_scale[-self.num_roi_levels:] # BBOX Branch if not cfg.MODEL.RPN_ONLY: - self.Box_Head = get_func(cfg.FAST_RCNN.ROI_BOX_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) - self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs( - self.Box_Head.dim_out) + if cfg.FAST_RCNN.ROI_BOX_HEAD is '': + # RetinaNet + self.Box_Outs = retinanet_heads.fpn_retinanet_outputs( + self.Conv_Body.dim_out, self.Conv_Body.spatial_scale) + else: + self.Box_Head = get_func( + cfg.FAST_RCNN.ROI_BOX_HEAD)( + self.RPN.dim_out, + self.roi_feature_transform, + self.Conv_Body.spatial_scale) + self.Box_Outs = fast_rcnn_heads.fast_rcnn_outputs( + self.Box_Head.dim_out) # Mask Branch if cfg.MODEL.MASK_ON: - self.Mask_Head = get_func(cfg.MRCNN.ROI_MASK_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) + self.Mask_Head = get_func( + cfg.MRCNN.ROI_MASK_HEAD)( + self.RPN.dim_out, + self.roi_feature_transform, + self.Conv_Body.spatial_scale) if getattr(self.Mask_Head, 'SHARE_RES5', False): self.Mask_Head.share_res5_module(self.Box_Head.res5) - self.Mask_Outs = mask_rcnn_heads.mask_rcnn_outputs(self.Mask_Head.dim_out) + self.Mask_Outs = mask_rcnn_heads.mask_rcnn_outputs( + self.Mask_Head.dim_out) # Keypoints Branch if cfg.MODEL.KEYPOINTS_ON: - self.Keypoint_Head = get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD)( - self.RPN.dim_out, self.roi_feature_transform, self.Conv_Body.spatial_scale) + self.Keypoint_Head = get_func( + cfg.KRCNN.ROI_KEYPOINTS_HEAD)( + self.RPN.dim_out, + self.roi_feature_transform, + self.Conv_Body.spatial_scale) if getattr(self.Keypoint_Head, 'SHARE_RES5', False): self.Keypoint_Head.share_res5_module(self.Box_Head.res5) - self.Keypoint_Outs = 
keypoint_rcnn_heads.keypoint_outputs(self.Keypoint_Head.dim_out) + self.Keypoint_Outs = keypoint_rcnn_heads.keypoint_outputs( + self.Keypoint_Head.dim_out) self._init_modules() @@ -127,10 +146,16 @@ def _init_modules(self): if cfg.MODEL.LOAD_IMAGENET_PRETRAINED_WEIGHTS: resnet_utils.load_pretrained_imagenet_weights(self) # Check if shared weights are equaled - if cfg.MODEL.MASK_ON and getattr(self.Mask_Head, 'SHARE_RES5', False): - assert compare_state_dict(self.Mask_Head.res5.state_dict(), self.Box_Head.res5.state_dict()) - if cfg.MODEL.KEYPOINTS_ON and getattr(self.Keypoint_Head, 'SHARE_RES5', False): - assert compare_state_dict(self.Keypoint_Head.res5.state_dict(), self.Box_Head.res5.state_dict()) + if cfg.MODEL.MASK_ON and getattr( + self.Mask_Head, 'SHARE_RES5', False): + assert compare_state_dict( + self.Mask_Head.res5.state_dict(), + self.Box_Head.res5.state_dict()) + if cfg.MODEL.KEYPOINTS_ON and getattr( + self.Keypoint_Head, 'SHARE_RES5', False): + assert compare_state_dict( + self.Keypoint_Head.res5.state_dict(), + self.Box_Head.res5.state_dict()) if cfg.TRAIN.FREEZE_CONV_BODY: for p in self.Conv_Body.parameters(): @@ -154,25 +179,30 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): blob_conv = self.Conv_Body(im_data) - rpn_ret = self.RPN(blob_conv, im_info, roidb) - # if self.training: # # can be used to infer fg/bg ratio # return_dict['rois_label'] = rpn_ret['labels_int32'] - if cfg.FPN.FPN_ON: + if cfg.RPN.RPN_ON: + rpn_ret = self.RPN(blob_conv, im_info, roidb) + + if cfg.FPN.FPN_ON and cfg.FAST_RCNN.ROI_BOX_HEAD is not '': # Retain only the blobs that will be used for RoI heads. `blob_conv` may include - # extra blobs that are used for RPN proposals, but not for RoI heads. + # extra blobs that are used for RPN proposals, but not for RoI + # heads. 
blob_conv = blob_conv[-self.num_roi_levels:] if not self.training: return_dict['blob_conv'] = blob_conv if not cfg.MODEL.RPN_ONLY: - if cfg.MODEL.SHARE_RES5 and self.training: - box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret) + if cfg.FAST_RCNN.ROI_BOX_HEAD is not '': + if cfg.MODEL.SHARE_RES5 and self.training: + box_feat, res5_feat = self.Box_Head(blob_conv, rpn_ret) + else: + box_feat = self.Box_Head(blob_conv, rpn_ret) else: - box_feat = self.Box_Head(blob_conv, rpn_ret) + box_feat = blob_conv cls_score, bbox_pred = self.Box_Outs(box_feat) else: # TODO: complete the returns for RPN only situation @@ -181,46 +211,62 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): if self.training: return_dict['losses'] = {} return_dict['metrics'] = {} - # rpn loss - rpn_kwargs.update(dict( - (k, rpn_ret[k]) for k in rpn_ret.keys() - if (k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred')) - )) - loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses(**rpn_kwargs) - if cfg.FPN.FPN_ON: - for i, lvl in enumerate(range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): - return_dict['losses']['loss_rpn_cls_fpn%d' % lvl] = loss_rpn_cls[i] - return_dict['losses']['loss_rpn_bbox_fpn%d' % lvl] = loss_rpn_bbox[i] - else: - return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls - return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox - # bbox loss - loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses( - cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'], - rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights']) - return_dict['losses']['loss_cls'] = loss_cls - return_dict['losses']['loss_bbox'] = loss_bbox - return_dict['metrics']['accuracy_cls'] = accuracy_cls + if cfg.FAST_RCNN.ROI_BOX_HEAD is not '': + # rpn loss + rpn_kwargs.update(dict((k, rpn_ret[k]) for k in rpn_ret.keys() if ( + k.startswith('rpn_cls_logits') or k.startswith('rpn_bbox_pred')))) + loss_rpn_cls, loss_rpn_bbox = rpn_heads.generic_rpn_losses( + **rpn_kwargs) + if cfg.FPN.FPN_ON: + for i, lvl in enumerate( + range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): + return_dict['losses']['loss_rpn_cls_fpn%d' % + lvl] = loss_rpn_cls[i] + return_dict['losses']['loss_rpn_bbox_fpn%d' % + lvl] = loss_rpn_bbox[i] + else: + return_dict['losses']['loss_rpn_cls'] = loss_rpn_cls + return_dict['losses']['loss_rpn_bbox'] = loss_rpn_bbox + + # bbox loss + loss_cls, loss_bbox, accuracy_cls = fast_rcnn_heads.fast_rcnn_losses( + cls_score, bbox_pred, rpn_ret['labels_int32'], rpn_ret['bbox_targets'], + rpn_ret['bbox_inside_weights'], rpn_ret['bbox_outside_weights']) + return_dict['losses']['loss_cls'] = loss_cls + return_dict['losses']['loss_bbox'] = loss_bbox + return_dict['metrics']['accuracy_cls'] = accuracy_cls + + if cfg.RETINANET.RETINANET_ON: + loss_retnet_cls, loss_retnet_bbox = retinanet_heads.add_fpn_retinanet_losses( + cls_score, bbox_pred, **rpn_kwargs) + for i, lvl in enumerate( + range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1)): + return_dict['losses']['loss_retnet_cls_fpn%d' % + lvl] = loss_retnet_cls[i] + return_dict['losses']['loss_retnet_bbox_fpn%d' % + lvl] = loss_retnet_bbox[i] if cfg.MODEL.MASK_ON: if getattr(self.Mask_Head, 'SHARE_RES5', False): - mask_feat = self.Mask_Head(res5_feat, rpn_ret, - roi_has_mask_int32=rpn_ret['roi_has_mask_int32']) + mask_feat = self.Mask_Head( + res5_feat, rpn_ret, roi_has_mask_int32=rpn_ret['roi_has_mask_int32']) else: mask_feat = self.Mask_Head(blob_conv, rpn_ret) mask_pred = self.Mask_Outs(mask_feat) # 
return_dict['mask_pred'] = mask_pred # mask loss - loss_mask = mask_rcnn_heads.mask_rcnn_losses(mask_pred, rpn_ret['masks_int32']) + loss_mask = mask_rcnn_heads.mask_rcnn_losses( + mask_pred, rpn_ret['masks_int32']) return_dict['losses']['loss_mask'] = loss_mask if cfg.MODEL.KEYPOINTS_ON: if getattr(self.Keypoint_Head, 'SHARE_RES5', False): # No corresponding keypoint head implemented yet (Neither in Detectron) - # Also, rpn need to generate the label 'roi_has_keypoints_int32' - kps_feat = self.Keypoint_Head(res5_feat, rpn_ret, - roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32']) + # Also, rpn need to generate the label + # 'roi_has_keypoints_int32' + kps_feat = self.Keypoint_Head( + res5_feat, rpn_ret, roi_has_keypoints_int32=rpn_ret['roi_has_keypoint_int32']) else: kps_feat = self.Keypoint_Head(blob_conv, rpn_ret) kps_pred = self.Keypoint_Outs(kps_feat) @@ -243,14 +289,22 @@ def _forward(self, data, im_info, roidb=None, **rpn_kwargs): else: # Testing - return_dict['rois'] = rpn_ret['rois'] + if cfg.FAST_RCNN.ROI_BOX_HEAD is not '': + return_dict['rois'] = rpn_ret['rois'] return_dict['cls_score'] = cls_score return_dict['bbox_pred'] = bbox_pred return return_dict - def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoIPoolF', - resolution=7, spatial_scale=1. / 16., sampling_ratio=0): + def roi_feature_transform( + self, + blobs_in, + rpn_ret, + blob_rois='rois', + method='RoIPoolF', + resolution=7, + spatial_scale=1. / 16., + sampling_ratio=0): """Add the specified RoI pooling method. The sampling_ratio argument is supported for some, but not all, RoI transform methods. @@ -273,12 +327,16 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI sc = spatial_scale[k_max - lvl] # in reversed order bl_rois = blob_rois + '_fpn' + str(lvl) if len(rpn_ret[bl_rois]): - rois = Variable(torch.from_numpy(rpn_ret[bl_rois])).cuda(device_id) + rois = Variable(torch.from_numpy( + rpn_ret[bl_rois])).cuda(device_id) if method == 'RoIPoolF': - # Warning!: Not check if implementation matches Detectron - xform_out = RoIPoolFunction(resolution, resolution, sc)(bl_in, rois) + # Warning!: Not check if implementation matches + # Detectron + xform_out = RoIPoolFunction( + resolution, resolution, sc)(bl_in, rois) elif method == 'RoICrop': - # Warning!: Not check if implementation matches Detectron + # Warning!: Not check if implementation matches + # Detectron grid_xy = net_utils.affine_grid_gen( rois, bl_in.size()[2:], self.grid_size) grid_yx = torch.stack( @@ -288,7 +346,8 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI xform_out = F.max_pool2d(xform_out, 2, 2) elif method == 'RoIAlign': xform_out = RoIAlignFunction( - resolution, resolution, sc, sampling_ratio)(bl_in, rois) + resolution, resolution, sc, sampling_ratio)( + bl_in, rois) bl_out_list.append(xform_out) # The pooled features from all levels are concatenated along the @@ -299,7 +358,10 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI device_id = xform_shuffled.get_device() restore_bl = rpn_ret[blob_rois + '_idx_restore_int32'] restore_bl = Variable( - torch.from_numpy(restore_bl.astype('int64', copy=False))).cuda(device_id) + torch.from_numpy( + restore_bl.astype( + 'int64', + copy=False))).cuda(device_id) xform_out = xform_shuffled[restore_bl] else: # Single feature level @@ -307,11 +369,14 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI # (batch_idx, x1, y1, x2, y2) specifying an 
image batch index and a # rectangle (x1, y1, x2, y2) device_id = blobs_in.get_device() - rois = Variable(torch.from_numpy(rpn_ret[blob_rois])).cuda(device_id) + rois = Variable(torch.from_numpy( + rpn_ret[blob_rois])).cuda(device_id) if method == 'RoIPoolF': - xform_out = RoIPoolFunction(resolution, resolution, spatial_scale)(blobs_in, rois) + xform_out = RoIPoolFunction( + resolution, resolution, spatial_scale)(blobs_in, rois) elif method == 'RoICrop': - grid_xy = net_utils.affine_grid_gen(rois, blobs_in.size()[2:], self.grid_size) + grid_xy = net_utils.affine_grid_gen( + rois, blobs_in.size()[2:], self.grid_size) grid_yx = torch.stack( [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous() xform_out = RoICropFunction()(blobs_in, Variable(grid_yx).detach()) @@ -319,7 +384,12 @@ def roi_feature_transform(self, blobs_in, rpn_ret, blob_rois='rois', method='RoI xform_out = F.max_pool2d(xform_out, 2, 2) elif method == 'RoIAlign': xform_out = RoIAlignFunction( - resolution, resolution, spatial_scale, sampling_ratio)(blobs_in, rois) + resolution, + resolution, + spatial_scale, + sampling_ratio)( + blobs_in, + rois) return xform_out @@ -329,7 +399,8 @@ def convbody_net(self, data): blob_conv = self.Conv_Body(data) if cfg.FPN.FPN_ON: # Retain only the blobs that will be used for RoI heads. `blob_conv` may include - # extra blobs that are used for RPN proposals, but not for RoI heads. + # extra blobs that are used for RPN proposals, but not for RoI + # heads. blob_conv = blob_conv[-self.num_roi_levels:] return blob_conv diff --git a/lib/modeling/retinanet_heads.py b/lib/modeling/retinanet_heads.py new file mode 100644 index 00000000..a4cdbb81 --- /dev/null +++ b/lib/modeling/retinanet_heads.py @@ -0,0 +1,207 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +"""RetinaNet model heads and losses. 
See: https://arxiv.org/abs/1708.02002.""" +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +import utils.net as net_utils +import math + +from core.config import cfg + + +class fpn_retinanet_outputs(nn.Module): + """Add RetinaNet on FPN specific outputs.""" + + def __init__(self, dim_in, spatial_scales): + super().__init__() + self.dim_out = dim_in + self.dim_in = dim_in + self.spatial_scales = spatial_scales + self.dim_out = self.dim_in + self.num_anchors = len(cfg.RETINANET.ASPECT_RATIOS) * \ + cfg.RETINANET.SCALES_PER_OCTAVE + + # Create conv ops shared by all FPN levels + self.n_conv_fpn_cls_modules = nn.ModuleList() + self.n_conv_fpn_bbox_modules = nn.ModuleList() + for nconv in range(cfg.RETINANET.NUM_CONVS): + self.n_conv_fpn_cls_modules.append( + nn.Conv2d(self.dim_in, self.dim_out, 3, 1, 1)) + self.n_conv_fpn_bbox_modules.append( + nn.Conv2d(self.dim_in, self.dim_out, 3, 1, 1)) + + cls_pred_dim = cfg.MODEL.NUM_CLASSES if cfg.RETINANET.SOFTMAX \ + else (cfg.MODEL.NUM_CLASSES - 1) + + # unpacked bbox feature and add prediction layers + self.bbox_regr_dim = 4 * (cfg.MODEL.NUM_CLASSES - 1) \ + if cfg.RETINANET.CLASS_SPECIFIC_BBOX else 4 + + self.fpn_cls_score = nn.Conv2d(self.dim_out, + cls_pred_dim * self.num_anchors, 3, 1, 1) + self.fpn_bbox_score = nn.Conv2d(self.dim_out, + self.bbox_regr_dim * self.num_anchors, 3, 1, 1) + + self._init_weights() + + def _init_weights(self): + def init_func(m): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + for child_m in self.children(): + if isinstance(child_m, nn.ModuleList): + child_m.apply(init_func) + + init.normal_(self.fpn_cls_score.weight, std=0.01) + init.constant_(self.fpn_cls_score.bias, + -math.log((1 - cfg.RETINANET.PRIOR_PROB) / cfg.RETINANET.PRIOR_PROB)) + + init.normal_(self.fpn_bbox_score.weight, std=0.01) + init.constant_(self.fpn_bbox_score.bias, 0) + + def detectron_weight_mapping(self): + k_min = cfg.FPN.RPN_MIN_LEVEL + mapping_to_detectron = { + 'n_conv_fpn_cls_modules.0.weight': 'retnet_cls_conv_n0_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.0.bias': 'retnet_cls_conv_n0_fpn%d_b' % k_min, + 'n_conv_fpn_cls_modules.1.weight': 'retnet_cls_conv_n1_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.1.bias': 'retnet_cls_conv_n1_fpn%d_b' % k_min, + 'n_conv_fpn_cls_modules.2.weight': 'retnet_cls_conv_n2_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.2.bias': 'retnet_cls_conv_n2_fpn%d_b' % k_min, + 'n_conv_fpn_cls_modules.3.weight': 'retnet_cls_conv_n3_fpn%d_w' % k_min, + 'n_conv_fpn_cls_modules.3.bias': 'retnet_cls_conv_n3_fpn%d_b' % k_min, + + 'n_conv_fpn_bbox_modules.0.weight': 'retnet_bbox_conv_n0_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.0.bias': 'retnet_bbox_conv_n0_fpn%d_b' % k_min, + 'n_conv_fpn_bbox_modules.1.weight': 'retnet_bbox_conv_n1_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.1.bias': 'retnet_bbox_conv_n1_fpn%d_b' % k_min, + 'n_conv_fpn_bbox_modules.2.weight': 'retnet_bbox_conv_n2_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.2.bias': 'retnet_bbox_conv_n2_fpn%d_b' % k_min, + 'n_conv_fpn_bbox_modules.3.weight': 'retnet_bbox_conv_n3_fpn%d_w' % k_min, + 'n_conv_fpn_bbox_modules.3.bias': 'retnet_bbox_conv_n3_fpn%d_b' % k_min, + + 'fpn_cls_score.weight': 'retnet_cls_pred_fpn%d_w' % k_min, + 'fpn_cls_score.bias': 'retnet_cls_pred_fpn%d_b' % k_min, + 'fpn_bbox_score.weight': 'retnet_bbox_pred_fpn%d_w' % k_min, + 'fpn_bbox_score.bias': 'retnet_bbox_pred_fpn%d_b' % k_min + } + return mapping_to_detectron, [] + + def 
forward(self, blobs_in): + k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid + k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid + assert len(blobs_in) == k_max - k_min + 1 + bbox_feat_list = [] + cls_score = [] + bbox_pred = [] + + # ========================================================================== + # classification tower with logits and prob prediction + # ========================================================================== + for lvl in range(k_min, k_max + 1): + bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order + # classification tower stack convolution starts + for nconv in range(cfg.RETINANET.NUM_CONVS): + bl_out = self.n_conv_fpn_cls_modules[nconv](bl_in) + bl_in = F.relu(bl_out, inplace=True) + bl_feat = bl_in + + # cls tower stack convolution ends. Add the logits layer now + retnet_cls_pred = self.fpn_cls_score(bl_feat) + + if not self.training: + if cfg.RETINANET.SOFTMAX: + raise NotImplementedError("To be implemented") + else: # sigmoid + retnet_cls_probs = retnet_cls_pred.sigmoid() + cls_score.append(retnet_cls_probs) + else: + cls_score.append(retnet_cls_pred) + + if cfg.RETINANET.SHARE_CLS_BBOX_TOWER: + bbox_feat_list.append(bl_feat) + + # ========================================================================== + # bbox tower if not sharing features with the classification tower with + # logits and prob prediction + # ========================================================================== + if not cfg.RETINANET.SHARE_CLS_BBOX_TOWER: + for lvl in range(k_min, k_max + 1): + bl_in = blobs_in[k_max - lvl] # blobs_in is in reversed order + # classification tower stack convolution starts + for nconv in range(cfg.RETINANET.NUM_CONVS): + bl_out = self.n_conv_fpn_bbox_modules[nconv](bl_in) + bl_in = F.relu(bl_out, inplace=True) + # Add octave scales and aspect ratio + # At least 1 convolution for dealing different aspect ratios + bl_feat = bl_in + bbox_feat_list.append(bl_feat) + + # Depending on the features [shared/separate] for bbox, add prediction layer + for i, lvl in enumerate(range(k_min, k_max + 1)): + bl_feat = bbox_feat_list[i] + retnet_bbox_pred = self.fpn_bbox_score(bl_feat) + bbox_pred.append(retnet_bbox_pred) + + return cls_score, bbox_pred + + +def add_fpn_retinanet_losses(cls_score, bbox_pred, **kwargs): + k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid + k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid + + losses_cls = [] + losses_bbox = [] + for i, lvl in enumerate(range(k_min, k_max + 1)): + slvl = str(lvl) + h, w = cls_score[i].shape[2:] + retnet_cls_labels_fpn = kwargs['retnet_cls_labels_fpn' + + slvl][:, :, :h, :w] + retnet_bbox_targets_fpn = kwargs['retnet_roi_bbox_targets_fpn' + + slvl][:, :, :, :h, :w] + retnet_bbox_inside_weights_fpn = kwargs['retnet_bbox_inside_weights_wide_fpn' + + slvl][:, :, :, :h, :w] + retnet_fg_num = kwargs['retnet_fg_num'] + + # ========================================================================== + # bbox regression loss - SelectSmoothL1Loss for multiple anchors at a location + # ========================================================================== + bbox_loss = net_utils.select_smooth_l1_loss( + bbox_pred[i], retnet_bbox_targets_fpn, + retnet_bbox_inside_weights_fpn, + retnet_fg_num, + beta=cfg.RETINANET.BBOX_REG_BETA) + + # ========================================================================== + # cls loss - depends on softmax/sigmoid outputs + # ========================================================================== + if cfg.RETINANET.SOFTMAX: + raise 
NotImplementedError("To be implemented") + else: + cls_loss = net_utils.sigmoid_focal_loss( + cls_score[i], retnet_cls_labels_fpn.float(), + cfg.MODEL.NUM_CLASSES, retnet_fg_num, alpha=cfg.RETINANET.LOSS_ALPHA, + gamma=cfg.RETINANET.LOSS_GAMMA + ) + + losses_bbox.append(bbox_loss) + losses_cls.append(cls_loss) + + return losses_cls, losses_bbox diff --git a/lib/roi_data/minibatch.py b/lib/roi_data/minibatch.py index 65ef9e66..62bff744 100644 --- a/lib/roi_data/minibatch.py +++ b/lib/roi_data/minibatch.py @@ -4,6 +4,7 @@ from core.config import cfg import utils.blob as blob_utils import roi_data.rpn +import roi_data.retinanet def get_minibatch_blob_names(is_training=True): @@ -15,7 +16,9 @@ def get_minibatch_blob_names(is_training=True): # RPN-only or end-to-end Faster R-CNN blob_names += roi_data.rpn.get_rpn_blob_names(is_training=is_training) elif cfg.RETINANET.RETINANET_ON: - raise NotImplementedError + blob_names += roi_data.retinanet.get_retinanet_blob_names( + is_training=is_training + ) else: # Fast R-CNN like models trained on precomputed proposals blob_names += roi_data.fast_rcnn.get_fast_rcnn_blob_names( @@ -37,7 +40,13 @@ def get_minibatch(roidb): # RPN-only or end-to-end Faster/Mask R-CNN valid = roi_data.rpn.add_rpn_blobs(blobs, im_scales, roidb) elif cfg.RETINANET.RETINANET_ON: - raise NotImplementedError + im_width, im_height = im_blob.shape[3], im_blob.shape[2] + # im_width, im_height corresponds to the network input: padded image + # (if needed) width and height. We pass it as input and slice the data + # accordingly so that we don't need to use SampleAsOp + valid = roi_data.retinanet.add_retinanet_blobs( + blobs, im_scales, roidb, im_width, im_height + ) else: # Fast R-CNN like models trained on precomputed proposals valid = roi_data.fast_rcnn.add_fast_rcnn_blobs(blobs, im_scales, roidb) diff --git a/lib/roi_data/retinanet.py b/lib/roi_data/retinanet.py new file mode 100644 index 00000000..03b329bb --- /dev/null +++ b/lib/roi_data/retinanet.py @@ -0,0 +1,258 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +"""Compute minibatch blobs for training a RetinaNet network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np +import logging + +import utils.boxes as box_utils +import roi_data.data_utils as data_utils +from core.config import cfg + +logger = logging.getLogger(__name__) + + +def get_retinanet_blob_names(is_training=True): + """ + Returns blob names in the order in which they are read by the data + loader. 
+ """ + # im_info: (height, width, image scale) + blob_names = ['im_info'] + assert cfg.FPN.FPN_ON, "RetinaNet uses FPN for dense detection" + # Same format as RPN blobs, but one per FPN level + if is_training: + blob_names += ['roidb', 'retnet_fg_num', 'retnet_bg_num'] + for lvl in range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL + 1): + suffix = 'fpn{}'.format(lvl) + blob_names += [ + 'retnet_cls_labels_' + suffix, + 'retnet_roi_bbox_targets_' + suffix, + 'retnet_bbox_inside_weights_wide_' + suffix, + ] + return blob_names + + +def add_retinanet_blobs(blobs, im_scales, roidb, image_width, image_height): + """Add RetinaNet blobs.""" + # RetinaNet is applied to many feature levels, as in the FPN paper + k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL + scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE + num_aspect_ratios = len(cfg.RETINANET.ASPECT_RATIOS) + aspect_ratios = cfg.RETINANET.ASPECT_RATIOS + anchor_scale = cfg.RETINANET.ANCHOR_SCALE + + # get anchors from all levels for all scales/aspect ratios + foas = [] + for lvl in range(k_min, k_max + 1): + stride = 2. ** lvl + for octave in range(scales_per_octave): + octave_scale = 2 ** (octave / float(scales_per_octave)) + for idx in range(num_aspect_ratios): + anchor_sizes = (stride * octave_scale * anchor_scale,) + anchor_aspect_ratios = (aspect_ratios[idx],) + foa = data_utils.get_field_of_anchors( + stride, anchor_sizes, anchor_aspect_ratios, octave, idx) + foas.append(foa) + all_anchors = np.concatenate([f.field_of_anchors for f in foas]) + + blobs['retnet_fg_num'], blobs['retnet_bg_num'] = 0.0, 0.0 + for im_i, entry in enumerate(roidb): + scale = im_scales[im_i] + im_height = np.round(entry['height'] * scale) + im_width = np.round(entry['width'] * scale) + gt_inds = np.where( + (entry['gt_classes'] > 0) & (entry['is_crowd'] == 0))[0] + assert len(gt_inds) > 0, \ + 'Empty ground truth empty for image is not allowed. Please check.' + + gt_rois = entry['boxes'][gt_inds, :] * scale + gt_classes = entry['gt_classes'][gt_inds] + + im_info = np.array([[im_height, im_width, scale]], dtype=np.float32) + blobs['im_info'].append(im_info) + + retinanet_blobs, fg_num, bg_num = _get_retinanet_blobs( + foas, all_anchors, gt_rois, gt_classes, image_width, image_height) + for i, foa in enumerate(foas): + for k, v in retinanet_blobs[i].items(): + level = int(np.log2(foa.stride)) + key = '{}_fpn{}'.format(k, level) + blobs[key].append(v) + blobs['retnet_fg_num'] += fg_num + blobs['retnet_bg_num'] += bg_num + + blobs['retnet_fg_num'] = blobs['retnet_fg_num'].astype(np.float32) + blobs['retnet_bg_num'] = blobs['retnet_bg_num'].astype(np.float32) + + N = len(roidb) + for k, v in blobs.items(): + if isinstance(v, list) and len(v) > 0: + # compute number of anchors + A = int(len(v) / N) + # for the cls branch labels [per fpn level], + # we have blobs['retnet_cls_labels_fpn{}'] as a list until this step + # and length of this list is N x A where + # N = num_images, A = num_anchors for example, N = 2, A = 9 + # Each element of the list has the shape 1 x 1 x H x W where H, W are + # spatial dimension of curret fpn lvl. Let a{i} denote the element + # corresponding to anchor i [9 anchors total] in the list. 
+            # The elements in the list are in order [[a0, ..., a9], [a0, ..., a9]],
+            # however the network will make predictions like 2 x (9 * 80) x H x W,
+            # so we first concatenate the elements of each image to a numpy array
+            # and then concatenate the two images to get the 2 x 9 x H x W
+
+            if k.find('retnet_cls_labels') >= 0 \
+                    or k.find('retnet_roi_bbox_targets') >= 0:
+                tmp = []
+                # concat anchors within an image
+                for i in range(0, len(v), A):
+                    tmp.append(np.concatenate(v[i: i + A], axis=1))
+                # concat images
+                blobs[k] = np.concatenate(tmp, axis=0)
+            else:
+                # for the bbox branch elements [per FPN level], we have the
+                # targets and the fg box locations in the shape M x 4,
+                # where M is the number of fg locations in a given image at
+                # the current FPN level. The elements in the list are in
+                # order [[a0, ..., a9], [a0, ..., a9]].
+                # Concatenate them to form M x 4
+                blobs[k] = np.expand_dims(np.concatenate(v, axis=0), axis=0)
+
+    valid_keys = [
+        'has_visible_keypoints', 'boxes', 'segms', 'seg_areas', 'gt_classes',
+        'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints'
+    ]
+    minimal_roidb = [{} for _ in range(len(roidb))]
+    for i, e in enumerate(roidb):
+        for k in valid_keys:
+            if k in e:
+                minimal_roidb[i][k] = e[k]
+    # blobs['roidb'] = blob_utils.serialize(minimal_roidb)
+    blobs['roidb'] = minimal_roidb
+
+    return True
+
+
+def _get_retinanet_blobs(
+        foas, all_anchors, gt_boxes, gt_classes, im_width, im_height):
+    total_anchors = all_anchors.shape[0]
+    logger.debug('Getting mad blobs: im_height {} im_width: {}'.format(
+        im_height, im_width))
+
+    inds_inside = np.arange(all_anchors.shape[0])
+    anchors = all_anchors
+    num_inside = len(inds_inside)
+
+    logger.debug('total_anchors: {}'.format(total_anchors))
+    logger.debug('inds_inside: {}'.format(num_inside))
+    logger.debug('anchors.shape: {}'.format(anchors.shape))
+
+    # Compute anchor labels:
+    # label=1 is positive, 0 is negative, -1 is don't care (ignore)
+    labels = np.empty((num_inside,), dtype=np.float32)
+    labels.fill(-1)
+    if len(gt_boxes) > 0:
+        # Compute overlaps between the anchors and the gt boxes
+        anchor_by_gt_overlap = box_utils.bbox_overlaps(anchors, gt_boxes)
+        # Map from anchor to gt box that has highest overlap
+        anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1)
+        # For each anchor, amount of overlap with most overlapping gt box
+        anchor_to_gt_max = anchor_by_gt_overlap[
+            np.arange(num_inside), anchor_to_gt_argmax]
+
+        # Map from gt box to an anchor that has highest overlap
+        gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0)
+        # For each gt box, amount of overlap with most overlapping anchor
+        gt_to_anchor_max = anchor_by_gt_overlap[
+            gt_to_anchor_argmax, np.arange(anchor_by_gt_overlap.shape[1])]
+        # Find all anchors that share the max overlap amount
+        # (this includes many ties)
+        anchors_with_max_overlap = np.where(
+            anchor_by_gt_overlap == gt_to_anchor_max)[0]
+
+        # Fg label: for each gt use anchors with highest overlap
+        # (including ties)
+        gt_inds = anchor_to_gt_argmax[anchors_with_max_overlap]
+        labels[anchors_with_max_overlap] = gt_classes[gt_inds]
+        # Fg label: above threshold IOU
+        inds = anchor_to_gt_max >= cfg.RETINANET.POSITIVE_OVERLAP
+        gt_inds = anchor_to_gt_argmax[inds]
+        labels[inds] = gt_classes[gt_inds]
+
+    fg_inds = np.where(labels >= 1)[0]
+    bg_inds = np.where(anchor_to_gt_max < cfg.RETINANET.NEGATIVE_OVERLAP)[0]
+    labels[bg_inds] = 0
+    num_fg, num_bg = len(fg_inds), len(bg_inds)
+
+    bbox_targets =
np.zeros((num_inside, 4), dtype=np.float32) + bbox_targets[fg_inds, :] = data_utils.compute_targets( + anchors[fg_inds, :], gt_boxes[anchor_to_gt_argmax[fg_inds], :]) + + # Bbox regression loss has the form: + # loss(x) = weight_outside * L(weight_inside * x) + # Inside weights allow us to set zero loss on an element-wise basis + # Bbox regression is only trained on positive examples so we set their + # weights to 1.0 (or otherwise if config is different) and 0 otherwise + bbox_inside_weights = np.zeros((num_inside, 4), dtype=np.float32) + bbox_inside_weights[labels >= 1, :] = (1.0, 1.0, 1.0, 1.0) + + # Map up to original set of anchors + labels = data_utils.unmap(labels, total_anchors, inds_inside, fill=-1) + bbox_inside_weights = data_utils.unmap( + bbox_inside_weights, total_anchors, inds_inside, fill=0 + ) + bbox_targets = data_utils.unmap( + bbox_targets, total_anchors, inds_inside, fill=0) + + # Split the generated labels, etc. into labels per each field of anchors + blobs_out = [] + start_idx = 0 + for i, foa in enumerate(foas): + H = foa.field_size + W = foa.field_size + end_idx = start_idx + H * W + _labels = labels[start_idx:end_idx] + _bbox_targets = bbox_targets[start_idx:end_idx, :] + _bbox_inside_weights = bbox_inside_weights[start_idx:end_idx, :] + start_idx = end_idx + # labels output with shape (1, height, width) + _labels = _labels.reshape((1, 1, H, W)) + # bbox_targets output with shape (1, 4 * A, height, width) + _bbox_targets = _bbox_targets.reshape( + (1, 1, H, W, 4)).transpose(0, 1, 4, 2, 3) + + # bbox_inside_weights output with shape (1, 4 * A, height, width) + _bbox_inside_weights = _bbox_inside_weights.reshape( + (1, H, W, 4)).transpose(0, 3, 1, 2) + + blobs_out.append( + dict( + retnet_cls_labels=_labels.astype(np.int32), + retnet_roi_bbox_targets=_bbox_targets.astype(np.float32), + retnet_bbox_inside_weights_wide=_bbox_inside_weights + )) + out_num_fg = np.array([num_fg + 1.0], dtype=np.float32) + out_num_bg = ( + np.array([num_bg + 1.0]) * (cfg.MODEL.NUM_CLASSES - 1) + + out_num_fg * (cfg.MODEL.NUM_CLASSES - 2)) + return blobs_out, out_num_fg, out_num_bg diff --git a/lib/utils/net.py b/lib/utils/net.py index 32c5d705..3f5e7f74 100644 --- a/lib/utils/net.py +++ b/lib/utils/net.py @@ -24,7 +24,7 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we abs_in_box_diff = torch.abs(in_box_diff) smoothL1_sign = (abs_in_box_diff < beta).detach().float() in_loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \ - (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta)) + (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta)) out_loss_box = bbox_outside_weights * in_loss_box loss_box = out_loss_box N = loss_box.size(0) # batch size @@ -32,6 +32,50 @@ def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_we return loss_box +def select_smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, num_fg, beta=1.0): + bbox_pred = bbox_pred.reshape\ + ((bbox_pred.shape[0], bbox_targets.shape[1], 4, bbox_pred.shape[2], bbox_pred.shape[3])) + box_diff = bbox_pred - bbox_targets + in_box_diff = bbox_inside_weights * box_diff + abs_in_box_diff = torch.abs(in_box_diff) + smoothL1_sign = (abs_in_box_diff < beta).detach().float() + loss_box = smoothL1_sign * 0.5 * torch.pow(in_box_diff, 2) / beta + \ + (1 - smoothL1_sign) * (abs_in_box_diff - (0.5 * beta)) + loss_box = loss_box.view(-1).sum(0) / num_fg.sum() + return loss_box + + +def sigmoid_focal_loss(cls_preds, cls_targets, num_classes, num_fg, alpha=0.25, gamma=2): 
+ masked_cls_preds = cls_preds.reshape(( + cls_preds.size(0), cls_targets.size(1), num_classes - 1, + cls_preds.size(2), cls_preds.size(3))).permute((0, 1, 3, 4, 2)).contiguous().\ + view(-1, num_classes-1) + masked_cls_targets = cls_targets.view(-1) + + weights = (masked_cls_targets >= 0).float() + weights = weights.unsqueeze(1) + + t = masked_cls_preds.data.new( + masked_cls_preds.size(0), num_classes).fill_(0) + ids = masked_cls_targets.view(-1, 1) % (num_classes) + t.scatter_(1, ids.long(), 1.) + t = t[:, 1:] + + p = masked_cls_preds.sigmoid() + # w = alpha if t > 0 else 1-alpha + alpha_factor = alpha * t + (1 - alpha) * (1 - t) + # pt = p if t > 0 else 1-p + focal_weight = p * t + (1 - p) * (1 - t) + focal_weight = alpha_factor * (1 - focal_weight).pow(gamma) + + cls_loss = focal_weight * \ + F.binary_cross_entropy_with_logits( + masked_cls_preds, t, weight=weights, reduction='none') + cls_loss = cls_loss.sum() / num_fg.sum() + + return cls_loss + + def clip_gradient(model, clip_norm): """Computes a gradient clipping coefficient based on gradient norm.""" totalnorm = 0 @@ -62,7 +106,9 @@ def decay_learning_rate(optimizer, cur_lr, decay_rate): if cfg.SOLVER.TYPE in ['SGD']: if cfg.SOLVER.SCALE_MOMENTUM and cur_lr > 1e-7 and \ ratio > cfg.SOLVER.SCALE_MOMENTUM_THRESHOLD: - _CorrectMomentum(optimizer, param_group['params'], new_lr / cur_lr) + _CorrectMomentum( + optimizer, param_group['params'], new_lr / cur_lr) + def update_learning_rate(optimizer, cur_lr, new_lr): """Update learning rate""" @@ -119,15 +165,16 @@ def affine_grid_gen(rois, input_size, grid_size): width = input_size[1] zero = Variable(rois.data.new(rois.size(0), 1).zero_()) - theta = torch.cat([\ - (x2 - x1) / (width - 1), - zero, - (x1 + x2 - width + 1) / (width - 1), - zero, - (y2 - y1) / (height - 1), - (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3) + theta = torch.cat([ + (x2 - x1) / (width - 1), + zero, + (x1 + x2 - width + 1) / (width - 1), + zero, + (y2 - y1) / (height - 1), + (y1 + y2 - height + 1) / (height - 1)], 1).view(-1, 2, 3) - grid = F.affine_grid(theta, torch.Size((rois.size(0), 1, grid_size, grid_size))) + grid = F.affine_grid(theta, torch.Size( + (rois.size(0), 1, grid_size, grid_size))) return grid @@ -139,7 +186,8 @@ def save_ckpt(output_dir, args, model, optimizer): ckpt_dir = os.path.join(output_dir, 'ckpt') if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) - save_name = os.path.join(ckpt_dir, 'model_{}_{}.pth'.format(args.epoch, args.step)) + save_name = os.path.join( + ckpt_dir, 'model_{}_{}.pth'.format(args.epoch, args.step)) if isinstance(model, mynn.DataParallel): model = model.module # TODO: (maybe) Do not save redundant shared params diff --git a/lib/utils/training_stats.py b/lib/utils/training_stats.py index 8cae8a18..610c4304 100644 --- a/lib/utils/training_stats.py +++ b/lib/utils/training_stats.py @@ -83,24 +83,24 @@ def UpdateIterStats(self, model_out, inner_iter=None): assert loss.shape[0] == cfg.NUM_GPUS loss = loss.mean(dim=0, keepdim=True) total_loss += loss - loss_data = loss.data[0] + loss_data = loss.data[0].item() model_out['losses'][k] = loss if cfg.FPN.FPN_ON: - if k.startswith('loss_rpn_cls_'): + if k.startswith('loss_rpn_cls_') or k.startswith('loss_retnet_cls_'): loss_rpn_cls_data += loss_data - elif k.startswith('loss_rpn_bbox_'): + elif k.startswith('loss_rpn_bbox_') or k.startswith('loss_retnet_bbox_'): loss_rpn_bbox_data += loss_data self.smoothed_losses[k].AddValue(loss_data) model_out['total_loss'] = total_loss # Add the total loss for back propagation 
- self.smoothed_total_loss.AddValue(total_loss.data[0]) + self.smoothed_total_loss.AddValue(total_loss.data[0].item()) if cfg.FPN.FPN_ON: self.smoothed_losses['loss_rpn_cls'].AddValue(loss_rpn_cls_data) self.smoothed_losses['loss_rpn_bbox'].AddValue(loss_rpn_bbox_data) for k, metric in model_out['metrics'].items(): metric = metric.mean(dim=0, keepdim=True) - self.smoothed_metrics[k].AddValue(metric.data[0]) + self.smoothed_metrics[k].AddValue(metric.data[0].item()) def _UpdateIterStats_inner(self, model_out, inner_iter): """Update tracked iteration statistics for the case of iter_size > 1""" @@ -125,13 +125,13 @@ def _UpdateIterStats_inner(self, model_out, inner_iter): assert loss.shape[0] == cfg.NUM_GPUS loss = loss.mean(dim=0, keepdim=True) total_loss += loss - loss_data = loss.data[0] + loss_data = loss.data[0].item() model_out['losses'][k] = loss if cfg.FPN.FPN_ON: - if k.startswith('loss_rpn_cls_'): + if k.startswith('loss_rpn_cls_') or k.startswith('loss_retnet_cls_'): loss_rpn_cls_data += loss_data - elif k.startswith('loss_rpn_bbox_'): + elif k.startswith('loss_rpn_bbox_') or k.startswith('loss_retnet_bbox_'): loss_rpn_bbox_data += loss_data self.inner_losses[k].append(loss_data) @@ -140,7 +140,7 @@ def _UpdateIterStats_inner(self, model_out, inner_iter): self.smoothed_losses[k].AddValue(loss_data) model_out['total_loss'] = total_loss # Add the total loss for back propagation - total_loss_data = total_loss.data[0] + total_loss_data = total_loss.data[0].item() self.inner_total_loss.append(total_loss_data) if cfg.FPN.FPN_ON: self.inner_loss_rpn_cls.append(loss_rpn_cls_data) @@ -156,7 +156,7 @@ def _UpdateIterStats_inner(self, model_out, inner_iter): for k, metric in model_out['metrics'].items(): metric = metric.mean(dim=0, keepdim=True) - metric_data = metric.data[0] + metric_data = metric.data[0].item() self.inner_metrics[k].append(metric_data) if inner_iter == (self.misc_args.iter_size - 1): metric_data = self._mean_and_reset_inner_list('inner_metrics', k)
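
Note (illustrative, not part of the patch): the dense classification loss added in `lib/utils/net.py` is the alpha-balanced sigmoid focal loss from the RetinaNet paper, averaged over foreground anchors. The sketch below restates that weighting on simplified inputs so the shape juggling in the patch is easier to follow. The function name, the input shapes, and the local foreground count are assumptions for this example only; the patch itself normalizes by the `retnet_fg_num` blob built in `lib/roi_data/retinanet.py` and masks ignored anchors through the BCE `weight` argument.

```
# Minimal sketch of the focal-loss weighting, assuming simplified shapes.
import torch
import torch.nn.functional as F


def sigmoid_focal_loss_sketch(logits, labels, num_classes, alpha=0.25, gamma=2.0):
    # logits: (N, A * (num_classes - 1), H, W) raw scores from the cls tower.
    # labels: (N, A, H, W) ints; 0 = background, -1 = ignore,
    #         1..num_classes-1 = foreground class id (as in roi_data/retinanet.py).
    N, _, H, W = logits.shape
    A = labels.shape[1]
    C = num_classes - 1  # foreground classes only; no explicit background logit

    # one row of C scores per anchor position
    preds = logits.view(N, A, C, H, W).permute(0, 1, 3, 4, 2).reshape(-1, C)
    targets = labels.reshape(-1)

    valid = (targets >= 0).float().unsqueeze(1)       # mask out ignored anchors
    t = preds.new_zeros(preds.size(0), num_classes)   # one-hot over C + 1 slots
    t.scatter_(1, targets.clamp(min=0).long().view(-1, 1), 1.0)
    t = t[:, 1:]                                      # drop the background slot

    p = preds.sigmoid()
    alpha_factor = alpha * t + (1 - alpha) * (1 - t)  # alpha for fg, 1 - alpha for bg
    pt = p * t + (1 - p) * (1 - t)                    # probability of the true outcome
    focal_weight = alpha_factor * (1 - pt).pow(gamma)

    bce = F.binary_cross_entropy_with_logits(preds, t, reduction='none')
    num_fg = (targets >= 1).float().sum().clamp(min=1.0)  # patch uses retnet_fg_num here
    return (valid * focal_weight * bce).sum() / num_fg


# toy check: 2 images, 9 anchors, a 5x5 feature map, 81 COCO classes
logits = torch.randn(2, 9 * 80, 5, 5)
labels = torch.randint(-1, 81, (2, 9, 5, 5))
print(sigmoid_focal_loss_sketch(logits, labels, num_classes=81))
```

With `LOSS_ALPHA: 0.25` and `LOSS_GAMMA: 2.0` from the config above, well-classified anchors contribute almost nothing to the sum, which is why the loss can be normalized by the number of foreground anchors rather than by all anchors.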