From d77163b77592d68fffd0e994314200de17192882 Mon Sep 17 00:00:00 2001
From: reacher-l <45810596+reacher-l@users.noreply.github.com>
Date: Thu, 31 Dec 2020 15:38:07 +0800
Subject: [PATCH] Add files via upload

---
 README.txt                                    |   2 +
 __pycache__/evaluate.cpython-36.pyc           | Bin 0 -> 2488 bytes
 __pycache__/loss.cpython-36.pyc               | Bin 0 -> 15754 bytes
 cut_data.py                                   | 103 +++
 data/edge_utils.py                            |  84 ++
 data/make_data.py                             | 239 ++++++
 evaluate.py                                   |  53 ++
 loss.py                                       | 476 +++++++++++
 model/__pycache__/hrnet_config.cpython-35.pyc | Bin 0 -> 2083 bytes
 model/__pycache__/hrnet_config.cpython-36.pyc | Bin 0 -> 1631 bytes
 model/__pycache__/hrnet_config.cpython-37.pyc | Bin 0 -> 1650 bytes
 model/__pycache__/seg_hrnet.cpython-35.pyc    | Bin 0 -> 21814 bytes
 model/__pycache__/seg_hrnet.cpython-36.pyc    | Bin 0 -> 19517 bytes
 model/__pycache__/seg_hrnet.cpython-37.pyc    | Bin 0 -> 19540 bytes
 model/hrnet_config.py                         | 130 +++
 model/seg_hrnet.py                            | 750 ++++++++++++++++++
 run.sh                                        |   2 +
 train.py                                      | 265 +++++++
 utils/__pycache__/ema.cpython-36.pyc          | Bin 0 -> 946 bytes
 utils/__pycache__/label2color.cpython-35.pyc  | Bin 0 -> 933 bytes
 utils/__pycache__/label2color.cpython-36.pyc  | Bin 0 -> 849 bytes
 utils/__pycache__/lr_scheduler.cpython-35.pyc | Bin 0 -> 2158 bytes
 utils/__pycache__/lr_scheduler.cpython-36.pyc | Bin 0 -> 1993 bytes
 utils/ema.py                                  |  20 +
 utils/label2color.py                          |  24 +
 utils/lr_scheduler.py                         |  36 +
 26 files changed, 2184 insertions(+)
 create mode 100644 README.txt
 create mode 100644 __pycache__/evaluate.cpython-36.pyc
 create mode 100644 __pycache__/loss.cpython-36.pyc
 create mode 100644 cut_data.py
 create mode 100644 data/edge_utils.py
 create mode 100644 data/make_data.py
 create mode 100644 evaluate.py
 create mode 100644 loss.py
 create mode 100644 model/__pycache__/hrnet_config.cpython-35.pyc
 create mode 100644 model/__pycache__/hrnet_config.cpython-36.pyc
 create mode 100644 model/__pycache__/hrnet_config.cpython-37.pyc
 create mode 100644 model/__pycache__/seg_hrnet.cpython-35.pyc
 create mode 100644 model/__pycache__/seg_hrnet.cpython-36.pyc
 create mode 100644 model/__pycache__/seg_hrnet.cpython-37.pyc
 create mode 100644 model/hrnet_config.py
 create mode 100644 model/seg_hrnet.py
 create mode 100644 run.sh
 create mode 100644 train.py
 create mode 100644 utils/__pycache__/ema.cpython-36.pyc
 create mode 100644 utils/__pycache__/label2color.cpython-35.pyc
 create mode 100644 utils/__pycache__/label2color.cpython-36.pyc
 create mode 100644 utils/__pycache__/lr_scheduler.cpython-35.pyc
 create mode 100644 utils/__pycache__/lr_scheduler.cpython-36.pyc
 create mode 100644 utils/ema.py
 create mode 100644 utils/label2color.py
 create mode 100644 utils/lr_scheduler.py

diff --git a/README.txt b/README.txt
new file mode 100644
index 0000000..e499be0
--- /dev/null
+++ b/README.txt
@@ -0,0 +1,2 @@
+Place the four large training images (382 and 182, together with their labels) in the data folder, then run run.sh.
+Note: run.sh first tiles the images and splits them into training and validation sets (python cut_data.py), then trains the model (CUDA_VISIBLE_DEVICES=0 python train.py --backbone=hrnet --batchsize=4 --lr=0.01 --num_epochs=150).
diff --git a/__pycache__/evaluate.cpython-36.pyc b/__pycache__/evaluate.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3785313dc64e3fdb628e354fab72d07a96b69403
Binary files /dev/null and b/__pycache__/evaluate.cpython-36.pyc differ
diff --git a/__pycache__/loss.cpython-36.pyc b/__pycache__/loss.cpython-36.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/loss.cpython-36.pyc differ
diff --git a/cut_data.py b/cut_data.py
new file mode 100644
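cut_data.py is the tiling and train/val-split step that run.sh invokes first. A minimal sketch of that step, assuming the layout that make_data.py's loaders expect (images/<scene>/<scene>_<row>_<col>.png with a parallel labels/ tree); the source file names, the 1024-pixel tile size, and the 9:1 split below are assumptions, not values taken from this patch:

import os
import random

import cv2


def cut_scene(scene, tile=1024, root='data'):
    # Tile one large image/label pair into tile x tile patches named <scene>_<row>_<col>.png.
    img = cv2.imread(os.path.join(root, scene + '.png'), cv2.IMREAD_COLOR)
    lab = cv2.imread(os.path.join(root, scene + '_label.png'), 0)
    names = []
    for r in range(img.shape[0] // tile):
        for c in range(img.shape[1] // tile):
            name = '{}_{}_{}.png'.format(scene, r, c)
            for sub, arr in (('images', img), ('labels', lab)):
                out_dir = os.path.join(root, sub, scene)
                os.makedirs(out_dir, exist_ok=True)
                cv2.imwrite(os.path.join(out_dir, name),
                            arr[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile])
            names.append(name)
    return names


if __name__ == '__main__':
    names = [n for s in ('382', '182') for n in cut_scene(s)]
    random.shuffle(names)
    split = int(0.9 * len(names))  # assumed 9:1 train/val split
    with open('data/train.txt', 'w') as f:
        f.write('\n'.join(names[:split]))
    with open('data/val.txt', 'w') as f:
        f.write('\n'.join(names[split:]))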
zj#Mq`NS8>mBF~)&?VOKQe=b}KsRl*Hego{h1>(aOKoYD$0%@(t+6akfka+k7{K*24 zTs#DK9I%DyM+JW;j|{@M9sA6-%e{6LKVj*7oE3A3Y}0VFb<*_e-9v_vW)tKN9ptQNfsASkSRwx z!$NPNSI!5#z%}^(^=d``%>Y?d;sieaBBg-E8P#fpwJDyaDu-oqiL66Jc(nqPDjStf zSSzM2b@gSwyA?xNA+xrnFa@p>7`KP{#dp4OT zv02RXAygtU&>OPW8q+wDV}s(@A`GXUt~g`@bRZ0*BRJELfzTR`6C7CB&qdS~r$>E* zt+|cKQGbPoS|Oc<9e>;|`ElB}az6e(Ge-QpGJcHu7W&{G%2f}MOZ_H)$W81}#ZZ!* zA6@SU&A-hB8m~eZYq`Xnr3dTv11>Iz-(d*!kMl~~MJfP?;YcskV zZ!PM{7Hj@ORbcTNie%~%|9nmFI9H^6@@I#_WAr8tN>u+v%^B7nVexJhP)gZsiWKjY zQqg-KJzf|6_gM>A#4Ls^m`M} radius] = 0 + dist = (dist > 0).astype(np.uint8) + channels.append(dist) + + return np.array(channels) + + +def onehot_to_binary_edges(mask, radius, num_classes): + """ + Converts a segmentation mask (K,H,W) to a binary edgemap (H,W) + + """ + + if radius < 0: + return mask + + # We need to pad the borders for boundary conditions + mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) + + edgemap = np.zeros(mask.shape[1:]) + for i in range(num_classes): + # ti qu lun kuo + dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :]) + dist = dist[1:-1, 1:-1] + dist[dist > radius] = 0 + edgemap += dist + # edgemap = np.expand_dims(edgemap, axis=0) + edgemap = (edgemap > 0).astype(np.uint8)*255 + return edgemap + + +def mask_to_onehot(mask, num_classes): + """ + Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one + hot encoding vector + + """ + _mask = [mask == (i) for i in range(num_classes)] + return np.array(_mask).astype(np.uint8) + +if __name__ == '__main__': + label = cv2.imread('/media/ws/新加卷1/wy/dataset/HUAWEI/data/labels/182/182_16_23.png',0) + img = cv2.imread('/media/ws/新加卷1/wy/dataset/HUAWEI/data/images/182/182_16_23.png') + oneHot_label = mask_to_onehot(label, 2) + edge = onehot_to_binary_edges(oneHot_label, 2, 2) # #edge=255,background=0 + edge[:2, :] = 0 + edge[-2:, :] = 0 + edge[:, :2] = 0 + edge[:, -2:] = 0 + print(edge) + print(np.unique(edge)) + print(edge.shape) + cv2.imwrite('test.png',edge) + cv2.namedWindow('1',0) + cv2.namedWindow('2',0) + cv2.namedWindow('3',0) + cv2.imshow('1',label*255) + cv2.imshow('2',edge) + cv2.imshow('3',img) + cv2.waitKey() \ No newline at end of file diff --git a/data/make_data.py b/data/make_data.py new file mode 100644 index 0000000..e45ce52 --- /dev/null +++ b/data/make_data.py @@ -0,0 +1,239 @@ +import cv2 +import pdb +import collections +import matplotlib.pyplot as plt +import numpy as np +import os +import os.path as osp +from PIL import Image, ImageOps, ImageFilter +import random +import torch +import torchvision +from torch.utils import data +import torchvision.transforms as transforms +from .edge_utils import * +class GaofenTrain(data.Dataset): + def __init__(self, root, list_path, crop_size=(640, 640), + scale=True, mirror=True,rotation=True, bright=False, ignore_label=1, use_aug=True, network='resnet101'): + self.root = root + self.src_h = 1024 + self.src_w = 1024 + self.list_path = list_path + self.crop_h, self.crop_w = crop_size + self.bright = bright + self.scale = scale + self.ignore_label = ignore_label + self.is_mirror = mirror + self.rotation = rotation + self.use_aug = use_aug + self.img_ids = [i_id.strip() for i_id in open(list_path)] + self.files = [] + self.network = network + for item in self.img_ids: + image_path = 'images/'+item.split('_')[0]+'/'+item + label_path = 'labels/'+item.split('_')[0]+'/'+item + name = item + img_file = osp.join(self.root, image_path) + label_file = osp.join(self.root, label_path) + 
+            self.files.append({
+                "img": img_file,
+                "label": label_file,
+                "name": name,
+                "weight": 1
+            })
+
+        print('{} images are loaded!'.format(len(self.img_ids)))
+
+    def __len__(self):
+        return len(self.files)
+
+    def random_brightness(self, img):
+        if random.random() < 0.5:
+            return img
+        self.shift_value = 10  # shift range taken from HRNet
+        img = img.astype(np.float32)
+        shift = random.randint(-self.shift_value, self.shift_value)
+        img[:, :, :] += shift
+        img = np.around(img)
+        img = np.clip(img, 0, 255).astype(np.uint8)
+        return img
+
+    def generate_scale_label(self, image, label):
+        f_scale = 0.5 + random.randint(0, 11) / 10.0  # scale factor in [0.5, 1.6]
+        image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_LINEAR)
+        label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_NEAREST)
+        return image, label
+
+    def __getitem__(self, index):
+        datafiles = self.files[index]
+        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)
+        label = cv2.imread(datafiles["label"], 0)
+        # rotate by 90/180/270 degrees
+        if self.rotation and random.random() > 0.5:
+            angle = np.random.randint(1, 4)
+            M = cv2.getRotationMatrix2D(((self.src_h - 1) / 2., (self.src_w - 1) / 2.), 90 * angle, 1)
+            image = cv2.warpAffine(image, M, (self.src_h, self.src_w), flags=cv2.INTER_LINEAR)
+            label = cv2.warpAffine(label, M, (self.src_h, self.src_w), flags=cv2.INTER_NEAREST, borderValue=self.ignore_label)
+        # rotate by a random angle in [-30, 30)
+        if self.rotation and random.random() > 0.5:
+            angle = np.random.randint(-30, 30)
+            M = cv2.getRotationMatrix2D(((self.src_h - 1) / 2., (self.src_w - 1) / 2.), angle, 1)
+            image = cv2.warpAffine(image, M, (self.src_h, self.src_w), flags=cv2.INTER_LINEAR)
+            label = cv2.warpAffine(label, M, (self.src_h, self.src_w), flags=cv2.INTER_NEAREST, borderValue=self.ignore_label)
+        size = image.shape
+        if self.scale:  # random scaling
+            image, label = self.generate_scale_label(image, label)
+        if self.bright:  # random brightness
+            image = self.random_brightness(image)
+        image = np.asarray(image, np.float32)
+        image = image[:, :, ::-1]  # BGR -> RGB
+        mean = (0.355403, 0.383969, 0.359276)
+        std = (0.206617, 0.202157, 0.210082)
+        image /= 255.
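+        # standardize with per-channel dataset statistics (mean/std of the RGB values scaled to [0, 1])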
+        image -= mean
+        image /= std
+
+        img_h, img_w = label.shape
+        pad_h = max(self.crop_h - img_h, 0)
+        pad_w = max(self.crop_w - img_w, 0)
+        if pad_h > 0 or pad_w > 0:
+            img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0,
+                                         pad_w, cv2.BORDER_CONSTANT,
+                                         value=(0.0, 0.0, 0.0))
+            label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0,
+                                           pad_w, cv2.BORDER_CONSTANT,
+                                           value=(self.ignore_label,))  # pad labels with ignore_label
+        else:
+            img_pad, label_pad = image, label
+
+        img_h, img_w = label_pad.shape
+        h_off = random.randint(0, img_h - self.crop_h)
+        w_off = random.randint(0, img_w - self.crop_w)
+        image = np.asarray(img_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32)
+        label = np.asarray(label_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32)
+        image = image.transpose((2, 0, 1))  # 3xHxW
+
+        if self.is_mirror:  # random horizontal / vertical flip
+            flip1 = np.random.choice(2) * 2 - 1
+            image = image[:, :, ::flip1]
+            label = label[:, ::flip1]
+            flip2 = np.random.choice(2) * 2 - 1
+            image = image[:, ::flip2, :]
+            label = label[::flip2, :]
+        oneHot_label = mask_to_onehot(label, 2)
+        edge = onehot_to_binary_edges(oneHot_label, 2, 2)  # edge=255, background=0
+        # suppress spurious edges along the image border
+        edge[:2, :] = 0
+        edge[-2:, :] = 0
+        edge[:, :2] = 0
+        edge[:, -2:] = 0
+        return image.copy(), label.copy(), edge, np.array(size), datafiles
+
+
+class GaofenVal(data.Dataset):
+    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321),
+                 scale=False, mirror=False, ignore_label=255, use_aug=True, network="resnet101"):
+        self.root = root
+        self.list_path = list_path
+        self.crop_h, self.crop_w = crop_size
+        self.scale = scale
+        self.ignore_label = ignore_label
+        self.is_mirror = mirror
+        self.use_aug = use_aug
+        self.img_ids = [i_id.strip() for i_id in open(list_path)]
+        self.files = []
+        self.network = network
+        for item in self.img_ids:
+            image_path = 'images/' + item.split('_')[0] + '/' + item
+            label_path = 'labels/' + item.split('_')[0] + '/' + item
+            name = item
+            img_file = osp.join(self.root, image_path)
+            label_file = osp.join(self.root, label_path)
+            self.files.append({
+                "img": img_file,
+                "label": label_file,
+                "name": name,
+                "weight": 1
+            })
+        self.id_to_trainid = {}
+
+        print('{} images are loaded!'.format(len(self.img_ids)))
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+        datafiles = self.files[index]
+        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)
+        label = cv2.imread(datafiles["label"], 0)
+
+        size = image.shape
+
+        image = np.asarray(image, np.float32)
+        image = image[:, :, ::-1]  # BGR -> RGB
+        mean = (0.355403, 0.383969, 0.359276)
+        std = (0.206617, 0.202157, 0.210082)
+        image /= 255.
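+        # validation reuses the training normalization statistics; no augmentation is applied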
+        image -= mean
+        image /= std
+
+        img_h, img_w = label.shape
+
+        image = np.asarray(image, np.float32)
+        label = np.asarray(label, np.float32)
+        image = image.transpose((2, 0, 1))  # 3xHxW
+        return image.copy(), label.copy(), label.copy(), np.array(size), datafiles
+
+
+class GaofenSubmit(data.Dataset):
+    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321),
+                 scale=False, mirror=False, ignore_label=255, use_aug=True, network="resnet101"):
+        self.root = root
+        self.list_path = list_path
+        self.crop_h, self.crop_w = crop_size
+        self.scale = scale
+        self.ignore_label = ignore_label
+        self.is_mirror = mirror
+        self.use_aug = use_aug
+        self.img_ids = [i_id.strip() for i_id in open(list_path)]
+        self.files = []
+        self.network = network
+        for item in self.img_ids:
+            image_path = 'images/' + item
+            label_path = 'labels/' + item[:-4] + '.png'
+            name = item
+            img_file = osp.join(self.root, image_path)
+            label_file = osp.join(self.root, label_path)
+            self.files.append({
+                "img": img_file,
+                "label": label_file,
+                "name": name,
+                "weight": 1
+            })
+        self.id_to_trainid = {}
+
+        print('{} images are loaded!'.format(len(self.img_ids)))
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, index):
+        datafiles = self.files[index]
+        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)
+        size = image.shape
+        name = datafiles["name"]
+
+        image = np.asarray(image, np.float32)
+        image = image[:, :, ::-1]  # BGR -> RGB
+        mean = (0.355403, 0.383969, 0.359276)
+        std = (0.206617, 0.202157, 0.210082)
+        image /= 255.
+        image -= mean
+        image /= std
+
+        image = np.asarray(image, np.float32)
+        image = image.transpose((2, 0, 1))  # 3xHxW
+        return image.copy(), np.array(size), name
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..9cfef8c
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+
+class Evaluator(object):
+    def __init__(self, num_class):
+        self.num_class = num_class
+        self.confusion_matrix = np.zeros((self.num_class,) * 2)
+
+    def Pixel_Accuracy(self):
+        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
+        return Acc
+
+    def Pixel_Accuracy_Class(self):
+        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
+        Acc = np.nanmean(Acc)
+        return Acc
+
+    def Mean_Intersection_over_Union(self):
+        MIoU = np.diag(self.confusion_matrix) / (
+                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
+                np.diag(self.confusion_matrix))
+        MIoU = np.nanmean(MIoU)
+        return MIoU
+
+    def Mean_Intersection_over_Union_test(self):
+        MIoU = np.diag(self.confusion_matrix) / (
+                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
+                np.diag(self.confusion_matrix))
+        return MIoU
+
+    def Frequency_Weighted_Intersection_over_Union(self):
+        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
+        iu = np.diag(self.confusion_matrix) / (
+                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
+                np.diag(self.confusion_matrix))
+
+        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
+        return FWIoU
+
+    def _generate_matrix(self, gt_image, pre_image):
+        mask = (gt_image >= 0) & (gt_image < self.num_class)
+        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
+        count = np.bincount(label, minlength=self.num_class**2)
+        confusion_matrix = count.reshape(self.num_class, self.num_class)
+        return confusion_matrix
+
+    def add_batch(self, gt_image,
+                  pre_image):
+        assert gt_image.shape == pre_image.shape
+        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
+
+    def reset(self):
+        self.confusion_matrix = np.zeros((self.num_class,) * 2)
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000..5de6e78
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,476 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from collections import Counter
+from collections import defaultdict
+from itertools import filterfalse as ifilterfalse  # needed by mean() below
+from torch.autograd import Variable
+
+
+# Per-class pixel counts; fill in before using class-balanced weights.
+nSamples = []
+weights = [1 - (x / sum(nSamples)) for x in nSamples]
+weights = torch.FloatTensor(weights).cuda() if weights else None
+
+
+def isnan(x):
+    return x != x
+
+
+def mean(l, ignore_nan=False, empty=0):
+    """
+    nanmean compatible with generators.
+    """
+    l = iter(l)
+    if ignore_nan:
+        l = ifilterfalse(isnan, l)
+    try:
+        n = 1
+        acc = next(l)
+    except StopIteration:
+        if empty == 'raise':
+            raise ValueError('Empty mean')
+        return empty
+    for n, v in enumerate(l, 2):
+        acc += v
+    if n == 1:
+        return acc
+    return acc / n
+
+
+def lovasz_grad(gt_sorted):
+    """
+    Computes gradient of the Lovasz extension w.r.t sorted errors
+    See Alg. 1 in paper
+    """
+    p = len(gt_sorted)
+    gts = gt_sorted.sum()
+    intersection = gts - gt_sorted.float().cumsum(0)
+    union = gts + (1 - gt_sorted).float().cumsum(0)
+    jaccard = 1. - intersection / union
+    if p > 1:  # cover 1-pixel case
+        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+    return jaccard
+
+
+def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
+    """
+    IoU for foreground class
+    binary: 1 foreground, 0 background
+    """
+    if not per_image:
+        preds, labels = (preds,), (labels,)
+    ious = []
+    for pred, label in zip(preds, labels):
+        intersection = ((label == 1) & (pred == 1)).sum()
+        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
+        if not union:
+            iou = EMPTY
+        else:
+            iou = float(intersection) / float(union)
+        ious.append(iou)
+    iou = mean(ious)  # mean across images if per_image
+    return 100 * iou
+
+
+def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
+    """
+    Array of IoU for each (non ignored) class
+    """
+    if not per_image:
+        preds, labels = (preds,), (labels,)
+    ious = []
+    for pred, label in zip(preds, labels):
+        iou = []
+        for i in range(C):
+            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
+                intersection = ((label == i) & (pred == i)).sum()
+                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
+                if not union:
+                    iou.append(EMPTY)
+                else:
+                    iou.append(float(intersection) / float(union))
+        ious.append(iou)
+    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
+    return 100 * np.array(ious)
+
+
+# --------------------------- BINARY LOSSES ---------------------------
+
+def lovasz_hinge(logits, labels, per_image=True, ignore=None):
+    """
+    Binary Lovasz hinge loss
+      logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
+      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+      per_image: compute the loss per image instead of per batch
+      ignore: void class id
+    """
+    if per_image:
+        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
+                    for log, lab in zip(logits, labels))
+    else:
+        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
+    return loss
+
+
+def lovasz_hinge_flat(logits, labels):
+    """
+    Binary Lovasz hinge loss
+      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
+      labels: [P] Tensor,
binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +class StableBCELoss(torch.nn.modules.Module): + def __init__(self): + super(StableBCELoss, self).__init__() + def forward(self, input, target): + neg_abs = - input.abs() + loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() + return loss.mean() + + +def binary_xloss(logits, labels, ignore=None): + """ + Binary Cross entropy loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + ignore: void class id + """ + logits, labels = flatten_binary_scores(logits, labels, ignore) + loss = StableBCELoss()(logits, Variable(labels.float())) + return loss + +# --------------------------- MULTICLASS LOSSES --------------------------- +def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) + return loss + + +def lovasz_softmax_flat(probas, labels, classes='present'): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0. 
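+    # for each class, sort the absolute errors and weight them by the Lovasz gradient of the sorted ground truth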
+    C = probas.size(1)
+    losses = []
+    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+    for c in class_to_sum:
+        fg = (labels == c).float()  # foreground for class c
+        if (classes == 'present' and fg.sum() == 0):
+            continue
+        if C == 1:
+            if len(classes) > 1:
+                raise ValueError('Sigmoid output possible only with 1 class')
+            class_pred = probas[:, 0]
+        else:
+            class_pred = probas[:, c]
+        errors = (Variable(fg) - class_pred).abs()
+        errors_sorted, perm = torch.sort(errors, 0, descending=True)
+        perm = perm.data
+        fg_sorted = fg[perm]
+        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
+    return mean(losses)
+
+
+def flatten_probas(probas, labels, ignore=None):
+    """
+    Flattens predictions in the batch
+    """
+    if probas.dim() == 3:
+        # assumes output of a sigmoid layer
+        B, H, W = probas.size()
+        probas = probas.view(B, 1, H, W)
+    B, C, H, W = probas.size()
+    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
+    labels = labels.view(-1)
+    if ignore is None:
+        return probas, labels
+    valid = (labels != ignore)
+    vprobas = probas[valid.nonzero().squeeze()]
+    vlabels = labels[valid]
+    return vprobas, vlabels
+
+
+def make_one_hot(input, num_classes):
+    """Convert class index tensor to one hot encoding tensor.
+    Args:
+        input: A tensor of shape [N, 1, *]
+        num_classes: An int of number of class
+    Returns:
+        A tensor of shape [N, num_classes, *]
+    """
+    input = input.unsqueeze(1)
+    shape = np.array(input.shape)
+    shape[1] = num_classes
+    shape = tuple(shape)
+    result = torch.zeros(shape)
+    result = result.scatter_(1, input.cpu(), 1)
+    return result
+
+
+class BinaryDiceLoss(nn.Module):
+    """Dice loss of binary class
+    Args:
+        smooth: A float number to smooth loss, and avoid NaN error, default: 1
+        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
+        predict: A tensor of shape [N, *]
+        target: A tensor of shape same with predict
+        reduction: Reduction method to apply, return mean over batch if 'mean',
+            return sum if 'sum', return a tensor of shape [N,] if 'none'
+    Returns:
+        Loss tensor according to arg reduction
+    Raise:
+        Exception if unexpected reduction
+    """
+    def __init__(self, smooth=1, p=2, reduction='mean'):
+        super(BinaryDiceLoss, self).__init__()
+        self.smooth = smooth
+        self.p = p
+        self.reduction = reduction
+
+    def forward(self, predict, target):
+        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
+        predict = predict.contiguous().view(predict.shape[0], -1)
+        target = target.contiguous().view(target.shape[0], -1)
+
+        num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
+        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
+
+        loss = 1 - num / den
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        elif self.reduction == 'none':
+            return loss
+        else:
+            raise Exception('Unexpected reduction {}'.format(self.reduction))
+
+
+class DiceLoss(nn.Module):
+    """Dice loss, need one hot encode input
+    Args:
+        weight: An array of shape [num_classes,]
+        ignore_index: class index to ignore
+        predict: A tensor of shape [N, C, *]
+        target: A tensor of same shape with predict
+        other args pass to BinaryDiceLoss
+    Return:
+        same as BinaryDiceLoss
+    """
+    def __init__(self, weight=None, ignore_index=None, **kwargs):
+        super(DiceLoss, self).__init__()
+        self.kwargs = kwargs
+        self.weight = weight
+        self.ignore_index = ignore_index
+
+    def forward(self, predict, target):
+        assert predict.shape ==
+               target.shape, 'predict & target shape do not match'
+        dice = BinaryDiceLoss(**self.kwargs)
+        total_loss = 0
+        predict = F.softmax(predict, dim=1)
+
+        for i in range(target.shape[1]):
+            if i != self.ignore_index:
+                dice_loss = dice(predict[:, i], target[:, i])
+                if self.weight is not None:
+                    assert self.weight.shape[0] == target.shape[1], \
+                        'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
+                    dice_loss *= self.weight[i]
+                total_loss += dice_loss
+
+        return total_loss / target.shape[1]
+
+
+class OhemCrossEntropy(nn.Module):
+    def __init__(self, ignore_label=-1, thres=0.7,
+                 min_kept=100000, weight=None):
+        super(OhemCrossEntropy, self).__init__()
+        self.thresh = thres
+        self.min_kept = max(1, min_kept)
+        self.ignore_label = ignore_label
+        self.criterion = nn.CrossEntropyLoss(
+            weight=weight,
+            ignore_index=ignore_label,
+            reduction='none'
+        )
+
+    def _ce_forward(self, score, target):
+        ph, pw = score.size(2), score.size(3)
+        h, w = target.size(1), target.size(2)
+        if ph != h or pw != w:
+            score = F.interpolate(input=score, size=(
+                h, w), mode='bilinear', align_corners=True)
+
+        loss = self.criterion(score, target)
+
+        return loss
+
+    def _ohem_forward(self, score, target, **kwargs):
+        ph, pw = score.size(2), score.size(3)
+        h, w = target.size(1), target.size(2)
+        if ph != h or pw != w:
+            score = F.interpolate(input=score, size=(
+                h, w), mode='bilinear', align_corners=True)
+        pred = F.softmax(score, dim=1)
+        pixel_losses = self.criterion(score, target).contiguous().view(-1)
+        mask = target.contiguous().view(-1) != self.ignore_label
+
+        tmp_target = target.clone()
+        tmp_target[tmp_target == self.ignore_label] = 0
+        pred = pred.gather(1, tmp_target.unsqueeze(1))
+        pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()
+        min_value = pred[min(self.min_kept, pred.numel() - 1)]
+        threshold = max(min_value, self.thresh)
+
+        pixel_losses = pixel_losses[mask][ind]
+        pixel_losses = pixel_losses[pred < threshold]
+        return pixel_losses.mean()
+
+    def forward(self, score, target):
+        score = [score]
+        weights = [1.]
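+        # a single logit map here, so the lone entry is routed to the OHEM branch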
+        assert len(weights) == len(score)
+
+        functions = [self._ce_forward] * (len(weights) - 1) + [self._ohem_forward]
+        return sum([
+            w * func(x, target)
+            for (w, x, func) in zip(weights, score, functions)
+        ])
+
+
+class SmoothCrossEntropy(nn.Module):
+    def __init__(self, ignore_index=255, eps=0.1):
+        super(SmoothCrossEntropy, self).__init__()
+        self.eps = eps
+        self.ignore_label = ignore_index
+
+    def forward(self, score, target):
+        pred = F.softmax(score, dim=1)  # NxCxHxW
+
+        mask = target != self.ignore_label
+        tmp_target = target.clone()
+        tmp_target[tmp_target == self.ignore_label] = 0
+
+        K = 9  # number of classes (hard-coded)
+        one_hot_labels = torch.zeros([score.shape[0], K, score.shape[2], score.shape[3]]).cuda()
+        one_hot_labels.scatter_(1, tmp_target.unsqueeze(1), 1)
+        smooth_label = (1 - self.eps) * one_hot_labels + self.eps / K  # NxCxHxW
+        loss = torch.sum(torch.mul(-1. * smooth_label, torch.log(pred)), dim=1)
+        return loss[mask].mean()
+
+
+# def calc_loss(pred, target, metrics):
+#     criters = nn.CrossEntropyLoss(ignore_index=255)
+#     ce_loss = criters(pred, target)
+#     loss = ce_loss
+#     metrics['loss'] += loss.data.cpu().numpy()
+#     return loss
+
+
+class CrossEntropy(nn.Module):
+    def __init__(self, ignore_label=255, weight=None):
+        super(CrossEntropy, self).__init__()
+        self.ignore_label = ignore_label
+        self.criterion = nn.CrossEntropyLoss(
+            weight=weight,
+            reduction='none'
+        )
+
+    def _forward(self, score, target):
+        ph, pw = score.size(2), score.size(3)
+        h, w = target.size(1), target.size(2)
+        if ph != h or pw != w:
+            score = F.interpolate(input=score, size=(
+                h, w), mode='bilinear', align_corners=True)
+
+        loss = self.criterion(score, target)
+
+        return loss
+
+    def forward(self, score, target):
+        # weight the two HRNet outputs (auxiliary and final)
+        hr_weights = [0.4, 1]
+        assert len(hr_weights) == len(score)
+        loss = hr_weights[0] * self._forward(score[0], target) + hr_weights[1] * self._forward(score[1], target)
+        return loss
+
+
+def calc_loss(pred, target, edge, metrics):
+    edge_weight = 4.
+    criters_ce = CrossEntropy()
+    loss_ce = criters_ce(pred, target)
+    loss_ls = lovasz_softmax(F.softmax(pred[1], dim=1), target)
+    # up-weight pixels that lie on semantic boundaries
+    edge[edge == 0] = 1.
+    edge[edge == 255] = edge_weight
+    loss_ce *= edge
+    # keep only the hardest 50% of the pixels (online hard example mining)
+    loss_ce_, ind = loss_ce.contiguous().view(-1).sort()
+    min_value = loss_ce_[int(0.5 * loss_ce.shape[0] * loss_ce.shape[1] * loss_ce.shape[2])]
+    loss_ce = loss_ce[loss_ce > min_value]
+    loss_ce = loss_ce.mean()
+    loss = loss_ce + loss_ls
+    metrics['loss'] += loss.data.cpu().numpy()
+    metrics['ce_loss'] += loss_ce.data.cpu().numpy()
+    metrics['ls_loss'] += loss_ls.data.cpu().numpy()
+    return loss
+
+
+def calc_smoothloss(pred, target, metrics):
+    criters = SmoothCrossEntropy(ignore_index=255)
+    loss = criters(pred, target)
+    metrics['loss'] += loss.data.cpu().numpy()
+    return loss
+
+
+if __name__ == '__main__':
+    criter = BinaryDiceLoss()
+    target = torch.ones((4, 256, 256), dtype=torch.long)
+    input = (torch.ones((4, 256, 256)) * 0.9)
+    loss = criter(input, target)
+    print(loss)
diff --git a/model/__pycache__/hrnet_config.cpython-35.pyc b/model/__pycache__/hrnet_config.cpython-35.pyc
new file mode 100644
Binary files /dev/null and b/model/__pycache__/hrnet_config.cpython-35.pyc differ
diff --git a/model/__pycache__/hrnet_config.cpython-36.pyc b/model/__pycache__/hrnet_config.cpython-36.pyc
new file mode 100644
Binary files /dev/null and b/model/__pycache__/hrnet_config.cpython-36.pyc differ
diff --git a/model/__pycache__/hrnet_config.cpython-37.pyc b/model/__pycache__/hrnet_config.cpython-37.pyc
new file mode 100644
Binary files /dev/null and b/model/__pycache__/hrnet_config.cpython-37.pyc differ
diff --git a/model/__pycache__/seg_hrnet.cpython-35.pyc b/model/__pycache__/seg_hrnet.cpython-35.pyc
new file mode 100644
Binary files /dev/null and b/model/__pycache__/seg_hrnet.cpython-35.pyc differ
diff --git a/model/__pycache__/seg_hrnet.cpython-36.pyc b/model/__pycache__/seg_hrnet.cpython-36.pyc
new file mode 100644
Binary files /dev/null and b/model/__pycache__/seg_hrnet.cpython-36.pyc differ
diff --git a/model/__pycache__/seg_hrnet.cpython-37.pyc b/model/__pycache__/seg_hrnet.cpython-37.pyc
new file mode 100644
Binary files /dev/null and b/model/__pycache__/seg_hrnet.cpython-37.pyc differ
diff --git a/model/hrnet_config.py b/model/hrnet_config.py
new file mode 100644
index 0000000..6edfcc5
--- /dev/null
+++ b/model/hrnet_config.py
@@ -0,0 +1,130 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Created by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn), Rainbowsecret (yuyua@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from yacs.config import CfgNode as CN
+
+
+# configs for HRNet48
+HRNET_48 = CN()
+HRNET_48.FINAL_CONV_KERNEL = 1
+
+HRNET_48.STAGE1 = CN()
+HRNET_48.STAGE1.NUM_MODULES = 1
+HRNET_48.STAGE1.NUM_BRANCHES = 1
+HRNET_48.STAGE1.NUM_BLOCKS = [4]
+HRNET_48.STAGE1.NUM_CHANNELS = [64]
+HRNET_48.STAGE1.BLOCK = 'BOTTLENECK'
+HRNET_48.STAGE1.FUSE_METHOD = 'SUM'
+
+HRNET_48.STAGE2 = CN()
+HRNET_48.STAGE2.NUM_MODULES = 1
+HRNET_48.STAGE2.NUM_BRANCHES = 2
+HRNET_48.STAGE2.NUM_BLOCKS = [4, 4]
+HRNET_48.STAGE2.NUM_CHANNELS = [48, 96]
+HRNET_48.STAGE2.BLOCK = 'BASIC'
+HRNET_48.STAGE2.FUSE_METHOD = 'SUM'
+
+HRNET_48.STAGE3 = CN()
+HRNET_48.STAGE3.NUM_MODULES = 4
+HRNET_48.STAGE3.NUM_BRANCHES = 3
+HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4]
+HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192]
+HRNET_48.STAGE3.BLOCK = 'BASIC'
+HRNET_48.STAGE3.FUSE_METHOD = 'SUM'
+
+HRNET_48.STAGE4 = CN()
+HRNET_48.STAGE4.NUM_MODULES = 3
+HRNET_48.STAGE4.NUM_BRANCHES = 4
+HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384]
+HRNET_48.STAGE4.BLOCK = 'BASIC'
+HRNET_48.STAGE4.FUSE_METHOD = 'SUM'
+
+
+# configs for HRNet32
+HRNET_32 = CN()
+HRNET_32.FINAL_CONV_KERNEL = 1
+
+HRNET_32.STAGE1 = CN()
+HRNET_32.STAGE1.NUM_MODULES = 1
+HRNET_32.STAGE1.NUM_BRANCHES = 1
+HRNET_32.STAGE1.NUM_BLOCKS = [4]
+HRNET_32.STAGE1.NUM_CHANNELS = [64]
+HRNET_32.STAGE1.BLOCK = 'BOTTLENECK'
+HRNET_32.STAGE1.FUSE_METHOD = 'SUM'
+
+HRNET_32.STAGE2 = CN()
+HRNET_32.STAGE2.NUM_MODULES = 1
+HRNET_32.STAGE2.NUM_BRANCHES = 2
+HRNET_32.STAGE2.NUM_BLOCKS = [4, 4]
+HRNET_32.STAGE2.NUM_CHANNELS = [32, 64]
+HRNET_32.STAGE2.BLOCK = 'BASIC'
+HRNET_32.STAGE2.FUSE_METHOD = 'SUM'
+
+HRNET_32.STAGE3 = CN()
+HRNET_32.STAGE3.NUM_MODULES = 4
+HRNET_32.STAGE3.NUM_BRANCHES = 3
+HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4]
+HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128]
+HRNET_32.STAGE3.BLOCK = 'BASIC'
+HRNET_32.STAGE3.FUSE_METHOD = 'SUM'
+
+HRNET_32.STAGE4 = CN()
+HRNET_32.STAGE4.NUM_MODULES = 3
+HRNET_32.STAGE4.NUM_BRANCHES = 4
+HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] +HRNET_32.STAGE4.BLOCK = 'BASIC' +HRNET_32.STAGE4.FUSE_METHOD = 'SUM' + + +# configs for HRNet18 +HRNET_18 = CN() +HRNET_18.FINAL_CONV_KERNEL = 1 + +HRNET_18.STAGE1 = CN() +HRNET_18.STAGE1.NUM_MODULES = 1 +HRNET_18.STAGE1.NUM_BRANCHES = 1 +HRNET_18.STAGE1.NUM_BLOCKS = [4] +HRNET_18.STAGE1.NUM_CHANNELS = [64] +HRNET_18.STAGE1.BLOCK = 'BOTTLENECK' +HRNET_18.STAGE1.FUSE_METHOD = 'SUM' + +HRNET_18.STAGE2 = CN() +HRNET_18.STAGE2.NUM_MODULES = 1 +HRNET_18.STAGE2.NUM_BRANCHES = 2 +HRNET_18.STAGE2.NUM_BLOCKS = [4, 4] +HRNET_18.STAGE2.NUM_CHANNELS = [18, 36] +HRNET_18.STAGE2.BLOCK = 'BASIC' +HRNET_18.STAGE2.FUSE_METHOD = 'SUM' + +HRNET_18.STAGE3 = CN() +HRNET_18.STAGE3.NUM_MODULES = 4 +HRNET_18.STAGE3.NUM_BRANCHES = 3 +HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4] +HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72] +HRNET_18.STAGE3.BLOCK = 'BASIC' +HRNET_18.STAGE3.FUSE_METHOD = 'SUM' + +HRNET_18.STAGE4 = CN() +HRNET_18.STAGE4.NUM_MODULES = 3 +HRNET_18.STAGE4.NUM_BRANCHES = 4 +HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] +HRNET_18.STAGE4.BLOCK = 'BASIC' +HRNET_18.STAGE4.FUSE_METHOD = 'SUM' + + +MODEL_CONFIGS = { + 'hrnet18': HRNET_18, + 'hrnet32': HRNET_32, + 'hrnet48': HRNET_48, +} \ No newline at end of file diff --git a/model/seg_hrnet.py b/model/seg_hrnet.py new file mode 100644 index 0000000..d2a8590 --- /dev/null +++ b/model/seg_hrnet.py @@ -0,0 +1,750 @@ +""" +MIT License +Copyright (c) 2019 Microsoft +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" +import os +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +logger = logging.getLogger('hrnet_backbone') + +__all__ = ['hrnet18', 'hrnet32', 'hrnet48'] + + +model_urls = { + # all the checkpoints come from https://github.com/HRNet/HRNet-Image-Classification + 'hrnet18': 'https://opr0mq.dm.files.1drv.com/y4mIoWpP2n-LUohHHANpC0jrOixm1FZgO2OsUtP2DwIozH5RsoYVyv_De5wDgR6XuQmirMV3C0AljLeB-zQXevfLlnQpcNeJlT9Q8LwNYDwh3TsECkMTWXCUn3vDGJWpCxQcQWKONr5VQWO1hLEKPeJbbSZ6tgbWwJHgHF7592HY7ilmGe39o5BhHz7P9QqMYLBts6V7QGoaKrr0PL3wvvR4w', + 'hrnet32': 'https://opr74a.dm.files.1drv.com/y4mKOuRSNGQQlp6wm_a9bF-UEQwp6a10xFCLhm4bqjDu6aSNW9yhDRM7qyx0vK0WTh42gEaniUVm3h7pg0H-W0yJff5qQtoAX7Zze4vOsqjoIthp-FW3nlfMD0-gcJi8IiVrMWqVOw2N3MbCud6uQQrTaEAvAdNjtjMpym1JghN-F060rSQKmgtq5R-wJe185IyW4-_c5_ItbhYpCyLxdqdEQ', + 'hrnet48': 'https://optgaw.dm.files.1drv.com/y4mWNpya38VArcDInoPaL7GfPMgcop92G6YRkabO1QTSWkCbo7djk8BFZ6LK_KHHIYE8wqeSAChU58NVFOZEvqFaoz392OgcyBrq_f8XGkusQep_oQsuQ7DPQCUrdLwyze_NlsyDGWot0L9agkQ-M_SfNr10ETlCF5R7BdKDZdupmcMXZc-IE3Ysw1bVHdOH4l-XEbEKFAi6ivPUbeqlYkRMQ' +} + +class ModuleHelper: + + @staticmethod + def BNReLU(num_features, bn_type=None, **kwargs): + return nn.Sequential( + nn.BatchNorm2d(num_features, **kwargs), + nn.ReLU() + ) + + @staticmethod + def BatchNorm2d(*args, **kwargs): + return nn.BatchNorm2d +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw + # print(probs.shape,feats.shape) + ocr_context = torch.matmul(probs, feats)\ + .permute(0, 2, 1).unsqueeze(3)# batch x k x c + return ocr_context + + +class _ObjectAttentionBlock(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + def __init__(self, + in_channels, + key_channels, + scale=1, + bn_type=None): + super(_ObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d(in_channels=self.key_channels, 
out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type), + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True) + + return context + +class ObjectAttentionBlock2D(_ObjectAttentionBlock): + def __init__(self, + in_channels, + key_channels, + scale=1, + bn_type=None): + super(ObjectAttentionBlock2D, self).__init__(in_channels, + key_channels, + scale, + bn_type=bn_type) + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. 
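+    (A shape sketch, assuming batch B, K classes and C = in_channels: the class
+    proxy produced by SpatialGather_Module is B x C x K x 1; every pixel feature
+    attends over these K object centers, and the attended context is concatenated
+    with the pixel features, so the output stays B x out_channels x H x W.)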
+ """ + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + bn_type=None): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, + key_channels, + scale, + bn_type) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + ModuleHelper.BNReLU(out_channels, bn_type=bn_type), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True, norm_layer=None): + super(HighResolutionModule, self).__init__() + self._check_branches( 
+ num_branches, blocks, num_blocks, num_inchannels, num_channels) + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.norm_layer = norm_layer + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 
1:
+            return [self.branches[0](x[0])]
+
+        for i in range(self.num_branches):
+            x[i] = self.branches[i](x[i])
+
+        x_fuse = []
+        for i in range(len(self.fuse_layers)):
+            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+            for j in range(1, self.num_branches):
+                if i == j:
+                    y = y + x[j]
+                elif j > i:
+                    width_output = x[i].shape[-1]
+                    height_output = x[i].shape[-2]
+                    y = y + F.interpolate(
+                        self.fuse_layers[i][j](x[j]),
+                        size=[height_output, width_output],
+                        mode='bilinear',
+                        align_corners=True
+                    )
+                else:
+                    y = y + self.fuse_layers[i][j](x[j])
+            x_fuse.append(self.relu(y))
+
+        return x_fuse
+
+
+blocks_dict = {
+    'BASIC': BasicBlock,
+    'BOTTLENECK': Bottleneck
+}
+
+
+class HighResolutionNet(nn.Module):
+
+    def __init__(self,
+                 cfg,
+                 norm_layer=None):
+        super(HighResolutionNet, self).__init__()
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self.norm_layer = norm_layer
+        # stem net
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+                               bias=False)
+        self.bn1 = self.norm_layer(64)
+        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+                               bias=False)
+        self.bn2 = self.norm_layer(64)
+        self.relu = nn.ReLU(inplace=True)
+
+        # stage 1
+        self.stage1_cfg = cfg['STAGE1']
+        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
+        block = blocks_dict[self.stage1_cfg['BLOCK']]
+        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
+        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+        stage1_out_channel = block.expansion*num_channels
+
+        # stage 2
+        self.stage2_cfg = cfg['STAGE2']
+        num_channels = self.stage2_cfg['NUM_CHANNELS']
+        block = blocks_dict[self.stage2_cfg['BLOCK']]
+        num_channels = [
+            num_channels[i] * block.expansion for i in range(len(num_channels))]
+        self.transition1 = self._make_transition_layer(
+            [stage1_out_channel], num_channels)
+        self.stage2, pre_stage_channels = self._make_stage(
+            self.stage2_cfg, num_channels)
+
+        # stage 3
+        self.stage3_cfg = cfg['STAGE3']
+        num_channels = self.stage3_cfg['NUM_CHANNELS']
+        block = blocks_dict[self.stage3_cfg['BLOCK']]
+        num_channels = [
+            num_channels[i] * block.expansion for i in range(len(num_channels))]
+        self.transition2 = self._make_transition_layer(
+            pre_stage_channels, num_channels)
+        self.stage3, pre_stage_channels = self._make_stage(
+            self.stage3_cfg, num_channels)
+
+        # stage 4
+        self.stage4_cfg = cfg['STAGE4']
+        num_channels = self.stage4_cfg['NUM_CHANNELS']
+        block = blocks_dict[self.stage4_cfg['BLOCK']]
+        num_channels = [
+            num_channels[i] * block.expansion for i in range(len(num_channels))]
+        self.transition3 = self._make_transition_layer(
+            pre_stage_channels, num_channels)
+        self.stage4, pre_stage_channels = self._make_stage(
+            self.stage4_cfg, num_channels, multi_scale_output=True)
+
+        MID_CHANNELS = 512
+        KEY_CHANNELS = 256
+        last_inp_channels = int(np.sum(pre_stage_channels))
+        ocr_mid_channels = MID_CHANNELS
+        ocr_key_channels = KEY_CHANNELS
+
+        self.conv3x3_ocr = nn.Sequential(
+            nn.Conv2d(last_inp_channels, ocr_mid_channels,
+                      kernel_size=3, stride=1, padding=1),
+            norm_layer(ocr_mid_channels),
+            nn.ReLU(inplace=True),
+        )
+        self.ocr_gather_head = SpatialGather_Module(9)
+
+        self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
+                                                 key_channels=ocr_key_channels,
+                                                 out_channels=ocr_mid_channels,
+                                                 scale=1,
+                                                 dropout=0.05,
+                                                 )
+        self.cls_head = nn.Conv2d(
+            ocr_mid_channels, 2, kernel_size=1, stride=1, padding=0, bias=True)
+
+        self.aux_head = nn.Sequential(
nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=1, stride=1, padding=0), + norm_layer(last_inp_channels), + nn.ReLU(inplace=True), + nn.Conv2d(last_inp_channels,2, + kernel_size=1, stride=1, padding=0, bias=True) + ) + + + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + + def forward(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + if i < self.stage2_cfg['NUM_BRANCHES']: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + if i < 
self.stage3_cfg['NUM_BRANCHES']:
+                    x_list.append(self.transition3[i](y_list[i]))
+                else:
+                    x_list.append(self.transition3[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        x = self.stage4(x_list)
+
+        outputs = {}
+        # See note [TorchScript super()]
+        outputs['res2'] = x[0]  # 1/4
+        outputs['res3'] = x[1]  # 1/8
+        outputs['res4'] = x[2]  # 1/16
+        outputs['res5'] = x[3]  # 1/32
+        x0_h, x0_w = x[0].size(2), x[0].size(3)
+        ALIGN_CORNERS = True
+        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+
+        feats = torch.cat([x[0], x1, x2, x3], 1)
+
+        out_aux_seg = []
+
+        # ocr
+        out_aux = self.aux_head(feats)
+        # compute contrast feature
+        feats = self.conv3x3_ocr(feats)
+
+        context = self.ocr_gather_head(feats, out_aux)
+        feats = self.ocr_distri_head(feats, context)
+
+        out = self.cls_head(feats)
+
+        out_aux_seg.append(out_aux)
+        out_aux_seg.append(out)
+
+        return out_aux_seg
+
+
+def _hrnet(arch, pretrained, progress, **kwargs):
+    try:
+        from .hrnet_config import MODEL_CONFIGS
+    except ImportError:
+        # fall back to an absolute import when this file is run as a script
+        from hrnet_config import MODEL_CONFIGS
+    model = HighResolutionNet(MODEL_CONFIGS[arch], **kwargs)
+    if pretrained:
+        # if int(os.environ.get("mapillary_pretrain", 0)):
+        #     logger.info("load the mapillary pretrained hrnet-w48 weights.")
+        #     model_url = model_urls['hrnet48_mapillary_pretrain']
+        # else:
+        #     model_url = model_urls[arch]
+
+        # NOTE: the bundled checkpoint is HRNet-W18, so it only matches arch='hrnet18'
+        pretrained_dict = torch.load('./pre-trained_weights/hrnetv2_w18_imagenet_pretrained.pth')
+        print('=> loading pretrained weights for {}'.format(arch))
+        model_dict = model.state_dict()
+        pretrained_dict = {k: v for k, v in pretrained_dict.items()
+                           if k in model_dict.keys()}
+        # for k, _ in pretrained_dict.items():
+        #     print('=> loading {} pretrained model {}'.format(k, pretrained))
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+        # print(model.conv1.weight[0,0])
+    return model
+
+
+def hrnet18(pretrained=False, progress=True, **kwargs):
+    r"""HRNet-18 model
+    """
+    return _hrnet('hrnet18', pretrained, progress,
+                  **kwargs)
+
+
+def hrnet32(pretrained=False, progress=True, **kwargs):
+    r"""HRNet-32 model
+    """
+    return _hrnet('hrnet32', pretrained, progress,
+                  **kwargs)
+
+
+def hrnet48(pretrained=False, progress=True, **kwargs):
+    r"""HRNet-48 model
+    """
+    return _hrnet('hrnet48', pretrained, progress,
+                  **kwargs)
+
+
+if __name__ == '__main__':
+    # use hrnet18 here: the checkpoint hardcoded in _hrnet is the W18 weights
+    model = hrnet18(pretrained=True)
+    # print(model)
+    input = torch.randn([2, 3, 512, 512], dtype=torch.float)
+    output = model(input)
+    print(output[0].shape)
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..9362ed4
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,2 @@
+python cut_data.py
+CUDA_VISIBLE_DEVICES=0 python train.py --backbone=hrnet --batchsize=4 --lr=0.01 --num_epochs=150
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..a78fb04
--- /dev/null
+++ b/train.py
@@ -0,0 +1,265 @@
+#coding=utf-8
+from collections import defaultdict
+import torch.nn.functional as F
+from loss import calc_loss,calc_smoothloss
+import time
+import torch
+from torch.utils.data import Dataset, DataLoader
+from utils.lr_scheduler import adjust_learning_rate_poly
+from utils.ema import WeightEMA
+from utils.label2color import label_img_to_color, diff_label_img_to_color
+from data.make_data import GaofenTrain, GaofenVal
+from tqdm import tqdm
+from evaluate import Evaluator
+import numpy as np
+import argparse
+import matplotlib.pyplot as plt
+import itertools
+import torch.nn as nn
+import cv2
+import os
+from tensorboardX import SummaryWriter
+from model.seg_hrnet import hrnet18
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--root', default='./data')
+parser.add_argument('--train_list_path', default='./data/train.txt')
+parser.add_argument('--val_list_path', default='./data/val.txt')
+# parser.add_argument('--test_list_path', default='./data/val_1w.txt')
+parser.add_argument('--backbone', default='resnest50', type=str, help='hrnet|xception|resnet|resnest101|resnest200|resnest50|resnest26')
+parser.add_argument('--n_cls', default=2, type=int)
+parser.add_argument('--batchsize', default=4, type=int)
+parser.add_argument('--lr', default=0.01, type=float)
+parser.add_argument('--num_epochs', default=150, type=int)
+parser.add_argument('--warmup', default=100, type=int)
+parser.add_argument('--multiplier', default=100, type=int)
+parser.add_argument('--eta_min', default=0.0005, type=float)
+parser.add_argument('--num_workers', default=8, type=int)
+parser.add_argument('--decay_rate', default=0.8, type=float)
+parser.add_argument('--decay_epoch', default=200, type=int)
+parser.add_argument('--vis_frequency', default=30, type=int)
+parser.add_argument('--save_path', default='./results')
+parser.add_argument('--gpu-id', default='0,1', type=str,
+                    help='id(s) for CUDA_VISIBLE_DEVICES')
+parser.add_argument('--is_resume', default=False, type=bool)
+parser.add_argument('--resume', default='', type=str, help='./results/checkpoint.pth')
+args = parser.parse_args()
+
+
+folder_path = '/backbone={}/warmup={}_lr={}_multiplier={}_eta_min={}_num_epochs={}_batchsize={}'.format(args.backbone, args.warmup, args.lr, args.multiplier, args.eta_min, args.num_epochs, args.batchsize)
+isExists = os.path.exists(args.save_path + folder_path)
+if not isExists:
+    os.makedirs(args.save_path + folder_path)
+
+isExists = os.path.exists(args.save_path + folder_path + '/vis')
+if not isExists:
+    os.makedirs(args.save_path + folder_path + '/vis')
+
+isExists = os.path.exists(args.save_path + folder_path + '/log')
+if not isExists:
+    os.makedirs(args.save_path + folder_path + '/log')
+
+os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+global_step = 0
+
+
+def train_model():
+    global global_step
+    F_txt = open('./opt_results.txt', 'w')
+    evaluator = Evaluator(args.n_cls)
+    classes = ['road', 'others']
+    writer = SummaryWriter(args.save_path + folder_path + '/log')
+    def create_model(ema=False):
+        model = hrnet18(pretrained=True).to(device)
+        if ema:
+            for param in model.parameters():
+                param.detach_()
+        return model
+    model = create_model()
+    ema_model = create_model(ema=True)
+    # model = hrnet18(pretrained=True).to(device)
+    # model = nn.DataParallel(model)
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001)
+    ema_optimizer = WeightEMA(model, ema_model, alpha=0.999)
+    best_miou = 0.
+    best_AA = 0.
+    best_OA = 0.
+    best_loss = 0.
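+    # Note: ema_optimizer above keeps ema_model as an exponential moving
+    # average of the student weights (mean-teacher style); utils/ema.py is
+    # assumed to apply an update of roughly this form after every step:
+    #
+    #   for ema_p, p in zip(ema_model.parameters(), model.parameters()):
+    #       ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)  # alpha=0.999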
+    lr = args.lr
+    epoch_index = 0
+    if args.is_resume:
+        args.resume = args.save_path + folder_path + '/checkpoint_fwiou.pth'
+        if os.path.isfile(args.resume):
+            checkpoint = torch.load(args.resume)
+            epoch_index = checkpoint['epoch']
+            best_miou = checkpoint['miou']
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr = optimizer.param_groups[0]['lr']
+            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
+            F_txt.write("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])+'\n')
+            # print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), file=F_txt)
+        else:
+            print('ERROR: No such checkpoint file!')
+
+    TRAIN_DATA_DIRECTORY = args.root  # '/media/ws/www/IGARSS'
+    TRAIN_DATA_LIST_PATH = args.train_list_path  # '/media/ws/www/unet_1/data/train.txt'
+
+    VAL_DATA_DIRECTORY = args.root  # '/media/ws/www/IGARSS'
+    VAL_DATA_LIST_PATH = args.val_list_path  # '/media/ws/www/unet_1/data/train.txt'
+
+    dataloaders = {
+        "train": DataLoader(GaofenTrain(TRAIN_DATA_DIRECTORY, TRAIN_DATA_LIST_PATH), batch_size=args.batchsize,
+                            shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True),
+        "val": DataLoader(GaofenVal(VAL_DATA_DIRECTORY, VAL_DATA_LIST_PATH), batch_size=args.batchsize,
+                          num_workers=args.num_workers, pin_memory=True)
+    }
+
+    evaluator.reset()
+    print('config: ' + folder_path)
+    print('config: ' + folder_path, file=F_txt, flush=True)
+    for epoch in range(epoch_index, args.num_epochs):
+        print('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr))
+        # F_txt.write('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr)+'\n')
+        print('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr), file=F_txt, flush=True)
+        since = time.time()
+
+        # Each epoch has a training and validation phase
+        for phase in ['train', 'val']:
+            evaluator.reset()
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                ema_model.eval()
+                model.eval()  # Set model to evaluate mode
+
+            metrics = defaultdict(float)
+            epoch_samples = 0
+
+            for i, (inputs, labels, edge, _, datafiles) in enumerate(tqdm(dataloaders[phase], ncols=50)):
+                inputs = inputs.to(device)
+                edge = edge.to(device, dtype=torch.float)
+                labels = labels.to(device, dtype=torch.long)
+
+                optimizer.zero_grad()
+
+                with torch.set_grad_enabled(phase == 'train'):
+                    if phase == 'train':
+                        outputs = model(inputs)
+                        outputs[1] = F.interpolate(input=outputs[1], size=(
+                            labels.shape[1], labels.shape[2]), mode='bilinear', align_corners=True)
+                        loss = calc_loss(outputs, labels, edge, metrics)
+                        pred = outputs[1].data.cpu().numpy()
+                        pred = np.argmax(pred, axis=1)
+                        labels = labels.data.cpu().numpy()
+                        evaluator.add_batch(labels, pred)
+                    if phase == 'val':
+                        outputs = ema_model(inputs)
+                        outputs[1] = F.interpolate(input=outputs[1], size=(
+                            labels.shape[1], labels.shape[2]), mode='bilinear', align_corners=True)
+                        loss = calc_loss(outputs, labels, edge, metrics)
+                        pred = outputs[1].data.cpu().numpy()
+                        pred = np.argmax(pred, axis=1)
+                        labels = labels.data.cpu().numpy()
+                        evaluator.add_batch(labels, pred)
+                    if phase == 'val' and (epoch+1) % args.vis_frequency == 0 and inputs.shape[0] == args.batchsize:
+                        for k in range(args.batchsize//2):
+                            name = datafiles['name'][k][:-4]
+
+                            writer.add_image('{}/img'.format(name), cv2.cvtColor(cv2.imread(datafiles["img"][k], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), global_step=int((epoch+1)), dataformats='HWC')
+
+                            writer.add_image('{}/gt'.format(name), label_img_to_color(labels[k])[:, :, ::-1], global_step=int((epoch+1)), dataformats='HWC')
+
+                            pred_label_img = pred.astype(np.uint8)[k]
+                            pred_label_img_color = label_img_to_color(pred_label_img)
+                            writer.add_image('{}/mask'.format(name), pred_label_img_color[:, :, ::-1], global_step=int((epoch+1)), dataformats='HWC')
+
+                            softmax_pred = F.softmax(outputs[1][k], dim=0)
+                            softmax_pred_np = softmax_pred.data.cpu().numpy()
+                            probability = softmax_pred_np[0]
+                            probability = probability * 255
+                            probability = probability.astype(np.uint8)
+                            probability = cv2.applyColorMap(probability, cv2.COLORMAP_HOT)
+                            writer.add_image('{}/prob'.format(name), cv2.cvtColor(probability, cv2.COLOR_BGR2RGB), global_step=int((epoch+1)), dataformats='HWC')
+                            # difference map: pixels where the prediction disagrees with the label
+                            diff_img = np.ones((pred_label_img.shape[0], pred_label_img.shape[1]), dtype=np.int32)*255
+                            mask = (labels[k] != pred_label_img)
+                            diff_img[mask] = labels[k][mask]
+                            diff_img_color = diff_label_img_to_color(diff_img)
+                            writer.add_image('{}/different_image'.format(name), diff_img_color[:, :, ::-1],
+                                             global_step=int((epoch + 1)), dataformats='HWC')
+                if phase == 'train':
+                    loss.backward()
+                    global_step += 1
+                    optimizer.step()
+                    ema_optimizer.step()
+                    adjust_learning_rate_poly(args.lr, optimizer, epoch * len(dataloaders['train']) + i,
+                                              args.num_epochs * len(dataloaders['train']))
+                    lr = optimizer.param_groups[0]['lr']
+                    writer.add_scalar('lr', lr, global_step=epoch * len(dataloaders['train']) + i)
+                epoch_samples += 1
+            epoch_loss = metrics['loss'] / epoch_samples
+            ce_loss = metrics['ce_loss'] / epoch_samples
+            ls_loss = metrics['ls_loss'] / epoch_samples
+            miou = evaluator.Mean_Intersection_over_Union()
+            AA = evaluator.Pixel_Accuracy_Class()
+            OA = evaluator.Pixel_Accuracy()
+            confusion_matrix = evaluator.confusion_matrix
+            if phase == 'val':
+                miou_mat = evaluator.Mean_Intersection_over_Union_test()
+                writer.add_scalar('val/val_loss', epoch_loss, global_step=epoch)
+                writer.add_scalar('val/ce_loss', ce_loss, global_step=epoch)
+                writer.add_scalar('val/ls_loss', ls_loss, global_step=epoch)
+                # writer.add_scalar('val/val_fwiou', fwiou, global_step=epoch)
+                writer.add_scalar('val/val_miou', miou, global_step=epoch)
+                for index in range(args.n_cls):
+                    writer.add_scalar('class/{}'.format(index+1), miou_mat[index], global_step=epoch)
+
+                print('[val]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA, AA, epoch_loss))
+                print('[val]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA, AA, epoch_loss),
+                      file=F_txt, flush=True)
+            if phase == 'train':
+                writer.add_scalar('train/train_loss', epoch_loss, global_step=epoch)
+                writer.add_scalar('train/ce_loss', ce_loss, global_step=epoch)
+                writer.add_scalar('train/ls_loss', ls_loss, global_step=epoch)
+                # writer.add_scalar('train/train_fwiou', fwiou, global_step=epoch)
+                writer.add_scalar('train/train_miou', miou, global_step=epoch)
+                print('[train]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA, AA, epoch_loss))
+                print('[train]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA, AA, epoch_loss),
+                      file=F_txt, flush=True)
+
+            if phase == 'val' and miou > best_miou:
+                print("\33[91msaving best model miou\33[0m")
+                print("saving best model miou", file=F_txt, flush=True)
+                best_miou = miou
+                best_OA = OA
+                best_AA = AA
+                best_loss = epoch_loss
+                torch.save({
+                    'name': 'hrnet18_lovasz_edge_rotate',
+                    'epoch': epoch + 1,
+                    'state_dict': ema_model.state_dict(),
+                    'best_miou': best_miou
+                }, args.save_path +
folder_path+'/model_best.pth') + torch.save({ + 'optimizer': optimizer.state_dict(), + }, args.save_path + folder_path+'/optimizer.pth') + time_elapsed = time.time() - since + print('{:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60)) + print('{:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60),file=F_txt,flush=True) + + print('[Best val]------miou: {:4f}; OA: {:4f}; AA: {:4f}; loss: {:4f}'.format(best_miou, + best_OA, best_AA, + best_loss)) + print('[Best val]------miou: {:4f}; OA: {:4f}; AA: {:4f}; loss: {:4f}'.format(best_miou, + best_OA, best_AA, + best_loss),file=F_txt,flush=True) + F_txt.close() +if __name__ == '__main__': + train_model() diff --git a/utils/__pycache__/ema.cpython-36.pyc b/utils/__pycache__/ema.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5df405403a8d9db3498c97fea7e9191e0586b3fe GIT binary patch literal 946 zcmZuwJ#W-77`F3uxwNHlR8>_W78r<-E`qLz7EwMHPDl)uP9V#T-Cja+m$MyctK4*z z_$~YyCKf!fFd!z_keGPiTuVXXwcbaM?e`;o>};*9L^p4IWS<>EzL3An1HJ>UuK@@m zXhDYL2@#HPPl<3Pox4Lv&=(~3zCi>Tq=@I{0@(rA*8wCMQb8cP(`M75D?H(!lA$L8 z$PaCIPiBWB^XS=wueaZ=fBNwA`*w=`(E@O+_7GePfQB}Nk;Va}jSJ`)x1qT=C%Hc- zbKDJJ#rbEYb9Bpu4ao!W&fM80uhw5WGjHZMUf#{4#?QNRO3C!`9-I}vi4Ni=>z5ND z3)__?XAA1{q8f4Q6`3|R(uNz!M3x#Gyx~PHwGAq+c&TlKIcMBG$tvp!Zn*W+Np;NP zzW585}}cupirTw>`RA zhM9^GmH@0T%yA{HZzd`o+0H>R;b!Zu^~$M|C& zf`HO3;3_bX+{g;G(AbC2h;Gnj7$=u5#K4e2Ocg%3*hd}5yp)XDh~W#U3*^g;O>15( zcGMb#s!fE8el5eM3jV70^=ec!q*2ta#I9P0xeX@!xlB#F_P=x+-dl9NkC&sdSm^u) Do9o;{ literal 0 HcmV?d00001 diff --git a/utils/__pycache__/label2color.cpython-35.pyc b/utils/__pycache__/label2color.cpython-35.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7bbd2736fdb76e212061648ded826196f2a1b13 GIT binary patch literal 933 zcmah{&ubGw6n^_-wrxz>+S(S6q6e2k*Gdr&MU<-4ONACGh!WQ9CfT~lt}`2ABmqHB zqF0aN)tfiLi~q-5r9%Dy&-%StqbVXTZ@&HJ?VI=Byl+=)waW0q=lU0$=qpV<*wb|^ zbrBH3EJ}&)TNK$8IrI`7#HC0W?r^|??~M^ndku?Uc)VY02XFNnfT-Wr!Xc9k`C!;4 zHjOMA5*-me0EanLSaf955QZ?9nm}Qk4!y9fkxNfEsX!whwsiC1c^Z|U1r3o)&{?>L zE@#E&u}#}gLOO)wz@OQZ-00ZA4`vQPi>Mgn@t7yzds?V=n8(`g=2|?9Wu~;~CEenx z7BcL%Vl6@`!vO-Fq|?&#Y4|8k>qS;?WN9XK1OQPD1G8R0Zi^?~Fv zs(rr=2s&|;gh5{g-`;*ae)azN&8MJ02pU-w2fO!f9NxMU?C!7M-PvDn$Sl9Ibt&i- zNvZ-<__E2lkq`7z*$9`MG{cuu^8m6JtQD+t)-sk{f#-x`vv3&2-hr@OP{|l8Hf6lU z;3tx@fX@E}6h=Vrq<~u{tx1I>QyN0He`u)wXvp^(KL;2!CF>?#1_AJeoMFy)l5-3d zhAP1KWQ{qGL@qEaGMr{uVqi^5I0L5FqNLfZ|2bu!sip6nWYzN6w{Q zB+bbYl}z$)%HFoTRH6a#!dbOHShbTbB!6x)o$YG@GoIcsS{j>!ks zvX)o$m^G~7(wkz=rTETT@AH$&ll~bUl8^hX0A4h-5%=4qK(wk)mqlIg-&QLAP3$_kw>XPFhL*69i+ zaW#UZwIV yTyJnGYIvz)h$`*@7kO4J>qWWhj)1~Ceq(o6$CU^U7KN@_|#gO>|faP1%esk58_ literal 0 HcmV?d00001 diff --git a/utils/__pycache__/lr_scheduler.cpython-35.pyc b/utils/__pycache__/lr_scheduler.cpython-35.pyc new file mode 100644 index 0000000000000000000000000000000000000000..389e61043ed8581c6231f397f3f7ca9c8a5bcc33 GIT binary patch literal 2158 zcmb_d&2QX96n|rTKeK5=^Ff+W0*H$*yNwVkhbZNvP%6+?Vv8cJs7A(iHg3IZXU5w! ztK?J^5E2M3oZ-lg8~+6S1HPgnRpP*f3nzYWyt~=Z4=zkHv(Im4{C@L$Z+>2xpRa#? 
zxNQmrJ+P?@k3&I;on4!PG6fZ5 z!&#+{TYH4}kSp?r{P3fWzZV9>I8;76C#^I#^bHsxUL!a#zgW=cM7wx2VNU=85vI{# zU~0O?L_V>oTjE_Ex0KSf5CBquwc}dt_uO(`?+@cFO5-S0`COJ{UM#~j@%Qrc-fk8usVC>p ztGm%4(qJPm>0yfHRVm@gq;yNUqr-SNuO(>~^`lX!lmlpr!78MS$*f7~$DY#)RePRro^vy}OO~Ox##SUEq$-LPjiHzH9BMXHGSwJp2<0`?lMQ8R+W=fCGxaIZ~ zEY|>^b`f1)!q7tn5Zzv+V^ly9o@nV?i}uUZQS(^g3X>Jr%B|+*yz;51Lm4Z0nJ}|L zorCpveJhT1<|qAho1Zm@s?likDLnSV46kW!#fk65dK<=M_pSd0tvb)}7BD;pRXqnU z_c6vZ@woU5JFmeQv*Vs&2iIVglOkxT$jv%D&rdjW1vYW}^Qvnb=JGlUI4kalscqN# zFK`#Ie+|*3suIOZCd4-!3xVv293c= zi)+6Sn+|q%T+zny`MgFBAQxGf7EDyx$!I__pDr&I`cvtLnTmW33L3nGE7;VhT83p3 z%gdREhFG-LtkL>Y7~oh?I>^<)>xWY2bt(G^i#qIvl>5-zV#a8Q;n@T!Vvetl0b%km zp-{83HNQ&fnx71zD66NnT&1a`M`$14lg~^IRX*SFRJms&tRg}UuXTfB2!CJZ4 Gu>S$SW343s literal 0 HcmV?d00001 diff --git a/utils/__pycache__/lr_scheduler.cpython-36.pyc b/utils/__pycache__/lr_scheduler.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4600023e0e65946b2e55fc90457711b931f482b3 GIT binary patch literal 1993 zcmbtV&u<$=6rR~%UOP@2Xb7$fR1K=c_F%UHl}nUXB5DyrK9EY4ELEfJ&c<1HH{Q&w z)i&Ck9I3tXXYjwkjniB?RpP+8C%!jpr!G)9FxI?z^JC||_q{itzg%DM-1z)gu@^A* z8(V1}^n2*~7KCJy=d9o(&e;Pd9qFDk>85<{jvVPp|D26n8ORn!Plg99YQM%;EaIjs zzW?DMIZEX$PgO#T%SHyM>8~JCHsX@uvhXE4Z_Xxu&F6?vGF6azwh&1kYc0f2?9JDMB9&P@_;!DPkgJoy(JX$JW&`Y3 zgD07ev)hANmE{_nZ$0RjFH{?s(bc*G!FaIvonHHAocc=_6He#RXe6vGUPn+~u?cnl z)Kd>JI_e&ztMAx^AW`Qb=H5P=bJ`W2bYB5Y(kK2LINEdG>bb?5|L$583@@CZgi#ci z(e*aOnAPm@#u=~42-{*mIw#%%Q|s7@xWzJZjMKkmT2Eq~id=G94x1MDoZrt?9ve`OIvj6BWN~hbew_>`G&LQoL+rTm}D!0 z1I}<5wTM8D<02Kpbc86T#Guf-LY#o{njQP#EvD9%SzoY7lIFP(jVFs&VqLQ-SfZzi zdf-tIwryC+l96eo{tdeBLDZoCiV;XP>Z=F(yo?I7B>9>c1N`M5=F_-}JX_i($V3Ku zG+_}xZ#2K0J_GMrY%TyYEf82a&0l;%mgq?`uC7B~xI!*c>04KBLV+8^;jZL!@63C? zUGoXwXOFvQKG+ptkh93YFU=f zm0H8~>N678NbEqEwUxtvb^b~$Co?>yJdM?Hc03j;u2LHgeXprp7~eqFyIkFXK8*ZP rubQglsBixW`m5}%D0-U)>!tl*#o$}W!ri76_6T^J!mhJ self.total_epoch: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + def step(self, epoch=None, metrics=None): + if self.finished and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + self.after_scheduler.step(epoch - self.total_epoch) + else: + return super(GradualWarmupScheduler, self).step(epoch) + +def lr_poly(base_lr,i_iter,max_iter,power): + return base_lr*((1-float(i_iter)/max_iter)**(power)) + +def adjust_learning_rate_poly(init_lr,optimizer,i_iter,max_iter): + lr = lr_poly(init_lr,i_iter,max_iter,0.9) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr +