[WAN] Adds VACE conditioning to WAN 2.1 #304
base: main
Conversation
img_height, img_width = image.shape[-2:]
scale = min(image_size[0] / img_height, image_size[1] / img_width)
new_height, new_width = int(img_height * scale), int(img_width * scale)
# TODO: should we use jax/TF-based resizing here?
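The snippet above computes an aspect-preserving target size before resizing. A minimal standalone sketch of that computation (`fit_within` is a hypothetical helper name, not part of the PR):

```python
def fit_within(img_height, img_width, image_size):
    """Return (new_height, new_width) scaled to fit inside image_size
    while preserving the aspect ratio (same formula as the snippet above)."""
    scale = min(image_size[0] / img_height, image_size[1] / img_width)
    return int(img_height * scale), int(img_width * scale)

# A 720x1280 frame fitted into a 480x832 target keeps its aspect ratio:
print(fit_within(720, 1280, (480, 832)))  # (468, 832)
```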
Let me know what you think about this.
I don't think it is necessary right now. Wouldn't it require casting to numpy for running the torch function below?
It has worked fine so far; video_processor.preprocess already returns a Torch tensor, but I will keep an eye on it just in case.
Force-pushed from e9086b5 to db2a559
Co-authored-by: ninatu <[email protected]>
blocks.append(block)
self.blocks = blocks

if scan_layers:
I haven't looked too deeply at the VACE architecture, but why is it that scan cannot be used?
I am not sure how to do it, because the nnx.vmap decorator does not differentiate between the separate layers: it simply creates a tensor with an extra axis, so passing a per-layer argument like apply_input_projection=vace_block_id == 0 is, to my knowledge, not feasible. The nnx.scan function could probably be used later in this context if we keep an extra variable that acts as a counter to identify the current iteration, but I was not able to work around the limitation at initialization time (and this parameter cannot be passed later, because it conditions how the layer is initialized). I would like to support this, though; if you have any ideas I would appreciate them!
I can also try to have the Wan layers vmap-initialized and skip it for the Vace ones.
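To illustrate the limitation being discussed (a plain-Python sketch, not flax/nnx code): stacked or scanned layers must share a single construction signature, so a static flag such as apply_input_projection=(block_id == 0) cannot vary per layer under vmap-style initialization, whereas an explicit loop can construct each block differently. The Block class and its body below are hypothetical stand-ins.

```python
class Block:
    def __init__(self, apply_input_projection):
        # Static choice made at construction time: it changes which
        # parameters the layer even owns, so it cannot be supplied
        # at call time the way a runtime input could.
        self.apply_input_projection = apply_input_projection

    def __call__(self, x, cond):
        if self.apply_input_projection:
            x = x + cond  # stand-in for the input projection
        return x * 2      # stand-in for the rest of the block body

# Explicit loop over blocks: only block 0 gets the input projection,
# which is exactly the per-layer asymmetry a uniform stacked init forbids.
blocks = [Block(apply_input_projection=(i == 0)) for i in range(3)]

x = 1
for block in blocks:
    x = block(x, cond=4)
print(x)  # 40: (1 + 4) * 2 * 2 * 2
```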
OK, sounds good; we can add it later.
This brings in the VACE model, taken from diffusers, trying to comply as much as possible with the upstream conventions.