From 4de84dd1eb8930febeae7c3a25e298bbf8f4c50e Mon Sep 17 00:00:00 2001
From: smail8
Date: Tue, 17 Nov 2020 07:55:25 +0100
Subject: [PATCH] Simplify data loading and add a JAAD data preparation script

---
 DataLoader.py                         | 56 +++-------------
 __pycache__/DataLoader.cpython-36.pyc | Bin 0 -> 6677 bytes
 __pycache__/network.cpython-36.pyc    | Bin 0 -> 2312 bytes
 __pycache__/utils.cpython-36.pyc      | Bin 0 -> 3896 bytes
 prepare_data.py                       | 92 ++++++++++++++++++++++++++
 test.py                               | 17 ++---
 train.py                              | 35 ++++------
 7 files changed, 123 insertions(+), 77 deletions(-)
 create mode 100644 __pycache__/DataLoader.cpython-36.pyc
 create mode 100644 __pycache__/network.cpython-36.pyc
 create mode 100644 __pycache__/utils.cpython-36.pyc
 create mode 100644 prepare_data.py

diff --git a/DataLoader.py b/DataLoader.py
index 195102f..1b2208b 100644
--- a/DataLoader.py
+++ b/DataLoader.py
@@ -8,14 +8,13 @@
 import numpy as np
 from PIL import Image, ImageDraw
 import cv2
-import matplotlib.pyplot as plt
 import time
 import utils
-import matplotlib.pyplot as plt
 
 class myJAAD(torch.utils.data.Dataset):
     def __init__(self, args):
+        print('Loading', args.dtype, 'data ...')
 
         if(args.from_file):
             sequence_centric = pd.read_csv(args.file)
@@ -30,16 +29,12 @@ def __init__(self, args):
         else:
             #read data
+            print('Reading data files ...')
             df = pd.DataFrame()
             new_index=0
             for file in glob.glob(os.path.join(args.jaad_dataset,args.dtype,"*")):
                 temp = pd.read_csv(file)
                 if not temp.empty:
-                    #drop unnecessary columns
-                    temp = temp.drop(columns=['type', 'occlusion', 'nod', 'slow_down', 'speed_up', 'WALKING', 'walking',
-                                              'standing', 'looking', 'handwave', 'clear_path', 'CLEAR_PATH','STANDING',
-                                              'standing_pred', 'looking_pred', 'walking_pred','keypoints', 'crossing_pred'])
-
                     temp['file'] = [file for t in range(temp.shape[0])]
 
                     #assign unique ID to each
@@ -51,8 +46,8 @@ def __init__(self, args):
                     temp = temp.sort_values(['ID', 'frame'], axis=0)
                     df = df.append(temp, ignore_index=True)
-        print('reading files complete')
 
+        print('Processing data ...')
         #create sequence column
         df.insert(0, 'sequence', df.ID)
@@ -119,18 +114,12 @@
             sequence_centric = data.copy()
 
-        if args.sample:
-            if args.trainOrVal == 'train':
-                self.data = sequence_centric.loc[:args.n_train_sequences].copy().reset_index(drop=True)
-            elif args.trainOrVal == 'val':
-                self.data = sequence_centric.loc[args.n_train_sequences:].copy().reset_index(drop=True)
-
-        else:
-            self.data = sequence_centric.copy().reset_index(drop=True)
+
+        self.data = sequence_centric.copy().reset_index(drop=True)
 
         self.args = args
         self.dtype = args.dtype
-        print(self.dtype, " set loaded")
+        print(args.dtype, "set loaded")
         print('*'*30)
@@ -195,32 +184,9 @@ def scene_transforms(self, scene):
 
 def data_loader(args):
-    if args.dtype == 'train':
-        train_set = myJAAD(args)
-        train_loader = torch.utils.data.DataLoader(
-            train_set, batch_size=args.batch_size, shuffle=args.loader_shuffle,
-            pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
-
-        args.trainOrVal = 'val'
-
-        val_set = myJAAD(args)
-        val_loader = torch.utils.data.DataLoader(
-            val_set, batch_size=args.batch_size, shuffle=args.loader_shuffle,
-            pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
-
-        return train_loader, val_loader
-
-    elif args.dtype == 'val':
-
-        #rgs.file = args.val_file
-        #rgs.dtype = 'val'
-        #rgs.trainOrVal = 'test'
-        #rgs.sample = False
-
-        test_set = myJAAD(args)
-
-        test_loader = torch.utils.data.DataLoader(
-            test_set, batch_size=args.batch_size, shuffle=args.loader_shuffle,
-            pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
+    dataset = myJAAD(args)
+    dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.batch_size, shuffle=args.loader_shuffle,
+        pin_memory=args.pin_memory, num_workers=args.loader_workers, drop_last=True)
 
-    return test_loader
+    return dataloader
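With the train/val branching removed, data_loader() now builds a single DataLoader for whichever split args.dtype names, and callers switch splits by mutating args between calls, as the train.py diff below does. A minimal sketch of the new calling pattern; this Args class is a hypothetical, trimmed-down stand-in for the full config classes in train.py and test.py, and batch_size is an illustrative value:

    import DataLoader

    class Args:                         # hypothetical minimal config
        jaad_dataset = '/data/smailait-data/JAAD/processed_annotations'
        dtype = 'train'                 # which split to load
        from_file = True                # read a pre-built CSV instead of reprocessing
        file = '/data/smailait-data/jaad_train_16_16.csv'
        batch_size = 64                 # illustrative value
        loader_shuffle = True
        pin_memory = False
        loader_workers = 10

    args = Args()
    train_loader = DataLoader.data_loader(args)    # one call per split

    args.dtype = 'val'                             # reuse the same args object
    args.file = args.file.replace('train', 'val')
    val_loader = DataLoader.data_loader(args)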
diff --git a/__pycache__/DataLoader.cpython-36.pyc b/__pycache__/DataLoader.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69b71ad089d8eaa3d972a1b1c6f20f6e157c69d2
GIT binary patch
literal 6677
[base85-encoded compiled-bytecode blob omitted]

literal 0
HcmV?d00001

diff --git a/__pycache__/network.cpython-36.pyc b/__pycache__/network.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4d2cf16e3793bbec789f4571dd47d64fa79824e
GIT binary patch
literal 2312
[base85-encoded compiled-bytecode blob omitted]

literal 0
HcmV?d00001

diff --git a/__pycache__/utils.cpython-36.pyc b/__pycache__/utils.cpython-36.pyc
new file mode 100644
GIT binary patch
literal 3896
[base85-encoded compiled-bytecode blob omitted]

literal 0
HcmV?d00001
diff --git a/prepare_data.py b/prepare_data.py
new file mode 100644
index 0000000..5abb1e5
--- /dev/null
+++ b/prepare_data.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import argparse
+import numpy as np
+import pandas as pd
+
+parser = argparse.ArgumentParser()
+parser.add_argument('jaad_path', type=str, help='Path to the cloned JAAD repository')
+parser.add_argument('train_ratio', type=float, help='Fraction of videos assigned to the train split')
+parser.add_argument('val_ratio', type=float, help='Fraction of videos assigned to the val split')
+parser.add_argument('test_ratio', type=float, help='Fraction of videos assigned to the test split')
+
+args = parser.parse_args()
+
+data_path = args.jaad_path
+sys.path.insert(1, data_path+'/')
+
+import jaad_data
+
+if not os.path.isdir(os.path.join(data_path, 'processed_annotations')):
+    os.mkdir(os.path.join(data_path, 'processed_annotations'))
+
+if not os.path.isdir(os.path.join(data_path, 'processed_annotations', 'train')):
+    os.mkdir(os.path.join(data_path, 'processed_annotations', 'train'))
+
+if not os.path.isdir(os.path.join(data_path, 'processed_annotations', 'val')):
+    os.mkdir(os.path.join(data_path, 'processed_annotations', 'val'))
+
+if not os.path.isdir(os.path.join(data_path, 'processed_annotations', 'test')):
+    os.mkdir(os.path.join(data_path, 'processed_annotations', 'test'))
+
+jaad = jaad_data.JAAD(data_path=data_path)
+dataset = jaad.generate_database()
+
+#JAAD contains 346 annotated video clips
+n_train_video = int(args.train_ratio * 346)
+n_val_video = int(args.val_ratio * 346)
+n_test_video = int(args.test_ratio * 346)
+
+videos = list(dataset.keys())
+train_videos = videos[:n_train_video]
+val_videos = videos[n_train_video:n_train_video+n_val_video]
+test_videos = videos[n_train_video+n_val_video:]
+
+
+for video in dataset:
+    print('Processing', video, '...')
+    vid = dataset[video]
+    data = np.empty((0,8))
+    for ped in vid['ped_annotations']:
+        if vid['ped_annotations'][ped]['behavior']:
+            frames = np.array(vid['ped_annotations'][ped]['frames']).reshape(-1,1)
+            ids = np.repeat(vid['ped_annotations'][ped]['old_id'], frames.shape[0]).reshape(-1,1)
+            bbox = np.array(vid['ped_annotations'][ped]['bbox'])
+            x = bbox[:,0].reshape(-1,1)
+            y = bbox[:,1].reshape(-1,1)
+            w = np.abs(bbox[:,0] - bbox[:,2]).reshape(-1,1)
+            h = np.abs(bbox[:,1] - bbox[:,3]).reshape(-1,1)
+            scenefolderpath = np.repeat(os.path.join(data_path, 'scene', video.replace('video_', '')), frames.shape[0]).reshape(-1,1)
+
+            cross = np.array(vid['ped_annotations'][ped]['behavior']['cross']).reshape(-1,1)
+
+            ped_data = np.hstack((frames, ids, x, y, w, h, scenefolderpath, cross))
+            data = np.vstack((data, ped_data))
+    data_to_write = pd.DataFrame({'frame': data[:,0].reshape(-1),
+                                  'ID': data[:,1].reshape(-1),
+                                  'x': data[:,2].reshape(-1),
+                                  'y': data[:,3].reshape(-1),
+                                  'w': data[:,4].reshape(-1),
+                                  'h': data[:,5].reshape(-1),
+                                  'scenefolderpath': data[:,6].reshape(-1),
+                                  'crossing_true': data[:,7].reshape(-1)})
+    data_to_write['filename'] = data_to_write.frame
+    data_to_write.filename = data_to_write.filename.apply(lambda x: '%04d'%int(x)+'.png')
+
+    if video in train_videos:
+        data_to_write.to_csv(os.path.join(data_path, 'processed_annotations', 'train', video+'.csv'), index=False)
+    elif video in val_videos:
+        data_to_write.to_csv(os.path.join(data_path, 'processed_annotations', 'val', video+'.csv'), index=False)
+    elif video in test_videos:
+        data_to_write.to_csv(os.path.join(data_path, 'processed_annotations', 'test', video+'.csv'), index=False)
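Two details of prepare_data.py are easiest to see with numbers. The splits are cut from the front of the ordered video list, so test_ratio must be passed on the command line but its value never affects the result: n_test_video is computed and then unused, and the test split simply takes every video left over after train and val. The bounding boxes, meanwhile, come out of the JAAD database as corner coordinates [x1, y1, x2, y2] and are flattened to top-left-plus-size columns. A small worked sketch, with illustrative ratios and box values:

    n_videos = 346                          # JAAD ships 346 annotated clips
    train_ratio, val_ratio = 0.7, 0.1       # example command-line arguments

    n_train = int(train_ratio * n_videos)   # 242
    n_val = int(val_ratio * n_videos)       # 34
    n_test = n_videos - n_train - n_val     # 70, regardless of test_ratio

    # Bounding-box conversion, as in the per-pedestrian loop above:
    x1, y1, x2, y2 = 410.0, 255.0, 460.0, 375.0
    x, y = x1, y1                           # top-left corner
    w, h = abs(x2 - x1), abs(y2 - y1)       # 50.0, 120.0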
diff --git a/test.py b/test.py
index bc0a9de..3790234 100644
--- a/test.py
+++ b/test.py
@@ -7,7 +7,6 @@
 import torchvision
 import torchvision.transforms as transforms
 
-import matplotlib.pyplot as plt
 import numpy as np
 from sklearn.metrics import recall_score, accuracy_score, average_precision_score, precision_score
 
@@ -18,12 +17,13 @@
 class args():
     def __init__(self):
-        self.jaad_dataset = '/data/haziq-data/jaad/annotations' #folder containing parsed jaad annotations (used when first time loading data)
-        self.dtype = 'val'
+        self.jaad_dataset = '/data/smailait-data/JAAD/processed_annotations' #folder containing processed JAAD annotations (used when loading data for the first time)
+        self.dtype = 'test'
         self.from_file = False #read dataset from csv file or reprocess data
-        self.file = '/data/smail-data/jaad_val_16_16.csv'
-        self.save_path = '/data/smail-data/jaad_val_16_16.csv'
-        self.model_path = '/data/smail-data/multitask_pv_lstm_trained.pkl'
+        self.save = True
+        self.file = '/data/smailait-data/jaad_test_16_16.csv'
+        self.save_path = '/data/smailait-data/jaad_test_16_16.csv'
+        self.model_path = '/data/smailait-data/models/multitask_pv_lstm_trained.pkl'
         self.loader_workers = 10
         self.loader_shuffle = True
         self.pin_memory = False
@@ -33,12 +33,9 @@
         self.n_epochs = 100
         self.hidden_size = 512
         self.hardtanh_limit = 100
-        self.sample = False
-        self.n_train_sequences = 40000
-        self.trainOrVal = 'test'
-        self.citywalks = False
         self.input = 16
         self.output = 16
+        self.stride = 16
         self.skip = 1
         self.task = 'bounding_box-intention'
         self.use_scenes = False
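The new stride field completes the sequence geometry in the config: input and output are the observed and predicted horizons in frames, stride is how far the window advances along a pedestrian track, and skip subsamples frames. A sketch of that windowing, under the assumption that this is how the loader consumes the three fields; the function and the integer track are illustrative, not taken from the patch:

    def make_sequences(track, n_in=16, n_out=16, stride=16, skip=1):
        """Cut one pedestrian track into (observation, target) windows."""
        track = track[::skip]                        # skip=1 keeps every frame
        span = n_in + n_out
        windows = []
        for start in range(0, len(track) - span + 1, stride):
            obs = track[start:start + n_in]          # 16 observed frames
            tgt = track[start + n_in:start + span]   # next 16 frames to predict
            windows.append((obs, tgt))
        return windows

    # A 64-frame track with stride 16 yields 3 overlapping windows:
    print(len(make_sequences(list(range(64)))))      # -> 3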
diff --git a/train.py b/train.py
index cfa2b26..fbf90e0 100644
--- a/train.py
+++ b/train.py
@@ -6,8 +6,7 @@
 import torchvision
 import torchvision.transforms as transforms
 
-
-import matplotlib.pyplot as plt
+
 import numpy as np
 from sklearn.metrics import recall_score, accuracy_score, average_precision_score, precision_score
 
@@ -17,12 +16,13 @@
 class args():
     def __init__(self):
-        self.jaad_dataset = '../../../../data/haziq-data/jaad/annotations' #folder containing parsed jaad annotations (used when first time loading data)
+        self.jaad_dataset = '/data/smailait-data/JAAD/processed_annotations' #folder containing processed JAAD annotations (used when loading data for the first time)
         self.dtype = 'train'
         self.from_file = False #read dataset from csv file or reprocess data
-        self.file = '/data/smail-data/jaad_train_16_16.csv'
-        self.save_path = '/data/smail-data/jaad_train_16_16.csv'
-        self.model_path = '/data/smail-data/multitask_pv_lstm_trained.pkl'
+        self.save = True
+        self.file = '/data/smailait-data/jaad_train_16_16.csv'
+        self.save_path = '/data/smailait-data/jaad_train_16_16.csv'
+        self.model_path = '/data/smailait-data/models/multitask_pv_lstm_trained.pkl'
         self.loader_workers = 10
         self.loader_shuffle = True
         self.pin_memory = False
@@ -32,12 +32,9 @@
         self.n_epochs = 100
         self.hidden_size = 512
         self.hardtanh_limit = 100
-        self.sample = False
-        self.n_train_sequences = 40000
-        self.trainOrVal = 'train'
-        self.citywalks = False
         self.input = 16
         self.output = 16
+        self.stride = 16
         self.skip = 1
         self.task = 'bounding_box-intention'
         self.use_scenes = False
@@ -46,7 +43,11 @@
 args = args()
 net = network.PV_LSTM(args).to(args.device)
-train, val = DataLoader.data_loader(args)
+train = DataLoader.data_loader(args)
+args.dtype = 'val'
+args.save_path = args.save_path.replace('train', 'val')
+args.file = args.file.replace('train', 'val')
+val = DataLoader.data_loader(args)
 optimizer = optim.Adam(net.parameters(), lr=args.lr)
 scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=15,
@@ -178,17 +179,7 @@
                '| fde: %.4f'% fde, '| aiou: %.4f'% aiou, '| fiou: %.4f'% fiou,
                '| state_acc: %.4f'% avg_acc, '| rec: %.4f'% avg_rec, '| pre: %.4f'% avg_pre,
                '| intention_acc: %.4f'% intent_acc, '| t:%.4f'%(time.time()-start))
-
-print('='*100)
-plt.figure(figsize=(10,8))
-plt.plot(list(range(len(train_s_scores))), train_s_scores, label = 'BB Training loss')
-plt.plot(list(range(len(val_s_scores))), val_s_scores, label = 'BB Validation loss')
-plt.plot(list(range(len(train_c_scores))), train_c_scores, label = 'Intention Training loss')
-plt.plot(list(range(len(val_c_scores))), val_c_scores, label = 'Intention Validation loss')
-plt.xlabel('epoch')
-plt.ylabel('Mean square error loss')
-plt.legend()
-plt.show()
+
 print('='*100)
 print('Saving ...')
 torch.save(net.state_dict(), args.model_path)
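train.py pairs Adam with ReduceLROnPlateau(factor=0.5, patience=15), a scheduler that only acts when the epoch loop hands it a validation metric via scheduler.step(val_loss). A minimal, self-contained sketch of that mechanism; the toy model and the flat loss curve are stand-ins, not part of the patch:

    import torch
    from torch import nn, optim

    net = nn.Linear(4, 4)                  # stand-in for network.PV_LSTM(args)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=15)

    for epoch in range(40):
        val_loss = 1.0                     # a plateau: no improvement after epoch 0
        scheduler.step(val_loss)           # after 15 stagnant epochs, lr is halved
        print(epoch, optimizer.param_groups[0]['lr'])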