Skip to content

Fix: per-frame timestep allocation for video2world (image2world) mode in CosmosPredict2#26

Open
csy2077 wants to merge 1 commit into
NVlabs:mainfrom
csy2077:fix/cosmos-predict2-video2world-per-frame-timestep
Open

Fix: per-frame timestep allocation for video2world (image2world) mode in CosmosPredict2#26
csy2077 wants to merge 1 commit into
NVlabs:mainfrom
csy2077:fix/cosmos-predict2-video2world-per-frame-timestep

Conversation

@csy2077

@csy2077 csy2077 commented May 16, 2026

Copy link
Copy Markdown

Problem

In CosmosPredict2.forward(), when running in video2world (image2world) mode, the conditioning (clean) frames were not assigned a special timestep of 0.0. Instead, the model received the same noisy timestep for all frames — including the clean conditioning frame(s). This caused the model to treat the clean first frame as a fully-noised frame, leading to:

  • Dramatic quality degradation in generated videos
  • No temporal coherence to the conditioning frame
  • Effectively broken video2world / image2world distillation

Fix

After replacing the conditioning frames in model_input, expand t to per-frame shape (B, T) and zero out the timestep for conditioning frames (indicated by condition_mask). This tells the transformer that those frames are already clean and require no denoising:

t_expanded = t.unsqueeze(1).expand(B, T)
mask_B_T = condition_mask[:, 0, :, 0, 0]  # (B, T)
t = t_expanded * (1 - mask_B_T)

The transformer already accepts timesteps_B_T of shape (B, T), so no other changes are needed.

Impact

Without this fix, CosmosPredict2 video2world distillation produces incoherent videos that ignore the conditioning frame. With this fix, the model correctly preserves the conditioning frame and generates temporally consistent video.

Summary by CodeRabbit

  • Bug Fixes
    • Corrected per-frame timestep tensor handling in video training. Conditioning frames identified via masking now receive the correct timestep value, ensuring clean reference frames are properly processed during model training.

In video2world (image2world) mode, conditioning frames were receiving the
same noisy timestep as all other frames. This caused the transformer to
treat the clean conditioning frame as a fully-noised input, breaking
temporal coherence and causing severe quality degradation.

