@@ -805,7 +805,12 @@ void Net::prepare_recurrent(vtensor tin, vtensor tout, int &inl, int &outl, vten
805
805
for (j=1 ;j<xtd[i]->ndim ;j++)
806
806
shape.push_back (xtd[i]->shape [j]);
807
807
808
- tinr.push_back (Z);
808
+ vector<int >zero_shape;
809
+ for (j=0 ;j<tout[i]->ndim ;j++)
810
+ if (j!=1 ) zero_shape.push_back (tout[i]->shape [j]);
811
+
812
+ if (!isencoder) tinr.push_back (new Tensor (tin[0 ]->shape ,tin[0 ]->ptr ,tin[0 ]->device ));
813
+ tinr.push_back (Tensor::zeros (zero_shape,tout[i]->device ));
809
814
for (j=0 ;j<outl-1 ;j++)
810
815
tinr.push_back (new Tensor (shape,xtd[i]->ptr +(j*offset),xtd[i]->device ));
811
816
}
@@ -837,12 +842,7 @@ void Net::fit_recurrent(vtensor tin, vtensor tout, int batch, int epochs) {
837
842
int inl;
838
843
int outl;
839
844
840
- vector<int >shape;
841
- for (j=0 ;j<tout[0 ]->ndim ;j++)
842
- if (j!=1 ) shape.push_back (tout[0 ]->shape [j]);
843
- Tensor *Z=Tensor::zeros (shape,tout[0 ]->device );
844
-
845
- prepare_recurrent (tin,tout,inl,outl,xt,xtd,yt,tinr,toutr,Z);
845
+ prepare_recurrent (tin,tout,inl,outl,xt,xtd,yt,tinr,toutr);
846
846
847
847
if (rnet==nullptr ) build_rnet (inl,outl);
848
848
@@ -851,10 +851,13 @@ void Net::fit_recurrent(vtensor tin, vtensor tout, int batch, int epochs) {
851
851
else if (isencoder)
852
852
rnet->fit (tinr,tout,batch,epochs);
853
853
else if (isdecoder)
854
- rnet->fit (tin ,toutr,batch,epochs);
854
+ rnet->fit (tinr ,toutr,batch,epochs);
855
855
856
856
if (snets[0 ]->dev !=DEV_CPU) rnet->sync_weights ();
857
857
858
+ for (i=0 ;i<tinr.size ();i++) delete (tinr[i]);
859
+ for (i=0 ;i<toutr.size ();i++) delete (toutr[i]);
860
+
858
861
if (isencoder) {
859
862
for (i=0 ;i<xt.size ();i++)
860
863
delete xt[i];
@@ -870,8 +873,6 @@ void Net::fit_recurrent(vtensor tin, vtensor tout, int batch, int epochs) {
870
873
yt.clear ();
871
874
}
872
875
873
- delete Z;
874
-
875
876
}
876
877
877
878
// TODO: train_batch_recurrent
0 commit comments