Skip to content

feat: add LTX-2 video generation support#1459

Open
mudler wants to merge 28 commits intoleejet:masterfrom
mudler:feat/ltx-video
Open

feat: add LTX-2 video generation support#1459
mudler wants to merge 28 commits intoleejet:masterfrom
mudler:feat/ltx-video

Conversation

@mudler
Copy link
Copy Markdown
Contributor

@mudler mudler commented Apr 24, 2026

Hi! I've been playing with claude code as I wanted to try to have LTX in stable-diffusion.cpp. I saw already an early attempt in #491 , and just saw #1458 right now, even if I was already working on it 😅

Hope it can be of any help - I got it to produce "acceptable" things

                                                                                                                                                         
  ~/ltxv-sd-cpp/build-cuda/bin/sd-cli -M vid_gen \                              
    -m ltxv-models/ltx-2.3-22b-distilled.safetensors \                                                                                                                                                    
    --text-encoder gemma-3-12b-it \                                                                                                                                                                       
    -p 'a cat walking across a grassy field' \                     
    -W 768 -H 512 --video-frames 121 \                                                                                                                                                                    
    --steps 8 --cfg-scale 1 \                                                   
    -o /tmp/ltx23_clean.webp --seed 42                                                                                                                                                                    
                                                            
ltx23_fix

mudler added 28 commits April 23, 2026 21:46
WIP 1:1 port of diffusers' LTX transformer + causal video autoencoder,
wired into sd.cpp as a new DiT family (VERSION_LTXV).

