From ef6f62d998848614f84fa7b1dd78d8830f67c944 Mon Sep 17 00:00:00 2001 From: reacher-l <45810596+reacher-l@users.noreply.github.com> Date: Thu, 8 Apr 2021 23:44:26 +0800 Subject: [PATCH] Add files via upload --- model/fpn.py | 23 ++ model/losses/__init__.py | 50 ++++ .../__pycache__/__init__.cpython-36.pyc | Bin 0 -> 2397 bytes .../__pycache__/pseudo_ce_loss.cpython-36.pyc | Bin 0 -> 955 bytes model/losses/pseudo_ce_loss.py | 16 ++ model/optim/__init__.py | 78 ++++++ .../optim/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 2228 bytes .../optim/__pycache__/cyclicLR.cpython-36.pyc | Bin 0 -> 5661 bytes .../__pycache__/lookahead.cpython-36.pyc | Bin 0 -> 3673 bytes model/optim/__pycache__/radam.cpython-36.pyc | Bin 0 -> 7054 bytes .../warmup_scheduler.cpython-36.pyc | Bin 0 -> 3222 bytes model/optim/cyclicLR.py | 125 +++++++++ model/optim/lookahead.py | 100 +++++++ model/optim/radam.py | 250 ++++++++++++++++++ model/optim/warmup_scheduler.py | 65 +++++ model/tools/Balanced_DataParallel.py | 112 ++++++++ model/tools/__init__.py | 2 + .../Balanced_DataParallel.cpython-36.pyc | Bin 0 -> 3627 bytes .../tools/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 253 bytes model/tools/__pycache__/metric.cpython-36.pyc | Bin 0 -> 3181 bytes .../__pycache__/split_weights.cpython-36.pyc | Bin 0 -> 987 bytes model/tools/metric.py | 75 ++++++ model/tools/split_weights.py | 34 +++ model/unet.py | 22 ++ 24 files changed, 952 insertions(+) create mode 100644 model/fpn.py create mode 100644 model/losses/__init__.py create mode 100644 model/losses/__pycache__/__init__.cpython-36.pyc create mode 100644 model/losses/__pycache__/pseudo_ce_loss.cpython-36.pyc create mode 100644 model/losses/pseudo_ce_loss.py create mode 100644 model/optim/__init__.py create mode 100644 model/optim/__pycache__/__init__.cpython-36.pyc create mode 100644 model/optim/__pycache__/cyclicLR.cpython-36.pyc create mode 100644 model/optim/__pycache__/lookahead.cpython-36.pyc create mode 100644 model/optim/__pycache__/radam.cpython-36.pyc create mode 100644 model/optim/__pycache__/warmup_scheduler.cpython-36.pyc create mode 100644 model/optim/cyclicLR.py create mode 100644 model/optim/lookahead.py create mode 100644 model/optim/radam.py create mode 100644 model/optim/warmup_scheduler.py create mode 100644 model/tools/Balanced_DataParallel.py create mode 100644 model/tools/__init__.py create mode 100644 model/tools/__pycache__/Balanced_DataParallel.cpython-36.pyc create mode 100644 model/tools/__pycache__/__init__.cpython-36.pyc create mode 100644 model/tools/__pycache__/metric.cpython-36.pyc create mode 100644 model/tools/__pycache__/split_weights.cpython-36.pyc create mode 100644 model/tools/metric.py create mode 100644 model/tools/split_weights.py create mode 100644 model/unet.py diff --git a/model/fpn.py b/model/fpn.py new file mode 100644 index 0000000..d5bdd9c --- /dev/null +++ b/model/fpn.py @@ -0,0 +1,23 @@ +import torch.nn as nn +import segmentation_models_pytorch as smp + + +class FPN(nn.Module): + def __init__(self, num_classes): + super(FPN, self).__init__() + + self.model = smp.FPN( + encoder_name='resnet50', + encoder_depth=5, + encoder_weights=None, + decoder_pyramid_channels=256, + decoder_segmentation_channels=128, + decoder_merge_policy='add', + decoder_dropout=0., + in_channels=3, + classes=num_classes + ) + + def forward(self, x): + logits = self.model(x) + return [logits] diff --git a/model/losses/__init__.py b/model/losses/__init__.py new file mode 100644 index 0000000..759ea05 --- /dev/null +++ b/model/losses/__init__.py @@ 
-0,0 +1,50 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pytorch_toolbelt import losses as L
+
+from model.losses.pseudo_ce_loss import PseudoCrossEntropyLoss
+
+
+class LossFunction(nn.Module):
+    def __init__(self):
+        super(LossFunction, self).__init__()
+
+        self.loss_func1 = nn.CrossEntropyLoss()
+        self.loss_func2 = L.DiceLoss(mode='multiclass')
+
+    def forward(self, logits, target):
+        loss = self.loss_func1(logits[0], target) + 0.2 * self.loss_func2(logits[0], target)
+        return loss
+
+
+class SelfCorrectionLossFunction(nn.Module):
+    def __init__(self, cycle=12):
+        super(SelfCorrectionLossFunction, self).__init__()
+        self.cycle = cycle
+
+        self.sc_loss_func1 = PseudoCrossEntropyLoss()
+        self.sc_loss_func2 = L.DiceLoss(mode='multiclass')
+
+    def forward(self, predicts, target, soft_predict, cycle_n):
+        with torch.no_grad():
+            soft_predict = F.softmax(soft_predict, dim=1)
+            soft_predict = self.weighted(self.to_one_hot(target, soft_predict.size(1)), soft_predict,
+                                         alpha=1. / (cycle_n + 1))
+        loss1 = self.sc_loss_func1(predicts[0], soft_predict)
+        loss2 = self.sc_loss_func2(predicts[0], target)
+        return loss1 + 0.2 * loss2
+
+    @staticmethod
+    def weighted(target_one_hot, soft_predict, alpha):
+        soft_predict = alpha * target_one_hot + (1 - alpha) * soft_predict
+        return soft_predict
+
+    @staticmethod
+    def to_one_hot(tensor, num_cls, dim=1, ignore_index=255):
+        b, h, w = tensor.shape
+        tensor[tensor == ignore_index] = 0
+        onehot_tensor = torch.zeros(b, num_cls, h, w).cuda()
+        onehot_tensor.scatter_(dim, tensor.unsqueeze(dim), 1)
+        return onehot_tensor
[GIT binary patch payloads for model/losses/__pycache__/__init__.cpython-36.pyc and pseudo_ce_loss.cpython-36.pyc omitted]
diff --git a/model/losses/pseudo_ce_loss.py b/model/losses/pseudo_ce_loss.py
new file mode 100644
index 0000000..4bc8ac9
--- /dev/null
+++ b/model/losses/pseudo_ce_loss.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+
+
+class PseudoCrossEntropyLoss(nn.Module):
+    def __init__(self, dim=1):
+        super(PseudoCrossEntropyLoss, self).__init__()
+        self.dim = dim
+
+    def forward(self, input: Tensor, target: Tensor):
+        input_log_prob = F.log_softmax(input, dim=self.dim)
+        loss = torch.sum(-input_log_prob * target, dim=self.dim)
+        return loss.mean()
diff --git a/model/optim/__init__.py b/model/optim/__init__.py
new file mode 100644
index 0000000..50eea8f
--- /dev/null
+++ b/model/optim/__init__.py
@@ -0,0 +1,78 @@
+import torch.optim as optim
+
+from .radam import RAdam
+from .lookahead import Lookahead
+from .cyclicLR import CyclicCosAnnealingLR
+from .warmup_scheduler import GradualWarmupScheduler
+
+
+def get_optimizer(params, optimizer_cfg):
+    if optimizer_cfg['mode'] == 'SGD':
+        optimizer = optim.SGD(params, lr=optimizer_cfg['lr'], momentum=0.9,
+                              weight_decay=optimizer_cfg['weight_decay'], nesterov=optimizer_cfg['nesterov'])
+    elif optimizer_cfg['mode'] == 'RAdam':
+        optimizer = RAdam(params, lr=optimizer_cfg['lr'], betas=(0.9, 0.999),
+                          weight_decay=optimizer_cfg['weight_decay'])
+    else:
+        optimizer = optim.Adam(params, lr=optimizer_cfg['lr'], betas=(0.9, 0.999),
+                               weight_decay=optimizer_cfg['weight_decay'])
+
+    if optimizer_cfg['lookahead']:
+        optimizer = Lookahead(optimizer, k=5, alpha=0.5)
+
+    # todo: add split_weights.py
+
+    return optimizer
+
+
+def get_scheduler(optimizer, scheduler_cfg):
+    MODE = scheduler_cfg['mode']
+
+    if MODE == 'OneCycleLR':
+        scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
+                                                  max_lr=optimizer.param_groups[0]['lr'],
+                                                  total_steps=scheduler_cfg['steps'],
+                                                  pct_start=scheduler_cfg['pct_start'],
+                                                  final_div_factor=scheduler_cfg['final_div_factor'],
+                                                  cycle_momentum=scheduler_cfg['cycle_momentum'],
+                                                  anneal_strategy=scheduler_cfg['anneal_strategy'])
+
+    elif MODE == 'PolyLR':
+        lr_lambda = lambda step: (1 - step / scheduler_cfg['steps']) ** scheduler_cfg['power']
+        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
+
+    elif MODE == 'CosineAnnealingLR':
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_cfg['steps'],
+                                                         eta_min=scheduler_cfg['eta_min'])
+
+    elif MODE == 'MultiStepLR':
+        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
+                                                   scheduler_cfg['milestones'],
+                                                   gamma=scheduler_cfg['gamma'])
+
+    elif MODE == 'CosineAnnealingWarmRestarts':
+        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
+                                                                   T_0=scheduler_cfg['T_0'],
+                                                                   T_mult=scheduler_cfg['T_multi'],
+                                                                   eta_min=scheduler_cfg['eta_min'])
+
+    elif MODE == 'CyclicCosAnnealingLR':
+        scheduler = CyclicCosAnnealingLR(optimizer,
+                                         milestones=scheduler_cfg['milestones'],
+                                         decay_milestones=scheduler_cfg['decay_milestones'],
+                                         eta_min=scheduler_cfg['eta_min'],
+                                         gamma=scheduler_cfg['gamma'])
+
+    elif MODE == 'GradualWarmupScheduler':
+        milestones = list(map(lambda x: x - scheduler_cfg['warmup_steps'], scheduler_cfg['milestones']))
+        scheduler_steplr = optim.lr_scheduler.MultiStepLR(optimizer,
+                                                          milestones=milestones,
+                                                          gamma=scheduler_cfg['gamma'])
+        scheduler = GradualWarmupScheduler(optimizer,
+                                           multiplier=scheduler_cfg['multiplier'],
+                                           total_epoch=scheduler_cfg['warmup_steps'],
+                                           after_scheduler=scheduler_steplr)
+    else:
+        raise ValueError('Unsupported scheduler mode: {}'.format(MODE))
+
+    return scheduler
[GIT binary patch payloads for model/optim/__pycache__/__init__.cpython-36.pyc, cyclicLR.cpython-36.pyc, lookahead.cpython-36.pyc, radam.cpython-36.pyc and warmup_scheduler.cpython-36.pyc omitted]
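A minimal usage sketch for the two factories above. The config keys are assumptions inferred from the dictionary lookups in get_optimizer/get_scheduler (there is no documented schema in this patch), and the snippet assumes the older PyTorch version this repository targets.

import torch.nn as nn
from model.optim import get_optimizer, get_scheduler

# hypothetical config dicts; only the keys actually read by the factories are filled in
optimizer_cfg = {'mode': 'RAdam', 'lr': 1e-3, 'weight_decay': 1e-4, 'lookahead': True}
scheduler_cfg = {'mode': 'PolyLR', 'steps': 10000, 'power': 0.9}

net = nn.Conv2d(3, 8, 3)  # stand-in for a real segmentation model
optimizer = get_optimizer(net.parameters(), optimizer_cfg)  # RAdam wrapped in Lookahead(k=5, alpha=0.5)
scheduler = get_scheduler(optimizer, scheduler_cfg)         # polynomial decay over 'steps' iterations

for step in range(scheduler_cfg['steps']):
    optimizer.zero_grad()
    # forward pass and loss.backward() would go here
    optimizer.step()
    scheduler.step()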
z0T^IqAnVIX4w@#ZfG@Owe`y+m{v6%#s#OE63` zfXkLv(#k+$jXT`pwwA-6qwt#2%Xl4ahEHoHJ_G#ThdhUufaA8F;|4#>??TyWstNG| zFuja&Dc&Cdy6@9Aaw)2kyVo>HlPfKlswruX<(J72u%@~a3Za@Ia_T{&xNeejU9kOX zhr$5&l~u2k=1A&?BnX$(t0X^UqPpMdjM#{vegyjp@rg^Q@sNzI6{Z|r@kSld@k1`s zb31Z6cc$Ytd#dLjLQ~VOAP>pi@HLo^oTeVB+UzR^xC#1N5i`^^SZbtVbN>?+D{1g< zQpHw{y~Iv6BS~(^JIIjmkt7mQkwFJMweqUA+TRKj_o^iQ@DPfR7T~t-pp{G5riAid z`lzIBOl>cqv34D6kz18U%4^yxLq(*Z=0?A$#AWh}>W3cMzsrZcL=izU?Oi2`B7C-| zD5{M`(V`MXr$A97QS|WtnJC(pGLFbXJSdT&qzvQ=9Xv@aK?*oZ$~Z$Q134jbkT9CG zF?NtiN=9(`-66drMactahm6n#Vo{-0L@THVHKc+^2_J}@g2*btQy}p}!iVA{cvQ=~ zB~KAjus9SY1tm&KXt@*@(W)Jn#$p6gWhh2itE{BR$X881_(8tjAf<}x3)RM0nv7f6 zAx(-OMw%$rkYl>tAxElM(DFVq>gc0n$Ndx~?NObQpP?|(x11+)gv?Pg$H*;}lY@Y&%c;Pd-8B zOJtrVbBfGUWKNS&@Ux9Uhmp39S6x}%suEBNAmw{~5sY5~<6DaTiTlC%#U02oB>uDH z#5ouc|5qveH8SVP(24uJB7XT5ia$eUk<4z6m(Nn#xFaI~Q_M~f=i?nQdlAe|X*a>_ z1$8-QiOiMJMVT$K>Y|KJ8_uqavQlEwJ=ukuCiU?>8PO6!rQ-Bg5%{QZdZdl+sA5ZG zaVwRH^vF!Jpd!wLM&h4|ga`@)h{Fmg?kj=}Z=HHLVNpTX8gnz<@+=IyZ+FL^+=NEqgBIj#A z9}_vh_~VZqk-zu8;?{0EvaQ095OK?aQ###|>>$kn-HnXN9=8z%j8zV|9E#xE25uu> z8r@tVCnauG;vn0PdyXCvJyHk7xQJT|2mm5xx}CtChc5DC2{1b0BL!rP8<62G0?~6x z(X(?SJ?|jT!jL>q)AIP>WBmgaTrGMy=+}c@*c7-#KBo@;>T9}(^gW>f{r&lsXjIZ( zYK;G-pg5M;7RToQ4!tU7RUA8}*V|;)({L{-U21xvTMwkf-E=eP9lastSj`FR%RU`h zg^OEW#lU0EO7DY-e1C;EJO#sWjJ~;Jcu5_^iEim=ewSuzWvxI*bB&JK>GmG7?enNt zejNsEdx63)kg1U&wpHdf8JA$_r+7z)21wD@nB7u<~i literal 0 HcmV?d00001 diff --git a/model/optim/__pycache__/warmup_scheduler.cpython-36.pyc b/model/optim/__pycache__/warmup_scheduler.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f03a30ae4ecca03e39300679e0d68f80216204e GIT binary patch literal 3222 zcmc&$QE%He5GEzrmK`T)v#d4NwdINdt%}r6FrW`Xnzci>Zb0S+O|SyS3Qdt1TaGL# zB#kbH^HK+BvA1Dw`v?0=dfiijJ?yzpyQ6H!ZPRYU9wwV7@<<+!ci-LVv-$bTAHUrB z;jl``?_}n&pnn&dTmcf0fJUT8Ta;qofWFZ(D0x5vGq6rbU~xJ!TV_z$Ax`ln>>v)! zs_xdK9lyhae#8Zg&p(2$&$k{ujJG3Caj);=%E^N#;v>HR!byt;q-6vYq|QAj`&6G) zZdpMgD4vj3A*cqWz&@d^Vlc1g%R%L&+$udM7HQd_$3k}Df$)O97d`ewub)m;@oa+- zHlB-)*i)Fk)lYAPu`jqM!?^9RhcP^;{tcso+mkr1iLx2n1%bZqfNF} zU#-8vyo1aY+$lk_l5AfjJ+{icIAA;uBaeODovypV+jaI9_#7>^ z=@&HW2*I~GZgLYkLt{u>G9+C<{{ay-MFV3*pV6TSd}C+>LDgONk0nM9)}BZK#c%xy;!kCFn$PV1tWmXzjExXa=I(4eCI`d z{-qF!$mV9unk~Lzl+4}-v`H{%R*?Oa3xR=^*{%z0<+@HGvm}r9v+@*30g*XSjSCx$@l*&(f>E_2UmUu>y4A$OZI>-eWksDFsdy(9P&SdxQe^XPON4=JS zyrrYS`(A}MR&E*xubqgX(}rQt(T1s^VU+e`?x#y{n8?fu`%&VltODRvdBAJjI>z6r zbL~`}v$eQ6(c;>~Jprj{U6zp=mmsPyPW5r<^iqzQ$Ipi#;2$dU>8UbKNlZ8D=o4j_ z%l<8z!X=8Mht}-0tyYKO?E&X0io zUt&fONdABC;uq5-EbqO;*i!M-eXAs$6{eX1S-_qF3A#2j2VpAau~$t!;q~0MNcySN z2!i|$(w(Bt99gNyl?Z*A6?7SHmB&drf;uv@ReS`(CW%O=9F3_9XDfOgub5-Z25c<> z(kk$)(FMb%=ZwLve|7CC<|Zj!V@|*s3d=8)RE0R(gT5)YkpmP^0cc#sQP%>5K-rC# z^3jmOy8!qAp5_!s@ODdwv9eWJ7itR&^$v-vFkgW}PZ41nOSI|CWtKXE%8A8Lpe{Nl z#bxa2DtQrlP*HUff~c(E1jdT^1PJ5JM2h@YV(QGvtd}l+y>bm$DF`g;8Zjkt{tLv`7+@O`z zDrEEHGVk*!f+bjnCJ@jXqk8&cn5!jT)3z;Rp=>W!7B#==SXoUa!tc~|#ji&*pJxI~ bQ*J+PoT>OzOJw#&uJ last set milestone, lr is automatically set to \eta_{min} + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list of ints): List of epoch indices. Must be increasing. + decay_milestones(list of ints):List of increasing epoch indices. Ideally,decay values should overlap with milestone points + gamma (float): factor by which to decay the max learning rate at each decay milestone + eta_min (float): Minimum learning rate. 
Default: 1e-6 + last_epoch (int): The index of last epoch. Default: -1. + + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1): + if not list(milestones) == sorted(milestones): + raise ValueError('Milestones should be a list of' + ' increasing integers. Got {}', milestones) + self.eta_min = eta_min + self.milestones = milestones + self.milestones2 = decay_milestones + + self.gamma = gamma + super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + + if self.last_epoch >= self.milestones[-1]: + return [self.eta_min for base_lr in self.base_lrs] + + idx = bisect_right(self.milestones, self.last_epoch) + + left_barrier = 0 if idx == 0 else self.milestones[idx - 1] + right_barrier = self.milestones[idx] + + width = right_barrier - left_barrier + curr_pos = self.last_epoch - left_barrier + + if self.milestones2: + return [self.eta_min + ( + base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) * + (1 + math.cos(math.pi * curr_pos / width)) / 2 + for base_lr in self.base_lrs] + else: + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos(math.pi * curr_pos / width)) / 2 + for base_lr in self.base_lrs] + + +class CyclicLinearLR(_LRScheduler): + r""" + Implements reset on milestones inspired from Linear learning rate decay + + Set the learning rate of each parameter group using a linear decay + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart: + .. math:: + \eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 -\frac{T_{cur}}{T_{max}}) + When last_epoch > last set milestone, lr is automatically set to \eta_{min} + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list of ints): List of epoch indices. Must be increasing. + decay_milestones(list of ints):List of increasing epoch indices. Ideally,decay values should overlap with milestone points + gamma (float): factor by which to decay the max learning rate at each decay milestone + eta_min (float): Minimum learning rate. Default: 1e-6 + last_epoch (int): The index of last epoch. Default: -1. + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1): + if not list(milestones) == sorted(milestones): + raise ValueError('Milestones should be a list of' + ' increasing integers. Got {}', milestones) + self.eta_min = eta_min + + self.gamma = gamma + self.milestones = milestones + self.milestones2 = decay_milestones + super(CyclicLinearLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + + if self.last_epoch >= self.milestones[-1]: + return [self.eta_min for base_lr in self.base_lrs] + + idx = bisect_right(self.milestones, self.last_epoch) + + left_barrier = 0 if idx == 0 else self.milestones[idx - 1] + right_barrier = self.milestones[idx] + + width = right_barrier - left_barrier + curr_pos = self.last_epoch - left_barrier + + if self.milestones2: + return [self.eta_min + ( + base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) * + (1. - 1.0 * curr_pos / width) + for base_lr in self.base_lrs] + + else: + return [self.eta_min + (base_lr - self.eta_min) * + (1. 
- 1.0 * curr_pos / width) + for base_lr in self.base_lrs] \ No newline at end of file diff --git a/model/optim/lookahead.py b/model/optim/lookahead.py new file mode 100644 index 0000000..378d874 --- /dev/null +++ b/model/optim/lookahead.py @@ -0,0 +1,100 @@ +import torch +from torch.optim import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + ''' + PyTorch implementation of the lookahead wrapper. + Lookahead Optimizer: https://arxiv.org/abs/1907.08610 + ''' + + def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"): + ''' + :param optimizer:inner optimizer + :param k (int): number of lookahead steps + :param alpha(float): linear interpolation factor. 1.0 recovers the inner optimizer. + :param pullback_momentum (str): change to inner optimizer momentum on interpolation update + ''' + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + self.optimizer = optimizer + self.param_groups = self.optimizer.param_groups + self.alpha = alpha + self.k = k + self.step_counter = 0 + assert pullback_momentum in ["reset", "pullback", "none"] + self.pullback_momentum = pullback_momentum + self.state = defaultdict(dict) + + # Cache the current optimizer parameters + for group in self.optimizer.param_groups: + for p in group['params']: + param_state = self.state[p] + param_state['cached_params'] = torch.zeros_like(p.data) + param_state['cached_params'].copy_(p.data) + + def __getstate__(self): + return { + 'state': self.state, + 'optimizer': self.optimizer, + 'alpha': self.alpha, + 'step_counter': self.step_counter, + 'k': self.k, + 'pullback_momentum': self.pullback_momentum + } + + def zero_grad(self): + self.optimizer.zero_grad() + + def state_dict(self): + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) + + def _backup_and_load_cache(self): + """Useful for performing evaluation on the slow weights (which typically generalize better) + """ + for group in self.optimizer.param_groups: + for p in group['params']: + param_state = self.state[p] + param_state['backup_params'] = torch.zeros_like(p.data) + param_state['backup_params'].copy_(p.data) + p.data.copy_(param_state['cached_params']) + + def _clear_and_load_backup(self): + for group in self.optimizer.param_groups: + for p in group['params']: + param_state = self.state[p] + p.data.copy_(param_state['backup_params']) + del param_state['backup_params'] + + def step(self, closure=None): + """Performs a single Lookahead optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = self.optimizer.step(closure) + self.step_counter += 1 + + if self.step_counter >= self.k: + self.step_counter = 0 + # Lookahead and cache the current optimizer parameters + for group in self.optimizer.param_groups: + for p in group['params']: + param_state = self.state[p] + p.data.mul_(self.alpha).add_(1.0 - self.alpha, param_state['cached_params']) # crucial line + param_state['cached_params'].copy_(p.data) + if self.pullback_momentum == "pullback": + internal_momentum = self.optimizer.state[p]["momentum_buffer"] + self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_( + 1.0 - self.alpha, param_state["cached_mom"]) + param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"] + elif self.pullback_momentum == "reset": + self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data) + + return loss diff --git a/model/optim/radam.py b/model/optim/radam.py new file mode 100644 index 0000000..f439c04 --- /dev/null +++ b/model/optim/radam.py @@ -0,0 +1,250 @@ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)]) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / 
N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + # more conservative since it's an approximated value + if N_sma >= 5: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + p.data.copy_(p_data_fp32) + elif self.degenerated_to_sgd: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + p.data.copy_(p_data_fp32) + + return loss + + +class AdamW(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: 
{}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, warmup=warmup) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + if group['warmup'] > state['step']: + scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] + else: + scheduled_lr = group['lr'] + + step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) + + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + + p.data.copy_(p_data_fp32) + + return loss \ No newline at end of file diff --git a/model/optim/warmup_scheduler.py b/model/optim/warmup_scheduler.py new file mode 100644 index 0000000..1a8d7eb --- /dev/null +++ b/model/optim/warmup_scheduler.py @@ -0,0 +1,65 @@ +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import ReduceLROnPlateau + + +class GradualWarmupScheduler(_LRScheduler): + """ Gradually warm-up(increasing) learning rate in optimizer. + Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. + Args: + optimizer (Optimizer): Wrapped optimizer. + multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. + total_epoch: target learning rate is reached at total_epoch, gradually + after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) + """ + + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): + self.multiplier = multiplier + if self.multiplier < 1.: + raise ValueError('multiplier should be greater thant or equal to 1.') + self.total_epoch = total_epoch + self.after_scheduler = after_scheduler + self.finished = False + super(GradualWarmupScheduler, self).__init__(optimizer) + + def get_lr(self): + if self.last_epoch > self.total_epoch: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_last_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + + if self.multiplier == 1.0: + return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] + else: + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in + self.base_lrs] + + def step_ReduceLROnPlateau(self, metrics, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning + if self.last_epoch <= self.total_epoch: + warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in + self.base_lrs] + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): + param_group['lr'] = lr + else: + if epoch is None: + self.after_scheduler.step(metrics, None) + else: + self.after_scheduler.step(metrics, epoch - self.total_epoch) + + def step(self, epoch=None, metrics=None): + if type(self.after_scheduler) != ReduceLROnPlateau: + if self.finished and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + self.after_scheduler.step(epoch - self.total_epoch) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(GradualWarmupScheduler, self).step(epoch) + else: + self.step_ReduceLROnPlateau(metrics, epoch) diff --git a/model/tools/Balanced_DataParallel.py b/model/tools/Balanced_DataParallel.py new file mode 100644 index 0000000..bcedc98 --- /dev/null +++ b/model/tools/Balanced_DataParallel.py @@ -0,0 +1,112 @@ +import torch + +from torch.nn.parallel import DataParallel +from torch.nn.parallel._functions import Scatter +from torch.nn.parallel.parallel_apply import parallel_apply + + +def scatter(inputs, target_gpus, chunk_sizes, dim=0): + r""" + Slices tensors into approximately equal chunks and + distributes them across given GPUs. Duplicates + references to objects that are not tensors. + """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + try: + return Scatter.apply(target_gpus, chunk_sizes, dim, obj) + except: + print('obj', obj.size()) + print('dim', dim) + print('chunk_sizes', chunk_sizes) + quit() + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return list(map(list, zip(*map(scatter_map, obj)))) + if isinstance(obj, dict) and len(obj) > 0: + return list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return [obj for targets in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). 
To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): + r"""Scatter with support for kwargs dictionary""" + inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] + kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs + + +class BalancedDataParallel(DataParallel): + def __init__(self, gpu0_bsz, *args, **kwargs): + self.gpu0_bsz = gpu0_bsz + super().__init__(*args, **kwargs) + + def forward(self, *inputs, **kwargs): + if not self.device_ids: + return self.module(*inputs, **kwargs) + if self.gpu0_bsz == 0: + device_ids = self.device_ids[1:] + else: + device_ids = self.device_ids + inputs, kwargs = self.scatter(inputs, kwargs, device_ids) + + # print('len(inputs): ', str(len(inputs))) + # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)])) + + if len(self.device_ids) == 1: + return self.module(*inputs[0], **kwargs[0]) + if self.gpu0_bsz == 0: + replicas = self.replicate(self.module, self.device_ids) + else: + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + + # replicas = self.replicate(self.module, device_ids[:len(inputs)]) + if self.gpu0_bsz == 0: + replicas = replicas[1:] + + # print('replicas:', str(len(replicas))) + + outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) + return self.gather(outputs, self.output_device) + + def parallel_apply(self, replicas, device_ids, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)]) + + def scatter(self, inputs, kwargs, device_ids): + bsz = inputs[0].size(self.dim) + num_dev = len(self.device_ids) + gpu0_bsz = self.gpu0_bsz + bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) + if gpu0_bsz < bsz_unit: + chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) + delta = bsz - sum(chunk_sizes) + for i in range(delta): + chunk_sizes[i + 1] += 1 + if gpu0_bsz == 0: + chunk_sizes = chunk_sizes[1:] + else: + return super().scatter(inputs, kwargs, device_ids) + + print('bsz: ', bsz) + print('num_dev: ', num_dev) + print('gpu0_bsz: ', gpu0_bsz) + print('bsz_unit: ', bsz_unit) + print('chunk_sizes: ', chunk_sizes) + return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) diff --git a/model/tools/__init__.py b/model/tools/__init__.py new file mode 100644 index 0000000..af87f13 --- /dev/null +++ b/model/tools/__init__.py @@ -0,0 +1,2 @@ +from .Balanced_DataParallel import BalancedDataParallel +from .split_weights import split_weights diff --git a/model/tools/__pycache__/Balanced_DataParallel.cpython-36.pyc b/model/tools/__pycache__/Balanced_DataParallel.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b89543b685c6c8c3cdd15f6007bab899df3923e4 GIT binary patch literal 3627 zcmai1&yOQV6|Smox7$6Qi8BnF*=Pc->=I%Ko)ts_R?7;Jm2yGD!h|2O5;Z;Uo|*Qv z+uo{jvSYO`c+ofqRzgBtxo`*yQhlaN?X3->YuVc(b$6t*U;mdiCC`SKoWD z{JqW1;O?(}`rD0{8T%JodtA^zMpN&hOIX5l)~Bz?`AGByC)&xKk=u7I?dFY<*Y`kg zB#p#-!u!pnoir04l>aVET1oJPB|$0{PQR7#L)P8=9Fo~Vmz(y3SjP`z8RvPL<9G9L z5Nn-EkXw^_IEp8e{Cq%W(L>(=%|TOdpu1of{6btf$8N$C@x=c`%>1P&*#fE(=M(W8 
z^ib-~{t|?E4)bh~s!*qe8cP*sg&v2{AjglhQLNMaJWL-=<2)RkPKz@Y#zkV6C7IGP zJDO@*dzy~Icp%40g(ulZX%W8r@RzC=KA27*972>`BGcnkriIlu4#!8sbf9T*tixEQ zVKLS#D|(iHz_A;8=(~yb6*To9beb)h=0mX%@)zeF?F`+eyW~r;WDEBY?DXnJ!k295 zE!YBQgkNF_eoHw^ZmH%1eDAXYF0m4Q$eZaFr0~Q&b3q4sL${X>M#NL=K+6|^@6$ED zPr=vz60ab4TkcJ8f40W`ndROhEi386m&i(Ju&XZEF}KL^m+k|rwc9)}4jg8jBpaCx zdlZq%W~q|on@dLt{2&}yM|D7)u#$b~&*6waibrtG-Ielm&*o5 z?bH`P-#;CX(*4ufaR1@?;53$_hjMI>YF}q@F*wa4sZ@HhKNu(J{s>*ZugBwD?f)ds zF`p*UI*5Cd^DlyXInXi~k0$rO20e<*^w{i1tK?^)mIt7&=hTz@irH{!#%#m z1H5g1gKwklqJ@06_rE-FP#u+PysVf^wUVzv-jzA^?Aicnu9(OW)2xNlfh+1*bNhzi zzK~xBJ)je(hrYDEchJ<|(UnYYYxJckozh*ng*ahzR&x0!M)@kICjBBP+oA(inrHiu%@EsCR1N!w9Kq|f0LHPe0kiEzN8A8={CVcnGgl@wguKX&7jF&z}R!>Yr#>Gi0 zcQA%|f~D~)3$KGzrNJU->)HK$oAhj>W8wzz+u?7Z-R3@Cc^jhxnp!;+jr6D=pi5bw zC#)}QX0t`9@8H1QC#>(FZCC(^1l6YiT*4rUA9Nk#BUkB0a$3g*zvC&60XlQ25$8ux06U zs!ce;AW1*M9V5yTWjqv&X`Wg)uR*mX)B36)$)UzOiE*irW^+8%$bC^Iq3g*mW-Wy6 z>B#R8vtKV(rdhG8FGY!IA{ihRBtHUsotHi2;(&LA$G00>do3Y+_-s~e@i;Ja+zdVE`ZWs7qBhr9Ej^vfU-ub{Nt1bn{__4nAL zqhGVqvrzhL>0d(0FFG36u%W*KG%ii8gz64m)mn*LHfpL-dNl=`YN`qS#1~Ynp%?o~ z{691(byOgfZ(-ZB24enhIBONt5n%~ruu26)V>X&14}jiWyIkE3aWkS|?QTfQdzQ8b zSRIDcfL7Es<@ zLj-B~08LTOWk|sR(oY9Tcn6T)!qvj(TcRWAx>3C^fODPPyFO*XA|9ntWP&KF0u*#R zis<5B%}7dk@@?uUKzyQ?%dj;_kU;du&NE!v3%@Fk` zkXsn+zie-Hy`t!?KrIRS{k`YLd(rW@hr4MSR?!UU}$4FzIpMQLIN%(99hXFF2#UKAq*eT>!;WLv9d zBcm*(mjDWIpN9@i6c=TVI)-hLRpW0`{6~H2^cmWn_+H{DiE)hV%oWV?<{F literal 0 HcmV?d00001 diff --git a/model/tools/__pycache__/metric.cpython-36.pyc b/model/tools/__pycache__/metric.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94a4f71990f247c4490e4467476299384a11356 GIT binary patch literal 3181 zcma)8?QaxC7~k2KyY01wQYy57WPL*xNDn1Ji!Y%NYihJypoFBPIhUPjZ?Aj1Wp-ES z<(d$#5sa4jL8HOM1Y?XZiN+5mM4R|G_;tU^QR)}I8$b9wvv+%K?@ZiepP8L`X6BjS z@0s7uiT?iFoqx{1`$w-N{V8pF6sVtuH`oV;N|dm>Q``MJXg@ zxh9n*eCOmkwkcGC*~pZs#=TxtDcY_T1T2VUZ!y-YjQc^n)Aqf&D8N-LSmnAY$7;ab zIgTrQCvhE$zt7}Leud>r)pEYDWS1;nDR95c>@diOj^)`Ur^sRA`Mm8@mal+u^P%s% zLB1kFA6;CE^{n-O+%imr*lU0;r_K##&pD) z;aXRio3`7&8&$khMu9Wx*rCB$5V>L6%?Sjyf#jGZ?x{GQ%e?upWU2y6j`7F(700vv z$P42vj4oM=%*^mC_GA~6Sz}Es8mkp6Sco;@%yIvMiNy5afWdj@o!!bEwg)*dQf83s zAwy($2ZNbP%1+vl_8(XN*!}c#&W;3Olo*ONVSa*XJN!W!*X3i zB%0yB$2fj1I@g(NShgK;%U((WD~ubjVJq0f(*@(GF+F7(`LqrV6V29i2!JVYRHMY9^Ak2;QsfG+n+tU_if|G+T*p48sFS}ymt5U z@zJs4$Bg5nFAF<%au&M4XZ`D6*MI)4ar=iCj2Dd)jgME7FPLoOtCfepeAu}4{`$Qi z*YDg)SKGh@g1lF)>rT*Qi4Q?rlDi0dlcZP|$}Fn`YA%T2`q1M)U&(_v%JF{xE(W~ zOPE;`DM4Hze-T(wjGcVE_5qM|0y?pEEM3i1IR1i?w9YmcfJPJrJ_2SYi%WKaYZta} zLF*8>2dEX2s*@b4KHDX|I|ik2p)TVqj{bS5v{M`)lZ0rB-`uL4ekqXj^4?9J!yvjK zzcL&I1-#Un0x*NEsdEUmTD_9)e9{XU58x*PISsH(w^R3`*J^5Mz{FI0n14}tr%EA2 zVq&U&Ecjjs-Xwe&3L(m-Cg9zck9woMkfN51x5$}Ir|4uS(R7X)k*Edg251N#ff;pH#3-Ap9B+km^_$TK?aB>D2^9COtJr;G;(13MQAO1vwZ9 z$xUEF1+)Qi1!MV>9Aa=uiEiKpybv}RwuVkAp z0=h-%87s6)6VTqgSiz6K+r_HonPwQOHts-V7@n6_?qG&=Tq2WGgu6|yG?<%k&V*EL z3rd6Q4GZ4_hX5s5BGut8SPQ2qvkxzao>gJRVw@`$;r@RgHjKg%DNYfP|{596%MQxUtz}H)}7pOS>rL zK)Ck8MQT!WXXfMZpgB62xKz2dQPf%lYjzP%TS&b&BK~K<|>O8Wz zjx3o(4KZ6wJ8pNT4oM|?*5WO3!9G7f)&#}C`Vv~X$0+*%z%P64P(POh6ERPME*#6fE z(|J)^k=qk3(vWf>z?LidjPf$KrIJi9sCw6eF*GnD{zpM5v*(hQB^L)E*d?J0qy3nS zi1t(B`fbh;7dcyuhORqP{G<}brKKsQu}tQ|5h-1=3-sORE_oh!a|xTzh2+}Lr-q?L zu3Ip}tko{6gmiIL8tc+i&6_6>283;^QRSU*rYTvwf{5z?jKw3HXZgNJHgumMbnYz%Q!IDpi>*Fm$w>m96>0 z4X$`R0p>O+)Q@pYhWHlllN1l56!&q8uai4OZ$NYuxt>+= 0) & (label < self.num_classes) + label = self.num_classes * label[mask] + predict[mask] + count = np.bincount(label, minlength=self.num_classes ** 2) + confusionMatrix = count.reshape(self.num_classes, self.num_classes) 
+        return confusionMatrix
+
+    def pixelAccuracy(self):
+        """
+        return all class overall pixel accuracy
+        PA = acc = (TP + TN) / (TP + TN + FP + FN)
+        """
+        Acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
+        return Acc
+
+    def classPixelAccuracy(self):
+        """
+        return each category pixel accuracy (a more accurate name would be precision)
+        Acc = TP / (TP + FP)
+        Returns a list such as [0.90, 0.80, 0.96], the prediction accuracy of class 1, 2, 3 respectively
+        """
+        classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
+        return classAcc
+
+    def meanPixelAccuracy(self):
+        """
+        Returns a single value, e.g. np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 = 0.89
+        """
+        classAcc = self.classPixelAccuracy()
+        meanAcc = np.nanmean(classAcc)
+        return meanAcc
+
+    def meanIntersectionOverUnion(self):
+        """
+        Intersection = TP
+        Union = TP + FP + FN
+        IoU = TP / (TP + FP + FN)
+        """
+        intersection = np.diag(self.confusionMatrix)  # diagonal elements, one value per class
+        union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
+            self.confusionMatrix)  # axis=1 sums the rows of the confusion matrix, axis=0 sums the columns
+        IoU = intersection / union  # per-class IoU values
+        mIoU = np.nanmean(IoU)  # average of the per-class IoUs
+        return IoU, mIoU
+
+    def FrequencyWeightedIntersectionOverUnion(self):
+        """
+        FWIoU = [(TP + FN) / (TP + FP + TN + FN)] * [TP / (TP + FP + FN)]
+        """
+        freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
+        iu = np.diag(self.confusionMatrix) / (
+                np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
+                np.diag(self.confusionMatrix))
+        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
+        return FWIoU
+
+    def addBatch(self, predict, label):
+        assert predict.shape == label.shape
+        self.confusionMatrix += self.genConfusionMatrix(predict, label)
+
+    def reset(self):
+        self.confusionMatrix = np.zeros((self.num_classes, self.num_classes))
diff --git a/model/tools/split_weights.py b/model/tools/split_weights.py
new file mode 100644
index 0000000..dc759a3
--- /dev/null
+++ b/model/tools/split_weights.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+
+
+def split_weights(net):
+    """split network weights into two categories:
+    one is the weights in conv and linear layers,
+    the other is the remaining learnable parameters (conv bias,
+    bn weights, bn bias, linear bias)
+    Args:
+        net: network architecture
+
+    Returns:
+        a list of two param dicts, split into the two categories
+    """
+
+    decay = []
+    no_decay = []
+
+    for m in net.modules():
+        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+            decay.append(m.weight)
+
+            if m.bias is not None:
+                no_decay.append(m.bias)
+
+        else:
+            if hasattr(m, 'weight'):
+                no_decay.append(m.weight)
+            if hasattr(m, 'bias'):
+                no_decay.append(m.bias)
+
+    assert len(list(net.parameters())) == len(decay) + len(no_decay)
+
+    return [dict(params=decay), dict(params=no_decay, weight_decay=0)]
diff --git a/model/unet.py b/model/unet.py
new file mode 100644
index 0000000..7bc7ea1
--- /dev/null
+++ b/model/unet.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import segmentation_models_pytorch as smp
+
+
+class Unet(nn.Module):
+    def __init__(self, num_classes):
+        super(Unet, self).__init__()
+
+        self.model = smp.Unet(
+            encoder_name="se_resnext50_32x4d",
+            encoder_depth=5,
+            encoder_weights='imagenet',
+            decoder_use_batchnorm=True,
+            decoder_channels=[256, 128, 64, 32, 16],
+            decoder_attention_type='scse',
+            in_channels=3,
+            classes=num_classes,
+        )
+
+    def forward(self, x):
+        logits = self.model(x)
+        return [logits]
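For reference, a hedged end-to-end sketch of how the pieces added in this patch could fit together. The metric class name follows the reconstruction above, and the smp encoder weights are downloaded on first use.

import torch
from model.unet import Unet
from model.losses import LossFunction
from model.tools.metric import SegmentationMetric  # class name assumed from the reconstructed header

net = Unet(num_classes=10)
criterion = LossFunction()

images = torch.randn(2, 3, 256, 256)
target = torch.randint(0, 10, (2, 256, 256))

logits = net(images)              # both Unet and FPN return a single-element list: [logits]
loss = criterion(logits, target)  # cross-entropy + 0.2 * Dice on logits[0]
loss.backward()

metric = SegmentationMetric(num_classes=10)
metric.addBatch(logits[0].argmax(dim=1).numpy(), target.numpy())
iou, miou = metric.meanIntersectionOverUnion()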