Skip to content

Commit bb9a242

Browse files
committed
some new fix
1 parent c08a050 commit bb9a242

1 file changed

Lines changed: 18 additions & 15 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -121,21 +121,19 @@ def prepare_latents(
121121
nlf=latents.shape[1],
122122
exp=num_latent_frames)
123123
latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image)
124-
mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype)
124+
# 1. Create a base mask at the latent frame level
125+
mask_lat_size = jnp.ones((batch_size, 1, num_latent_frames, latent_height, latent_width), dtype=dtype)
126+
# 2. Apply masking based on last_image
125127
if last_image is None:
126128
mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0)
127129
else:
128-
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
129-
first_frame_mask = mask_lat_size[:, :, 0:1]
130-
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
131-
jax.debug.print("first_frame_mask.shape:{shape}, is None:{isnone}",
132-
shape = first_frame_mask.shape if first_frame_mask is not None else (-1,),
133-
isnone = first_frame_mask is None)
134-
jax.debug.print("first_frame_mask_stats: min={mn}, max={mx}, mean={mean}",
135-
mn=jnp.min(first_frame_mask) if first_frame_mask is not None else 0.0,
136-
mx=jnp.max(first_frame_mask) if first_frame_mask is not None else 0.0,
137-
mean=jnp.mean(first_frame_mask) if first_frame_mask is not None else 0.0)
138-
mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
130+
mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0)
131+
132+
# 3. Expand the mask to match the temporal scale factor during reshape
133+
mask_lat_size = jnp.repeat(mask_lat_size, self.vae_scale_factor_temporal, axis=2)
134+
jax.debug.print("mask_lat_size shape after repeat: {shape}", shape=mask_lat_size.shape)
135+
136+
# 4. Reshape to combine latent frames and temporal scale factor
139137
mask_lat_size = mask_lat_size.reshape(
140138
batch_size,
141139
1,
@@ -144,16 +142,21 @@ def prepare_latents(
144142
latent_height,
145143
latent_width
146144
)
145+
# 5. Transpose and squeeze to get the final mask shape (B, F_l, H_l, W_l, T_sf)
147146
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
147+
jax.debug.print("mask_lat_size final shape: {shape}", shape=mask_lat_size.shape)
148+
149+
# 6. Concatenate with latent_condition
148150
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
149151
jax.debug.print("condition shape: {shape}, channel dim: {c}",
150152
shape=condition.shape,
151153
c=condition.shape[-1])
152154
jax.debug.print("condition stats: mask_mean={mm}, latent_mean={lm}",
153-
mm=jnp.mean(condition[..., 0]),
154-
lm=jnp.mean(condition[..., 1:]))
155+
mm=jnp.mean(condition[..., :self.vae_scale_factor_temporal]),
156+
lm=jnp.mean(condition[..., self.vae_scale_factor_temporal:]))
155157

156-
return latents, condition, None
158+
first_frame_mask = mask_lat_size[:, 0:1, :, :, :] # (B, 1, H_l, W_l, 4)
159+
return latents, condition, first_frame_mask
157160

158161

159162
def __call__(

0 commit comments

Comments
 (0)