Transformer (ltxv.hpp):
* 28-layer LTXVideoTransformer3DModel with 32 heads x 64 head_dim
* 3D rotary positional embedding (F, H, W; dim//6 freqs per axis)
* rms_norm_across_heads QK norm; cross-attention to T5-XXL (4096->2048)
* AdaLayerNormSingle with 6-way modulation and final scale_shift_table
* FeedForward with gelu-approximate activation

Causal video autoencoder (ltxv.hpp):
* Encoder (causal) + Decoder (non-causal) with CausalConv3d stacks
* Residual blocks with channel-wise RMSNorm, optional timestep conditioning
* Pixel-shuffle 3D upsampling via reshape (TODO: match diffusers' exact
  permute order; likely needs correction after hardware validation)

Wiring:
* model.h: VERSION_LTXV + sd_version_is_ltxv helper + DiT aggregator
* model.cpp: detect via scale_shift_table / adaln_single / caption_projection
* diffusion_model.hpp: new LTXVModel wrapper
* stable-diffusion.cpp: T5CLIPEmbedder (is_umt5=false) + LTXVModel ctor,
  VAE factory arm for LTXVVAERunner, FLOW_PRED + default_flow_shift=3,
  latent channels=128, temporal compression=8 for generate_init_latent,
  8k+1 frame rounding in GenerationRequest
* vae.hpp: get_scale_factor returns 32 for LTX

End-to-end hardware verification still pending; known simplifications
flagged with TODO comments.

Upstream-tracking refs: based on leejet#491
(stduhpf/wip-ltx-support) with VAE filled in and modulation order
corrected per current diffusers transformer_ltx.py.
Replace the LTX-1 transformer + VAE with the LTX-2 architecture.

Transformer (ltxv.hpp namespace LTXV):
* LTX2VideoTransformer3DModel: 48 layers, 32 heads * 128 head_dim =
  inner_dim 4096, cross_attention_dim 4096, caption_channels 3840
* LTX2AdaLayerNormSingle with configurable num_mod_params (6 or 9)
* Gated attention (to_gate_logits + 2*sigmoid per-head gate)
* 3D RoPE 'interleaved' with patch-boundary midpoints, vae_scale_factors
  (8, 32, 32), causal_offset=1, fps scaling (split rope is a TODO)
* Video-only forward path: audio branches, a2v/v2a cross-attention, and
  audio FFN/proj_out are intentionally skipped (isolate_modalities=True)
  while their weight slots are still registered so LTX-2 checkpoints open

VAE (LTX2VideoEncoder3d / LTX2VideoDecoder3d / LTX2CausalVideoAutoencoder):
* PerChannelRMSNorm (no weight, y = x / sqrt(mean(x^2, C) + eps))
* CausalConv3d with runtime causal flag
* ResBlock conv_shortcut is a plain Conv3d (no temporal causal padding)
* Default block_out_channels (256, 512, 1024, 2048); upsample types
  spatial / temporal / spatiotemporal; all-spatiotemporal scaling

Wiring:
* model.h: VERSION_LTXV -> VERSION_LTXV2; sd_version_is_ltxv2 helper
* model.cpp: detect via audio_scale_shift_table / av_cross_attn_*
  / audio_proj_in / audio_time_embed
* diffusion_model.hpp: LTXVModel -> LTXV2Model wrapper
* stable-diffusion.cpp: T5-XXL conditioner + LTXV2Model ctor; VAE factory
  arm for LTXVVAERunner; FLOW_PRED + default_flow_shift=3; latent
  channels=128, temporal compression=8, 8k+1 frame rounding
* vae.hpp: get_scale_factor returns 32 for LTX-2

Known gaps (flagged in docs/ltxv.md): split rope, LTX-2.3 prompt modulation
gate, exact pixel-shuffle 3D permute order, latents mean/std scaling.
After inspecting the ltx-2.3-22b-dev.safetensors header (5947 tensors)
the checkpoint uses different top-level names than diffusers' LTX-2.0
code. Renames:
  time_embed                          -> adaln_single
  audio_time_embed                    -> audio_adaln_single
  proj_in                             -> patchify_proj
  audio_proj_in                       -> audio_patchify_proj
  prompt_adaln                        -> prompt_adaln_single
  audio_prompt_adaln                  -> audio_prompt_adaln_single
  av_cross_attn_video_scale_shift     -> av_ca_video_scale_shift_adaln_single
  av_cross_attn_audio_scale_shift     -> av_ca_audio_scale_shift_adaln_single
  av_cross_attn_video_a2v_gate        -> av_ca_a2v_gate_adaln_single
  av_cross_attn_audio_v2a_gate        -> av_ca_v2a_gate_adaln_single
  attention.norm_q / norm_k           -> q_norm / k_norm

Remaining gaps (flagged in docs/ltxv.md) before LTX-2.3 weights load:
  * caption_projection is LTX-2.0 style (2 linears). LTX-2.3 uses an
    8-block video_embeddings_connector with 128 learnable_registers
    and self-attention transformer_1d_blocks. Ditto audio.
  * VAE has 9 down/up blocks (not 4), block_out_channels starts at
    128 (not 256), deepest latent width is 1024 (not 2048).
  * Split RoPE still unimplemented.
  * prompt_modulation forward path is stubbed.

CPU + CUDA builds remain clean; runtime load on an LTX-2.3 checkpoint
will fail until the connector + VAE rewrites land.
… VAE

Rewrites ltxv.hpp to match the LTX-2.3 22B checkpoint layout exactly
(inferred from ltx-2.3-22b-dev.safetensors header — 5947 tensors).

Transformer additions (each a weight slot the checkpoint has):
* EmbeddingsConnector  — 128 learnable_registers + 8 transformer_1d_blocks
  (attn1 gated + ff 4x). Replaces the old caption_projection. Video uses
  inner_dim=4096, audio uses 2048.
* Per-block scale_shift tables: the 9-param scale_shift_table, the 2-param
  prompt_scale_shift_table (LTX-2.3 prompt modulation), and the 5-param
  scale_shift_table_a2v_ca_video / _audio tables (a2v/v2a cross-attn
  modulation). All six tables are registered per block on both video and
  audio branches.
* Gated attention always on (to_gate_logits linear + 2*sigmoid per-head).
* q_norm/k_norm tensor path (was norm_q/norm_k in diffusers LTX-2.0 code).
* Audio-to-video / video-to-audio cross-attention modules registered so
  weights load; forward path skips them (isolate_modalities=True).

VAE rewrite to match checkpoint:
* 9 encoder down_blocks: res×4 @128, spatial↓, res×6 @256, temporal↓,
  res×4 @512, st↓, res×2 @1024, st↓, res×2 @1024
* Mirrored decoder up_blocks with spatiotemporal/temporal/spatial
  upsamplers at the checkpoint-observed conv output sizes
  ([4096,1024], [4096,512], [512,512], [512,256]).
* VAEResBlock is the LTX-2.3 simplified shape (no norms, no conv_shortcut,
  no timestep modulation on the main path).
* per_channel_statistics (mean-of-means, std-of-means) registered so they
  load; not yet consumed by vae_to_diffusion / diffusion_to_vae.

CausalConv3d now uses tensor names conv.weight / conv.bias (not plain
weight/bias) to match diffusers' nn.Conv3d-wrapped-in-self.conv layout.
No LayerNorm at transformer output (collapsed to scale_shift + proj_out).
Patchify uses tensor name patchify_proj (not proj_in).

CPU build remains clean; next step is a DGX load-test against the
ltx-2.3-22b-distilled.safetensors checkpoint.
Checkpoint metadata confirms LTX-2.3 22B uses rope_type=split, not interleaved.
Split RoPE pair layout is (x[k], x[k+r]) where r = head_dim/2, applied as:
  first_new  = first  * cos - second * sin
  second_new = second * cos + first  * sin
vs. interleaved's pair (x[2k], x[2k+1]).

* LTXAttention gains a rope_type field (default "split").
* apply_split_rotary_emb implements the (first, second) swap using
  ggml_sub / ggml_mul on [r, 2, num_heads, L*N] layout.
* compute_rope_ltx2 gains a split_rope flag: when true it produces cos/sin
  tables sized dim/2 per position with layout [pad, (F0,H0,W0), (F1,H1,W1), ...]
  matching diffusers' transpose+flatten(2).
* Runner passes split_rope=true and uses rope_tbl.dim (not hardcoded inner_dim)
  when allocating the backend tensors.
LTX-2.3 checkpoints don't ship their multilingual text encoder — the
weight file only contains the 'text_embedding_projection' aggregate
embedder, not the upstream encoder that produces the 3840-dim
per-token features it consumes. Until that encoder is ported, wire
up a no-op LTXV2Conditioner that returns zero embeddings of shape
[1, 128, 4096] so the rest of the pipeline can load the diffusion
model and exercise its forward path.

Also ignore audio_vae.*, vocoder.*, text_embedding_projection.* at
load time — those live in the single-file 22B release but aren't
consumed by the video-only inference path yet.

Tensor-layout parity with the ltx-2.3-22b-distilled checkpoint is
verified by /tmp/compare_tensor_names.py (zero missing, zero shape
mismatches over 4444 transformer + 170 VAE tensors).
…p st_t-1 frames per temporal upsample

Two fixes after first DGX run:
1. CausalConv3d weights must be F16/F32 (not BF16) because
   ggml_cuda_op_im2col_3d GGML_ASSERTs on BF16 destination. Load path
   converts checkpoint BF16 -> F16 on the way in.
2. VAEUpsampler was doubling every temporal chunk uniformly, giving 16
   output frames from 2 latent frames. Matching diffusers by dropping
   the first (st_t - 1) frames after each temporal-upsampling step so
   f_out = (f_in - 1) * st_t + 1 across the decoder.

DGX run status:
  * 46 GB checkpoint loads clean (4444 transformer + 170 VAE tensors
    assign + 1333 extra tensors logged as unknown/ignored)
  * Detected as VERSION_LTXV2 (ltxv2.3 desc)
  * Transformer forward: 2 sampling steps complete in 2.26s on GB10
  * VAE decode graph builds (3.1 GB compute buffer) and runs in 1.57s
  * Only remaining crash was output-index mismatch from wrong frame
    count, addressed by this commit
sd.cpp's tensor_to_sd_image treats 4-D video tensors as [W,H,C,T] but
our VAE produces [W,H,T,C]. The framework already supports a 5-D video
format [W,H,T,C,N] which matches our ordering — so we simply unsqueeze
the decode result to 5-D and take that branch.

Milestone: end-to-end video generation works on DGX GB10:
  * 46 GB BF16 checkpoint loads in ~9s
  * 22B transformer forward runs (~1.1s/step, 128 MB compute buffer)
  * VAE decodes 2-latent to 9-frame 704x480 output (~1s, 1.8GB buffer)
  * Total wall time: 3.64s for the 9-frame test

Output quality is obviously not meaningful yet — the LTXV2Conditioner
stub returns zero text embeddings, so the transformer has no semantic
signal. Next steps: port the LTX-2.3 text encoder, apply the
per_channel_statistics latents normalisation, and fix the pixel-shuffle
3D permute order in the VAE (currently a simplified reshape).
…mute

Two small reverts after on-hardware testing:
* scale_input back to true (the sd.cpp default). Our earlier override
  to false prevented the [-1,1] → [0,1] mapping and contributed to all
  frames saturating to black.
* Revert the input/output token permute order change (1,2,3,0 /
  3,0,1,2) — the 'correct' w-fastest ordering triggered a 4D vs 5D
  broadcast mismatch in the sampler. Restored the original permute
  pair that produces the expected round-trip shape, at the cost of
  the known RoPE-vs-token ordering mismatch still flagged in
  docs/ltxv.md.

End-to-end reality on DGX GB10:
  * Load, forward and decode all run to completion for both the 46GB
    BF16 checkpoint and the 28GB q8_0 GGUF.
  * Output is a valid WebP file (704x480, 9 frames).
  * Output is IDENTICAL across different seeds, which means either (a)
    the transformer's cross-attention with zero text conditioning
    collapses to a constant, or (b) there is a numerical bug in the
    forward path (q/k norm, modulation, rope alignment) that produces
    constants regardless of input noise. Diagnosing this properly
    requires the real multilingual text encoder and/or per-op intermediates
    dumped against the PyTorch reference — tracked separately.
Instrumented diagnostic runs on DGX revealed:
* Transformer layer 1 output magnitude is ~1e11 from clean noise input
  — all loaded weights appear normal (verified via direct safetensors
  dump: scale_shift_table std=0.37, q_norm std=0.17, linear weights
  std ~0.017). Root cause not yet localised.
* VAE decode of the transformer output produced all-NaN frames because
  I had omitted the stateless PerChannelRMSNorm before each conv in
  VAEResBlock. Adding it back gives bounded but still uniform output
  (mean 826, std 109) — the VAE itself is seed-insensitive when the
  upstream latent is constant.

What this adds:
* log_tensor_stats helper for min/max/mean/std of sd::Tensor.
* LTXV_DEBUG_MODE env var to selectively skip modulation, attention,
  FF, rope, gate, qk_norm at runtime (exposed only when LTXV_DEBUG_MODE
  is set; unaffects normal runs except in the debug path).
* LTXV_DEBUG_MAX_LAYERS env var to bisect how many transformer blocks
  run.
* LTXV_PROBE_STAGE env var returning an early-stage tensor (not yet
  fully stable; some stages crash on unused rope/context inputs).
* LTXV_BYPASS to skip the transformer entirely (validates sampler+VAE
  independently).
* PerChannelRMSNorm restored in VAEResBlock (missing norms caused NaN).

Next: reference PyTorch harness to compare per-op layer outputs
against the C++ pipeline.
…rmalization, depth-to-space

Three fixes that, combined, bring LTX-2.3 decoder output from single-color
banding to colored spatial structure.

1. EmbeddingsConnector: add pre-rms_norm inside each of the 8 transformer_1d_blocks
   and a final rms_norm after the stack, matching Lightricks' reference
   Embeddings1DConnector._BasicTransformerBlock1D. Without these, residual
   magnitudes compounded across 8 blocks and drove the connector output to
   std≈1.1e12, exploding cross-attention inside block 0.

2. Per-channel VAE latent normalisation: materialise
   per_channel_statistics.{mean,std}-of-means to CPU on first call and apply
   (x * std + mean) in diffusion_to_vae_latents (and the inverse in
   vae_to_diffusion_latents). Values taken from the checkpoint — no more
   identity fall-through.

3. Decoder conv_norm_out (PerChannelRMSNorm) + SiLU before conv_out. Missing
   these activations left the decoder output at ~O(1000) per pixel instead of
   [-1, 1].

4. Implement depth_to_space_3d matching einops
   `b (c p1 p2 p3) f h w -> b c (f p1) (h p2) (w p3)` with p3-inner/p1-outer
   convention. Use in VAEUpsampler (replaces the naive ggml_reshape) and
   final decoder unpatchify. Eliminates visible banding artefacts in decoded
   frames.

5. Add intra-block probe infrastructure (blk0_*, attn prefix) that surfaced
   the connector bug; keep it in place for future sampler tuning.
…es real frames

Adds the missing LayerNorm (elementwise_affine=False, eps=1e-6) between
the 48 transformer blocks and the final adaln modulation / proj_out, matching
Lightricks' LTXModel._process_output:

    x = norm_out(x)
    x = x * (1 + scale) + shift
    x = proj_out(x)

Without norm_out, the post-block activation std accumulated to ~285 across
48 layers and the predicted velocity came out at std≈57 — 40× larger than
expected. Chained through the sampler this produced completely saturated
garbage on 4+ steps.

With norm_out, transformer_out is std≈1.0 at every step, 8-step distilled
sampling converges to real photo-realistic frames (VAE output in [-1.5, 1.2]).
The unconditional result is generic (the text encoder is still stubbed to
zeros — it's the next remaining item) but the full transformer + VAE stack
is now demonstrably working end-to-end on the 22B checkpoint.

Combined with the previous commit (connector pre-norms, per-channel VAE
normalisation, VAE conv_norm_out, einops depth-to-space), this completes
the numerical correctness milestone.
… frames

Document the five numerical-correctness bugs that blocked output quality
(connector pre-norms, final norm_out, VAE conv_norm_out+SiLU, per-channel
latent normalisation, depth-to-space ordering) and reframe the remaining
items around text encoder / schedule tuning / audio branch.
…enizer

Lays the groundwork for LTX-2.3's text conditioning path. Native port of
Gemma-3-12B (the exact text encoder the LTX-2.3 pipeline uses, confirmed
via text_embedding_projection.video_aggregate_embed.weight shape
[4096, 188160] = 3840 * 49 = hidden * (48 layers + 1 embed)).

src/gemma3.hpp
  - Gemma3Params (3840 hidden, 48 layers, 16Q/8KV heads, head_dim=256,
    sliding_window=1024, global/local RoPE θ=1e6/1e4, linear scaling=8,
    query_pre_attn_scalar=256, GELU-tanh MLP at 15360 intermediate).
  - Gemma3RMSNorm: applies `(1 + w)` per Gemma convention.
  - Gemma3MLP: SwiGLU-with-GELU(tanh).
  - Gemma3Block: input_norm → GQA (qk_norm + RoPE + 1/sqrt(256) scale)
                 → post_attn_norm → residual → pre_ffn_norm → MLP
                 → post_ffn_norm → residual. Matches
                 llama.cpp/src/models/gemma3.cpp's layer_build_gemma3.
  - Gemma3TextModel: embedding (scaled by sqrt(hidden)) + 48 blocks +
    final RMSNorm. Exposes `forward_with_hidden_states` returning all 49
    intermediate states for LTX's aggregate_embed projection.
  - RoPE precompute + sliding-window mask builder.
  - Gemma3Runner wraps it all for backend graph execution.

src/tokenizers/gemma3_tokenizer.{h,cpp}
  - Minimal SentencePiece protobuf parser (just the `pieces` field —
    262144 entries for Gemma-3; parses string/score/type tuples, skips
    trainer_spec and normalizer_spec).
  - Classic SPM BPE encoder: meta-space pre-tokenisation, byte-level
    fallback via <0xHH> pieces, score-priority merge loop.
  - Gemma-3 specific tweak: `add_dummy_prefix=False`, so the first word
    is encoded without a leading meta-space (the HF GemmaTokenizerFast
    default).

tests/gemma3_tokenizer_test.cpp
  - Standalone CLI that loads tokenizer.model and prints the encoding
    of a given prompt. Validated against HuggingFace's
    `GemmaTokenizerFast.encode("...")`:
      "a cat walking across a grassy field" → matches exactly.
      "Hello, World!"                       → matches exactly.
      "A person riding..."                  → matches exactly.
      "1234 abc 日本語 𝟘🎉"                 → matches exactly (Japanese,
                                             digits, math-style alphanum,
                                             emoji).

No runtime behaviour change — gemma3.hpp is included from conditioner.hpp
but not yet wired into any LTXV conditioner. Next phase: run the
transformer forward with real Gemma-3-12B weights and diff against HF.

Assisted-by: Claude Opus 4.7 [Code] [Agent]
…bf16 precision

Implements the Gemma-3-12B text transformer (48 decoder blocks) and
validates every intermediate hidden state against the HuggingFace
reference on a real 23 GB checkpoint. All 49 states match to within
bf16 rounding:

    layer      ours       HF
      0      0.974      0.974    (post-embed)
      1     83.967     83.917
      4    334.240    334.817
     12   1791.846   1788.172
     24   5105.149   5091.607
     36   5957.374   6003.428
     47   6551.764   6593.038
     48      2.531      2.394    (post-final-norm)

Summary of the port:

  src/gemma3.hpp
    * Gemma3Block: pre-attn RMSNorm -> GQA (16 Q / 8 KV heads, head_dim=256)
      with per-head q_norm/k_norm, NEOX-style RoPE, 1/sqrt(256) scale
      -> post-attn RMSNorm -> residual
      -> pre-ffn RMSNorm -> SwiGLU-with-GELU-tanh MLP -> post-ffn RMSNorm
      -> residual. Matches llama.cpp's `llm_build_gemma3`.
    * Gemma3RMSNorm applies `x * (1 + w)` (not `x * w`) — HF convention.
    * compute_gemma3_rope generates NEOX-layout tables with halves
      duplicated: [cos_0..cos_{r-1}, cos_0..cos_{r-1}] for direct
      element-wise multiply in apply_rotary_emb. Dual θ (1e6 global,
      1e4 local) and linear scaling factor 8 for the global family.
    * build_causal_mask produces lower-triangular masks clipped to
      the sliding window. Gemma-3 uses causal attention everywhere
      (`use_bidirectional_attention = False`); full-attention layers
      just get a plain causal mask instead of a windowed one.
    * sliding_window_pattern = 6 — every 6th layer is full attention
      (layers 5, 11, 17, 23, 29, 35, 41, 47 for 48-layer config).
    * Gemma3TextModel wires embedding (scaled by sqrt(hidden)) + 48
      blocks + final RMSNorm, exposes forward_with_hidden_states which
      returns the 49-entry hidden-state list. A `max_layers` argument
      truncates the forward for diagnostic/probe use.
    * Gemma3Runner manages params_buffer, lazy RoPE/mask rebuilds,
      and a `compute_layer_hidden(layer_idx)` probe that builds a
      graph ending exactly at the requested hidden state so the
      runner's compute<> picks it up as `final_result`.

  examples/gemma_test/
    * Standalone `gemma3-test` CLI: takes a Gemma-3-12B directory (HF
      safetensors shards), a tokenizer.model path and a prompt;
      tokenises, runs the forward on CUDA, prints per-layer hidden
      stats. Used as the validation harness above.

Known follow-ups (Phase 4+):
  - concatenate 49 hidden states along the channel axis
  - apply `text_embedding_projection.video_aggregate_embed` from the
    LTX-2.3 22B safetensors to project 49*3840 -> 4096
  - push through video_embeddings_connector and into the LTX DiT
  - replace LTXV2Conditioner's zero stub with the real pipeline

Assisted-by: Claude Opus 4.7 [Code] [Agent]
Extends Gemma3Runner with compute_concatenated_hiddens(input_ids, out_dim),
which produces the exact 188160-dim feature that LTX-2.3's
text_embedding_projection.video_aggregate_embed expects on its input —
matching the reference FeatureExtractorV2 pipeline:

  1. Run Gemma-3 → list of 49 hidden states [hidden=3840, L, 1]
  2. Per-token RMSNorm along the hidden axis for EACH layer
  3. Concatenate the 49 normed layers along the channel axis
  4. Rescale by sqrt(out_dim / hidden_size)   (= sqrt(4096/3840) for LTX)

Numerical match against HuggingFace:
  ours: min=-60.0572 max=+63.9996 mean=+0.0158 std=1.0327
  HF:   min=-60.0856 max=+63.9996 mean=+0.0153 std=1.0327
on prompt "a cat walking" (4 tokens).

The rescaled output is the input to the Linear(188160, 4096) that lives
inside the LTX-2.3 22B safetensors as `text_embedding_projection.
video_aggregate_embed`; Phase 5 wires that projection into LTXV2Conditioner
and replaces the zero stub.

Assisted-by: Claude Opus 4.7 [Code] [Agent]
End-to-end: CLI flag --text-encoder points at a Gemma-3-12B-it directory
(tokenizer.model + safetensors shards); sd-cli then tokenises the prompt,
runs Gemma-3 on CUDA, applies LTX-2.3's per-token RMSNorm + sqrt(out/D)
rescale + `text_embedding_projection.video_aggregate_embed` Linear, and
feeds the resulting [4096, L] cross-attention features to the LTX video
DiT's `video_embeddings_connector`.

Changes:
  include/stable-diffusion.h
    + sd_ctx_params_t.text_encoder_path (appended to preserve existing
      aggregate-initialiser callers).
  examples/common/common.{h,cpp}
    + --text-encoder CLI flag, plumbed into the ctx params struct.
  src/stable-diffusion.cpp
    + When version == ltxv2 and text_encoder_path is non-empty:
      * Enumerate shards in the directory (dirent scan for *.safetensors).
      * Build a secondary ModelLoader with prefix="language_model." so
        HF names collapse to the runner's "model.*" expectations.
      * Construct a Gemma3Runner (clip_backend), alloc_params_buffer,
        load tensors.
      * Attach the Gemma runner + tokenizer + LTXTextEmbedProjection to
        LTXV2Conditioner. When path is unset, the conditioner returns
        zeros (unconditional) so the pipeline still works.
    + Remove `text_embedding_projection.` from LTX ignore list so the
      projection weight loads together with the rest of the LTX weights.
  src/conditioner.hpp
    + LTXTextEmbedProjection: GGMLRunner that owns the 188160 -> 4096
      Linear from the LTX 22B safetensors.
    + LTXV2Conditioner::attach_gemma + real get_learned_condition that
      runs Gemma + projection, with diagnostic logging of input/output
      stats.

Status:
  * Gemma forward: validated numerically against HF reference (exact
    match on prompts incl. Japanese/emoji, per-layer std within bf16
    noise). See `gemma3-test` binary.
  * Conditioner output: 188160-dim concat matches HF exactly
    (min/max/std within rounding). The 4096-dim projected embedding
    flows into the transformer's `video_embeddings_connector`.
  * End-to-end: prompts produce DIFFERENT c_crossattn, verified by
    per-prompt stats diffs. The video DiT runs to completion without
    asserts or NaNs.

Follow-up (separate commit): the cross-attention gate inside
`LTX2VideoTransformerBlock.attn2` reads gate_logits ≈ -11 on block 0,
closing the cross-attn to ~3e-5. This mutes the prompt signal at the
perceptual level even though the condition flows through the graph.
Needs investigation — either a weight-loading order issue or a
legitimate behaviour that only unblocks in later blocks / denoising
steps; does not affect the correctness of the port itself.

Assisted-by: Claude Opus 4.7 [Code] [Agent]
Two related bugs in the Gemma -> LTX conditioning path that together
prevented prompts from meaningfully steering generation.

1. 188160-dim flat ordering was transposed.

   HF reference does  encoded = stack(hidden_states, dim=-1); normed =
   (encoded * rsqrt(var)).reshape(B, T, D*L). In PyTorch memory this
   makes the layer axis FAST and the hidden axis SLOW within the flat
   188160-dim.

   We were doing ggml_concat(..., axis=0) over a list of [D, T, 1, 1]
   tensors, which puts hidden FAST and layer SLOW — the opposite order.
   text_embedding_projection.video_aggregate_embed was trained with the
   HF ordering, so every row of the weight was multiplied against the
   wrong input element and the projection output was tiny and nearly
   identical across prompts (std 2.4 vs HF's 6.8 on the same prompt).

   Fix: stack along axis 2 -> [D, T, L, 1], permute(2, 0, 1, 3) ->
   [L, D, T, 1], reshape to [D*L, T, 1]. Now the flat's fast index is
   L (matches HF). After the fix the projected cross-attn input reaches
   std=5.24 with min/max +/-87..+277 — close to HF reference
   (std=6.83, min/max -148..+281) to within bf16 noise.

2. EmbeddingsConnector output shape was wrong.

   Reference Embeddings1DConnector produces a FIXED 128-token output
   whose first L positions are the real text and positions [L..127]
   come from learnable_registers[L..127]. We were concatenating 128
   registers + L text = 128+L tokens, in the WRONG order. Rewrote the
   connector's register path to match the reference's
   `_replace_padded_with_learnable_registers` semantics.

3. Drop the diagnostic LTXV_SKIP_XATTN_GATE env knob — bypassing the
   learned cross-attn gate breaks self-attention too; gate values ~-11
   at block 0 with noise queries are correct, text influence builds up
   through later blocks once the layout fix lets the projection do its
   job.

With these fixes, conditioned generation now visibly reacts to prompt
changes (different seeds / prompts produce meaningfully different
scenes) — previously every prompt produced the same "person in
kitchen" fallback because the projection was effectively noise. Seed
42 + "cat walking in a grassy field" now yields an entirely different
scene (chefs in a kitchen) compared to unconditioned, and with CFG=3
plus a negative prompt the scene moves outdoors.

Assisted-by: Claude Opus 4.7 [Code] [Agent]
… work

The Gemma-3 attention scale `1/sqrt(query_pre_attn_scalar)` equals
`1/sqrt(head_dim)` for Gemma-3-12B (both = 256). I was applying that
scale to Q explicitly AND `ggml_ext_attention_ext` applies the same
`1/sqrt(d_head)` internally — so the effective softmax temperature was
16× too small, which flattened cross-token attention into a near-
uniform mix. With noise-driven queries that showed up as a fixed
"attention sink" outlier at one hidden dim (2339 for this vocab) at
every layer, and the projection output was 25% below HF (std 5.24 vs
6.83) because the singular outlier swallowed all the RMSNorm budget.

Removing the explicit Q scale (and asserting the scalar==head_dim
invariant for future-proofing larger variants) makes the forward match
HF numerically:

    layer  1 post-block tok=1 d=2339: ours 799.92  HF 800.00
    layer 47                        : ours 124879 HF 124928
    projected [4096, L]: min -148.24 max +280.51 std 6.828
                         HF           -148.28     +280.99    6.830

Prompts now visibly change the generated content: seed 42 with
"a cat walking across a grassy field" finally produces a cat walking
across a grassy field instead of the generic "person in a kitchen"
fallback.

Assisted-by: Claude Opus 4.7 [Code] [Agent]
Applies the official `DISTILLED_SIGMA_VALUES` sequence
[1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
when the user asks for exactly 8 sampling steps on an LTX-2.3 model
and hasn't provided a --sigmas override — matching the reference
ltx_pipelines.DistilledPipeline default.

The distilled schedule is non-uniform: five tight steps clustered near
σ=1 plus three sharp drops at the end. The generic shifted-flow
schedule (shift=3) we defaulted to before spent denoising budget too
uniformly, producing softer / smoother output. With the distilled
schedule the cat-in-grass test ("a cat walking across a grassy field",
seed 42, cfg=1, 8 steps) finally looks crisp instead of smudged.

Assisted-by: Claude Opus 4.7 [Code] [Agent]
The reference LTX-2 VideoDecoder ends with `ops.py::unpatchify`:
  rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)")
where the channel axis is packed as (c outer, p_t, p_w, p_h) with the
h_patch (q) innermost.

The intermediate DepthToSpaceUpsample uses a DIFFERENT convention:
  "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
where p3 (w_stride) is innermost.

I was reusing my (c p1 p2 p3) helper for the final 4x4 unpatchify. That
silently transposes every (p_h x p_w) output block, producing a visible
fine-scale hatching artefact that survived every diffusion step
regardless of the sigma schedule or prompt conditioning.

Add a dedicated depth_to_space_3d_patch helper that swaps the inner
(p_w, p_h) sub-axes of the channel layout to match the reference
convention, then delegates to the existing helper. The decoder's final
call is now correct; a TODO marks the matching encoder patchify bug
for future v2v/i2v work (the encoder isn't exercised in T2V).
Add the ninth bug (two conflicting channel-packing conventions between
ops.py::unpatchify and sampling.py::DepthToSpaceUpsample) and record the
safetensors __metadata__["config"]["vae"] cross-check that confirmed the
absence of a residual skip, reflect padding, or timestep conditioning in
the 22B checkpoint.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant