Fix: per-frame timestep allocation for video2world (image2world) mode in CosmosPredict2 #26
In video2world (image2world) mode, conditioning frames were receiving the same noisy timestep as all other frames. This caused the transformer to treat the clean conditioning frame as a fully-noised input, breaking temporal coherence and causing severe quality degradation.
Fix: after replacing conditioning frames in model_input, expand t to shape (B, T) and zero out timesteps for frames where condition_mask=1, signaling to the model that those frames are already clean (t=0).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
juliusberner left a comment:
Thanks a lot for the PR, I left a few comments.
v2w_condition = {"conditioning_latents": conditioning_latents, "condition_mask": condition_mask}
model_input = self.preserve_conditioning(latents, v2w_condition)

# Per-frame timesteps: conditioning frames get special timestep if enabled
This part is now handled by the forward pass and can be removed. We can add conditional_frame_timestep as a kwarg to the forward method and then we just need to pass the conditioning_latents, e.g.,
cond_with_mask = {
    "text_embeds": condition,
    "conditioning_latents": conditioning_latents,
    "condition_mask": condition_mask,
}
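As a rough illustration of the suggestion above, the sketch below shows one way a forward pass could derive per-frame timesteps from such a condition dict. The helper name `per_frame_timesteps`, the default value of `conditional_frame_timestep`, and the assumed mask layout `(B, 1, T, H, W)` are hypothetical and are not taken from the repository's actual `forward` method.

```python
import torch


def per_frame_timesteps(t, condition, conditional_frame_timestep=0.0):
    """Hypothetical helper, not the repository's forward().

    t: (B,) or (B, T) timesteps.
    condition: dict that may hold "conditioning_latents" and "condition_mask",
               where the mask is assumed to be (B, 1, T, H, W) with 1 on
               conditioning frames.
    """
    mask = condition.get("condition_mask")
    if mask is None or condition.get("conditioning_latents") is None:
        # No conditioning frames: leave the timesteps as they are.
        return t
    B, T = mask.shape[0], mask.shape[2]
    t_B_T = t.unsqueeze(1).expand(B, T) if t.ndim == 1 else t.expand(B, T)
    mask_B_T = mask[:, 0, :, 0, 0].bool()
    # Conditioning frames get the special timestep, all others keep theirs.
    special = torch.full_like(t_B_T, conditional_frame_timestep)
    return torch.where(mask_B_T, special, t_B_T)
```

With this shape of API, the training code would only build the dict and pass it along, and the choice of timestep for clean frames stays a single keyword argument.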
# model knows they are noise-free and should not be denoised. Without this, the
# model treats the clean first frame as a fully-noised frame, causing incoherent
# video2world (image2world) generation.
t_expanded = t.unsqueeze(1).expand(B, T)
Can we make this more robust:
if t.ndim == 1:
    t_expanded = t.unsqueeze(1).expand(B, T)
else:
    t_expanded = t.expand(B, T)
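A minimal, self-contained sketch of the more robust expansion suggested above, assuming `t` arrives either as `(B,)` or already as `(B, T)`. The helper name `expand_timesteps` and the explicit error for other ranks are illustrative additions, not part of the PR.

```python
import torch


def expand_timesteps(t: torch.Tensor, B: int, T: int) -> torch.Tensor:
    """Return per-frame timesteps of shape (B, T) (hypothetical helper)."""
    if t.ndim == 1:
        # One timestep per sample: repeat it across the T frames.
        return t.unsqueeze(1).expand(B, T)
    if t.ndim == 2:
        # Already per-frame; expand also covers a (B, 1) input.
        return t.expand(B, T)
    raise ValueError(f"expected t of shape (B,) or (B, T), got {tuple(t.shape)}")
```

Note that `expand` returns a view without copying, so the result should not be modified in place; the multiplication by `(1 - mask_B_T)` in the PR creates a new tensor and is unaffected.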
# video2world (image2world) generation.
t_expanded = t.unsqueeze(1).expand(B, T)
mask_B_T = condition_mask[:, 0, :, 0, 0]  # (B, T)
t = t_expanded * (1 - mask_B_T)
I think we need to introduce a new time variable for the transformer input (e.g., transformer_t), since we use t again below for convert_model_output and it needs to be expanded as t = transformer_t.reshape(B, 1, T, 1, 1) for per-frame timesteps.
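To make the shape concern concrete, here is a small sketch that keeps the original `t` intact, stores the per-frame values in `transformer_t`, and reshapes them to `(B, 1, T, 1, 1)` so they broadcast against a `(B, C, T, H, W)` tensor. The final multiplication is only a stand-in for scheduler arithmetic; it is not the real `convert_model_output`, whose signature is not shown in this thread. The tensor sizes and the choice of frame 0 as the conditioning frame are assumptions for illustration.

```python
import torch

B, C, T, H, W = 2, 4, 3, 2, 2

t = torch.tensor([0.75, 0.5])           # (B,) original timestep, left untouched
keep = torch.tensor([0.0, 1.0, 1.0])    # 0 marks the conditioning frame (assumed: frame 0)

# Per-frame timesteps for the transformer: 0 on the conditioning frame.
transformer_t = t.unsqueeze(1) * keep   # (B, T)

# Reshape so the per-frame values broadcast over channels and spatial dims.
t_bcast = transformer_t.reshape(B, 1, T, 1, 1)

model_output = torch.ones(B, C, T, H, W)
# Stand-in for element-wise scheduler math; NOT the real convert_model_output.
scaled = t_bcast * model_output         # (B, C, T, H, W)
```

A `(B, T)` tensor would not line up correctly here: broadcasting aligns trailing dimensions, so `(B, T)` would be matched against `(H, W)` rather than the batch and frame axes.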
@coderabbitai review
✅ Action performed: Review finished.
@greptileai review
Walkthrough
The change modifies how the CosmosPredict2 model processes timesteps during video2world training. When conditioning latents and a condition mask are provided, the timestep tensor is converted to a per-frame representation and timesteps for conditioning frames are set to zero, treating them as noise-free inputs to the transformer.
Changes
Per-frame Timestep Conditioning for Noise-Free Video Frames
Estimated Code Review Effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
♻️ Duplicate comments (2)
fastgen/networks/cosmos_predict2/network.py (2)
1361-1361: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Introduce a separate variable for transformer timesteps to avoid shape issues in convert_model_output.
Reassigning t to shape (B, T) affects downstream calls to convert_model_output at lines 1394-1396 and 1402-1404. For proper broadcasting with output shape (B, C, T, H, W), the timestep tensor likely needs to be reshaped to (B, 1, T, 1, 1). Using a separate variable (e.g., transformer_t) preserves the original t and allows correct reshaping for convert_model_output.
Please verify how convert_model_output handles per-frame timesteps:
#!/bin/bash
# Description: Check convert_model_output implementation for timestep shape handling
# Find the noise_scheduler implementation
rg -n -A 20 'def convert_model_output' --type py
#!/bin/bash
# Description: Check if there are any tests for per-frame timestep conversion
rg -n -C 3 'convert_model_output.*B.*T' --type py
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@fastgen/networks/cosmos_predict2/network.py` at line 1361: Currently t is being overwritten with a (B, T) shaped tensor, which breaks broadcasting in convert_model_output; instead create a new variable (e.g., transformer_t) from t_expanded * (1 - mask_B_T) and leave the original t intact. Update uses of the timestep for the transformer to use transformer_t and, before calling convert_model_output (references: convert_model_output calls around lines handling outputs with shape (B, C, T, H, W)), reshape transformer_t to (B, 1, T, 1, 1) so it broadcasts correctly with the model output; ensure any other downstream code that expects the original t continues to use the original variable.
1359-1359: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Handle both 1D and 2D timestep inputs for robustness.
The current implementation assumes t has shape (B,), but the function signature documents t: torch.Tensor of shape (B,) or (B, T). If t is already 2D, t.unsqueeze(1) has shape (B, 1, T), and the subsequent expand(B, T) fails because a 3D tensor cannot be expanded with only two sizes.
Proposed fix per juliusberner's suggestion:
- t_expanded = t.unsqueeze(1).expand(B, T)
+ if t.ndim == 1:
+     t_expanded = t.unsqueeze(1).expand(B, T)
+ else:
+     t_expanded = t.expand(B, T)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@fastgen/networks/cosmos_predict2/network.py` at line 1359: The current line t_expanded = t.unsqueeze(1).expand(B, T) assumes t is 1D; update the code to handle both 1D and 2D timestep inputs by checking t.dim(): if t.dim() == 1, unsqueeze and expand to (B, T); if t.dim() == 2, validate or reshape to ensure it is (B, T) and use it directly (or expand/broadcast if one dimension is 1); otherwise raise a clear error. Apply this change where t_expanded is created so functions that accept t of shape (B,) or (B, T) (reference the variable t and the t_expanded assignment) behave correctly.
Problem
In CosmosPredict2.forward(), when running in video2world (image2world) mode, the conditioning (clean) frames were not assigned a special timestep of 0.0. Instead, the model received the same noisy timestep for all frames, including the clean conditioning frame(s). This caused the model to treat the clean first frame as a fully-noised frame, which broke temporal coherence and severely degraded quality.
Fix
After replacing the conditioning frames in model_input, expand t to per-frame shape (B, T) and zero out the timestep for conditioning frames (indicated by condition_mask). This tells the transformer that those frames are already clean and require no denoising.
The transformer already accepts timesteps_B_T of shape (B, T), so no other changes are needed.
Impact
Without this fix, CosmosPredict2 video2world distillation produces incoherent videos that ignore the conditioning frame. With this fix, the model correctly preserves the conditioning frame and generates temporally consistent video.
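The fix described in this PR can be sketched in isolation as follows. It mirrors the three lines quoted in the review (`t_expanded`, `mask_B_T`, and the multiplication by `1 - mask_B_T`), wrapped in a hypothetical function `zero_conditioning_timesteps` so it can be run on its own. The mask layout `(B, 1, T, H, W)` is an assumption inferred from the indexing `condition_mask[:, 0, :, 0, 0]`.

```python
import torch


def zero_conditioning_timesteps(t: torch.Tensor, condition_mask: torch.Tensor) -> torch.Tensor:
    """Return (B, T) timesteps with 0 on conditioning frames (illustrative only).

    t: (B,) sampled timesteps.
    condition_mask: assumed (B, 1, T, H, W), 1 on conditioning frames, else 0.
    """
    B, T = condition_mask.shape[0], condition_mask.shape[2]
    t_expanded = t.unsqueeze(1).expand(B, T)      # (B, T)
    mask_B_T = condition_mask[:, 0, :, 0, 0]      # (B, T), one value per frame
    return t_expanded * (1 - mask_B_T)            # clean frames get t = 0


# Usage: a batch of two videos with three latent frames, conditioned on frame 0.
t = torch.tensor([0.75, 0.5])
condition_mask = torch.zeros(2, 1, 3, 4, 4)
condition_mask[:, :, 0] = 1.0
timesteps_B_T = zero_conditioning_timesteps(t, condition_mask)
```

Reading the mask at a single spatial position assumes it is constant across channels and pixels within each frame, which is what the indexing in the PR also relies on.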