From a0dc5c9211f1e8d8890af5362b63811752c4c021 Mon Sep 17 00:00:00 2001 From: Ha0Tang Date: Sat, 20 Jul 2019 23:30:43 +0200 Subject: [PATCH] add source code --- .DS_Store | Bin 6148 -> 6148 bytes data/__init__.py | 74 +++ data/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 2648 bytes data/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 2605 bytes .../aligned_dataset.cpython-36.pyc | Bin 0 -> 3180 bytes .../aligned_dataset.cpython-37.pyc | Bin 0 -> 3000 bytes .../base_data_loader.cpython-36.pyc | Bin 0 -> 742 bytes .../base_data_loader.cpython-37.pyc | Bin 0 -> 699 bytes data/__pycache__/base_dataset.cpython-36.pyc | Bin 0 -> 3458 bytes data/__pycache__/base_dataset.cpython-37.pyc | Bin 0 -> 3409 bytes data/__pycache__/image_folder.cpython-36.pyc | Bin 0 -> 2124 bytes data/__pycache__/image_folder.cpython-37.pyc | Bin 0 -> 2087 bytes .../__pycache__/single_dataset.cpython-36.pyc | Bin 0 -> 1709 bytes data/aligned_dataset.py | 81 ++++ data/base_data_loader.py | 10 + data/base_dataset.py | 103 +++++ data/image_folder.py | 68 +++ data/single_dataset.py | 42 ++ data/unaligned_dataset.py | 62 +++ models/.DS_Store | Bin 0 -> 6148 bytes models/__init__.py | 39 ++ models/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 1206 bytes models/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 1163 bytes models/__pycache__/base_model.cpython-36.pyc | Bin 0 -> 5717 bytes models/__pycache__/base_model.cpython-37.pyc | Bin 0 -> 5672 bytes .../cycle_gan_model.cpython-36.pyc | Bin 0 -> 5005 bytes models/__pycache__/networks.cpython-36.pyc | Bin 0 -> 12991 bytes models/__pycache__/networks.cpython-37.pyc | Bin 0 -> 11074 bytes .../__pycache__/pix2pix_model.cpython-36.pyc | Bin 0 -> 7164 bytes .../__pycache__/pix2pix_model.cpython-37.pyc | Bin 0 -> 4274 bytes models/__pycache__/test_model.cpython-36.pyc | Bin 0 -> 1543 bytes models/base_model.py | 159 +++++++ models/gesturegan_twocycle_model.py | 219 +++++++++ models/networks.py | 429 ++++++++++++++++++ models/test_model.py | 46 ++ options/__init__.py | 0 options/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 204 bytes options/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 208 bytes .../__pycache__/base_options.cpython-36.pyc | Bin 0 -> 5468 bytes .../__pycache__/base_options.cpython-37.pyc | Bin 0 -> 5472 bytes .../__pycache__/test_options.cpython-36.pyc | Bin 0 -> 1254 bytes .../__pycache__/test_options.cpython-37.pyc | Bin 0 -> 1258 bytes .../__pycache__/train_options.cpython-36.pyc | Bin 0 -> 2404 bytes .../__pycache__/train_options.cpython-37.pyc | Bin 0 -> 2408 bytes options/base_options.py | 119 +++++ options/test_options.py | 21 + options/train_options.py | 28 ++ test.py | 41 ++ train.py | 59 +++ util/__init__.py | 0 util/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 201 bytes util/__pycache__/__init__.cpython-37.pyc | Bin 0 -> 205 bytes util/__pycache__/html.cpython-36.pyc | Bin 0 -> 2366 bytes util/__pycache__/html.cpython-37.pyc | Bin 0 -> 2323 bytes util/__pycache__/image_pool.cpython-36.pyc | Bin 0 -> 1051 bytes util/__pycache__/image_pool.cpython-37.pyc | Bin 0 -> 1008 bytes util/__pycache__/util.cpython-36.pyc | Bin 0 -> 1950 bytes util/__pycache__/util.cpython-37.pyc | Bin 0 -> 1942 bytes util/__pycache__/visualizer.cpython-36.pyc | Bin 0 -> 5897 bytes util/__pycache__/visualizer.cpython-37.pyc | Bin 0 -> 5852 bytes util/get_data.py | 115 +++++ util/html.py | 64 +++ util/image_pool.py | 32 ++ util/util.py | 60 +++ util/visualizer.py | 163 +++++++ 65 files changed, 2034 insertions(+) create mode 100755 data/__init__.py create mode 100644 data/__pycache__/__init__.cpython-36.pyc create mode 100755 data/__pycache__/__init__.cpython-37.pyc create mode 100644 data/__pycache__/aligned_dataset.cpython-36.pyc create mode 100755 data/__pycache__/aligned_dataset.cpython-37.pyc create mode 100644 data/__pycache__/base_data_loader.cpython-36.pyc create mode 100755 data/__pycache__/base_data_loader.cpython-37.pyc create mode 100644 data/__pycache__/base_dataset.cpython-36.pyc create mode 100755 data/__pycache__/base_dataset.cpython-37.pyc create mode 100644 data/__pycache__/image_folder.cpython-36.pyc create mode 100755 data/__pycache__/image_folder.cpython-37.pyc create mode 100755 data/__pycache__/single_dataset.cpython-36.pyc create mode 100755 data/aligned_dataset.py create mode 100755 data/base_data_loader.py create mode 100755 data/base_dataset.py create mode 100755 data/image_folder.py create mode 100755 data/single_dataset.py create mode 100755 data/unaligned_dataset.py create mode 100644 models/.DS_Store create mode 100755 models/__init__.py create mode 100644 models/__pycache__/__init__.cpython-36.pyc create mode 100755 models/__pycache__/__init__.cpython-37.pyc create mode 100644 models/__pycache__/base_model.cpython-36.pyc create mode 100755 models/__pycache__/base_model.cpython-37.pyc create mode 100755 models/__pycache__/cycle_gan_model.cpython-36.pyc create mode 100644 models/__pycache__/networks.cpython-36.pyc create mode 100755 models/__pycache__/networks.cpython-37.pyc create mode 100644 models/__pycache__/pix2pix_model.cpython-36.pyc create mode 100755 models/__pycache__/pix2pix_model.cpython-37.pyc create mode 100755 models/__pycache__/test_model.cpython-36.pyc create mode 100755 models/base_model.py create mode 100755 models/gesturegan_twocycle_model.py create mode 100755 models/networks.py create mode 100755 models/test_model.py create mode 100755 options/__init__.py create mode 100644 options/__pycache__/__init__.cpython-36.pyc create mode 100644 options/__pycache__/__init__.cpython-37.pyc create mode 100644 options/__pycache__/base_options.cpython-36.pyc create mode 100644 options/__pycache__/base_options.cpython-37.pyc create mode 100644 options/__pycache__/test_options.cpython-36.pyc create mode 100644 options/__pycache__/test_options.cpython-37.pyc create mode 100644 options/__pycache__/train_options.cpython-36.pyc create mode 100644 options/__pycache__/train_options.cpython-37.pyc create mode 100755 options/base_options.py create mode 100755 options/test_options.py create mode 100755 options/train_options.py create mode 100755 test.py create mode 100755 train.py create mode 100755 util/__init__.py create mode 100644 util/__pycache__/__init__.cpython-36.pyc create mode 100644 util/__pycache__/__init__.cpython-37.pyc create mode 100644 util/__pycache__/html.cpython-36.pyc create mode 100755 util/__pycache__/html.cpython-37.pyc create mode 100644 util/__pycache__/image_pool.cpython-36.pyc create mode 100755 util/__pycache__/image_pool.cpython-37.pyc create mode 100644 util/__pycache__/util.cpython-36.pyc create mode 100644 util/__pycache__/util.cpython-37.pyc create mode 100644 util/__pycache__/visualizer.cpython-36.pyc create mode 100755 util/__pycache__/visualizer.cpython-37.pyc create mode 100755 util/get_data.py create mode 100755 util/html.py create mode 100755 util/image_pool.py create mode 100755 util/util.py create mode 100755 util/visualizer.py diff --git a/.DS_Store b/.DS_Store index 1de80bb211acc19a242e65d9d1b6d44263ae07ae..502330b5646ee96aa4357757ec349bacbe0147b0 100644 GIT binary patch delta 122 zcmZoMXfc@J&ndvbz`)4BAi%J>ig`1mB#6txkiwA2P{NQ{7F?8eRvJ#*&cA!iFP-`ZT&YK*?tk229PzvVf0L@os+04%ImmdHJ C$r>#H delta 61 zcmV-D0K)%-FoZCWPZbOR0003101yBGa{y%kbO3W~XE9T=@dKv;0SuFn0u-}20!smt Ty#pSxXQTnL1lj_#2MGNS$ny~F diff --git a/data/__init__.py b/data/__init__.py new file mode 100755 index 0000000..b44ca1f --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,74 @@ +import importlib +import torch.utils.data +from data.base_data_loader import BaseDataLoader +from data.base_dataset import BaseDataset + +def find_dataset_using_name(dataset_name): + # Given the option --dataset_mode [datasetname], + # the file "data/datasetname_dataset.py" + # will be imported. + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + # In the file, the class called DatasetNameDataset() will + # be instantiated. It has to be a subclass of BaseDataset, + # and it is case-insensitive. + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + exit(0) + + return dataset + + +def get_option_setter(dataset_name): + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + dataset = find_dataset_using_name(opt.dataset_mode) + instance = dataset() + instance.initialize(opt) + print("dataset [%s] was created" % (instance.name())) + return instance + + +def CreateDataLoader(opt): + data_loader = CustomDatasetDataLoader() + data_loader.initialize(opt) + return data_loader + + +## Wrapper class of Dataset class that performs +## multi-threaded data loading +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = create_dataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads)) + + def load_data(self): + return self + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i * self.opt.batchSize >= self.opt.max_dataset_size: + break + yield data diff --git a/data/__pycache__/__init__.cpython-36.pyc b/data/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0057abdb17d14ea676c9071b036672453b268abd GIT binary patch literal 2648 zcmZ`*U2haO6t(ALcXqRhf(Q+zv`kw-U0Nj+5apq&qJl~*(WnxrNL!FKS$mV&%+4&f zmyl>8RSK{DA^kUf_G_Q|7y8t5J=p}JW>@z0*uLX)zmDhb;$ru&Cm+2Wbr}1bz4MDO z-as=i&`BnF!3K@wg8$nIk% zdvftDlZz@k3I_|4KV|);e{cfZ>~pvDeP-07%w|ujOe&3;h53w8cF1WTer+^@_Ylpz zMJIX9Rsm_19dgUYyyi#z5Fii3nvDY)%IGMNF;=5D_v&CAj-w;?h`qe|A~=i>TXj5+ zk60a1zZ>ZLwpF)m`v{}Dbr`<7|D4sKWV?y}-6pbC^Ew&@AFw0NneE6n$vT7?$>u>U zfqiX|2$DeTxNyw1kDm^*pUU)xS)1;EmD(MpRcdysSs~Lcm1e1#Z4Haen6w(D?--r# z<#s1+R?2LmpqbfplG)*oGI(2>muXS$DLur2Ynoit@48MtnO52s`IcL3hGJ65S)p88 zt7(x9l}m(>`OpgC;@ql*i(-MTdBsi>Ep-t&7x~-1gh9cWjLyG=t^&(L zpXnaY;B4K3{@|j`^tQ5dpLjhN(Owr03-c$_7qRW%*XujgM6Exas$rfL>svzJx&2`M zSypcUm{sdGpNg-=?K?s%Q{rWLX{@M9^_aS!v#EH!y)Eo(kxNzD-0lmj^J$^%BQdOs zN(;MJJ=`A_YMs!n3+U%C9iecmqr8+S;ASQ-w>?C{ZZ4x^QIB`yh%fUMv@TDA1TqQq zC9Iv~WPk^1J|~wju&hRo93mITl;nJ~-)Oiih(tcx2e8Q`D`ko#uSEa$0WvC7o z=U(-3_x@?R6|A}ofJ+AEI#V$==NC(1)te{}*Rl2q+5Sk2%h+Yv-Xn-g|we3 zg$vZRiDncQC;O~nzLo$|^%0DX2Q-!`8U-w_=dy0ZWf1=0(W%SoKz F_a9*1e!c(z literal 0 HcmV?d00001 diff --git a/data/__pycache__/__init__.cpython-37.pyc b/data/__pycache__/__init__.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..02825d704f1801a0775f5149d702bdf0083dc66b GIT binary patch literal 2605 zcmZ`*TW=dh7@gURy>^@^RZ~)`LKjLa48Dm)+*wG|Z-L`X$c+O}G4cEq}?G zN*b*tr1X{lkUa92=9MS@0=#g(+1M!&S!?E-`<(A~#t&Cl+YH}7531+=7Gr->XYmDS zY@_HGs00(dWJxV~;-T-A{wPQSPGf!fqfCQWTv7&+25bjR3sp%+k;kJD^8RmIrYKpqsP+q9iiZBjy$tS@@~qLgl= zeL8}>RRjBz%O+E=q*>U+?Xv-`wSHdeKbX3VWq&`}8carVE7w9!&DK;+hB7yL%M{b} z%k=KOR7pL?DaBZubTXDtsQM|JZqapHDaH%v0^8cnelZry+h)2LzjAjH7=0BL3p%_V z27H}wptO19MUauFu3&B%WF9(HT#&7yVORyH&kx8i z7|V1rHN|ACv1A=vlhuWpoFkpWb+v(6dx^|b??@3fMyd}ntSQ79kGR@|T;d|1EQ7y? z4&fT|8rcz>dzE+WRqTYq(??=>6zp}lRvvJBJBL?K(ZtVh>6h_-rsG`6%t+A z49z4W0=6x(vN)8A452Qe>iX6PWanZVkyo;DF3)SWRv*0Z#!YP0j&fl32Ff~D*DI)b@qLLjtv ziH=f@0I^LynDzUmOjgFTQMx}-ds6A{s-ml%Yn;RDunEmX|Ce9`Y1bCQfw6`}8)o?x5=WP65jcb&JN%85|EB9yXPx-(ylos2B);gM0iizF41n%T70< z2a=Kp4`e$%eLd0cr0MHfR@O5v4eO}x5`r(Nx=Ym=7S|$oxupII4aa~HkNl{G;=6G$ z1O2N>)?G^UTt@ULA=FJ8zCz93E}nSj!Udy|dy-54ox=m--fhR^?&+vAj@aysSKa~W cqBHuo{(tZt(UGP%3jmQI9WS~Xg%Q4g0q4GMlK=n! literal 0 HcmV?d00001 diff --git a/data/__pycache__/aligned_dataset.cpython-36.pyc b/data/__pycache__/aligned_dataset.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c99224218e21840d3c5e8c010a948c39f6c7df4 GIT binary patch literal 3180 zcmc&$&5zs073Yu?Nl{v@Hs0)K+NN%RRzaZMx&@j5uI>G39RtoH$htrwG?>yFN}^1W z@^E(7UQn0HJ@nR|TcGFOdu`9{%|LH+DjEdE{tLOZzcLoRM?zrXb$%K9Hh`Ibt?{QG?S2A6S>#fm3cp?IFfcZqxyjrnoE z&($$cM4YKaJ%#yXoW^ROck?utT%F{5PrGTnMc`~jGvCjVYLB1ljZrR=-cyK=Mo}ix zB#U`IR!N=}HCl>UpMBbTdD-n(YrTXBp>7b_Cq!6janK>Pwn9t{E2hVW-qgmJuL4yq zt)V|@B&*$fc!+tTw*uf zB2Ie}hhVmYv%Rd6l`zQ?guw~X{r|VZhyp+BGZ3zlKJ_Fh(a1C$( zvUO`>D>k%Cd&-6g+9{hjB>?<|J#i-$RViIzZz7B*9!5QcvQv6xrRPH0hlOBBu!8o; zN;tJAKFO7kgq*i(eWh$LlO?$e`LENdUh2&C+^?J?RJCD!YL(dfi9hE9@WGs~g0If` z5PYZ>h7IAUMR;_%zUG=G+%;=zV}-S^m8h3K(vS6PyR73SE6r8v)$m^QdJeGn%iuL7 zgR=UXlGQSNO-aILNi2#c4#_z_OJYg9f&Mv0OX8YXM*keACGn93d}`li&e2Ub@nOx z7S-x-tz0Ay66>N0o?$IDX<4pVAJUd2$Mp{PN)lQOuHJRI+ zWpb0Lb&%$fI3x*w_uk&m@aYg&aaQzlIVyDUAb${NxDm96L6MKNPc0d9M$r=;^l&j6 z_38l>Op4cwCxfIr;J8oXE?p41njfk0k>XhwSA90;{4D3aG#Tp(EcRnvQMv36ba+3{ z`tv#KNhaba9QO(g>FQBdJU)u!)3~)HS81gc^S|TIFR6~Zv%9yi*@Gz?RS}Z?IvXK+6~-yp&hWm-e7MVt-2kucQ7`i0XyAzb>v?hJPH7c zcykrfO{(ZJ#MNi&Dx5VZjWZJtEoNZ4x)h&tNO1lUbZ2!8J6-wG#75+$Nois;DMgtn zJnaKn0dFRF#bj<4jbg&;e5PwxW}^QGt*Q<R4~g6(a#2LiI7&6fPf{?oN5JZC;0Hen95aqW$6)j5*t6L|xq zvtr`$I4Kap?aLz}w@KnAh+d+?s{Ihf(|P46X=6IG@q}t^-pf-F%ULnye0=Y|Iii=S bz2K{2>}K>4pTH literal 0 HcmV?d00001 diff --git a/data/__pycache__/aligned_dataset.cpython-37.pyc b/data/__pycache__/aligned_dataset.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..41f6128103d95b94316e19b6fb050c1cb4dd08a4 GIT binary patch literal 3000 zcma)8L5~|r74GV8cbD59k7qL^ljHz|Jzzkd>MjFwe6m` zyPfKqBs*$fm~h~>oI%2Kh;#vA_wLBcdK9iq3JmPq0XxKU~Z$9AA(3nvWXKj+q$u9TQBotpDq5?;u^->Nsxtc zXve$>&+2jA^8O^s8gYX;_Z(T1!E;9jNqFSNP01cQo$v%JIC~vt8ohp*?DjRNMz@*u zA0j_6lkRiph{eA2r2pK(?(Dj8AVbN~4`p3O zXt``)2Tjx3nWUpUk;}ygE~cYGEknJ7S`s9c19`UhXQ$(uVA@w@q6|+>vSsK64&1rDk<|si!?8_DDvbUk@x!3ZL)T|zu3&m z^zQUyb0sTeI{a7+iY)8ra+2nWD5g3s^0GmDvm15~J1^VOeYMt0CX!{=#9cCxRgM7N zr`FZpzCW+o9@`I&U5tgA>w4wvhezy)WFq`+NcPArZ|ml7;3R;^-81X3F+EHKwkrmz zn0D5Thbs14ug2@)+$ zYdL$N?BqDX5@$hx+*PVjI3B-DCPM|kUL48daOyNml;){Uf%0Lpj(%xx0Sg#so2<X8u9&U?o|Hkbn0h1OFjQ9 z*NAju-aK|H9Q`a@@*F&0@^$d_B_Dy0^xC{7J-v<{o!Qska-_fI9J#Qt@wF1oDnx8? ze(P3Ed^Ah@67?GRPJ2TS*oPH=O-WwWUsJMPMXxDI*lZxq)Nn~w_^kRXj8^>>POHAf z%3AY=3?8|kxJugIh1r^IH)pN9z2p`n7iKNn|Hdb7`6)15J9cFF$bsgklFRyYM_;VM zxi6z*Mp(;+Y#e#Jyzq$1Xe-0Q=@;N3Lz+%w-GQ z`lbUv3}qYq2Jy3Y)t+CG9_97{#PGY6UU(nU`#oLkoQ@AjfeOV&iXIP9%Qnae*Y#I?#wzxE_p$%* zh4_H5;dZYZmY#)jMIqz!=*A>1HKOnI=1*xFnlo??yXQiM5Tit=I?06S-k21F{-nGK zW~F*UYA63g>A2RY?&xic6X8M2%VD9i((n(850V^V7+MJCMP@>3X)-VCKQnxY*T|~Z z_n=_os#$+JP6uOwAdZI*QQ6ds13f*^A|K$v%$Ho47h*U`r>2G(M~SKFLJh_ydbh|& z%Qc&6E|X^h!5kBr`axbkK1hM=TB6(nk7r(Ty6?thf@$Y8t&_usWw^bEBZ Z_G;4;eq&+!O#-`YfxutjPgL=FUjUmb?6&{_ literal 0 HcmV?d00001 diff --git a/data/__pycache__/base_data_loader.cpython-36.pyc b/data/__pycache__/base_data_loader.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca9773387d2056ca487651fd271ac92c10ea9f73 GIT binary patch literal 742 zcmZuv!A{#i5Z$$%V4~7PFBEa*l1ohC(nD1t5(y487dS;iqm5??mK?9UyMPiQF7PuI zKf~GA{zXrn*(d=bR(fy7-dVqSvp(F~>b#uoUOo#U{t{LzrtFB|z7U-V6eynvkI_hA zbS}DaPVZMMqU?y^_K36?Nfe_9B}PwT6l07D#ROZ}rr5#_I-fS5kCoE{<<*%{sBJFk zC@+qTKEWkKD5zcz{}45a)8&OYcA=`tsWO=_YWN2G z2Vk|Uw1=W{9!#ZAn0~1mZYLA)w@{$2eBo!{ZBdun4n&^Yuuml%8Np(Ql6!eCU}=)gNg7^>XSfe)p literal 0 HcmV?d00001 diff --git a/data/__pycache__/base_data_loader.cpython-37.pyc b/data/__pycache__/base_data_loader.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..00a1ef4fd855f995d33742377a637818f297715f GIT binary patch literal 699 zcmZuuu};G<5VeyuO)JX40_-eV8em{Ts47TZpbTtbIGNappfm|~I#6|^KZE$Ato#Km z+__3itDf}Uo$WilyYsNu>oT<8i}~Xh`foZk=3+2Fm;ya_KBs>`GU9=snqzKKhnyV|`HX4~0=Tq7t`-kV+>UOUcL+t`WvV zk__|t_!Fu2jLM!h$yxfjk6wz2iV`DLSw*ETUR6?=s7gyXgVS@+%H*noG&dC#xf&7O zi87L?8$DqIKoMlNg?{;72Y|_G1wa~xi1FFsW~==sC$aCCLY`oIrlrpLoI%q$_U#sB z6*`)>s(m8LUJ{L1cS$rASvpfKJlnY$_xbF2OM0ER@d3LQ_D^F=twZzL$vY%QuQ=t= zmeeX~9Qw4ETuc-|-36Ewa+*=R2k_HG<< R>vt8^!R1o$imDBQ{Q<&-k)!|s literal 0 HcmV?d00001 diff --git a/data/__pycache__/base_dataset.cpython-36.pyc b/data/__pycache__/base_dataset.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ad338871f822d175088a6dea3adca41b0d3a52b GIT binary patch literal 3458 zcmb7G-;Y~I6~1$S*mu3QvrWws5};Q=s<=wMX&Pve!cwv+ZBSTM&4Q3cHQacv?dx3M zYi4G%#JZLcUa1vNh(7{}e*uaAf#-Ru!~@EEpLpUsbG=?~nyO+~GxyAoGv}P|ocTO& zw_1&VJ-qhGCoRYMx3lr*Lw*lU{RW+I1k0QWo49=!vYUBRzwfhZKA6<{wGSNO3IEg) ze&SB*ORpb_K-5m1{+6hV5Z{K_5)FK}#kOeT+Z61P(``M+>CR!7nc%@R9wh@Bp8a_% zLA?hdar#U+efLvm$OIGaY1sEFlYKu4Dq~C5udU40SLVV`9Uo4>)VA5a7pvrctkLTM zTOCecGDJ;XMyDOkN@wX#+@(ViLX{nMJrk(oSt2FvGa=_`o@&nFxUZ6Ii0>BXgDh5x z^S?X)o$rl{X|nfdmJHH3+k4E#?OX5e9mV5RX{Z{6lHQF)?qmeRb)lMkr- zWjy0gMC?bPW3acWjf0eeV-4CtdLxv6bH`+S+XY(%-&9Iq(Moyd2tVIpfbcnm&qDFR<61&42_@1)1(3E%+AiH9a&8JQ#Y>p9WQ^K%SP5L921%D|Qhm4~A9Le}&`ujz-LonHq|mxfgACA)!WN6vzM=1fp* zPh5Eq5{=l%UAQGkMtc*#bTx=&=`I;_R&#ai373p!^1oSa7uvsj$?9iz)x+xPYIV!< zhA)ECvlE2fQ^THrT-Z}T-}91_dsZuimO0Y`(*&o^%CEL=hz9N1#mt{5s_o(n^v$O8 z7vy%O(^_Jc;zFeDrGx&GXo}X-9l7rS8H{&I#0lzYADT;G2^1Pf;q!qkX2!Suwi?99 z?MW*1cnvfa5`p|GnrfocPU*lrTe=mGxs=rnO%o1rnI7SMP?(oaM@ViPdR4aVMGx<0 z@$|8X_rFJnqB--|Mqn-6)Labh~TT{0hx9K1Nv-(*t-UuUgfIHU*~f>o_}3-j}kFrk<)J zqSbyqFZeJ^XQorBeOSnJUgSE?eonLDQE`;y3L0C7$gAZyQ$HxCvqB|so(iMiUi-we zRB00wRGnw+WoyTlP3uC9T!E^(h0b9}^{}w(;d)7a7l#(@Yf5l)0oV)Mh=y6b0zX#;y+jQ_9>gfJdaBo?-`_0Nq#ia6wv`o;-cc^<8UB9*3F;3Fa7#W$$ zAJZCfO4Sj%ma={UK~b%r<)nkG_$X~;1_j-A?XCgU3H{NRO4^q54G8r36d-H$1MLd8 z1kO~(Lz@jeMm^2AXCEp4q1hhCs;a9*%>Q`=xKesTM^t1Nz=|)xb*C56UZU1SY^t2Z zD!Or9+ytQAn~^G_6v$()VoB@;^k?gI2H3A!n&Op;9Az10Mu388XN;lsM@1)N(1&#k z8WO&NGtUyJD1y1uC+oLFmVQ)L2sVdPJMN!Q;<3gzRa=u?6@nq_*R&wOu(*|WMW zv!e?+8265KnyDT=l}%t}kL>DcOuS_8 mm`!O_VCNNKCCh{1-Rc##Pn4jDAL*7hyvvuv?}nZ58~+2BXdHw9 literal 0 HcmV?d00001 diff --git a/data/__pycache__/base_dataset.cpython-37.pyc b/data/__pycache__/base_dataset.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..7127dfa19b10c20ed4e88a63dd4b9957b42a1bf9 GIT binary patch literal 3409 zcmb7G-EJJW6(;9rceGlq6-8BDr)>sB9C%%Om83y(W2mW>Hg%H-MeO!s+JYG`N7|9+ zXG>C1BaDEe5|E%*L7%|7DbQCa&==uVfnNC)+M9ldv%6Y34vOxAJ|O;r661v+c2#+0<-LRehL^jovo-RJk+K9T!KA`I3bh<6! z-S%TRC`f{zqv_`=gvlW!WOL~pqq!$n4n&9Ar4fIJvD> z(6vF8mKhOaZfAYj4o!}o1JI!MijAg)sombQooT9crfigJVN{w|&)d3D)-vD!T#W0o zOeC$_~s?w?AI)OGr!!itq()b`_rRv=4p|iWG^l3^OLP* z`KugzX6d42}wt`$GKwC@tC<#}GxK zvwdw+laI^HOlqkJOLdK^b(&n8QFNRYg%Fn*fDptlMKo|UA$n-qFXDYa^7}#LyYa=r z=@!rucWZ9pv<>VjQ@#FwfYreinjk1*PYp#wF~Iexp30K1w(xed2aQ>JMi8GGU%^2I45#R zf6rdS?FZ2Wh%d&fp4!0Gt9qOww#T_NlLgR3Bm((m6x~NrOhw--4dI&9@0`2c5?gf-B}oF<)RSZ)N!uksy1;Ne}+^b<2ECWhW-I9 zVeFqj{<<)H$que=q`n7L+exR>tdh33Kvmn$Cob@c4q?0ZAME}LCFy+vBGlzBn4oUB zr5`!aZSU7p54?zXFs=yoMHDHw&=+kNMgJL9gUj_>@06cJRv^bI3sxp5gh@ zu%}MmM|wI8F6cf5d(Y^Lsp!@{qiPu9`wndErVvw=S4L1z9H*+vs{`9+qer);Vc{;ud^>>$ER#Am zb^JJs37LFaWO2P8-;^WuL#!Ki)g3f;@Dq*a13d>?_WM9y-KC93RMGuu;a+iYC;iq+ zn>*DuO%qYon^e7pDp{LpOtSo7f{2vrM>I!Sq^$^Dm!y6HK@->KG1){^0#F-~0ig$; z(;5>`=tnJ4+R^zI1bRP8q-Xt_@ub%PXQJ`Giw544zTv`m&yrxzu1-?j@+y-j|9Awr za(V^FM6xSjrB~p3dlZu{CDTkUFF8(ieCwvX13-s&VqM2MkjGe?7oiL2=kv4&@?ASs zW);P#C@3-#NtZNJbe%ssKAs?bIJdAu!8c^dc>py<@a;##z;ny@qRc-PCG486BCt&?9;r+tA52a2@}b4?c3dNVc5f zXSoJ@MvFwJ6@Bb%_x<)+-9CmXFS$GBLfR7Sk`YcKTG;4b`|jGIl%Od;ZV(&%jg9F0 J(Ps3m{{r3O4bT7p literal 0 HcmV?d00001 diff --git a/data/__pycache__/image_folder.cpython-36.pyc b/data/__pycache__/image_folder.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f78433abaeda193511ceb7efc7492e0a4f5524b GIT binary patch literal 2124 zcmZuy-EJF26rR~XuN^ytCjEgTVNeUT2u^4vRH`T{ZPJjcF;P=dWKoT_p0TsedUrcB zP7`Z=gCoH+@CLjM*K^Aiuh1*bnQg74>R5B;%dml5?#F z`hDms1;GiY2^rF%({(_15_jZvJzB2&UEl5pFuL61J`Z?>hrG&be1X>owQhws_~IGq zhWsPG1Y?zd%rC)M+aY|JubdISBAnsE%i|G_s&r$vpwJBkKk;+7{Ks&e|rbk@)URWJ66v}sPFHAd^yTNKG@suQG5*=jvjOuy1D{FlRrqGa>|{v zP}40kqci-x+nHZ+-X6b$SFV95$cv>Z{gVty0f&EqH(voFaEYS1p-8!U9qYrEW1RN7 zsfrUVQWa+@Zua9uq|r$HP2Nn_4zrP1JIF;Zj*_*PjNe|ryY?(f2R}vGnvQe!9b3Q6 zq)@5QEKZeXSt@o=>_s^{9t@a1W-%A3j`bo*nn-b^#yjGD($67wdHrReEQp0=vS$&~T(NFN$M zRKE|%;MApQJw=IS{38sy}+UikI5|4hB6cM zkqpZCc9e_YyG5}S7Q^}|IuZ<{sDwuPsU;AEdi0v}DXr0vF49Ymybfz~K#}DbQwt~_ zM<5Wep~FjlJ00Ub-FeV*C7v^024$jxd%)G{I|rn258&UStczZ0t+gA8W8G+Z z@)qcjRVf!ImUG4hT;6wTfw6&rA}2ElrilFZ;umdLG5=)~1Rhg33muQv8)2P{W%Pw0kp L$qT)(5qkdv=zHt_ literal 0 HcmV?d00001 diff --git a/data/__pycache__/image_folder.cpython-37.pyc b/data/__pycache__/image_folder.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..28318ddb6a01b9a561fa3be11e5752f41bbf3c9d GIT binary patch literal 2087 zcmZuy-ESL35Z~SV@Y%6LNYW1|5(c$Ui{KQHDyX8Ul%`P%F;;^pGE}R>yLK))-(7dt zZEBrgNF?|hcuF4mm-dw>{)O_w%(}6finVrjZa-&!^PBO#Mx#cc{d>Fo`OgX=f8)n_ z! z3_(8ZZchTx7W^cAQ>}Q19-7{B=}-ytmu!Q@jltjy`k`y1op;kl)FG z3M$x1WaxczOpoz-uQNZwULE-Yk6Z;&k|#@3`Ug2BB|W1*!ktfODUDzkVb{~|BHUEj|~a=ovGEX;bL@&nm7dfg-i|BT=IoGYm_ zX?T)p!}CmTqu7m$brG9*JwZ(PAWwxc&uJ@~BXom(a|a&68%L`Kz7qZhIrm z`*Er_+w;v}qZel{?BdSuHsAW`hpo=u&i9?Y7F8i!qktCR)l82-w~Y%IuAVfL?6vU8BsUZjSl$GR+>i3K%Xh|*u+P~4GhsHPLG{BW(h`^ICfAOYl>P`Kl zwbF<0>(ju5WpGT@hP$IuvNAYh)3BtcQ5l|52*Ac;@^tMn1O&psI+JTRH2B`ka?^~P z&*C%@O_3mTGv zX2ZBBAi%AP_3&Ki(1puXo*PTK4F?!qZ38!|MeyCBI1>)T`Y1k>9HXeEK|bjv5QO^l zDtnjKXhaw3GE>)JZVo8Y7Aflh#q9tX0C)y~3gDn~+Io++@3lOIvA2E>X4*4A(&A7s@fWYAHF++^iJKF&n*R&y4vW|G}$PClsZ zjf=v?c{bmiXR2<-O5SO{q-(933xxF`ZEgM0I5WveI*zCnxLGZrxB|lZ$!Mr;)u=er zgItYlT}d-m8AqOJ8_qCph1SPhehS#t|0L zI%U9=MHbT8Z?gIt#O?DVeWU`m^NQz?-h_#g901nukSX8>L}F)j3fLV$89S`1n=r>p z@70{AZUe4t!1hS#?Lq0E6*_-rt+N}6L%!rRhf`Sj&s%nT;6nQf%Bm>iIF4D zzk&lA*-5~o*)OCKn_Ri*(R}}foInSiQ}~Ak^nnuN$~RPErapu@XC=@d&3e;{3g-}2 z&fNtLDmq!Y_^MlnWQQnVUUP11oR4xbPEoIO{&XCtvlW*zs)g@hil|R7zUab=`LCKF zaGS=ji21bMh-$OhhFyXR)!%QAO_J(1)(9Io-TLms>+XE#5?nRlu&v~}jr5w)5KFLQ rt24z`?$1h$yZSZX=a$%LW9G^xTB4UhEVDY)pU8{ok{|g|Bl7Qk literal 0 HcmV?d00001 diff --git a/data/__pycache__/single_dataset.cpython-36.pyc b/data/__pycache__/single_dataset.cpython-36.pyc new file mode 100755 index 0000000000000000000000000000000000000000..1559f3c603c57b5a58dd7beb0ab33dc522fefae1 GIT binary patch literal 1709 zcmZ`3%Wfn!&~`uP5fWBec9RPdj}>#sj06$_gvd^Ir6>{_A@+irgXwKMGd$wy zd32m)a(_J8i3_n)rAP0=8lj`3+S+m9^sqWM+j%L{!(*P5c@E-PT1Z}2IxUMjM6;}| zZ28U&4~GA*bi+ARw$5R48`=;KaE_R1XYS5D+@0dy*Fi((zNJIm0XkgK-KLf3i8eS8 zP6rG$Av&bWQXAJoR@TO?Vm&kNyi5zoNvil*VOUKt=E$jyUzbWtp)j$gb5#w@xV6j< zRUgi;;;gzT)=i64r*W1pC9-P-g!(k38>|oi^8U@(T!~-C0yBq(*jRiE1o#J$*dixP z(>ZI{DV;-AozfErIBw$rGhxoFvu|hPa6Ow0p7sBIYvSMf!?&YMg zla5WIJFT7_U8_lFpBudVX}5@TIk@)jN`Iv~MD}FRF(Kzja&EeuL$fb3r2CvdU&Psp z)30@`(H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T01hy>PoshJ++x+wx$|{ z_RsilJm>1kvo}w^q;wDq-jd3DFW=i9uB~;yy}otu(?{q#`sX8{KLb)bz!+jUM~Rgr z#|23UM$BXW7nFD`U?E7Kbyx&4VE8TSt^5ElH0)us@;p`iO~rVgVfzoh4iG$t_JGtE zU<@~C3qo(91=eVW8@$8|!qCEN(2UrJmV^ax>z_Vq$jqDhOSF#;?tUbTU=cRKELft( zxBQUET^%-|?kquT!iD!~_dRMnh7Ke7Sw}2vaO023CR$>Q^a|@(*?0gs8qPa~#i{y9 zNS$UdHOK@`pVL=mysi56*}Yg#x#Y2$R!2FDM?6krbu`NIRH?Wc$L9&hCqhr-_J5fc z9Q;xp7pcysTtTc@lyP32aG61g`qGxC*EK7ms4J;+F*0kdYQY@k+;k+b^EBfoqLhhD zQ)&XCd7(_;!kDg5e}PEbKNHkal-l_GROnvEbX_LYVo5Hx`dH-Log;UKS7iX6$%anl zgzL5r&cpcD&Ujg_z9M-UjDDRCrd7cQZ)={3G#`v8d$99(@GdPUFVkwEMNRM1od;BM zRYE9HDov}Bzp(60TGQjngz94|7%#QZXH-j3=Ung8tja4%^+~mNmgRilkf{NM2m+FZ z7MN>eQL;baj+7`T4x&Y_t^-5c$cV_RaCbh3wnUcP99{%PqXy<; z0rNMrLyU*L)+4zI5XATlfE7hrvRss$R<#yYslenCxF#3cdb)1A>a^?%+?fqq$=04K zt~Hmp;kU9ofN+H6HIV11JUCc#sQRE-R5QB~b}!I^zy=VD6Xw^0%^`Lh+IemWpeDol zro$3{*j5MeNvh&Z@>Fxy3taxL!7wXOOz4`JbGZhu$aP>n&v<~dwTNuP(lX=!r?c9| zYs-HJ+*B7Bf;xvl>sa2fp+XsgNH*H;^+zz{t%n=U7Ku&VS^x8doh<3f7Cix982qPp PGgJgD!mD0%JqmsU2`xCo literal 0 HcmV?d00001 diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..5472b962534e584e55119db24061f0139da17181 GIT binary patch literal 1163 zcmZ8hOKTKC5bo}I>|{-1qIoF@gAidscB4p6A`%73A$##4Obo+hx;NX&yzK5J4|+jJ ze#q|8ztmSxp1tYC>K+rshOVlfuIl=#>f7C0TIwQLzjyWFPamQ0_Herh7;j)! zIAuQxQ)d~R8e{@zU#0t1ye<39CVekFj>JH zypk%dEB3*PgRB|2RC18Q4TmbAr)GV`EBZIvMDpsy<+M=b76{}o;%?yM6}$$!iz5<& zj)+);yE|gpn(cPPS_VTQ4TGKnp-wF;@vzrMB#!|CA6=$-U6xrz3tq9bZWOO82`={` zG`ZI{$c@jH+lq&9XI5=%TK`lsrI>gEpQR-=!Vwnhpzlz5aJZv;9}X7PDQ%g(3u%thUqZ9|LLI2*uVAvu@lQlb{z*2Fph!5l>miBNL*z~DQB2f&#p#0n$dQT z)@w1MQjrT%1r+e&Dkz?*;)Q3P_zgVID^K|gdEz@gl15r@2!&MB(=&ZJefoUoJE!N3 zTCM!AA76a%t8<$6Z*Adcp#D8v@+JzdaTaLp?9SQ@&w8NujkduwZg4X&`&Qdx+WQ)} zc;T_e3lqIv;OrxExqGP?{37sprcPUp*v~%E-q=JNHlb1t-onaJuM7eAwAKyed)x) z^+QoW4^@0Q+UZ7pZ|9?d*YVw8XV>93u77Lio*VA}z>Ridf8f0CT)*K6PhtZdKa{Z( zh2HyA+;s=e(f+;@A2~kv!q|@=VS0ZMc<~*l69tiQu;uSQ>IB|SpH?e(cCm0r)wc(a z(hI2hdymk!-*-bE_@U=SgV>KkSs^`?a{B133N$`cw}L9NQo14Ha_3w>+7y0fwmqdKdqTvO$` zDmOTpy3iUGJ45c?{KZ(1`l-F2C z;z)G5X@Pr3e#c7-dy(k7ajHi$H3x3oP0e0}ZY!Ox*EtxV8cT=!LXeeGTe?RcJ+|vW zi9do3h0_tyAhmbh&Vd&~l7*#(VOrYt!cMpEii5NqL{d7mtun1BL`)y5N4^}o!K}&e zyZfF)i;zuSlp#>msF))-vt8K_{TKj-_;obPEfg9nFdP4?Y?GDoFR1@$Gq-9!XvwyO zJ7rrEm)TbBF(Y!Z{BSUgXKQJo$H}#nr1BK>dJe5KFm2>w=k5Omw3+)(@tWyT$|{B)DO*%rhcdk@}hQa zANwVNEj|o90lw0=ucsv5g zAOig5ZB%Cco323(GH)Jfp4PNZVsF>c78_YBzKNqi2v5Z0SI^ONzlBAV}^_N+Fi8pDMxeCNfC0t5{ z1}X%LiT3-tLX}PnrDgt9`%wE?wp@6eAy&~dEryY^FI+AN?tp}1g5}kuX6(uMEwrl$ zs9`aS6i$sxOPNOrnuhjLMNG{7>{p9eA*?QN(k+rk_|w?vd~UiEy-pWP;Ggo!3?Zj! zwK0P#SWtj3V^7HME{J_57kf(2x#*eetIelnA8$?ZQyo*tf*~@T;H}Qtc{o0m6tc#8 zj6^8Ww?Gb4r#D+!l+gn*+6+8bgnqc+5N_->-fmpkerpd1dl^CyX=8672Di{kSmDmC}wJv zr@>5ku?m(a_1YkybcRA8eR8tC@<$jyX9e=8bv9bC@hKyg^dc~;ouPDksuwW@>2-No zuNX-QxI!nc0PQd9r?=XciqXwNrshS=m#djtd!WX6SzAwtNJu1$r)#^uxOSS`lOLW? zR=>bTmvSpUJzyTFNzXYS{sj-mPo{YOEwqF#86OfCCi)(*J2oikQi)h0fn~>LT!j9x zv*JoCrT|!)Fm4=};?ua?tMsZ9P1S0N)vNbbCR%TGq9x`;2Vq&g3iam(rF7&u7(xA( zdd+#)Y>jGHknh0USLF5}?q+mK6x3W(T%2#MpztX+sW?N$SthW8?v#qfQ*WQX_ZEmLGbV$j?)lqsVJsNa}cYyYbtx&^D#5H1%61$ z9Hm>Q@qL>0o`u7YF@dCXD`WmDTV-|q9NS{+`si$KL^b7XjG~PW8IFEr#5rL9jEBTX zOc0KhAZI=c_`*=^KoXDI= zYr|BFW1nD50b`z=>=j;sIg1FR)46CRw1$Z}rMzNG5Qny_Kb%I)7P92DDuy9(pVRl; zFs;m1u8YM^q0h1!j;o3D#0Bmq)G zIt?SwFPlt#ZN9GUubXJ@y+qrT8Y!mwCo0GPn0oKu|4;8_2|T!0q5~8Gq*obRaTQU^ z%%T?BYKaL}*LrpMxcNJMT!>ea0talwXG!7btH)R4wWQEnpTM8NmFLD_hu%g~NUWsT z+r))G4R2C!&7iio{8x+Q$K!QdBze z0uBi6E~1<(-p774xBgdF$hbhq@v3?Uqd!2B5CmSJ>@L2IsZx_xA=ao`?7EWjJ3+*Y z2IR$}g{rtq&z2knew?nTq;~;M_EW5(dgQF39&^@Exch*nCFjlxvbea4;s}?d>~n!N z0I3428gx!Gjdga8E%Jw&e8wMh1T9g=FELw;SVlHE?!sCk=fK-F|~> zdJwm#M*~xFhY1#eiO=P{v3!zxZ+>Q_I5AJuh{#|P9CjwUTusonlm@E6#2z;l*sYB%g>^(R>Wq@hf zp$}yU;)n`W$2r6|5sJ|CM4st)Ko{MJ3p!9`rlGG5;(01wpyE|3=zu2Ppn{H!;+s@Z zVko{t#oJWep@Q6Bd8Z)D35SYYQ&m$AOa>v|mX}f3X4SL}yN1gwmu%a%sAd;YE7`Sj z-9BU2)n_+=PUSQzYpLMVVBmGsXFio#rkeT=kbj_CkW4~T`&PCSw`q&iQIkIFR6!X3 E0S6j17XSbN literal 0 HcmV?d00001 diff --git a/models/__pycache__/base_model.cpython-37.pyc b/models/__pycache__/base_model.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d713f6b8b65d99ef3bd8ad86ea5ab3058395ebe4 GIT binary patch literal 5672 zcmbVQOOG4J5uTn0IpmVur&h1EP1}+!%J#1Oh>a++lGuWcBoO1+0u<1ZC$sF?<#2~H z-0tCO#S8<(DNZ2 z`x>`+?y<&mCwe2t*+*I}{}twI^%_f;ek!;pJbuS-#$PeK#5+6m;)jtqkWCsj`AvEM z4R1V+_%(FT;Ed~!wG-AbiRWXjVP)LFDbH=R^Sr=|C~ba?mw5RxYZQ2eFXK&-ukclr zOMH#5qb%_aehy`sZ}RgfEBpe#h;o@<;+Ijb@Gbrn%2i(D&+sZ%TI0|1=kR8oKhLkA z+~C*w3;ad&oa0N7^QE-#jw`*p5%+?oK9+<^lJs3X(gaB~3}yAdYc)O1bzLDn(Xf5# z#KQGMk;4c@zW&KhE9!VVP077(ywep?+iS*hC-%F}tImzr9N|fvoa2WwcB0UGpU7Re zyVD`|a%UGiaa4V~`zXDHn!oqRX-1un8}h&pJtykMeiX_j+L+p{-oG;4|MoXeXp%Ve zqjEtsrER?~EOaO#pk7mY!QM-<&rnHNqB)># zDj7;0C7w1Duf@jDoTx99j=rg>l9BmJgU9EoNt&2jI@Z+dBNFA^O$PW4EpX4j2dso9P&Y`NL;ng?A} zW9e{T2+~q&OZUj5W}^b_`+ew8I870CQ+wBK9(W-*S=d?_riEQEY_>YCI7o{@B&9>g zD$^xp5#vVb$d^4gm~{Caci(ep6SAfYf(20~GBwFecSS$+W8V$HehuAn6GUSGXS1Q9V|=d}}kpeOp!903du z?faRs^m=MQnwllHutZw)B>}Y?!siVc3m}6RyD=)JcjN{WNyy8^D?}D8P`rwFJ9s2X zrD6X?R#`wzT4gE%p;7^l5}}F;nZ-!^Lq%DYP8+3d{!IH&`+2rq_?sbCFfz@Dk+Uyc zF38wn5()oDR;rK);U=8#bc}=2k0vg6nZ`QLIqX(kD5qPc${cyi3+}NwWUA?;f z)*dYEMKD2}jlHfQOHWH`|IlL8>mnd1?-^TIqcHVB@g~Neva@uf*X3B-nQ(H5lagd& zcqKr5|LNILoHb$LEa>#r3a_COYbV-(9lrqC89ZtW62o_510Ge0%|uJ!TL=UCE$y&! z9|0k=6ibli#ZyGKiI6LZ8z2+0isMiwyjX?G({inop)`9!cmR@Yto$+NPf>w9YK`^h zbbQ8)1-ZzWm0Xk^kL4nyAi1tA$`vCe0aoaQ6#)H3`Sg0jQW3hA%f!5l^=4wG)*hfS zUX<1oA`%iQ;Mvly&F`Jo_T-1Bgw-#x{?<&3PfnPIX_9k_hkr%m_{*_9{|-_@mJAOG z3ln`0wmUG8Mwu!N%O%E0ADD3-@W3#>Mv1ur$)BHY_JryE0rJ zX~UI~mY5?QfMpGrXgoJ4p(D@9#x0JT;;vcmm#-q*fx55C?QY!4;FKV!wyLmrcC>=R zr`RBJj>vf;WL2pVc%i6JQ(>Y>>0VkwlI|Y4e&Fr~9{dftSeipVMZvU0vx)$C_zuVA z2D(%bQ=>Hq)f;7%yY2Z{8Q20pq*RWQt+Vhx%|=hc;m24&QnHmHe~GQIihhA@vNgSb zex^jVqO_;fsJ0|+T0qFl96dcRr7Pd#GD-Fe0h3QcnO(J9$|F67VU)g zFfzxGm#+&0J8YDGG>({c3Z1fIJ<6)j<3b5Np83{LX-nmUKju!hU-Z#v6B37gXdsO2K@Ywz<+CaEQt{v+#^aIx9d=fO7~ru146rtDCdgzaUQkJ{}mN7EYNYeD&4{C z50E4Tffp#fi|=Bo)TCXARjTG&uB7}<5b&Y{d9kRYD(+I-lHI_M(`A+P&Y6?_1Zt=e zGt^Mc8EPoI`+$}u2hRv;TwDV=!Xvjp=1>DBm18A?ZfT~m#xAgV{7{Qe@MCI0v@$&k zI)05xj95nb$#DxhCJlZv7wuLkichiL;Ksl_-b#$2HKGi9g#ZUy&5g9Tw1IhelVZGR zWbngiWda40TbDJ8B*f)#EwJRFAc zpn5Ng-JlBO*!4s;+N;jezUocnR@9slh?bFBGKkV-;toiMYT3*@2Y2#8@Q=D;1 zE%5h(NxXOo?6L%gZUK!V!I~OOc^qY>rxkr2!@j^XzJ}Z5tAVMo!vu;z#TRC{v3Qeu zZ+c^;I5ADs2*{ukTy{=$xstGbhLr8)Bui|J!~psi#3s^OV~D#1)LI-UwK>M&%lEaz z&G8iocM3BLw)e51Xn>?;;T`sTe7sP14CJm9B^xG;PKpOKN?m?_LA6DTPFLwFzcBmH z7$qt2(gh@k$EG{c<(aTg%)n~nTSeRPA-shmb-ElbS^_|#M% j`BY|^YU(?{?DO0lXJVS#x3ZJCO-H1zn)Fcz*Ja~B-aah& literal 0 HcmV?d00001 diff --git a/models/__pycache__/cycle_gan_model.cpython-36.pyc b/models/__pycache__/cycle_gan_model.cpython-36.pyc new file mode 100755 index 0000000000000000000000000000000000000000..3e899c9bdface373349457a2b811b112f023f8cd GIT binary patch literal 5005 zcmb_g%X8bt8OMtN2tsdJFRLz6tFGBtl4W=Ln#2wnjn$;CCNohV!=yzaE<}O`33>s> zk%l_SGwsDQ{Wtoj^wN9w)R~@g>ajEJrN3_h!lG<7J(QaLcK6$FcfbAK(8HM-=ifgZ zfA*J~it=~m($B#71GMyK08C+OthDl8ZK;@RvDVjHx~edp8L`o?v?_93iOs&%vQ*_j zVJ5R)D$EKi7kbNP>ZeNG{u-w#hjmp{Kk5ga@RKBozgB@FOfyKsUnDGyF|h|>_9EfW z(l%|F{7R;dfe962q@^*HX)l$Q9vWnwJ^uyhD%EDo`JLE|(Y-j22@H9(!ySqKPW;+f z2_xNY6AOMP82A#gGdvZ>An1oxx?BTL)U(CO0g}%hT0UyaH7rw4j7e@k#$c@Uz%kQa zsj4zo`4h}lre_+7kD0FwUB(k!f)$?v)eDisMF27lZ1QnJzkuosdu;GZ9npw1RC|aE{yDS7z&GpC&4JrgwYM- z;iUtc*PfFfArGscEtDKYEF5G}c3Lpr*ci+pRewcb`KZxTAPj6n#!*()AGd^l(Q8JL%-t%xB+{;VD z?2QM$@7>?ogYaD}`i1CI;RcNP0q>0Z5TU4?#7U6VHDL||o`zi5QR-(rhz751oRdL{ z4H#*3a_YB}ejmmR;e>uNBu-Kn>**|6gzl<_x~Mv;tuAV3cc+c_`U6UfuHXy@O^p`@Ce_M^kbZE;*Ltk0L~sbH1B9M&Z<3UtoCYXF4_#*EZQ8}ybJDn z3uuctW$B^{`sIS2Xlc(k(chr^|D^J#pj|mPuzM9F+P#Lhj<$g|H#PvLZ@O8~(knHsEvdG$@AmHWzIQ>gx_fC{$N2lb zce8uFZO|_Fc8H>PgZc{%C+(u$XA379TYLqb_=_H z%pTB;DTq}|hb?{IZ?yCS04oox4N-|$M#z1?Nh{6CN;6++3X_LH>^FpY59ZS>p$R>JA@MrqAWrj zISYkDPVZCvPK6bvN91BchXaYZD}hOvh+N@-3HOUA9f8AqQ%xp`e8S03gbkl2FFb(L z4u+%5AGAdkk-8Y_gU*SVf6deY0gF8OT?cH}4C5z}~!MA9=%%gVT_w{=VD}?G$~JQ`Tx#ho}(zlq6HjG2N2Aan$D%5Ra3q(=8dO>P6b9OAw-!BfkXv|iyh_VC;E z!r9B(gB?!Ko1#WWlt!I?60s|>mU}p04K3XVP*hi6(HsO-3apOiXr?-^ITcskQqeS9 z#a!K_hOw)<+J=Oi>N2p)Se;eR)}}-8Yr`xYE0Higyjhe6g& zTNac~gSsWF}E}}H$%#vbI z59>DH!zuY|NZb4Yfpr4(IF8a?;hO+;HMgrPudz^cy$Cq_1SG{`<)nsC%stffTc@KOV4B7Ua)TiTZ3h+7qCewho#+9+Pim!uZ<=h_IcrOa( zv5hEf93DL>>#@4dAK_r~+AL^yg?${fpAob5YOW;%L1kYU6l+@sMF4~Z9x`7dYBE^v z=aU>SNhWwXF>|7;I-Z~0$UQd)N~ zASqE*Lb^Vuo-|FegUl`{&*g>=4O{%Dpcj?rqmZ9+spHRp;6EZjI&Rs;OB;19*RdqJ zqG6tRmqh=C{bm>2{ulO} zCHo}%(pnvq`H(c^FC`BhAUM&Ssy+%<2ImuPl)gvx9ocXNRSE2NtfHZ+0|wPv&yjRB zjEIhU2XtCV=MG#(9*!?vKo{#XXtQW@7pO=MCf~wD*ptt+a@5Qt zjf^TcK*GtETX{($tE4>i$hf*Pz0{O`oLel}W}$M_1zef|DZIorc@t9^!dkAh?6&;I zZOAY-g=)5P|2^Qp!IruqtsgAqUzlKu`*hUr30$!!r*wyd7|W>#px}!Gf0~W(w}G~a zBxxdT=6`qINPw@8DD8{{S=w2K31qFv=vH_;=6VPJmY{DDjPY$zteQn{^sKkg!$Ksz zUD8(zDyQh}d|#o_Cr9-frvxk(C<*cK`3N~M^*cOZxgMn8DT9eB3X&vW321JlSvcfB z$B_S)!0!OA$sS9Sx&_&DQu6!}5)7n z@Y*Kih*FyTCk^HFlx!H*plhmIvme-|ZP_!nfxk^jgBmEAwX<<{qg71Sj>lQN!(uzTp~I%O+^4#y zXL_pJr)uo+q?${{VmBm{EFlOIf-n~$t|8(9gn$qM5(0s==f9~hMeCI#sy;D=AH?JOj<((fkjQ=zy9vAT+!y7$o8ip{0*)Qu&2WfUUx8^onq%DY?aPOIoyvU2fJ)=<&Cq-FQkWv&yQED44 zTTF>6zDtXSm=?8rhNuPBN_op}RK$#^XSs6XfS5(DDe=fesw)nPLrATn>|t>PWshWK zr;#!zjw0n~mQq8?G4VE}ye&(aLCT}zF{C_}rPPr!FW!!nw`VB_kn*^A2U6aVrOYDb zo#F|kJdve5A`Ef7dvNVg4cpa|CO%v zqULfS*1LgB3!N~E{jeD{G1Ka?5a#e+#2ei}kYN5-Ejbq(D`sr1*uoO_wly%r^Po7O zoX;o@X5A9*whh`GXnD{s2?Jl1I425;anRT@KV>HDS8*xL-Sp$;@(py!^;@k@7)S@L zU0$#xmp?raOTW45$@MVogl!MKjHA>EdttEoMA(ySez&vf$DLj{-}UbXa{l8_M6b+u zqWQ2F&$oK(p;&O!VrxBY#=TxQO7neR#vQ+#T49)$pJUTYsO6kUs~^dxAp1VB)Qdl| z*6#*uK^O<(!&3I-f|VB8s2AVu2Wd$~RkMF!$QktZyBlYgduzd&>;0gKM$X*y#JMM* zK6Aql+gJSFnYh#U&U#Ot^JEZ(LF{qTJ&a_L!3%!hyWMVk@oldo&}1jR>&3Ez4#bzd zX0O|mUVNu_@ouvloLTFMpc|b*9d~+iH9FnDn^xPXaNMB^dK1+U7-qq&k5jUu#_WQ zk;j_b;VDHefGR-Efde3GlshrCoddlRyQ0cGTh?_WvB3q?OR3xS*KP{`0rq2WqbCsv zGlsaMGW!slGgW9EMkTSWnn(JUdDJ)zMpLZiEVWS#&v!0JGW8CFgj_T5~?l#!78_Xdg z;)ELJDRc92uE$8gPCwi2HT`b%!P653P+gJLs*$Ho6bJqGm%qg4FE5bOnBYa7%^=OS z{k1i}k?(i?IPljo2FQZW`cD_ypPc>b;(u)3yRbufU2sx|?2)=HA3QJ9V!RA?Tkds5 zTIfSk0!aDhsczs)3e>zrjr09px6{14OLjVA)?dgs*iW5Fl2y{eRn_UM*HZU7dU^F) zqlz4A&OAnul+>$qVBb<>h2XWg`G7FOUh$Nk#A###nDu$*=Dxqvt(B68p2t;2g%#-VO1`r!0&!rPsaP9~JyikaAq2?Gr3EbY&j1@IKIBJ%(`5zJ|BS@N3+WI|Y5rl_yl(P{PN z;OoGuEx-A?DJVD3#||Z9#a+n@_jO|n7A(988g*q0G-yfEUN=Nx+XC$}XhqQOfmY1Q zDVf@4$rGSXfqEa*(w>r)xD07i!}^r&<23sCCURATv09XGqRbS(*pn@`bx!rTh91wZ z+HaWBL5^WxXLMipmQj=R2&DK<&}lEn4d-Gnye%KW6sPW0sB~Y-)A+~)(l8qChJU*g zNJxOw_B&xCzv_b}!nTrBvTg|F2pKSa`A*hS>dAPy*Y1UWck?~8V9>5Zy}uTKBp$}R zsdKaAN0Rz*qz`uf50895l2ouzO^cq_?D|pUd1=A(LVqnlEWhkWejLlxfmsrgIy22{ z=JF(UQtAjl_S32!jMBb-*Gs3f!Hg2h5h3c>2-F!@y(W?`UmC>lC%{fNZqz3bu zS((4niI?Y}h^PenvO`7b_xs(u&@eW#$ct!KbyHqqjP@_|I7o?M9KeL;ssm^N1GHjq8PQzq06G|8NL!dkl&pdVQGt1~x2%B? zLv#`orplsdB$QrzNwF1l3MX};nA+VSOmk3<&E?c-t_we{bRy5c?RUEV&2ErZKNLj) zU_vicRzJ0nwTz-^caw?3hq{VSL@@}AhAE7OrA)KZS`Ay+!U42(Uc&@7au{tc z&A+s`c(oTrBP&0|a!T<%jf!F;#w2c8E7q1_8Zjnz#o5ASdAS%Gxr26gM>`UYdKU z-9Gz_{Bfq*Ktg)T-K%Fg76^-^VZoZD`HYXjm`)UQTk!-q1_H1dSZRyw2JEwCTE@U0IKmv{Ha-wjmx4XF%uB{A zCkJjq4QVIt5teci9dr1-lH0;e&l>MH2Dw`&UN#agl9jQfD?g5g(wqcB4CF@;m!Dwp z3R7JX#C~&mAum0Cnj5Kkw~^Z*N^Ior@-KC|oEx*Vm(B96p1k8rv4nzA34sBf;6M|c zG&iS*Gmzz3aOE?szvOuU4gfMxeHG8UweEMdB2+UeTg)?+>B(l$?Rwt0)#2<@=GD3` zBBH=h!K}C{MN+sgH;V}QP@cFjYPRnqf*nE;B9HmqHU}k$w*`xKihN3OPD)9RtiJ6I zieU>>ZAB-kZbrQVD(&bpEplRF$DvCXRoRoEmO%Z5j9T7Bt;DvdB(~n6#7DK0 zVf=LZn%ZTIN~TjEMRTK1ZHTCF7RucJ_1B@7nrRXJ&Z2hMZZ&FmmODVos#WT(MQ*<2 zn(C$2^DnZ~6v2i=x*mkmcoS+*1l-zriyTcv2xovDOO`fOG*{{+hPw^g4mH&-Rv8KA zCMbJY1-c#!dQ!8BuobH~*lHBR=U@lOD3>!f(ll4mj)O)vc~vulD&=lFC`Bd!)9lLx z#Ss%=-x_tA@(%j5U~{F_MlhP-f~eD8>ve>@%PiUAvipr(zq1i^H_u$^M9^x$J>TEL z9$32SW2?KXbXtFrb?!==Twy8EFzmK?a78aQO2f{3m*hI~N%nHT#l7^Fi+ib!V$nea zhNTymScGpK-vv|t3@A;KhtJUVM?{B1>n2LdzQK)fuL%1FGerAFQ-uGVvU4uPv7F`? zKeR;KaO2y$zVBSfB70B;TOmV5elKdE2lt0UW)+n8pLA!%(o%C+5w^}fP&okAHgD~nz zZSV3gT>mf!B?U?O=Ps%*zzo`eMv>a38@2tgF%$ce8>4aCqoj#0>}wT2g{m^_QH4#Y zLli64W-Xh}2Q4EfkQhSgYVO$*OI9#Zu!o_Ge{*0bHarBdhFBPL3A&I?$SnkEG1%z) zq44}D8$_B@$|kL50|4~m7}BZiJDZI?8f1oJ-Ie1rs6Dgo;4aHGH;)WQ@Ic;?kstDc z8VdCghd~1a>uT4m#HOdu?~f*3ud)g$SZbh zB5m~0@bZbyAs9)S8Q5^zz`-M&7NiQ!9y&{lja=p|b(N&$F%w1EN5P#RLvUG!2-Ca_ zTHT--Hx;x($`oE&bQ^EH=*7CmEg_=M(IY zWti?W8h`4Ff~r;)TCbz|Ld0?XtFE`b4cZ@LTcLXJ1DC zq{NywPN9BR)jvNdquv=^4tT)<8icyq&b3Wd>c3t#Hd>?JG*oX&N%^ob>QP;njSmdl zK8mf4?v?6SX&bffrtaG+r*WDazZp?9u^fQuDlw^Fz9NJ6#S?{aKnBc#2Y)*4r;#-Ie7!|)Q#Y~6oLFD5|{_O z#!6O3Z)j_z5bLA`DmCuXfm~Qv{xSoWOsg3hnQRSWsd6#H!XfzrGb=sYXH3Zsjh&1b z(5IO?bk|CvG)XLs{X90$O5I8_E|HRAXKXHB^aXsuF8vqycdi9jKPq2kp0Y+$&k^yZ z)1==3Xpj26C_qM=+hYfQ-sQA(J-D?_|1DHnPQj*yDW$xX2_E-q4EE$28jxRP@GA^{ z3qhkAdZKqHjC}avIL`qSowfBYocH_fbiM+5bQyu(Pr4?5*jHKvSrzqFffcU84j;xo z2znpsRde3rzGw~>z(JlmtaT0asCgt4Y8+TB8ZNCj zXBbmKoiXkx{t<7aYBVPdC947643MEKVPK@k0bIR;GnsIoI^`&ac7M)D{XFanvGeV zh6}Zu>z%HEmpZ(y_YrAn9LKheGn?9?v!iL|Ni~- zrfsno=uQFbZ9bMMuKCGhsR`G5wtW~YG_6kw z1@L_9!a_l-sW0IxtzFA%AL=O$@V^Y2Bzvd%OM2fa@8b(amE`>{$x6BG<`4NhEMoS@ z$UcME-8u!--~RC z+JLY?yIT$!11rf8n5lXRJgJ6q6Zvogv!0#6?A9dz&6+6SboFc=nC!(+kL(Y$IAEoV z-b#f~DE|d-2DBLvXp7 z&#I)!;>($e{sF;mWfDV)2opd?g*XU?a|xiqhNC0D@g&?AXbdN0CR|kWgTmh2lr_M# z9CP1#`enEQQ4he-ur?)&dF++*ic(M%f+b@&1^GeSlWIY|nMocec93AUfHDAlnG;bt zHEsYa&`{&-XZ@!3Jn79d+^BGNtn~mP-T``+P5u$$6Gro8l5A~M$D}XpDWLorS!LvZ zXibF`)MOJYq=)DgxxU)z!={J8v=A<3ucIPD}XWXRimU_ zI?BnTO!od-_sp> zx-@ny^2C@{OUXJXbkuq_WFrszl3Q_Mt)5KsD|x;BfcK}AjN*_VW3?bJC)nVE z4yPon7`tXCMdfTPA+=0V#_q8U4#)OnP(fM+Ddu1bDX?OLYLb(Ghtw(BO-PX*9-Ner zt?8uHHV3t2Dw#%(Dge1Xm`Q30VEQ2=nSm~>hhI!w`9He+ux!@;bW%onOP5coSy^4` z6A!jHSt_X@w`w;lSHMx=P$(V5CNzp_^^aL*x9P-1B0zIOF&&dZ zcU{Nm8g01ieU*a$;kU}`uvBV+s=BZ*jIUELub@e2h9Q85-Gw_1Fuj6U39b&fYm`+c zz^|zo>7(vVn$%n}C~4of!pn###t2aL`H>V8|tF^Rb6@`A4Oj$XGMO zMhm88&YhA>+ohXE*8E9OCUle5OLhxKZqQ5GO$imC+>~Q`UXY(eF=fvc$h$n+A&^_C zx0QlgLN9hJC>mR+Cb%S_nohzS0{_SgZqp20Rkp(zbLOeJpie-dZ*!F>u@J6M9Q1c3>x{c<#GGpsALfjTIP0YhWkT|D`f8v zpfbW??FRsqx9Voq#=lYyNR5IQfZcicG6`6amPzeI-ce4# z8J)KcPrg^iCi|;U`0&4?PRJOqKp|FTnlkofC1VAx?)ANj2kiJ*ya)y!B3{#4ywr4k zU*ffhI&oLD-w$vp=g;t=%zEYy(biiD)^3^FUrO$f6XsCINJ{KFuaa|ObO>qTfrJzpOiW703yA9LJ z6LDN89i>lg9FME?I>E@zsq>uT$}}ABNTJIO_l0ntck3qH@h6y1ILaLAoTYto!riZO zD#01mS>7k-pJzGMUSO27vQN%$A?Hc74#1#VKc@OQ^!1m*Eqz{AP_CW7XD1G!6SQjCAiPm$2pn+M_g C`&NGd literal 0 HcmV?d00001 diff --git a/models/__pycache__/networks.cpython-37.pyc b/models/__pycache__/networks.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..1e77d38baa664f0d47471eb3f689b6baac0aea6d GIT binary patch literal 11074 zcmb_iU5p#ob)G*CIV6|M<*t6blC6y$yP=~sM5+V7k@ z{9CPJIQ1;%^3FZ?Z|*(kJKsIC7w6`R8m|97pS=FJKd)&&poh_ABl88^QCZhCp$R?E z8tPZ?8Vy6Iybx53GiTGBe0^?S_rAPl}wd@9K@b$cw^Vtx*te#?P$qvo0~27xcrLMM!3FKqfvtQ=nk(j4wfxTBj$60GNj zA(vxqU5|}*Qy9YBF@}109ufoJcD#W=oJgFRu3D7rXty@A_fvi!VvplXW9?Xizu4)A!RNXL;Da*5m?a`(H1f zUF&uIv&~5O{rGHO_SXGo9G#6jefMeisdKLMqtK6CT5T7vaFvpmz5dy5PxwJ}7P8wt zxe=Y|-$^TNRGf@ZM)y$_iKZ9y8tx_Bwo%sQBFdVyU3Agf<#C(1qZr9OZC!2L6Yps) zoi^=mVQ9qqx{+v!AuMd68NS4|a-=hJ2VVD<@b1$mnr-wH5~0ViMs#K#qv3Q;jpOJfHhFlIZ|f(t zGJPX1IQQL6Vm4Q&P{RX4l(HOMbuCTyvY)@i@&v#i7)>R?zd}w8T%Y@SENp zY+#OSHZ16W0Hrf~{h0So*3eyAQqzXM1^zM$B3e*QKdo;)O8YUkm^04>y`~pL&!3rU zfbJaXmft+JqS)`Z|KyM9`pyMX9V&36&X%9%+FrNoHS+zyi+yi^cYt!Ms{eeU{k5fU zU;X#3yBGG!0P9w2k$R+d%Yzcj)QQ)iMr*x5q=i0I--p?5oeq3Yl7*~D^tjUR1)b)d z8Dr>7bh(~y&^%QoNve_-uJNR=T~F;BnB}$WjS6b;nz^)?QMzy)Skd&aUH^sf6*X&F zQPXpXZFHd5uAZ2C$O#P!YONlw zm<8xm0!?1mhX6S!r@n2-?!sKz0{0Cu_$xwL>jG1?e25yO0(Sv@FLZSR23A zJ8F*enB(_Qt1Pq)M}7}8nWHC~$sUb$j%Qp|GyaAyt*N;#sJR|&!;4gd72ozd?X|dJ zUFwBz@Bc5Z)V>Cc_M|+6_09}Z(;D`s_jbpZumG#=b;3q|!-GnMZMIZW-4LJ$8_+%Z zNi`PMMwh4G_SPDl>|bmCA`>6D{A%(w_f0;^V#gj1*OS!T~>SfIcn`R zd+an~mtL*gaCizdINaCIat(Snb{A71<8UE4B~sdG1qq@06B()#vLB^9EMN&(U&Xz3 zirc8|h$iSW?rCe8vb>7^O6EsaZg=9fmB%AOy}s-aGJ5@fa0ifLqKdqPae12ZGUdpF zW%E`dHrDd=z(0d35&30JuPnkoUqG(O5Qz-^jwu64Vt9jX9kjs$Ie^F6HTSe=Ikw>Z zHCT%&^oKMuNyzNoJp7TlXAHF%wvy=ZO$OOQLgs3&vK|3MXi@7&tt-_V2Lo!@6Z~-?hceINH zWXQqVSVG^?hefP|0b4&EV^!$AlVXw+COoq}bix*-3SXwBb|!rRQbL$D!m30Exl*ki zZu=ypMM&Sxq@@{YIW|Q(F%=n$t;9*}B%c(vEv)w(n4+A=Ow&rnSf0HFf433s(>v8o zNkk+q4Io%^}-`?N?~ zpT6pY1A5r76=p>%WKJ_+ApK^q9m3y;dXCvoI)9yc)wpFk_bP@PKPro*ovD|I>h=fN zz-H=T-dWZR+pR|R_F4xVf=8tdzDhh(c6pY{AYMTi@)v2+7Rl91uxd={hwzETx>v~p zB64n;Zs|os*)-WT-{jriglwOj*Nm#r0HUja^*5=7H7Ev`%ZUP@gkp(>pA`f`I{W;{F}I z+$fGF?_QQJ>PedAVVisTCvNWLUq-X&C=$(3n@gZgcMZ=4U9Lewi!r`NAQ_S0^tFaA zw1y#yOv-U$y6ZGdBHTE0TB0Bv{N~;RqBd-pV=2vF{o?AiUKHK@q3ZC*7qZL&_$6B- zg^Ea1V*qOcKD}$~Y0Q8Cxa|5tWr!F6gaJKV0%pV|y1W~q=Mx9 z%QwD6@0MBtrTL4O_>p*TiC;|FBGm9!MhBoB*&v!ZuM=d$N)Jq|$00=?@@jUoDq2qpauVks}*uxc>ztSLAC*)4cRs0nYhN-WZh3;U=fgmiS#dQrI9txu_g#pzKWA zVU-|5ut8X;Yr!#VE(j4)JI*K8wjLMO>3Hx$f&)NYS}$)HL)c?tjq-76caDxEc&^a) zVd8pH04>NR7LhK<=S6N$Zv*9pJuGg1Ij$zfZT+(L#!JJJu#?iFHY_I&w>OjW637>& z!p4p0jQ|zEQN^a-M&(nznx6xOdF$)1p?*@No;FXTf1CSX8J5uRf@%j!VI?{AAzO zV~^jLHd+^@9cxm)yR@gNC#=bl#Ce}E@F?S6S`Rn*e?I@NV;mdbOKy?6a<*bc591x*mqLCH|?&ydEazl+J*%h}; zD;Z t{S|wW#a*G5Io84k;lmRaQj)P`*LQSCOO^1+OF#IZ2^ZoQ-f0VJMcQxJ0H{ zok5;o@dP4cHXZ)(YhCxRy_)8v7q4D=2rDN#zgkhlX*=Tq5LwYv;+>F-+efiU@E;BeuFGE*!8p#MJ z85ude0=fdMWV%8~IZapcl9saJT@6vU4Jpa_}+*3Co*$l=^99P~i|3~-j;tZ!9smSbCMY|C9s(0+M-hTBi;w@Hlc$ z<}qezq@6?!h^Or#G(I$vjGyT|OR!CPl~mtE0DF)HuqS#==KllS6TM!Xp2G-56bYL` z5wZtu4!| z70Hz*!3SOBwCJX_iv9`7ER@865TS$0a4K+khV8)S5Dm$1K7|+u2EzviB0(#|!ok{P zX&`Djs{Ph8uW2j?9UAq_DjjaMP1KK6t=>X#}(Z58t+hd{U4C1PXI>6JPEs6ffHUvpk&F|Iyw`|nPBqA<#<`u|A!`!Io%_( z1CEU4|Lbx|oqt5bD2p-lr3$Zl6J_ITxR6v65cOkPvH&Ekg&!og{7=<>)HWM`J}IHSq1q>vtgUMG z)%!=BZk3c#o5#)C74W6xh3SzopLtOrKg+YIOv{Ef^j=MBd}vr;t9Dl9pJUl2S>W_x zV8Z9wvyH#{-GuSZ^8Ja=t0^O;omKqqFDQ49k};2$ze$hm8;QnCa)POI&G$C$5N~IH zD1V0}ze~xVQbJx)ne-x?bfW-gf>=%1Y#>Yu&_QkrQFzonYBTFN@oy5K=0wC)PWr(> zsr(*H6SI1e>gH&P}t}Nh&532jKag~BIX97UPi8nFbG09?2}1e zh}g$C#Iu(MQ-fX@0ztGJ zKz=YGV|BFWylC`-D)eQ73yqEee@3Nag09F0Io_w~r5OtF%0*-A)Ptt{abXiF7QVGi zuxa_Uq_QUN@nvPX=qY5Tuu0*P*|HdyR4fw^z_HAUKQBlh&Dfta#k(>-ACR`vfyu-j ztzs6l07cFg&}6Msbo)ufLl7ZZ$3G(sZH|Xzd7`}nF%rleY>5F^YpQrioZl@fNCGf8 zs8hsU+WZPWqZ){m7$5_yh=de}rMQeX$^Aex`p<)#OhmG%#Q~tIR}~2HHZoECaQu;fNn2;?jVQlc!V&x1fc?^LS=ktWt|kjZ89VzD+gL%tz4}UhpC6x?*EU}E0p>NY z%!`-nBbnD#^ojqS>-T;9i^E^xK^=BgFhqH8HZ1npW1BkMN;;$o%jjclC1wrwKrynS z{s=6lYHXdj|MS~VVFWbmksrKGlRm@`4yTM@`H)dRE*qdVUz^y#(IrVGSRvAUOp^3o zo~8T|<-`m8p9g2cQ1LCYQzl|*NeTELqWB-H`1>^f8x#Mtf%^A=H>v&_lGJ{gzkcvn z3Kg-XgS9%1eT#C8wyYDA_2C>v2-3oH>YVsICa*;Vu@v1TT2#&#%Z2&k@v5^_JNi&< Qx%PDJOzl|heC_Q20VI;s5C8xG literal 0 HcmV?d00001 diff --git a/models/__pycache__/pix2pix_model.cpython-36.pyc b/models/__pycache__/pix2pix_model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..866164b2fa93f84b249bcd62d82b5b717d674dc5 GIT binary patch literal 7164 zcmai3NpBp-74Bu0o`w5}q!!Dx7Rw`xqDUXIz_B$HW7+^sD>fY@9XOrlRCCy)*<9U} z%)#hH1<6G~eu4w!;#)5H2S$!Y4!NXH0g{UjL5>D;alZGeyV=7<=Af&q-d0twzW1uS z_}0Wk`EOg7Kl}QQru|FH{N+%;k0*E^h0uiF(QJAf9ldMVhOP-y-#( z4jNX^3!n7mqo7F&$zM7s8x&8ob)ngY(1r0#vrR82E{L4SKhuwNJFhyo3&Ij5aEqcW zDtKF>DkktQiAgbqcUeq}8N4fER?OjD6*UZVKB_+SzgWY=<7v_e>5ty@JirssV0co- zwk)9j6-^=+c%402M3amEJzd%FcfFNQ242&5J1e`6SiASbl^wU&df)A@g#N%;b?&V> z(hD$Q$L|H9)9-nkRD9?ToX4$}6Fzo);q^j4Jaj_o4?14B;WYc5zI5QC2Zzm$x6&oI z1uFv#8jr)xyMx0h*K@m`MWZ!PX!>v_JFmOM-DxrlTzb<)cy_->Wu&z+Lt)TTKR0#N zhG(DaalI(4xUP#*XpoOtPM6Q1tBWdV6XfaV00~p3Pf=z-f$uXiHBc^y`GkV?SlE{2 z_MONg54pXjS2v^50N`-~f9L@%eWw$&+#V*}m))r7cKlY)6IhU{a6{lKWS>Nnv7>*q z+vztS1>jFV+4r0K%3}`JU~@RVy!`P`H$UBKeEfmk*!ke2?bh@E{N>AU|N7^5>PD0c z4+maU6yBbD&ijs-4(90wc5f6@%PL7!`UaE-`WqC_WE*ZYjvs94}!WW=fIRT z6j5&9>kNk0@^bUA3B}<=YGu!J!`RM^`@^M-9pW5FZ>j5s&3(VuVynZMR5O7{tX;c5 zT+C=E+;}Wel{^wpB8aN45RNNb2VKA{%I|ghZiu-S2CfV|38x1Dp6mCHaV<+Uk4wOr zzjp{QcDrs*bo`#@^ar8e?*$d`1DdQ>nAWTKP3t**R{f0OtuYjyZzYAM2}5N(49!0Q z3nAb&(nk6bGSNtH8zar(L_km#9Q?v+S)iJtM6jf=7n7|5C^-PpkPpR_0DM~}DEXcx zrdqJ)*~OeUkHJ{|~zSGF*W4;t`GSBD61|HC@%j+;hEs88TOx^GbNN{c3yZ2w~>D z2B~X2PSUyFBH9h`Z?<0tb*g;}QhNIhJatjqL$Ez3FWgoV;hXI{kh;K9T6><}M>LO< z^2?CA(Y}lF_K`{msMLdCq#=PTmiqvh$(E6ogzOLc2VqpAUK~yys@kG%imkx`l6w&4 zDf9SpB5Q90_Qc+)1ho;Bdrrqa^km>{up&H1`WtJKGPtagkt(~bvF=ckjEY!;EzBq~ z!hTc`-ebS%Mb)PCLr)@F!Hx1^UpDt6qt}ZHTdOQwNd>l6qaqfIsmE%xd}WIFE_^33 z39d)EhH$%48Oa0gLkEGpO(SBjuSS(r0S+tC078rq%FTW9NQRBIcLxKnCy+a73yVys zMipAtz;AW?zBrMOB+f9}D|mt*qtK9aEOXwNH4FNrVHs6@$tdUry^KV3MU|$dSItY1 zn=}iBNxcSoO`p_<3uEc&7$0Nlh({8LG_QP!M1^$!6e+C@ zPqIhzF+UGlA?B06b0dTK#h8zcAmt;~fv;Ai6!VeJ+h(j!+e2=YWBqcb|HMA^l^f-m zU&+`{^s&W&UtoTfR|w}rnpAy4(!3(G>4wVLnP&P*T`hLbtJ3g zBnWz1U#}NALhL#0#JiM>D2X<>+&CsuN?2#&1G|=$R7B^&xacYUW1|^KA}aljlV~X7 z{PV<>;un}ru!us_YY4-dUN!L0+f?siZ44K&+88P*RWmC7jtV@CXbX+DT1%~+S8I2q zGdd7E?IJcU9kHlk8%MNZ1JfLwP*aU^yxL@;KxlYM>u!y5!erSBT7-(qf+lh`inc0R z$Gw)(nn-(P{EhN)uY>?=RET?>M2nu(QoFj4*sHduj_^UjD9J07ytfmd;Qsk;eOkVQ z*+%)#53tWMc;v4^P~c(6yiI&0^w4eY%XKvDIo`Q8R$~b8nlqkGa(LnFVk|D5Q_AZO zDK{c(*KIzcb&OFKRTL1|bTl`Rp1}T|?az)SR2L^Cn5b+dJ&Z}j#FJpEB%?rIZH!d` z%DjSnJqZ|KCptW$(GrN7e1ilIG&)-;}(r1$m&1Cqut(Vl%7AmnVFDiu5H*T7f7+LEImS^%6a$<0{0u~CiN9%Cun35iE#kJ%n$y(AvijxfF8Dhkb9;&@n8@nB?I zFNun^v8bS3o1@|$;v=&GCL3J30dsoIr%_OdWeImyCcW}z%jq?b2;fxQE->N}4NB8$ z3s!MMXfM+i`;?HLPm!K%x6}j0kd;Atd6XbM52UMJkaCf@DK1RhN@>5Piwqs5s&l!A zV;5uy`-u#WVrR-wL?$x$#ynGoz6p!p(-*~FB6D6}6r+jEd40`721me`&TB3sb6#Jy zjLdm`U4TsDxAT0+->dO>&g*L-W9PiS79n#H53R{GZN}Up@6Q|=w4&T5PtV5fOYO_F zMKhgqE^9G0mA?YMLdB)_D>W^pU(L{8&Ctnn^J(woqlHZG)x+y84BdfdNr4YcccvU9n?GbC!Z2T-X8=O*Zv+OH!5Z{c}^bh9W=>v{!IpfG=? z1TIt+e-X#0BmE`z*pjk)raiWF%vbgnG@ixyND~Zaq7r)9k)9=S{zy+spC0MB#JMA7 z&(`5f%0iq=+ShnJZ*o~@{%!DXx8F>6&4N0{?eK-=3j?`S-a=8ob8^Gvy^_vvYMY$k zZ8EtnQQu*-6xKyOds%O8UOB3+m!iy40A8naSsj z{0*pbmsq9PDO;=ZeQH-?QMwGZ)g2ArBH5GUH@1-(9CRbo-3@GuP2iHrE+}K!!+@l- zw%n(}r-BZavO`6eiXIhxD(J&h6@2koPEVAZe*`C>W4eZq8k|D8tQoumb6GQP==hy3 z&F?hRSTnxUeq}5fC+-}M?WbEaj;P1q#vvL2o<^|4o3><1lxB;+k(rij=VNXme&d@c z=HM84N)JhkX{A)J97~nx*cpGNnA0p#xoxPe?}4| z+za&;=Mma+K*U@SdIPylZJI`1=X2Rf7tba4M7VrNE|R{1xVlu^Eqz$=#tNIB;7URl00-nzllq%{l1Ww zX~1`=_>2njfJ`UFGm&2q?HDRtBfRjxhp5Ga$DmC^ug+L+S(Y_xU9_%R%PfHtDxJIR zse{n(+>LKua5=`;dJ!;(n=ih7R2P>TUq8x^k`B)5IkMhXGkk~O6VSA9=fywE_&$U6H$R#1-DDK1Mb$=VSs0*+aN7uaE`*|B;ok1Wk7 zJ);d)Ij3ZcldJd>04%`3c$ z(cvy%z*y!rzKF5H>wF1gl`rFX4e5Ro{rWDNI!A|2Mn4*(XMppRzchen}#e8 zVPQXw#vx3a`2DmS#JEf%2GS1tQ8x}bwCD05gYh$UUs)@Z^TA#}=^Un@*Pk9lor7E( zKJ?Ze*Von_f4}|d_U7Y<-sbMZAMbQe{`Kd{*MI%Py{0Zp+0ihRHV^lMQ9m13*Vg*M zV2=m>_C^;c#P73@#>-E`sC$sD?k8e(dth`CK@^`kg(#EU>#%9Gf8=+P!61lv zKZ-*?8D>!ur&Z9?Iy%j&GZ+6lE3sv^tdDQcS^HuMGZ!!*gv2Qq_yz2eX%lv=9kWw} z?lJ4>$C|Gqhg^cg6}lEM28Sc?GMS&d3R8lyb@4+O-|EsDGj{kw7m|E4vnDpV1f)4= zWwZ)f)g}2Z+5*Y7jbxXVkco}U4t$|iK*3r+ex(Kw20zw74zx%pulf>dOLbak^GYEi9jfi(l8IA?zew) z`R{-G`rr5Ab9!G$b1y^)HNPb&J;_QGM_GgbL)_U>x=W5IE9B!or3!fOmr**xQSw<; z<(YgUOcLP0&&e~}(ulkJVi`D5Q<4$ITqMI}l*uyf;;UfG!7goMcZVZC;;A$#QX{eO zs=OBD@%{e^=$5R+em^)0Me4WI!q$eMG$veH?J6~GZuyi=qzz?n<0?|m5^3@9%cv7d zwxQh(7CkYK@E;K3zD1%ZFVaF}acN3*ONYp}eHugs-_Gds_PKDmnu{V@PIFXeD{s6RuH9qZyU z*JHc|*R%;jx@V~HDJjbP+yF(NqnO8xm*y}QFytZt-CNYXiB47{ zDj^CHh5F4X9*#2ch?r)tah`a~l)e@DA*d-i7ODR_{>ZUxe0ff>Gv-+lKsu%1QhIFx z$Wtn%(w>2og4Fs9iD&{%Dvg>2-2~k#=->riS9Dux2t@=5ZaU&V25EGHEWa#WQGAEG zRqE8W-v(4TnpeRCvPU7h**b%_9-P0mjrZ+oAox$PG$jw!SOad>VB^NzeG9hPZ7B=P zYKRw@kQYr@rcaPKd-N=r`JGOgx;1=aVxMZ5(H)&*c+n{!AJxQcMWxUJBuuPh_}w$e z0dI*Linj`>crJR?D;h5~$ctRe+{wZzsxPH2GuPI~yTv>HkNeG<_%2Q-%_pN!94RG= zp8=68ah0GCh>kpw1)T$N4-?N-nGhM8@``e+GgO|m3kK=z1)W3EU(?Lv+b&X(bcZ73 z$oWc!XXlB1%gfPfh7@+r1*=nOSoawYl0Wvw4a9cMXdvoq1|IvXbB1|k$;>$9F*D;( z)%E{eQ5nRX83gZq53a|xuc#bk=yO>0`v`Oj-dA`K{IBI9LrvwNf_N=Y1w2#v7wttx zpYmL^7rDWd=c2t(ZA@z5Bllb+ck0!9OXNffOE;<#R?X<_?x}udN~LoyTGNU_P)K`~+^JA7aIbjl z{H_SY*sB(lz5XclD!G&8Yz17f6*FQJTM(J&O3JEUc~<;-R$j|`c_pA~70=4gBRo2M zo<+qkNy6vUopa#I%p2Q>AgB0h#OH?c+|oM?e_i#)W3#58d#TP=&fJ(j7iJmt3rt`h zjTDR#nG`0^%T*Y4QY!FfJ_2LSaU$E$x}H7bbP7&8=T!N_*xCIUug%Y~mt3M0FTy7y zJU1!S9dMN?1icQjjM5;Cww8|FdSNw>9#%2cxLEsXIaZg zh|MEJc?VvpublQ3I8oJO)@;Oq8C7*v^-pzGebrx{oTTq(zx^~!2>Fxz>+?Z=0i?eI zCWxRF+0cwqlv%|#Ugm-93BU3)e?~+g!nZ_(lHPh*Ap9T5D7vGtbwq7)nQQrdBV+|i zTuXD^sMq=qCds@&{?Ericmbrp047OB1<9D8g1se~hrVF+J5S&v5;5qZNWfQWlb2E( z=cGWl_~D2&1=1Lgpsfnk0}~;FWdJ%PT|`R;QfQ)eEMk$sexwwr z2h{K5i6=mVJy}1xOMyFpL^s9``N+5j) zcp`Tmu+gcFO8rvhWo-u&t)(($Q-7yaqZCGO)3Pp2nOEgb+O%r4=9nRE`)^9U1)rVP zbGhvlh3DE>$VFL8KC@oET-f3Dsw`Gr2{?dk21RJp#-??{g=*SnYixq9xVdTNL0sak zqGh|~rO-COZDlv;5mOPm@tW2gA);cisVBhBfb=OaLi@Dug>1-DHl!cZA>ExFt9=aa zLCrC8SMwi1*i%8Dk~LI#O>Zewh=GMYuy|nc9$0*^_zx@rSc0dBlh9QyoLf4v^wQF) zrB_)r;caeKdK6h^8XF4vrYxjF^r=3uLkXv9fJ(32E7~ABmzxme2q<(~ z>>`Dpc5h8$1J|k6D{|9ChIeuisZUVqA^Q~B39v(rvtH3`ULlqy)2k0^ZKFk_u5%@x z!Fr7?P3T8Z@!jKNV22fR$ZX6VL^-!9M=WkD338wFp8?8 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + # load models from the disk + def load_networks(self, which_epoch): + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (which_epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + print('loading the model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) + + # print network information + def print_networks(self, verbose): + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + # set requies_grad=Fasle to avoid computation + def set_requires_grad(self, nets, requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad \ No newline at end of file diff --git a/models/gesturegan_twocycle_model.py b/models/gesturegan_twocycle_model.py new file mode 100755 index 0000000..c8cd3ae --- /dev/null +++ b/models/gesturegan_twocycle_model.py @@ -0,0 +1,219 @@ +import torch +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import itertools + +class GestureGANTwoCycleModel(BaseModel): + def name(self): + return 'GestureGANTwoCycleModel' + + @staticmethod + def modify_commandline_options(parser, is_train=True): + + # changing the default values to match the pix2pix paper + # (https://phillipi.github.io/pix2pix/) + # parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch') + parser.set_defaults(pool_size=0, no_lsgan=True, norm='instance') + parser.set_defaults(dataset_mode='aligned') + parser.set_defaults(which_model_netG='resnet_9blocks') + parser.add_argument('--REGULARIZATION', type=float, default=1e-6) + if is_train: + parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + parser.add_argument('--cyc_L1', type=float, default=100.0, help='weight for L1 loss') + parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') + parser.add_argument('--lambda_identity', type=float, default=5.0, help='weight for identity loss') + + return parser + + def initialize(self, opt): + BaseModel.initialize(self, opt) + self.isTrain = opt.isTrain + # specify the training losses you want to print out. The program will call base_model.get_current_losses + # self.loss_names = ['G_GAN_D1', 'Gi_L1', 'G' , 'D1_real', 'D1_fake','D1'] + self.loss_names = ['G_GAN_D1', 'G_GAN_D2', 'G_L1', 'G_VGG', 'reg', 'G','D1','D2'] + # specify the images you want to save/display. The program will call base_model.get_current_visuals + self.visual_names = ['real_A', 'real_D', 'fake_B', 'real_B', 'real_C', 'recovery_A'] + # self.visual_names = ['fake_B', 'fake_D'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks + if self.isTrain: + self.model_names = ['Gi','D1','D2'] + else: # during test time, only load Gs + self.model_names = ['Gi'] + # load/define networks + self.netGi = networks.define_G(6, 3, opt.ngf, + opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD1 = networks.define_D(6, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) + self.netD2 = networks.define_D(9, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) + + + if self.isTrain: + self.fake_AB_pool = ImagePool(opt.pool_size) + + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + self.criterionVGG = networks.VGGLoss(self.gpu_ids) + + # initialize optimizers + self.optimizers = [] + # self.optimizer_G = torch.optim.Adam(self.netG.parameters(), + # lr=opt.lr, betas=(opt.beta1, 0.999)) + # self.optimizer_D = torch.optim.Adam(self.netD.parameters(), + # lr=opt.lr, betas=(opt.beta1, 0.999)) + + self.optimizer_G = torch.optim.Adam(self.netGi.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD1.parameters(),self.netD2.parameters()), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) + self.real_C = input['C'].to(self.device) + self.real_D = input['D'].to(self.device) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + combine_realA_realD=torch.cat((self.real_A, self.real_D), 1) + # combine_ACD=torch.cat((self.real_A, self.real_D), 1) + self.fake_B = self.netGi(combine_realA_realD) + combine_fakeB_realC=torch.cat((self.fake_B, self.real_C), 1) + self.recovery_A = self.netGi(combine_fakeB_realC) + + combine_realB_real_C=torch.cat((self.real_B, self.real_C), 1) + self.fake_A = self.netGi(combine_realB_real_C) + combine_fakeA_realD=torch.cat((self.fake_A, self.real_D), 1) + self.recovery_B = self.netGi(combine_fakeA_realD) + + + combine_realA_realC=torch.cat((self.real_A, self.real_C), 1) + self.identity_A = self.netGi(combine_realA_realC) + combine_realB_realD=torch.cat((self.real_B, self.real_D), 1) + self.identity_B = self.netGi(combine_realB_realD) + + def backward_D1(self): + # Fake + # stop backprop to the generator by detaching fake_B + realA_fakeB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) + pred_D1_realA_fakeB = self.netD1(realA_fakeB.detach()) + self.loss_D1_realA_fakeB = self.criterionGAN(pred_D1_realA_fakeB, False) + + # Real + realA_realB = torch.cat((self.real_A, self.real_B), 1) + pred_D1_realA_realB = self.netD1(realA_realB) + self.loss_D1_realA_realB = self.criterionGAN(pred_D1_realA_realB, True) + + # Combined loss + self.loss_D1 = (self.loss_D1_realA_fakeB + self.loss_D1_realA_realB) * 0.5 + + + realB_fakeA = self.fake_AB_pool.query(torch.cat((self.real_B, self.fake_A), 1)) + pred_D1_realB_fakeA = self.netD1(realB_fakeA.detach()) + self.loss_D1_realB_fakeA = self.criterionGAN(pred_D1_realB_fakeA, False) + + # Combined loss + self.loss_D1 = (self.loss_D1_realB_fakeA + self.loss_D1_realA_realB) * 0.5 + self.loss_D1 + + self.loss_D1.backward() + + def backward_D2(self): + # Fake + # stop backprop to the generator by detaching fake_B + realA_realD_fakeB = self.fake_AB_pool.query(torch.cat((self.real_A, self.real_D, self.fake_B), 1)) + pred_D2_realA_realD_fakeB = self.netD2(realA_realD_fakeB.detach()) + self.loss_D2_realA_realD_fakeB = self.criterionGAN(pred_D2_realA_realD_fakeB, False) + + # Real + realA_realD_realB = torch.cat((self.real_A, self.real_D, self.real_B), 1) + pred_D2_realA_realD_realB = self.netD2(realA_realD_realB) + self.loss_D2_realA_realD_realB = self.criterionGAN(pred_D2_realA_realD_realB, True) + + # Combined loss + self.loss_D2 = (self.loss_D2_realA_realD_fakeB + self.loss_D2_realA_realD_realB) * 0.5 + + realB_realC_fakeA = self.fake_AB_pool.query(torch.cat((self.real_B, self.real_C, self.fake_A), 1)) + pred_D2_realB_realC_fakeA = self.netD2(realB_realC_fakeA.detach()) + self.loss_D2_realB_realC_fakeA = self.criterionGAN(pred_D2_realB_realC_fakeA, False) + + # Real + realB_realC_realA = torch.cat((self.real_B, self.real_C, self.real_A), 1) + pred_D2_realB_realC_realA = self.netD2(realB_realC_realA) + self.loss_D2_realB_realC_realA = self.criterionGAN(pred_D2_realB_realC_realA, True) + + # Combined loss + self.loss_D2 = (self.loss_D2_realB_realC_fakeA + self.loss_D2_realB_realC_realA) * 0.5 + self.loss_D2 + + self.loss_D2.backward() + + + def backward_G(self): + # First, G(A) should fake the discriminator + realA_fakeB = torch.cat((self.real_A, self.fake_B), 1) + pred_D1_realA_fakeB = self.netD1(realA_fakeB) + self.loss_G_GAN_D1 = self.criterionGAN(pred_D1_realA_fakeB, True) + + realB_fakeA = torch.cat((self.real_B, self.fake_A), 1) + pred_D1_realB_fakeA = self.netD1(realB_fakeA) + self.loss_G_GAN_D1 = self.criterionGAN(pred_D1_realB_fakeA, True) + self.loss_G_GAN_D1 + + realA_realD_fakeB = torch.cat((self.real_A, self.real_D, self.fake_B), 1) + pred_D2_realA_realD_fakeB = self.netD2(realA_realD_fakeB) + self.loss_G_GAN_D2 = self.criterionGAN(pred_D2_realA_realD_fakeB, True) + + realB_realC_fakeA = torch.cat((self.real_B, self.real_C, self.fake_A), 1) + pred_D2_realB_realC_fakeA = self.netD2(realB_realC_fakeA) + self.loss_G_GAN_D2 = self.criterionGAN(pred_D2_realB_realC_fakeA, True) + self.loss_G_GAN_D2 + + self.fake_B_red = self.fake_B[:,0:1,:,:] + self.fake_B_green = self.fake_B[:,1:2,:,:] + self.fake_B_blue = self.fake_B[:,2:3,:,:] + # print(self.fake_A_red.size()) + self.real_B_red = self.real_B[:,0:1,:,:] + self.real_B_green = self.real_B[:,1:2,:,:] + self.real_B_blue = self.real_B[:,2:3,:,:] + + self.fake_A_red = self.fake_A[:,0:1,:,:] + self.fake_A_green = self.fake_A[:,1:2,:,:] + self.fake_A_blue = self.fake_A[:,2:3,:,:] + # print(self.fake_A_red.size()) + self.real_A_red = self.real_A[:,0:1,:,:] + self.real_A_green = self.real_A[:,1:2,:,:] + self.real_A_blue = self.real_A[:,2:3,:,:] + + # second, G(A)=B + self.loss_G_L1 = (self.criterionL1(self.fake_B_red, self.real_B_red) + self.criterionL1(self.fake_B_green, self.real_B_green) + self.criterionL1(self.fake_B_blue, self.real_B_blue)) * self.opt.lambda_L1 + self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + self.criterionL1(self.recovery_A, self.real_A) * self.opt.cyc_L1 + self.criterionL1(self.identity_A, self.real_A) * self.opt.lambda_identity + (self.criterionL1(self.fake_A_red, self.real_A_red) + self.criterionL1(self.fake_A_green, self.real_A_green) + self.criterionL1(self.fake_A_blue, self.real_A_blue)) * self.opt.lambda_L1 + self.criterionL1(self.fake_A, self.real_A) * self.opt.lambda_L1 + self.criterionL1(self.recovery_B, self.real_B) * self.opt.cyc_L1 + self.criterionL1(self.identity_B, self.real_B) * self.opt.lambda_identity + + self.loss_G_VGG = self.criterionVGG(self.fake_B, self.real_B) * self.opt.lambda_feat + self.criterionVGG(self.fake_A, self.real_A) * self.opt.lambda_feat + + self.loss_reg = self.opt.REGULARIZATION * (torch.sum(torch.abs(self.fake_B[:, :, :, :-1] - self.fake_B[:, :, :, 1:])) + torch.sum(torch.abs(self.fake_B[:, :, :-1, :] - self.fake_B[:, :, 1:, :]))) + self.opt.REGULARIZATION * (torch.sum(torch.abs(self.fake_A[:, :, :, :-1] - self.fake_A[:, :, :, 1:])) + torch.sum(torch.abs(self.fake_A[:, :, :-1, :] - self.fake_A[:, :, 1:, :]))) + + self.loss_G = self.loss_G_GAN_D1 + self.loss_G_GAN_D2 + self.loss_G_L1 + self.loss_G_VGG + self.loss_reg + + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() + # update D + self.set_requires_grad([self.netD1, self.netD2], True) + self.optimizer_D.zero_grad() + self.backward_D1() + self.backward_D2() + self.optimizer_D.step() + + # update G + self.set_requires_grad([self.netD1, self.netD2], False) + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() diff --git a/models/networks.py b/models/networks.py new file mode 100755 index 0000000..ad8bc06 --- /dev/null +++ b/models/networks.py @@ -0,0 +1,429 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler + +############################################################################### +# Helper Functions +############################################################################### + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) + init_weights(net, init_type, gain=init_gain) + return net + +class VGGLoss(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss, self).__init__() + self.vgg = Vgg19().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + netG = None + norm_layer = get_norm_layer(norm_type=norm) + + if which_model_netG == 'resnet_9blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + elif which_model_netG == 'resnet_6blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + elif which_model_netG == 'unet_128': + netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif which_model_netG == 'unet_256': + netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) + return init_net(netG, init_type, init_gain, gpu_ids) + + +def define_D(input_nc, ndf, which_model_netD, + n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + netD = None + norm_layer = get_norm_layer(norm_type=norm) + + if which_model_netD == 'basic': + netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + elif which_model_netD == 'n_layers': + netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + elif which_model_netD == 'pixel': + netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % + which_model_netD) + return init_net(netD, init_type, init_gain, gpu_ids) + + +############################################################################## +# Classes +############################################################################## + + +# Defines the GAN loss which uses either LSGAN or the regular GAN. +# When LSGAN is used, it is basically same as MSELoss, +# but it abstracts away the need to create the target label tensor +# that has the same size as the input +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(input) + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) + + +# Defines the generator that consists of Resnet blocks between a few +# downsampling/upsampling operations. +# Code and idea originally from Justin Johnson's architecture. +# https://github.com/jcjohnson/fast-neural-style/ +class ResnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.ngf = ngf + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, + bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, + stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +# Defines the Unet generator. +# |num_downs|: number of downsamplings in UNet. For example, +# if |num_downs| == 7, image of size 128x128 will become of size 1x1 +# at the bottleneck +class UnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, num_downs, ngf=64, + norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetGenerator, self).__init__() + + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) + for i in range(num_downs - 5): + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + return self.model(input) + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class UnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + return self.model(input) + + +class PixelDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + if use_sigmoid: + self.net.append(nn.Sigmoid()) + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + return self.net(input) + +from torchvision import models +class Vgg19(torch.nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out \ No newline at end of file diff --git a/models/test_model.py b/models/test_model.py new file mode 100755 index 0000000..4d73f70 --- /dev/null +++ b/models/test_model.py @@ -0,0 +1,46 @@ +from .base_model import BaseModel +from . import networks +from .cycle_gan_model import CycleGANModel + + +class TestModel(BaseModel): + def name(self): + return 'TestModel' + + @staticmethod + def modify_commandline_options(parser, is_train=True): + assert not is_train, 'TestModel cannot be used in train mode' + parser = CycleGANModel.modify_commandline_options(parser, is_train=False) + parser.set_defaults(dataset_mode='single') + + parser.add_argument('--model_suffix', type=str, default='', + help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will' + ' be loaded as the generator of TestModel') + + return parser + + def initialize(self, opt): + assert(not opt.isTrain) + BaseModel.initialize(self, opt) + + # specify the training losses you want to print out. The program will call base_model.get_current_losses + self.loss_names = [] + # specify the images you want to save/display. The program will call base_model.get_current_visuals + self.visual_names = ['real_A', 'fake_B'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks + self.model_names = ['G' + opt.model_suffix] + + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, + opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + # assigns the model to self.netG_[suffix] so that it can be loaded + # please see BaseModel.load_networks + setattr(self, 'netG' + opt.model_suffix, self.netG) + + def set_input(self, input): + # we need to use single_dataset mode + self.real_A = input['A'].to(self.device) + self.image_paths = input['A_paths'] + + def forward(self): + self.fake_B = self.netG(self.real_A) \ No newline at end of file diff --git a/options/__init__.py b/options/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/options/__pycache__/__init__.cpython-36.pyc b/options/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8169d4919ebc9e93b27aabf5fd99a958ae1f828b GIT binary patch literal 204 zcmYL@F$%&!6ht>*Awo{z32fY`jfIGypoQARRz5NL#D&fNW%nD&k-Uwyt+%jpw-5*B zF-$ScEKLW`{qQV;-Vr|~A*_OZFQ5{+7$wJtaIya5)5$9}D07RYDzv;nHO=PoRG4~K zkYr_H0@)NCdV>tgcm^`KjW`q*+PVgAK~-p2aR=;_)yNwtNs|M<)4D4)O0ulTcnM)b UF~Hb%ul}N;>ZWlSFZp2c1;J1{y#N3J literal 0 HcmV?d00001 diff --git a/options/__pycache__/__init__.cpython-37.pyc b/options/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9bd8e779c0aed5f0b979814573fb1b462fd11d3 GIT binary patch literal 208 zcmYL@F$%&!5Jfj&Awo{z32fYGWg#LeXrVT-l_6$_xUkt-vSTDi@;27C-onb=Lj3UG z|NLT@RhkY&&~raL$9zZpl!USj_PqclaxqGd57ET=k54DBRHMutmbx(V0&1Gg>L@SWCOX;6}5P1Z{Y VCo}_$UH9tGYf?9Zi};cc7GKWbI?Mn7 literal 0 HcmV?d00001 diff --git a/options/__pycache__/base_options.cpython-36.pyc b/options/__pycache__/base_options.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1771c6ed5224d5314b9ff38cf07bf71bd06ae924 GIT binary patch literal 5468 zcmaJ_OLyDG6$ZekDC%Lyabm}ICXSPsu}H~so@$)<70XFu=V7N#;iw3N8IqttfEf(+ zfXS88W!gdz%Rawh8x@r3_E+8wuyHuu=7Z^ zytr@KWo|w+8l`8DWjtt@X>q9hz@xY6pM}m1{NlIJ2!qzJEpBq_2LrOTB*-rDGTuwP z!mD_e`3$e&S>dygUQcIlxv}_sp!_I|Jri>1kDes_Hhysp4L3)I1GO>5dTPQiX zLOFgID<=xYIyLv*fwSG|IBMJRxd@f7_8cYsK_Ju}$BTkUI%+4nz2^mDIm(n?-oyr+ z$xCanmsTCf^p)d~;xTQ*Fq2ccZMLQio*u+FJxO)F-*50XGJDvwf+@+2KNii5Qc@N# z!;puRDkv~$jxq~MozLYc^Pnv7#T?~?rku=Co&)6+e?CV!ttl_$C@(^vGyH6h@)9WL z`1u^=f|mSpj>5q63cr}6yb8)C{#uUmIw)`O%Q;E|lsEa69AycVCU4~^%e;)$R`}He z`>cZJ8heW${kQX z=67?HbM3- zXI*z&uz0xHjsh-Z%#?Kf@a9tXytTx(-2iWB(HF{VGzw`^?F~d);iByh1Ct!B8WcO0UV8S6S3%+2VS$%Jn25%>qH8b~PPx}4eoN%() z-0AvW*Xc)G1P&729lL}~<$K9_ojb`JQ5fve8l3rUc3i9L*Qcz>OtaY;43T)_)LcofpwG}>W1`h*u@%Nv zmaZ_sJrl`3fe-PUG^yOWp(j`f zD%Ee7lqChTP(DIV7Msl&E^`A%kBT_0X(yp54@Ghxc%lw$;BLVQljMT{iX_qvxwMag zOtOfVSOgTv9UpMdCg8QtBpe;}gT028mR)VSgDoS{2Eehi1Fb2=JU1aQ5I8)pi9q|5 z&*#tR=#re9jQe~Fy&bX1CS0_s=>%+U8{xg!Rcf%dyo~5_gKh+WvTZ+xRxD%qH7Rq4 zQbf3&7W>Go$+gL}sh$||WfT(RnT~|Zlf2512`|!=iw%wL$wmi}RQ^9+#uhG2BoXD; zmZunx9FByE9G9F6BR1)4Fj2H2rwQpIq$s8I-CbuA9}Zb9E%;$O`SX2u*Y6MeEFAVX z;V{JMLG%Sdv;zQl-LXD5p^sSgQoEc5RPsKO`%WacP+len=PdA^WaFb}&+DccV^(+_iBauLBEr zo}T*m`hSzZeUx0CTBMDjfgZYm=$Kw&nJ{guLqRRtAJD#ugqgIC?r$KGRx;;tu{rE? z$htrX$)o&`RFMdV?Y6(08p(Iw5J{j9e{ndIh-(bz@0JGm^~3I!gHfs=OYEu6P-Exm zu^O~K){uNUJw1#cHA-pS<=k;)XGnHU3$c;}S(TRBU{jLLAPq~F@JP!8S0YaJ%Vu&qf1fy9=->ub#aeV;(Km__ zIwPZJ#21v=v+yf$^G9=xJ5>*-!hPzQz0$t%tu?mBg^|@OkBh3(tB#5oD~(Dc;+)yH ze!*88m(?uz%HXSvDtSJfIO~u)x1ZhF8m7IFc@t+G#|zvzZp>>drd3#3ThN}rIkEQ_ z+MxFN>HD4RFc{y%S@5pL1fCo2XT*y zX~mBp(LvO%Pg4@{rbT*_FOk$*CY9ths@lt%o)#b|H66RAg**Dp@m$tFvc#ce&GP(o zog+-g0W;shAsW`nnBr?tB&IzYR@p3?HM3Tz7s`0n&B@a$Tgllx5RM8sxL#5%r&6xL z@Dn=R99tucTkjfUYu8qW0s!^_jeTR^ykl(K1rQ43BDhL$d;uUSz`^BFaa7(a$WM7e zmB{6%4UAj7*el~59B2WUQE9(Ga=*Xu*uZFs`l>Rr_D$1J)!t0623XAER|lXfYOXgw zs^rEhybO9(EsUzXvTyzM?09B0(_75uQ;%1P`!PNSd}(}Te3N|?n0X*iV_(Ue&EtRU z^X!_=jt$h3`I#q^Vz|a$i#L)JPuD(NeS5sdp00gxZDss)TuU0qCN-PvW90u|P@B!9 z0wod{Daiqj%p zsnTK(2QGUdchI7y zV%ztGBxg6uhmM!;YJsR+bfvP(Stqr8e%!4ko$bB@d*)7y+*dX)nRN>}2jzweQxyeh z-aLt1l$GmP{ OC9b2Py%BI}iT?vE2|P~# literal 0 HcmV?d00001 diff --git a/options/__pycache__/base_options.cpython-37.pyc b/options/__pycache__/base_options.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1340ef638d8778a5181a8f259fdfd9f65ae732f3 GIT binary patch literal 5472 zcmaJ_OLN;s5(dDhDC%Lyabm}ICXSPsu}Il+o=Tkf70b!S&cn_+3&)E~%X1en24 z512eDd(7@-YtPbgtG4#MKO<+{w)W(|kkfuWASFmiMJf&kJv}`=-CuXl;O5Ls&A{K! zpTG3vAEyoDfAli>E1+{7zxWOsZg4X&?CfdUCf=>U>RGmBX0(D`&^$%EsGlXfv|{iA zFaBuo;*MpPx%t3ol%7GBalc`v#ewnzkKQJK7CP7Qi{C;c3|hmsxXG;_4anA#AiKoN zcrWn^ui{zeQ@n;}g-=6zJ)OSk#^Q^<@}n^JOvs@>dXn&4_{FqeZVn9xYGaCZ{|=Oy z_CsI!Zr~?^lWwsH+W$0U8N9!Y7Q0bTEI#N94;+iD4qv+b-r_?y?A&#uMdkOME6(L5 zM~XNU%JIWkIZ-H9sJZ9%osCY%Q5%lWMW}qW?I`K@1EFp^UKB*qQJc}NZ7&duQKt0b zDmLJZUt0a`wCX^nuN;RIk7*l*nViUNvo&Gx^dQFRNviApev7xExrRL}n3Bx+Q_;*Q zC1vq440%AQf&zo)DASkn({)9@*?y(#ZTub zFM)D~pUqLuX~{3=C=5KW@bfv!tDs!qujMGOgYpKyn4>g6d6QqtQ5HaH@>Y(r$jeym zGQYBKpC#~IQu%v$^UjuloS;{;Y^dN#^~qKTu3lSqw7mD~Mc9EdDuUkmOYCxq+Cd1m#rMkN0`o zekim!$~MVK)fFN0WYni!9AeZ9IidHA(ra?+RTi~bwz%Fb^SW*riXbK>qk)1}U`|eJ zzU)0y;efH!Y=)iov+E@9g68XB@wOi*vXM_+87t;R;RXxbZ6Ony!Q072&CLDSlYZ}a zCme4!H@m*qb$SsOfrCVM+b$tf`Cf7sxstrh^bl=#5DV6hBqOfm*E3h_*>SFXe>I4_ z_4xNitR!`4>uAMe=XUZ&6b9S024{Yo9o6dEwFzr7)ogb910>!!H7_TZ&}Zl`Gtug_ z*k#5pFI-tzTEJterB!ZT(OTd7@|&cL72Te&=R_x3Qc>U0FmI)bS zdA-FxRg9xPaDa8QE2JPh#6UgZ>h_TDV+PD;CW@2UEM$OqZBM&`!qW)_Pd@g#Q51u@ z8*Q>lP%sx#2#SKY*yBMsW@bZn(2M=B6No2CrPs19#Yp(DD}~ErLa55R+Ki+yl+wFs z6eJZdIYrGh!4&enuJG3Tksqqq!79m2YcWSkp4a<;*|Z5a;B7Li_Dm#u1U|%X(xh_h zhMr&{s8qjQQkE3VLirFmnQt~@xXcY4Ju2d~rk#YMJP^q};E6i0fx8YTjFS%nD3VAw zjTHmCbXs$^W2!gK;ZDW zCIamfKA%6Mqf2sTJnr)e^ftvR8*|a7CKIr^4TSf6SE>H;;v%BU4Z0Eh$u|5LTCt4b zx1`J*ND<*iTI?aSCRfMPCVFDTmr+QNXF3utj`J!*COl75&Nno=#~bZOQu%u?V+-fT zl8EwK%TtU;4oAX7j!Vvj5gYe4m?+wi(}Z*pQj}79?v^u-4~MLl7W}ZC{NsSgQoEc5RPsKO`(`B9QC`Le=Pd9ZXXC>ycf%KQxbFHr>}&|HccV^(-1TH% zUI!NLJU#L6mH#Gx|2Vlau}B+113h#B(J{TiGGW?Q2ZCC(KcIaT2{UOO-d|rLtz^#Q zVs+5zkad9&l1KR=sUi^!+HHR;HInbW0g^xu{^D>Z5tkXx-z^RBYX{v+`=eAr7T8mr zp+?TrBQ4&j9{-*Hnpxc+o!{>rH|Nchc*`i%md)g3{yuTA(EbgoinZvx zqHh#KbcV*75ua1$nuTA1n?IQ&+^KePD%_){f?GjFpC^ zAvo8jcC26V)kbAC4Zbq?D#J>i4=2t#q|WSRceaLU&t=}k8OQMgH;x;#+KOovR@N4@ zXK#${{iQajeRlGGJ39=buy;-Din!nkA_t?i)PKN73zgDo^^Blw8~a;IuC@y0`{+$ zRLiN9YcTwb4mU^E(Bjs+#>m>Tm7xHD9dp;(Gj`0|#*;e$LSa+{R|$?U03-!CxI8Ql z%j*UC884{Pu)JfQG%#-Q;%*u5g&h-ZamO5%_6j8ThjWh%jFzacDno0>G!0cjo<&_=lZ>L97!>Qf*p+!8@<5lWC!pDHGjIWJv zvyTE(_vK0KD|xed^pAa!UDMf-fm$*<^>|zim)UFaljPXb<&T!$9xbz{%O767JbF5+ zC5&1O=8*IH^z+14fyQfnMb%Q14F;AwFKXiiJM61_M>*EB(LytaH1 zdZ*kn)E_*I_(4w?bVx{6X1NWAs1hAUPR(1|_LApwq!` z&60PhdbtNDTvVhqSO*MQ3x@X5phes{~1;`C)7DUB173mqD( zY+W1{fOdvtH$}*$RL6_!CE1093UFGYc|UX?8H(^)!Nt9}UX=e*^)*5*D7<*e$Ydkk z%|X*S`eK%0QND^z{9-OhIuSPJLjz#wq9q$dd4rl3HE*M_E0Y^g2`5iq$x|TW zSlWkgd0Gg_!zi|1*DiVk?%FlFl5nx%dqR@48|4GX%XhUvR4%$w+2yR0T0TGO){@S4 z--A7Kr$z27n-|Qwg`9(OLxrh|0yJwLM=nZ^9dak->idlaGnbXPTV+i;&~=*?NkJ1z$#Vi8_-_dKl3x3X2f;a+4gR6%Hlyw2YhmP%wN+JVQiR?O4`)wN z%w1HHBoqk7Qt;l8#0MYu2-?YoH1=(<3vk2zhbYfb%pa&Mc}X{9L$AHpbVFa_Cf2tB zc-Its2yOz3Ylsf)074k+rLf<9B&+d*z4@&WAcAa~#Z4&m;Nzuotq_2CFU zeCH9uefa1G3xiohlp%bKOpYA_pKJ&pZ73e@!C5nmqSE34TVIun=I}FCEts>Jd?kuH zmu5cm+rX}C*$zQ2#45LKkjcDm_M%8jgSO$IbTgjETQhDBjkuJCZEcoGE$7YcC=#ZY zspVQ=Cze<=9REIP9`rElV3)s-W2z8KX&l2k!~P))TB!1bU5fmqV|v1*O&y!7 zOr;r@b(LmGn2XN)%MN56$ZW;(O289lr0tB($IVtovOa>PBF!YQq5_#WVH8PBl#Gq7 zwB%N=Wb>(Wc4+6rV5L>L>}zIKA?Gu%4MpninjYc-n{gCn)fF#9xmLgY`WtDE(PYl- zgZg7S7J5}C!&<#MMZsG;=r|;!g(~H_YUGS2Bgfy(zn$IE?hdwNRVu4Qj#x4k09@$h zs*q*ZP}>d`7=d$Rbuz_c^KB1JGFiU2lA$seuB^yCFfv~xVJ);7ouE#C#u-0nt#q8o^otliMy)OSYDO#xl=!gnqMxL++K18$V=HZ z)=}ZrvfuEu$}7$7RrPG0<}$AOCmOrf@vSkh*Ls2+1}Nq}Dnh69h#vGmZ=W7`Q}4hF zX>)XeVcUl~e_xg}e8VHoOHoM9+cD=@)>ZE2Q_g>0iM(Iwl60rtQFkz%_EEVHk8uSb n(at`b`N{MJ-paQ$yKZz#JK?T&+U?PfzUZ@l<-SPWMpOTPGoNIf literal 0 HcmV?d00001 diff --git a/options/__pycache__/test_options.cpython-37.pyc b/options/__pycache__/test_options.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cad7876f123e8aa175c76c05ee0029138c75b72 GIT binary patch literal 1258 zcmYjRy>A;g6esUaACctLu@g6~86-qAPONVO z@Lo~yA-D`Eo*~+^J?O(=3mKx;H5hFnW7L_z&K7bVkvithFThw4F{#`(Jb1CQN3@(xioAevs7w1tFMQl zFqKR!*8(T8z^38)_hEgfgINo^_?;b7xmZZ!GOSXZA7+7h)wiq`syJfjB0Fj`Jz~-( zE}M%~B`KFxnWS+)6Rq}lEy!As*^*_YfQQUT+X}6Y+bxY`odruol1gBC2{NnuVJIp{)*srB=nF>zP%#oK3x^FA{gx^Z*O2M`4(j7d#imO8x%kEvXMt zWKQgix-wlBdRfGSO1(Hn!du&GbBKp?Rmf9S%PEbAE`PWGVR}uwJJ^g=p{x=aV(~-( zaG@8=To&y>ZPS~h1@4X2@dV4}n;DvTy!a%=17*%!TQT>*$ZQ_>E1`|lABf(?DF3~T z(lVFPDUJ_gMb9`K-}@#y6UE|3QASo({64>T%yDo9dR2w7ye#BX$9@(SKVK}kJ?9EA zFJ)KQM}=3*e$SIKD>b(l<&#yC$*Am}Xyiu6H^Qu1=`rTeLo&CK5jvpM8F7wnU1n}S;r#hhWZg*{q^)*~-N10#N9H~} q#x;CMTlr+_$CGDxE3eXQr_mm5nY-R;=c5^Z+eLlgzDTX3iT^(eu4PjI literal 0 HcmV?d00001 diff --git a/options/__pycache__/train_options.cpython-36.pyc b/options/__pycache__/train_options.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ba3138fefc14d48253c1f7c460e1d571f79112f GIT binary patch literal 2404 zcmaJ@&2t+y6yLQSCv}}R{YVS66hWcXsn<@@QaViO2W6%lAk4HmndxZcm29hMS6fNz zjCTx|_Q0io2>%UdpBOl0_zyVop6sOSPKV8`BK`C}e(ya!d2g<-H~#qk(z7RxAox2d zJvEp=fY%Zjc)&vyBw;TMgU11{@an68S4H@;+N(wCqL+BDB=Nqh@elZ=5^@=k5BZf6@)2ke{&5NU1dyxz(-QKTTjAIE=OyGi zAT7RCLT&(?c8N^{Ryw>}lGp}`osz^ZaJk8El_b9KI=ANt|8g48`ZR=k+}SHmMG==a zQ*1&9M!YzLLlP8SoaQ2pCqxg39q9v^4vE9k;lv1=D`$yLiH(gAX{$P`xk)Bwt6U5i z?9FN;p|axi8{=uLl|5WR>#-b=Ok3&XShUH)J7G!@nXsX-M5crh3!ao9V`;fg$bsmS zjHN=61L;O27RCV?lIBTY7$mt+&0uBpL8mwyMR~@V6LjPfHII#>_x5sZZ2Zp9oF=GO zG@~dpGIicSvAwMOKY59vIkXZ;!?LlU3Ygfr75UpapmA;4)PlxHkh$bhj*Bx310vJ$ z$g?omhsCXT>Bl1x|D0*yYYB72fjNX zzamOA?hP$?C7KJ>s?6#v#v8z_BAN9ln)8cwp1LAN$JQ|ej)rsB#P77Bjs)}JEqYU= zKOuKZ%JwcsJG$+DK`DTHW@S3CYbCt5iUk*2O|lo^FC(v zc^Jjl9$avPVq=*1i4~AESomV$EVsu6i5X;<6HgkQr_u>?h*w>4^{PLCu*R$(2d)G& z^L_z)NW`m9E@C!0JdKJwZxqq3OhH3Uh?gKs>PV&#Mx*DxaBOD?`tbaDzt~K4B2t$p z5Vc8=3F9m&YEh)jFnVpIEQTV@9uIHPz2=xM6#kfp^PuE|!a-PSyQ z!7N1eB4fA!`Ly-yt3FC{FE>3u6z8jwaG2k^nLU76VtjWsLOpjXeBvBLZwm`yPC@2}kvoQ&?%OvF(5 zZlChq?K|ClmJYvTx(f{H4&C0RkpEOTDpR@B_=}csQhPOeF{B(oNusfS82) zmcmW04Rr_l(Ii%)3t7fbfv)dAx>$9xiP-=Wc;VYN3_;io&zAl+N^`_%CUB h*X+Yha>U&i!f%z-F{1`;%+~JCz23un*c)nA{{t6+=YIeI literal 0 HcmV?d00001 diff --git a/options/__pycache__/train_options.cpython-37.pyc b/options/__pycache__/train_options.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68185f13dfa4f1f2c64be53b6e5789ea393ddac7 GIT binary patch literal 2408 zcmaJ@PjeeJ6yLQSCrzC;{gW1GDS|?&Q?H$*rF59kmX?`vfH2eMWG17LSF$b8uC|ia z8SfY_^uVRx0gikd&OR}4%J2<1@t*9Y>rRKwtRnsN{`}s1dh$M6TU!m_`Qz&=2Tx)4 zH!hB!8cZI-YY7ZI;GqhVup5TKlYm!v^+mv|BK*19t?=5jpt*Duf><-0o_@rv_%?G= zr#1%Y7-4heEYT^ku@NF|R;M*L&ct+? zi#~(BX>BM}R-AohJdL%oM@wiumVJ_GE1evP7MXh|OerE0HV~G`lrUn!lM-YkE!PP- z6g`r$R48&N-H^n>I3PpPJn0F8Bjas1~AJ=W-W?l{9>J_u87gGb+a@*;u^HBlT5v4wy%LRimw7aO_d$6w7ADL*-W+D zoW(Dgg{YqA%n0)2i@g?s{_Qhmp?NVsBbG{~7;Y7(p(t7_YGomgc7R{;(4PfZo-`@x z<^4XS$_;(8;PhFN88MPNw-Ee5q{2YL_`IcyQ=Xac#-~6Gc+HLwoN|a)a0$XE2llvc zFkj#4%(Y7`LAf*s9gn6g5eFT(R60mm8JcM&<1rqOq{ms&TulCK7n;HEPOKBz1g=ac zSdJ|P`o3fokAy95CvXGG3}S(^&B%Zy2_s@8%$O`n5|n5DhjhQWZ17ibx;n2DWUSj@ zoKt27Io5Ytg@C73>&$ei569JYm!;juD;c`6W`*i^>lrgvnBN1l0jB=_sWXI=(Rr4M z7z*F%QNFWvud~n6!BBxXHDl?odA-$4Yb{%lIkK@%=~Vt9CXv>p%i8eA|K{2pi$~(%*V%jySG8R;h=@ zWgpl4E1~@=BYW@#u|g?UkJ9NXrO=JJ!g+&IIG}2_<30Nvl!i;d!Bk= 0: + opt.gpu_ids.append(id) + if len(opt.gpu_ids) > 0: + torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt diff --git a/options/test_options.py b/options/test_options.py new file mode 100755 index 0000000..e114ec7 --- /dev/null +++ b/options/test_options.py @@ -0,0 +1,21 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') + parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + # Dropout and Batchnorm has different behavioir during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--how_many', type=int, default=10000000, help='how many test images to run') + + parser.set_defaults(model='pix2pix') + # To avoid cropping, the loadSize should be the same as fineSize + parser.set_defaults(loadSize=parser.get_default('fineSize')) + + self.isTrain = False + return parser diff --git a/options/train_options.py b/options/train_options.py new file mode 100755 index 0000000..fdfb4d6 --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,28 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + parser.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen') + parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + parser.add_argument('--update_html_freq', type=int, default=100, help='frequency of saving training results to html') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate') + parser.add_argument('--niter_decay', type=int, default=15, help='# of iter to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') + parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') + parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') + parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + + self.isTrain = True + return parser diff --git a/test.py b/test.py new file mode 100755 index 0000000..3df0965 --- /dev/null +++ b/test.py @@ -0,0 +1,41 @@ +import os +from options.test_options import TestOptions +from data import CreateDataLoader +from models import create_model +from util.visualizer import save_images +from util import html + + +if __name__ == '__main__': + opt = TestOptions().parse() + opt.nThreads = 1 # test code only supports nThreads = 1 + opt.batchSize = 1 # test code only supports batchSize = 1 + opt.serial_batches = True # no shuffle + opt.no_flip = True # no flip + opt.display_id = -1 # no visdom display + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + model = create_model(opt) + model.setup(opt) + # create website + web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) + # test + + # Set eval mode. + # This only affects layers like batch norm and drop out. But we do use batch norm in pix2pix. + if opt.eval: + model.eval() + + for i, data in enumerate(dataset): + if i >= opt.how_many: + break + model.set_input(data) + model.test() + visuals = model.get_current_visuals() + img_path = model.get_image_paths() + if i % 5 == 0: + print('processing (%04d)-th image... %s' % (i, img_path)) + save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) + + webpage.save() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100755 index 0000000..0877a35 --- /dev/null +++ b/train.py @@ -0,0 +1,59 @@ +import time +from options.train_options import TrainOptions +from data import CreateDataLoader +from models import create_model +from util.visualizer import Visualizer + +if __name__ == '__main__': + opt = TrainOptions().parse() + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + dataset_size = len(data_loader) + print('#training images = %d' % dataset_size) + + model = create_model(opt) + model.setup(opt) + visualizer = Visualizer(opt) + total_steps = 0 + + for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): + epoch_start_time = time.time() + iter_data_time = time.time() + epoch_iter = 0 + + for i, data in enumerate(dataset): + iter_start_time = time.time() + if total_steps % opt.print_freq == 0: + t_data = iter_start_time - iter_data_time + visualizer.reset() + total_steps += opt.batchSize + epoch_iter += opt.batchSize + model.set_input(data) + model.optimize_parameters() + + if total_steps % opt.display_freq == 0: + save_result = total_steps % opt.update_html_freq == 0 + visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) + + if total_steps % opt.print_freq == 0: + losses = model.get_current_losses() + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) + if opt.display_id > 0: + visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) + + if total_steps % opt.save_latest_freq == 0: + print('saving the latest model (epoch %d, total_steps %d)' % + (epoch, total_steps)) + model.save_networks('latest') + + iter_data_time = time.time() + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % + (epoch, total_steps)) + model.save_networks('latest') + model.save_networks(epoch) + + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + model.update_learning_rate() diff --git a/util/__init__.py b/util/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/util/__pycache__/__init__.cpython-36.pyc b/util/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4deaeba3e212608af04d73a87b2eebc477ca4097 GIT binary patch literal 201 zcmYL@u?oU47=%-B5TQ@t6S$aK7Y7kRK?ilItNhgFrv}m_B!4S?B;Ur_)wggm8N`G8 zI1UbOo~EPcVSHgh?}(q05Z1wd5Fl8=Cdui+T&(~2bo^!-VcfJ3TX?KQN{t?&8Dvu@|L9H5cZM) SrnY|#myHpv!d3j^gV`4X(m8AZ literal 0 HcmV?d00001 diff --git a/util/__pycache__/__init__.cpython-37.pyc b/util/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c32e50c02cfe3bf8f534046a16ad1d39ec859520 GIT binary patch literal 205 zcmYL@u?oU46h%{T5TRe-7r2<($w5R^&_P}5Dv#PcY9LKQ@>=O9`8Uq4{)LmtAYQoV zUM@U%>ogs)pyy$HiTRHBDG6m2><0mY6>O569?Zr0k59*Mra@jIGqE)zz|M+0`1rRkkdjJ3c literal 0 HcmV?d00001 diff --git a/util/__pycache__/html.cpython-36.pyc b/util/__pycache__/html.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc2abd9801ba2ab58b4a3104195419db9a6573e7 GIT binary patch literal 2366 zcmaJ?Pj4GV6rb6Dwl}s@+7b!{6eNOVMC{UDN{UcLE47C<5^becE2?GVnb_;Bcb(Zu z;#dwB_ee-joH=lSZ^1|4JXcQn3Y_@8ar37^VpsFtn|W_%-h1 z`+>kx(MaM5+gfV#cbhpRy(iTe({`P$A+xPs;tdi zSo?$KjB$APl(p)|@RY5$xZ$67Nj!O62zN2cZ$TuBxL}bZ-Vu&)U-1Y52>%s}JaQV` z7nem%gwO=yil~bQq>5;YMd+&H5)Qg*yob*oKT;gqbtzBu5MRK!hf%%;nL1MzvndDV z+8MY5Z^{N-@ENW9b|s*d882^RR+-D3ooak&p1jdOhNe2GO&#GJ(w^be74DX&e!_R* zVBu&9`nvFRW9rVh@DCl%p5varV^7!%aJU_;n;;z~Jy{|}zD{)}Z3w2K;k74trJ+$Z=Zg(o~1Z zmV_lFQ4G^O(GsaEx;sN`XIx0e>&rwKZzy$Q+)}HSXS`Boo2rVV?0G7H^aoS#g8id9 z{bDFPPe-zwCRqo5+`MtS^DN1G50j#!(^33oeB)-UWSL7Hr+KO4B9{+{JWfXOZm$>X z-8dCEKh^uOR_Q2{`hMIkvO>jruei70&17dsr&*`3hgo~HZ>ljiN_8BgG+4O=!gzr9 zGH>!a4Y$dwJmixGp`m?Z-6dGiL!5?11OKri7R2p*s&K&@{UpYHnPs!ldC zxtb;W#g1NE-b}|*+-W%`*eDdvH-4%2GYPEt*3fWmLohD9IEmpp&bk;wjl8ssT+) z`P!cbZ{O%yrEh=s{9N6it=kOO)+bAQ$mHstN=9qT8%icytF&;(1YMK{GVJPN)N<8_ z2saA)NtX8VDA=`J(RH?X(gR&@c|8c**72vf+$;J2b>B>I5X$KnK}c2ms8W3ZUwfX8Lt@09I>brNw!rNC<#kIx|gn`tHu-@F*5S0=l16=o$tlcDv5h;34{ASN9(7Aiup59W-r(W+4^%afsuV-v>l zun;>L$@Mti-bu2#M!g49b&Uv>`TQ1dk@OW2>cZy;btOglqP|vAGeaj_th!Y%^uoH; z(4oU5&Et6D;m^^DYFeT#Q-G6&zRa>>8PLnbF~Mq5pEtY}(H`Z!Nkz-Cj#or$Q7u9n zHBNpbZJqSWIf&?VwFZKJ3HF?O{RJ&jid364zQP{@smx7~jBo=&{cl93`t>|H-%u7y W!m!1V+iX*#4;_!w+q5tJ2L1)mu^(yx literal 0 HcmV?d00001 diff --git a/util/__pycache__/html.cpython-37.pyc b/util/__pycache__/html.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..0cc87076947d4c7d37a4db096225ff77d58885c2 GIT binary patch literal 2323 zcmaJ?&2Jk;6rb5|Z)~TuC6rGEi69vfyHE*1icr-el>ydGF1$6SH*an?8ZpE3?6{Keg_du@PZ}8or2$A%$>x!%Y-YuCro&UPU1>W_&bi8OIHLU#ELKdc_5-= z&RLHoAx4Bzn{3Ef)DHQV?8kPpB*S@K#Bynck)4b%TRNN7&g*s`K4AD`LxwVvwPP37 z{=9k2IK2CgwdyDElx?=S;U9EKJbhdUw=k;PAd)3qu*4Cs2uHY2c!B_g|AZwTISuZM z%c3S?XaaFX)I|eQD4Jp!x=38YMpupZ)%|b2Rvf?UdLZ-=uVLK5sGfn$oEb~mjDvFR z4BWXlV*@VuF|GS{C7_jKUVW;adAN`{JJa~k+;dZZExX4ag2U}#(**f2>&Xf!@^!8YX+toflG}x>`Yk7k z`&y4vxx1G?GGTXAYFX-5QxO zFzfPy&?ctQBiAVNlqFiK91zB*m+eL`sk>GME%7rzy@LtucGN_f`HUX|4j0M`FC29V zc8Sx!0nvJL1@nvU#F-E&rC+Mop{^`~5%_f3dhifX0EOyttdwt2)~%HEL<3ghNzBM#l+YpQk zFHU246MJ2Zp+;WXPu_<1c?eBDZJvc-l_$>-LYL4>4?(MmVI%kp!i>!swGZ<8=np!2 zhwSJ#&L+QV7YEnDQF76we*VEd@DGB6Py}$Bqb| z@}pI(#v+_~E9^|?Ll?b#F!jaLgpA~tpOoB(2D=KG;Cte>7U>fh|Liab;#CU?t zlRXZm>bB}?6|y1(8b{-Y(|im;7(O;UQv^{(*N9po3m3h>I9sX}Spd=;%|cJm75pME zcfjmE#>t29ZK91w%NOfwF}i40@8Y)=ftdlL^j_naKv(#CIJ)OByRbextS&w{gX?un z1VNOoumes9NH~YifirXF05A3CV|)8x?%Vt1IeZ&yRVOcnoV2Nj7-wI%0^6aCHyX;kSW5FmJts1qmG8@V?HF25_M`Euaxt^xGds(s2sMlbst`VVQzNq32l0G9sZS@?X zuB0eu)YmF%X6S^=ksEoj7uT(ZHXUYpnWj?@S3xJKX^D1A0Zy0tvM5HYfL8S%5Z)i#=et}IDNTY%fe_6#xd07C1d$RcAVGu_BqU4G%6K=9FaEprI^-f> zb$JF#DjHscC!pR6i7Kx^#f(kDtu?c=voquI_l=*gt`2^_eR#U-5b~QWEe`fRNc9+$ zB$B2i=yn>=BO;mPH$?IU3z(!QWaQjJ9663q-z^5lK1j66ylF{7!Tawm(En(!h9|c zW>KofwMo)(+lIgncgmUd1(ZvS5FYeV`Jf2(=m6w;;~tcqyP$h?uxl5jMdt??|8QY? ztpP$}GG|v`8+u4izs$Lz6V~tvmvq7A93mZzz92e4Te3~W5Pj~%Yxsgv(y%*gx@CG3 zA8CVJr*R|)?7xApkN5)@Y>U8+i`k9)fJ_DpGKVVNd9SPFVa@5*a@CInx>{*qARwDH z42P)W?-ognL16uQH5#@=S=SUgI=B5gucoz9z^o%f6Ka@Tq7ez}g=ML7X+0h0vdC<& z&Xco3gMp4DH`a|(Ak6Y8G^3tIz}w7VRV!2L9L59(#JcTnX&LmpQb4TbCM(-n=@b%W zUFby5hG14tKoOwLfoDM1=&*a(kot53&I3AR=)GS5PimWeJaTac@-S0E*ntpPA?p~R6zwuRS#519FT&L6|LQE+Re|}jzp{M zi9Q2Iw3WE<8a#$iNSyi#aA784OC8PU@p$HoJ)d8$uMYsnpJ&a7{Q%$>EmlLIbBJai zAt+F621z%wgdKt6N?d~ypLxO+dk3T77W3d_Lbw_nokKLci=ZK43MgCgYe<9=D!@5V zAy!Aef1IaN{kAM;6Dp%YGubON`w*dF4M=F<9FKDmm|aC{HpPhK3Xq1M7>0Qx;rkUb zNt<119~X5lvwUiOl+i!zi2K0m*{PvBeh1UT>G$#YtjzUzVwJAkxH9Fro;W*pStXyy z{io7sTWBY>j8Bq2Gn1p=Ax2Y)< z3ztzNZWx1xue|w=>n&r_0c(RMPy(Cz4P$+3C^nbx0PciT-GmQdIkxM0$JgFIGLdkUQALq>KSTZTbZmI?P^n??xeU6+u5@! z>338zQ(m~dYQ^t>gz_Pt#9R*1wT}@1Wee~dunjisK0ahI+rssL4LL=xHvfy-R-cSQ zdILq8YbpIe%DhzdjQpXL^E#b%JEo8PW}ASdzXx;@Loc72V5uE3i=#F&5@wTj*9a)( UURlf1-XcEee)*jGz9JBR0onNHH2?qr literal 0 HcmV?d00001 diff --git a/util/__pycache__/util.cpython-36.pyc b/util/__pycache__/util.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6880cdf4494030355e5533d834f0a65fc512da61 GIT binary patch literal 1950 zcmZ8iTW=IM6t+E=ot@n*5YiM1YEfINu7naAsurPAQGsZYAVMM%l2tS{>&a#|bHnzA z1kFAzf2RMY&+}BNPyB^G^|NQ2rkRz`IX=F9KIc2$J8Nt4A5X8ndKMD$7di8Jupgl5 z5<(J5b5c-FskK?oICI+LKD3tyMaaX|SdaHoe-`ai>C50f<$ZZs_GAPrmVFt&Cww3` zsha4hfF;-(X8Ki zvbSUl4|WVYS~AH#CNj82WVq3-7<A0kXfVD9-!fgpJk*i)o5xzXzT1XZ2w0DGN|p zq)x!$Y($*6Eylw(&U994Bg#qI2BuPz10L?BrLI)#OJM}}CiT1x_VY@ZHYn#sebNT{ zK-6i=%9{Ho%TwMrN|d^;blQeOo0B>n`+Sg<_1q+xjT#TL!n*m;braK_Fz9-2*WZb{ z2wZ=w?SraFw_nuhBoq1e>qPF{zPr6A%IPnn+BR98e3RVXNmQ!K6uSd^QkCh?Ry-4R zax|SL<|xTzTAIw9Bt~U*o|;F=q{=Imn77r#lS!U#&jH+tEnKMI;n~TG7!b6@LFkZ1 z%%>Z41nUY_>+rQbocXp=ho>yC0sbEgBwZkL3-9@Ed3zZzeq*ddg8TKecKNl!N&pjaHZBSu|#dSfk`zl4G-KlE+w6~WCt(G zOiWAUIavi_6JBlC5$fYNqA_BH$hn`o3jb-SEK@B}UBO&|VaSX&WQl+^Y{t4Px~tJ$ zFx5`y{e4vxLMd^=eJzes_%B$0+!3c1=79HCc9x$<8`~uv0l&a1ZOKo7P<;Zu_jqED zu%Fd7NDtBUZxA+hu9WEl<*Tk6PXZ}Qm{HPD^($!<9=IFucpJhZ9<_5XU?$+>YGCPGhg7$E|lHa&!;s^N}2M&!H>Sud%1L z2ig+xU5lXAXYfqyn`A$mXd6}+(t9wv*AZ{ymb36m4vE@2{d*5tHsdRYH01Y1C6!RY!Hf3RF)8 z?@82GYidOe!81^o)MfCjs4Hp}^iZ+qB>d|%y`Ww9rP-8$n;~| zo%d?4_=3$i*a6sL!4&(LDDNgw{+in{@7Hue0c2-aQ@HOGMaWMvUK-Ij5pt(a5vE%< z1_RrJ*7&=85DuCkHECh2EJjJ=*;0=VB7ZL_OsN~8q?HjLRnx}X&q`?WOFqJCObQyZrktH@7tbVma>jBR^Eza6<#|0Mb{ zfcsZ(9h7;p^}I?(sm!)s#cKP`&eonR#=pvP%cfQQO?+oN)`=++I32LZWs&@X;+d@C zqwzSlM{%l>!lw2lwmPk{#6F5gWmf9gz9}D`jIv~F3c#JX4Ik8N@oc9#lF=XzLImF`KbQh6}Jds7i@$W0cfX{|@s|z%v4VX6f9=uK?Xy>h4na zmO4Pa@2!)wb~0)_Vo}z z>lF}D$5rZg-)hW^9xPqb0>j(U5sSEeFR;h zehp_Dybw!7xZVY=KLgJQUm|Sy=wMh~NO!Yp@0g6mvj_Oy2_g z3@IW$LyE+(WI}S-*3O(ZoB{W;Y*MIH~J00EmmiZ&-A|gfyP|c{7hrbxzTSj{bQ}$ehw{=LvgwYm1b_V{0ZYRsCm zmaLB&MGL#MnNie==7F)H&6d=TZebTob7Qt#v`5XKnzPPqWwu(F%sjOCZ)B@z7aiFy zRtj?+TpU0=%zkRl+6qJC|4A5_?-b@sm@if?FhFw&tC#17t`+ssGN^#pq5-I>Aa^nL zl2FI^EFMHRi$3GzHjS+$#DL3G#3 z+Htg>gPpQ=6pKlah|;atQ#`Z+kw>BQc_8DAHx$ux?)$D7W?cFT(^A}<@fhsO6DdkN ziPJ;sg8eAz*W$6y0vVJBm_4JQRx}icY72c>>c^$-|C`2dfzqE&c81wF+W8n9#6hyN z@3XtFy}9!+NC)2!vK<-c{1S#50W}{5xqmbm`0~h)S(M6Hp8Ar< zc@oL@{V+>1?#tusz0)v>cE~Pvs*i^?@LqmeHZQ~`nqWn|ibm7x`jXx;Ts*6~tGDph z!q_$Sv@q^sg#Q&VLqcEt3;^%q5xZz2W%00fof)5Luy&JK=%K#ML2t7ftD|>VgSqHy ztjSvF>#Pm6TPoeZQ-?%+hO{sJ=(Jrt%21)z8q5re>r=01>u8M@y_TM8^69W zGv)Qs%glw#unKc@gEb4Q+U*k<=8v==YmX`eZ9TZJEc9X5?mHslN3a|#hq9Jtlpggv zJcvc)M<-#FL$OoQUz^}PPQz>*rvqPP;bA2ECTs&nOx7x<*FqizGV(Cy!E|9289()4 z{k`6_m1IKrNjB&~^-DdN8r$LY+i$<{^B#h1iv7J0f$rVf7G3WJl(*gPE2Ya~n0{~y z)h%nWP^t^>LuF^!;tTom+4R)>_vc{9OfIHynLKUxnqaqF=)3bnU{( zdW$_m^7r(ko-D5livX1UR9a^|xZlPfYThMizu& zPnT#Fnx)=#{;$#~ zED5JQ*O3QV(vr$?&*41h$fJ-38vkHi)sRWqGl!XlQ`jnhA{Seq)tNPG;Bg1q4C&yk zb;97B8!G#Hgf(sCS?;Jw*$;9om)S*oL;H?4bBiY2{qhJI8Qi=xFo_!SqZV^8+bUY* zaKWL$Wzm2;chxFI>52SAlS{~+URcfQ>)Nbya{GL(`Ap{!}UPXcrA+k`0BuhRm|DNzw)G zrlzo$&ajuCn!5Ihmfo;5xqj$^-fwhRIKCHwX^#F=Ej1o#=_BSqfxU zi(d{(ZQ_Y27cEfi5VeIQpXpD2tMbduXC}W^n9mHc1FeI{HSjUR)_Qz|&2(L76pXex+u%N%G3OZB=>pbF$umQkz{NPm^fR z^}gaKgZ(7(e!ib^0oh)U-DTxI$9ylqCi4x@N5l#90JksgCF zym=x^;0pJXAUynPBD*`)rRP3h!poH=@`bW?##xErE5&qWSIMy2jikAU4O9*<;6MKsd-ETm0Zg45csT7u zc@_@0d#|&za%j5#y%Q87h~ugBD71TTBUS3^{Th!Bl1L1vu3F;#WSc#PE2hfg`JeIY zzRF*teZPaIG?9#zWKa4Jw$qQg4`!JBG)g_;dCsEv4#iaZl2 zsnWb`peR6w79U3ZO~UCBj>X~-yQ&&%`Mq;r|2=Z`{BH=uR-)r~X$i&k7cw4Bp`2f# zhFpB#`u@X@KIHcRmNv)_B2M0(liTO7)2Kz+Sa-8Z4z4OZC5N+$17UQMNcHiYvKCHt zN57#?#1DnwAJMz3j-$!}mo8u*M_PTC8M71qHCq4oXv$_)v89u7e%fy;Wv@!YKJEYp zs(j;Hgs7+d2CaI98jl*P!!D&%n~TsLPB_QWQtbtXg^FhqxEd-V%B+_38x}#<*}tln zHV`#iXEjqtdr}CtdBZ*5?%STC7dJ5;5yKVpjXm<-b3VSW6cXOL@mP>Re7w(1d-F|+Ld_|Z~bwBsF^UwF@;m`@ zsUPw8X{FLw)u>ZgSM{m?MXmfCh=OWe4Gssc3c`obU%4!Fs%l?Y^`P4*>%J;%eZO>l zf1I&N5&>@c{u2}q6c>G;-zPRcpoX-EA5e2dO+n2Cf9Fi{Z(`ym9>o(BVWV!=?fU)t znwmjqrt{`>dE0HF#3#S66u63BqyYR~YJNzKikm8Wxgw17)8051p}Im;S6OPyFXN)D Y?^YkQhrCP^bSk8i^d}$RvfQix3#_W^3IG5A literal 0 HcmV?d00001 diff --git a/util/__pycache__/visualizer.cpython-37.pyc b/util/__pycache__/visualizer.cpython-37.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9fcd7406972a08640d2c52cee81b77bdd6535d19 GIT binary patch literal 5852 zcmai2&2uBib)Rnx1^@|uf!y_KHFlQrBDD)zk{u^2igGM#ZL4G^v9_WWEY)a;9s&ab zGqc@|T@o6cs8sC3RyI{R>Ht=`T=q4Ua`~RB`~^PEHHXwz<>0d~&hPbrT$+{~pr-px z_v?OdUcb-&ZlmF8_+9_I{3ri;L(~47Ue13efbZZDpP*qH(??oYSH00SRBv`o)mvSQ z8O$8nW2fur8ej`>rCR~4F?-~WtKBN#I&)d|GrjA5Ut=DteWtP6snM-5{e7+7cn&Js z{&O8|TQ97Mj7QWD*3eh*yp2bAn8~%dcBaj2 zsn3nvVEQ%fOrJT@99nZbcg}PgnN@)0pX6Rb(&X++2yyjIv zH3fO+V~^hb0N*@p=5Aiiz1(8XQ;Wn8(N^|3Np)J?7QwdYWJ$k|SNwkWR=08x_a=BC zVkb^y#4|wllbFdtd!^gxM>2@VVLu93%nK)xSttjEdk~5!3CB@kq@u9OF_Dd886C^Q z3mFRrQOF`**cciMPO4N?Wt51NzZQ=RYnaAKVT(bSMVyYXFhRWVMEE!==@m|xWl_S) z-4qL=y;fA7LdiaZNl|+t z2T{Y`;uSQS?&?c=)9~=D>YiT5TODK9)l;1RoMB4xrbU!57B zX)s@tS?HmU%t3Fn3Ukprtjav}6;@+)^e$^akCqDWAC=UJ&mQedKRP}VwX!^@sg~hx{`ve4V;9rRAFqpdJ%U1%=c2CQWW=b??NGXq#2W@@=HtOIV~S;Dgn zj&)`|wK3PkDwF?4t_)Y_YnUTi4ztfp%xx@2wc!SH&J2vSww{FO^|n||xfD?k4M z$R^m|e-G&X?QPNaub{l?wO=kg7Q>dqlOSeAB^GK`x00j*x#GelP6_HRpW{7&PD3RM z7eGJ|M<{_wr21z!+h zVPZpG2N~X=W|aso_Xbh#D1&Q}LisE579|jm`wFl#j@i;5IpZS;QWP!(7bxqtQn(yA zgczR96_Ot+$bQ1%#oBgZiD-1lUnhFsB6@1!VLXZoGfAH)FHknO)og$Rjb#vQ+(IW>=KkCWy$~Xr=*-xo;#{T`b&cBI4 zbl*jHl_eWKD3WhrMkwpmr_D>W)Z`ZvTE`LzQp#}PI;dBn#oP~lu;h7fr;`!$lT`Y> zG)bVHzD)h{SpD)e{o4a6GjVTcryt9~(LtIX5$%xo2IL5#32px{-KCB?GGyciA zrXe-5XAU!SC%08XL@w5yxy+hX@pyf0hWu_;KW126Rq512tZ5)|^5!*4XOJ{`%+4Dv z?c3VS%WLy`zC1^Y1vl^X%`;u4Id$e>ww~9?;etbh%e)GA?x|IZ(j)nSCYO*7U0KcC zb#2x>zH_nGe5Ui3qSGu>C{|UJ7c?IIFA^z_Sx1D|xlHBcA-=6ST*=|~mE@dFNI|aV zF7p7@p6aar)P`j=@+wlps={7Fs#x7DC2eeK3VUgRz5LXq(2?A9%{)$=0Ji!zL4PkKan8u)%xQv2c-t_M4aYzP;3&leDRgf^hdu|3FGE7 zlW*nbGefk1x%JExTWkg4a_jhe7xe!6Dm{5^czsTR@djJXjl4l8kMws1XZljUoUc&2 z`pkIr3Qpwa@MYlnYz?i8vvb%wQs33%f0JLOw4806nnLf zY-83se)oc#J;lxP3;h8PYb6h}m3$rgv!c#*15iUjt@APP^dI=@)!|>{>-k1bUwn&g zQlgBRpVG{Ho8+#1+g6!2KPTG_DY4lV@?;e4dH$FDWOy)&{7(;3j@0)q0^BIfM05{M zA`6j{MsN6wDUpRexTSmF%#Ppip9~OP?$S_nkHScpz2Q?HeHW(PPww?1%F1W|4^hed zPgT->&);UW>mZ8z19|?IAtfh*zM_~;B(hyFhBsfNBRIl?QP?~Baw5As<);_gw6-hR z-nQ7?L8sgrjqq}W?>+wmgf*=aj>4E${6+4LszSzc6t%6wO5o!P6WKyhS#VY$^hz;Z z*;O(ucOz-;VFQ(t3wY0c?9Klk4Pffjz=LTs%F^CoyYo6*R5#Q0caBk4Ac`l_r@-#N ziA1TbcPl(P97SUA{15m|J+;pN$u|2K-k4I@_A7jc_Wdg~g^5(GFh`USBvjT|I3|=p zkHgU<5}fKymFDm!LGYy-U!kUj2Bi)U6Op9?6;hHFRiygJ&*Gzqf17ab5{|{<9(GkF z)AHM8(1fRdha5fs0b$rmbo^CXLQ(xnzQga)*akJ^;Jelj9{lJ%{x<*%8|3>DU#2;7 z`usH-wI~;BZL?U7(1`MsM>0x_L_d}QWX9Y1}`ots8pi(OWGm|G!mX3 zS(PDt@^@1IsWvw#!ZPD?Dq`>qQKDVb<`#-2L=WZ7Xw0Iqf7T9lRk%_4WU|UNTvnW_ zczN;(M(m*@D{~hQMIyD&ncmkW%Bm9kRA1MaduCi7d2K!EVU=y)Uf*J6!Ge; zLB0|Z>jolLaA7gYYhCY8S2#?s$= zyQ5f0B$(OWQ`(O`35Ck6(qnx@5IK#uU7GhNfV+)eI-%&XxDa`rkVe#yq~P-u zOad#`zCav5!-ODf)S&Av*vBp1fIYkdyPz6|VAzO0z4_;*+H+_N{OG3f;-Y}g9;FMk z8o88xh^%hr9N$7wjY~V^16zpkm6@A6>Z;A9`~o?e^XNJ<3O#pJ9fG2sv>Ju^I&!V* z++tSld|^3m?@xDbD_2JW@s3ZPT$CUkB%sdghj)Cs@hRc9PnTsS?DGzIotoRigK(9% zr!-zA0Kj!4Mu6nq2Rhs2~<}Dx{P2Cm&z8ylejl DJ^<&d literal 0 HcmV?d00001 diff --git a/util/get_data.py b/util/get_data.py new file mode 100755 index 0000000..6325605 --- /dev/null +++ b/util/get_data.py @@ -0,0 +1,115 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """ + + Download CycleGAN or Pix2Pix Data. + + Args: + technique : str + One of: 'cyclegan' or 'pix2pix'. + verbose : bool + If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Args: + save_path : str + A directory to save the data to. + dataset : str, optional + A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full : str + The absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/util/html.py b/util/html.py new file mode 100755 index 0000000..c7956f1 --- /dev/null +++ b/util/html.py @@ -0,0 +1,64 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/image_pool.py b/util/image_pool.py new file mode 100755 index 0000000..52413e0 --- /dev/null +++ b/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import torch + + +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = torch.cat(return_images, 0) + return return_images diff --git a/util/util.py b/util/util.py new file mode 100755 index 0000000..ba7b083 --- /dev/null +++ b/util/util.py @@ -0,0 +1,60 @@ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +# Converts a Tensor into an image array (numpy) +# |imtype|: the desired type of the converted numpy array +def tensor2im(input_image, imtype=np.uint8): + if isinstance(input_image, torch.Tensor): + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() + if image_numpy.shape[0] == 1: + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100755 index 0000000..69cf926 --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,163 @@ +import numpy as np +import os +import ntpath +import time +from . import util +from . import html +from scipy.misc import imresize + + +# save image to the disk +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + h, w, _ = im.shape + if aspect_ratio > 1.0: + im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') + if aspect_ratio < 1.0: + im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') + util.save_image(im, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + def __init__(self, opt): + self.display_id = opt.display_id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.opt = opt + self.saved = False + if self.display_id > 0: + import visdom + self.ncols = opt.display_ncols + self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True, use_incoming_socket=False) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + self.saved = False + + def throw_visdom_connection_error(self): + print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n') + exit(1) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, save_result): + if self.display_id > 0: # show images in the browser + ncols = self.ncols + if ncols > 0: + ncols = min(ncols, len(visuals)) + h, w = next(iter(visuals.values())).shape[:2] + table_css = """""" % (w, h) + title = self.name + label_html = '' + label_html_row = '' + images = [] + idx = 0 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + label_html_row += '%s' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '%s' % label_html_row + label_html_row = '' + white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '' + idx += 1 + if label_html_row != '': + label_html += '%s' % label_html_row + # pane col = image row + try: + self.vis.images(images, nrow=ncols, win=self.display_id + 1, + padding=2, opts=dict(title=title + ' images')) + label_html = '%s
' % label_html + self.vis.text(table_css + label_html, win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except ConnectionError: + self.throw_visdom_connection_error() + + else: + idx = 1 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + + if self.use_html and (save_result or not self.saved): # save images to a html file + self.saved = True + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + # losses: dictionary of error labels and values + def plot_current_losses(self, epoch, counter_ratio, opt, losses): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) + try: + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + except ConnectionError: + self.throw_visdom_connection_error() + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, i, losses, t, t_data): + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message)