@@ -121,21 +121,19 @@ def prepare_latents(
121121 nlf = latents .shape [1 ],
122122 exp = num_latent_frames )
123123 latent_condition , _ = self .prepare_latents_i2v_base (image , num_frames , dtype , last_image )
124- mask_lat_size = jnp .ones ((batch_size , 1 , num_frames , latent_height , latent_width ), dtype = dtype )
124+ # 1. Create a base mask at the latent frame level
125+ mask_lat_size = jnp .ones ((batch_size , 1 , num_latent_frames , latent_height , latent_width ), dtype = dtype )
126+ # 2. Apply masking based on last_image
125127 if last_image is None :
126128 mask_lat_size = mask_lat_size .at [:, :, 1 :, :, :].set (0 )
127129 else :
128- mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
129- first_frame_mask = mask_lat_size [:, :, 0 :1 ]
130- first_frame_mask = jnp .repeat (first_frame_mask , self .vae_scale_factor_temporal , axis = 2 )
131- jax .debug .print ("first_frame_mask.shape:{shape}, is None:{isnone}" ,
132- shape = first_frame_mask .shape if first_frame_mask is not None else (- 1 ,),
133- isnone = first_frame_mask is None )
134- jax .debug .print ("first_frame_mask_stats: min={mn}, max={mx}, mean={mean}" ,
135- mn = jnp .min (first_frame_mask ) if first_frame_mask is not None else 0.0 ,
136- mx = jnp .max (first_frame_mask ) if first_frame_mask is not None else 0.0 ,
137- mean = jnp .mean (first_frame_mask ) if first_frame_mask is not None else 0.0 )
138- mask_lat_size = jnp .concatenate ([first_frame_mask , mask_lat_size [:, :, 1 :]], axis = 2 )
130+ mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
131+
132+ # 3. Expand the mask to match the temporal scale factor during reshape
133+ mask_lat_size = jnp .repeat (mask_lat_size , self .vae_scale_factor_temporal , axis = 2 )
134+ jax .debug .print ("mask_lat_size shape after repeat: {shape}" , shape = mask_lat_size .shape )
135+
136+ # 4. Reshape to combine latent frames and temporal scale factor
139137 mask_lat_size = mask_lat_size .reshape (
140138 batch_size ,
141139 1 ,
@@ -144,16 +142,21 @@ def prepare_latents(
144142 latent_height ,
145143 latent_width
146144 )
145+ # 5. Transpose and squeeze to get the final mask shape (B, F_l, H_l, W_l, T_sf)
147146 mask_lat_size = jnp .transpose (mask_lat_size , (0 , 2 , 4 , 5 , 3 , 1 )).squeeze (- 1 )
147+ jax .debug .print ("mask_lat_size final shape: {shape}" , shape = mask_lat_size .shape )
148+
149+ # 6. Concatenate with latent_condition
148150 condition = jnp .concatenate ([mask_lat_size , latent_condition ], axis = - 1 )
149151 jax .debug .print ("condition shape: {shape}, channel dim: {c}" ,
150152 shape = condition .shape ,
151153 c = condition .shape [- 1 ])
152154 jax .debug .print ("condition stats: mask_mean={mm}, latent_mean={lm}" ,
153- mm = jnp .mean (condition [..., 0 ]),
154- lm = jnp .mean (condition [..., 1 :]))
155+ mm = jnp .mean (condition [..., : self . vae_scale_factor_temporal ]),
156+ lm = jnp .mean (condition [..., self . vae_scale_factor_temporal :]))
155157
156- return latents , condition , None
158+ first_frame_mask = mask_lat_size [:, 0 :1 , :, :, :] # (B, 1, H_l, W_l, 4)
159+ return latents , condition , first_frame_mask
157160
158161
159162 def __call__ (
0 commit comments