Fix: after replacing conditioning frames in model_input, expand t to
shape (B, T) and zero out timesteps for frames where condition_mask=1,
signaling to the model that those frames are already clean (t=0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@juliusberner juliusberner left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR, I left a few comments.

v2w_condition = {"conditioning_latents": conditioning_latents, "condition_mask": condition_mask}
model_input = self.preserve_conditioning(latents, v2w_condition)

# Per-frame timesteps: conditioning frames get special timestep if enabled

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is now handled by the forward pass and can be removed. We can add conditional_frame_timestep as a kwarg to the forward method and then we just need to pass the conditioning_latents, e.g.,

  cond_with_mask = {
      "text_embeds": condition,
      "conditioning_latents": conditioning_latents,
      "condition_mask": condition_mask,
  }

# model knows they are noise-free and should not be denoised. Without this, the
# model treats the clean first frame as a fully-noised frame, causing incoherent
# video2world (image2world) generation.
t_expanded = t.unsqueeze(1).expand(B, T)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this more robust:

if t.ndim == 1:
    t_expanded = t.unsqueeze(1).expand(B, T)
else:
    t_expanded = t.expand(B, T)

# video2world (image2world) generation.
t_expanded = t.unsqueeze(1).expand(B, T)
mask_B_T = condition_mask[:, 0, :, 0, 0] # (B, T)
t = t_expanded * (1 - mask_B_T)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to introduce a new time variable for the transformer input (e.g., transformer_t), since we use t again below for convert_model_output and it needs to be expanded as t = transformer_t.reshape(B, 1, T, 1, 1) for per-frame timesteps.

@juliusberner

Copy link
Copy Markdown
Collaborator

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Jun 10, 2026

Copy link
Copy Markdown
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@juliusberner

Copy link
Copy Markdown
Collaborator

@greptileai review

@coderabbitai

coderabbitai Bot commented Jun 10, 2026

Copy link
Copy Markdown

Review Change Stack

Walkthrough

The change modifies how the CosmosPredict2 model processes timesteps during video2world training. When conditioning latents and a condition mask are provided, the timestep tensor is converted to per-frame representation and timesteps for conditioning frames are set to zero, treating them as noise-free inputs to the transformer.

Changes

Per-frame Timestep Conditioning for Noise-Free Video Frames

Layer / File(s) Summary
Per-frame timestep expansion and conditioning frame masking
fastgen/networks/cosmos_predict2/network.py
When conditioning_latents and condition_mask are present, the timestep tensor t is expanded to shape (B, T) and zeroed for frames identified by condition_mask, ensuring conditioning frames are treated as timestep-0 (noise-free) inputs to the diffusion transformer.

Estimated Code Review Effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Poem

🐰 A timestep once scattered, now framed just right,
Per-frame expansion brings conditioning light,
When masks mark the clean frames, to zero we go,
So the transformer knows—no noise, all glow!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: fixing per-frame timestep allocation for video2world mode in CosmosPredict2, which is the core fix addressed in this PR.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (2)
fastgen/networks/cosmos_predict2/network.py (2)

1361-1361: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Introduce a separate variable for transformer timesteps to avoid shape issues in convert_model_output.

Reassigning t to shape (B, T) affects downstream calls to convert_model_output at lines 1394-1396 and 1402-1404. For proper broadcasting with output shape (B, C, T, H, W), the timestep tensor likely needs to be reshaped to (B, 1, T, 1, 1). Using a separate variable (e.g., transformer_t) preserves the original t and allows correct reshaping for convert_model_output.

Please verify how convert_model_output handles per-frame timesteps:

#!/bin/bash
# Description: Check convert_model_output implementation for timestep shape handling

# Find the noise_scheduler implementation
rg -n -A 20 'def convert_model_output' --type py
#!/bin/bash
# Description: Check if there are any tests for per-frame timestep conversion

rg -n -C 3 'convert_model_output.*B.*T' --type py
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@fastgen/networks/cosmos_predict2/network.py` at line 1361, Currently t is
being overwritten with a (B, T) shaped tensor which breaks broadcasting in
convert_model_output; instead create a new variable (e.g., transformer_t) from
t_expanded * (1 - mask_B_T) and leave the original t intact. Update uses of the
timestep for the transformer to use transformer_t and, before calling
convert_model_output (references: convert_model_output calls around lines
handling outputs with shape (B, C, T, H, W)), reshape transformer_t to (B, 1, T,
1, 1) so it broadcasts correctly with the model output; ensure any other
downstream code that expects the original t continues to use the original
variable.

1359-1359: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Handle both 1D and 2D timestep inputs for robustness.

The current implementation assumes t has shape (B,), but the function signature documents t: torch.Tensor of shape (B,) or (B, T). If t is already 2D, t.unsqueeze(1).expand(B, T) will produce an incompatible shape (B, 1, T).

Proposed fix per juliusberner's suggestion
-            t_expanded = t.unsqueeze(1).expand(B, T)
+            if t.ndim == 1:
+                t_expanded = t.unsqueeze(1).expand(B, T)
+            else:
+                t_expanded = t.expand(B, T)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@fastgen/networks/cosmos_predict2/network.py` at line 1359, The current line
t_expanded = t.unsqueeze(1).expand(B, T) assumes t is 1D; update the code to
handle both 1D and 2D timestep inputs by checking t.dim(): if t.dim() == 1,
unsqueeze and expand to (B, T); if t.dim() == 2, validate or reshape to ensure
it is (B, T) and use it directly (or expand/broadcast if one dimension is 1);
otherwise raise a clear error. Apply this change where t_expanded is created so
functions that accept t of shape (B,) or (B, T) (reference the variable t and
the t_expanded assignment) behave correctly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@fastgen/networks/cosmos_predict2/network.py`:
- Line 1361: Currently t is being overwritten with a (B, T) shaped tensor which
breaks broadcasting in convert_model_output; instead create a new variable
(e.g., transformer_t) from t_expanded * (1 - mask_B_T) and leave the original t
intact. Update uses of the timestep for the transformer to use transformer_t
and, before calling convert_model_output (references: convert_model_output calls
around lines handling outputs with shape (B, C, T, H, W)), reshape transformer_t
to (B, 1, T, 1, 1) so it broadcasts correctly with the model output; ensure any
other downstream code that expects the original t continues to use the original
variable.
- Line 1359: The current line t_expanded = t.unsqueeze(1).expand(B, T) assumes t
is 1D; update the code to handle both 1D and 2D timestep inputs by checking
t.dim(): if t.dim() == 1, unsqueeze and expand to (B, T); if t.dim() == 2,
validate or reshape to ensure it is (B, T) and use it directly (or
expand/broadcast if one dimension is 1); otherwise raise a clear error. Apply
this change where t_expanded is created so functions that accept t of shape (B,)
or (B, T) (reference the variable t and the t_expanded assignment) behave
correctly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Enterprise

Run ID: 4520c194-03d3-493e-b094-8039ddd969f2

📥 Commits

Reviewing files that changed from the base of the PR and between 123e6a2 and 81c4cb3.

📒 Files selected for processing (1)
  • fastgen/networks/cosmos_predict2/network.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants