
Cosmos3 context parallel #14054

Open
atharvajoshi10 wants to merge 4 commits into huggingface:main from atharvajoshi10:cosmos3-context-parallel


atharvajoshi10 commented Jun 23, 2026


Cosmos3 multi-GPU inference: context + tensor parallelism

What this PR does

Adds multi-GPU inference for the Cosmos3 (Cosmos3OmniPipeline) family along two orthogonal, composable sharding axes, so the model can run faster, fit when a single checkpoint exceeds one GPU's memory, or both. All parallelism logic lives in examples/cosmos3/cosmos_parallel.py; the model itself stays parallelism-free apart from two tiny no-op seams.

The two axes

  • Context parallelism (CP / Ulysses): enable_cosmos3_context_parallel. Shards the sequence across GPUs; attention runs with two all-to-all collectives per layer (gather-seq/scatter-heads → local attention → gather-heads/scatter-seq). Weights are replicated, so it cuts latency but not weight memory.
  • Tensor parallelism (TP): enable_cosmos3_tensor_parallel. Shards the attention and MLP weight matrices (Megatron-style: column-parallel q/k/v + gate/up, row-parallel out + down), so a checkpoint that doesn't fit one GPU (e.g. Cosmos3-Super, ~120 GB) loads across several.
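
To make the two all-to-all steps in the CP bullet concrete, here is a minimal pure-Python simulation of the layout change, not the PR's implementation (which, per the commit message, uses DTensor redistribute). The function names `seq_to_heads` and `heads_to_seq` are hypothetical, and nested lists stand in for tensors. Each rank starts with a slice of the tokens and all heads; after the first exchange it holds all tokens and a slice of the heads, so attention can run locally per head.

```python
def seq_to_heads(shards):
    """Simulate gather-seq/scatter-heads across len(shards) ranks.

    shards[r] is rank r's local slice: a list of tokens, each a list with one
    entry per attention head. Returns out[r]: every token, but only rank r's
    contiguous block of heads.
    """
    world = len(shards)
    num_heads = len(shards[0][0])
    assert num_heads % world == 0, "heads must divide evenly across ranks"
    per_rank = num_heads // world
    out = []
    for dst in range(world):
        lo, hi = dst * per_rank, (dst + 1) * per_rank
        # Concatenating source ranks in order restores global token order.
        out.append([tok[lo:hi] for src in range(world) for tok in shards[src]])
    return out


def heads_to_seq(shards):
    """Inverse exchange: gather-heads/scatter-seq."""
    world = len(shards)
    total_tokens = len(shards[0])
    assert total_tokens % world == 0, "tokens must divide evenly across ranks"
    per_rank = total_tokens // world
    out = []
    for dst in range(world):
        local = []
        for t in range(dst * per_rank, (dst + 1) * per_rank):
            row = []
            for src in range(world):
                # Stitch the head blocks back together in rank order.
                row.extend(shards[src][t])
            local.append(row)
        out.append(local)
    return out


# 2 ranks, 4 tokens, 4 heads; every entry is labelled (token, head).
full = [[(t, h) for h in range(4)] for t in range(4)]
seq_sharded = [full[0:2], full[2:4]]
head_sharded = seq_to_heads(seq_sharded)
```

In this toy layout the round trip `heads_to_seq(seq_to_heads(x))` returns the original sequence shards, which is the property the second all-to-all relies on.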

They compose on a 2-D (tp, cp) device mesh (e.g. TP=2 × CP=2 over 4 GPUs).
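
The Megatron-style split named in the TP bullet can be checked on a toy MLP. This is a sketch under stated assumptions: plain-Python integer matrices replace tensors, a ReLU stands in for the model's gated activation, and `tp_mlp` and the helper names are hypothetical. It shows that column-sharding the up projection and row-sharding the down projection needs only one sum across ranks to reproduce the unsharded result.

```python
def matmul(a, b):
    """Plain-Python matrix product for small integer matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    # Stand-in for the real (gated) activation; any elementwise op works here.
    return [[max(v, 0) for v in row] for row in m]


def split_cols(w, parts):
    step = len(w[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]


def split_rows(w, parts):
    step = len(w) // parts
    return [w[i * step:(i + 1) * step] for i in range(parts)]


def tp_mlp(x, w_up, w_down, tp):
    """Each loop iteration plays the role of one TP rank."""
    partials = []
    for up_shard, down_shard in zip(split_cols(w_up, tp), split_rows(w_down, tp)):
        hidden = relu(matmul(x, up_shard))           # column-parallel: no communication
        partials.append(matmul(hidden, down_shard))  # row-parallel: partial sums
    rows, cols = len(partials[0]), len(partials[0][0])
    # The final sum is what an all-reduce would compute across ranks.
    return [[sum(p[r][c] for p in partials) for c in range(cols)] for r in range(rows)]


x = [[1, -2], [3, 4]]
w_up = [[1, 0, -1, 2], [2, 1, 0, -1]]
w_down = [[1, 2], [0, 1], [3, -1], [1, 1]]
reference = matmul(relu(matmul(x, w_up)), w_down)
sharded = tp_mlp(x, w_up, w_down, tp=2)
```

The same argument applies to attention: q/k/v are column-sharded by head, and the output projection is row-sharded.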

Design

The model carries no parallelism logic — it exposes two optional, default-None seams on Cosmos3OmniTransformer: _cp_shard_fn / _cp_gather_fn, which shard each pathway's sequence (and rotary embeddings) before the decoder stack and re-gather after the final norm. Attention parallelism lives entirely in standalone attention processors (Cosmos3CPAttnProcessor, Cosmos3FlashAttnProcessor) installed via set_processor — each is self-contained (its own __call__), so the core model file needs no override hooks.

Why a custom CP path (not the declarative _cp_plan): Cosmos3 attention has (1) grouped-query attention — KV heads must be repeated to match query heads; (2) separate understanding (causal) and generation (full) token streams, where generation attends to cat(und, gen); (3) ragged per-stream lengths that are padded independently with the padded generation keys masked. These can't be expressed declaratively.
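
As an illustration of point (3), the sketch below pads one stream's length and builds a key mask that hides the padding. It assumes each stream is padded up to a multiple of the CP degree so it shards evenly; the PR text only says the streams are padded independently, so that rule and both function names are hypothetical.

```python
def padded_length(length, cp_degree):
    """Smallest multiple of cp_degree that is >= length (assumed padding rule)."""
    return -(-length // cp_degree) * cp_degree


def key_padding_mask(valid_len, cp_degree):
    """True for real keys, False for the padding appended after them."""
    total = padded_length(valid_len, cp_degree)
    return [i < valid_len for i in range(total)]


# Hypothetical ragged lengths for the two streams on a CP=4 group.
und_len, gen_len, cp = 5, 10, 4
und_padded = padded_length(und_len, cp)
gen_mask = key_padding_mask(gen_len, cp)
```

Because the two streams are padded separately, their padded lengths generally differ, which is one reason a single declarative plan over one sequence dimension does not fit.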

GQA + flash: SDPA's flash/cuDNN kernels reject enable_gqa, and the native kernel falls back to math (materializing the full [S, S] scores → OOM on long sequences). Both attention paths instead expand KV heads up to the query-head count and call SDPA with enable_gqa=False, so it dispatches to flash (O(S) memory).
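
The KV-head expansion described above amounts to an index mapping from query heads to KV heads. The sketch below assumes the common interleaved layout (each KV head repeated for a contiguous group of query heads); whether the PR uses exactly this layout is an assumption, and the function names and sizes are hypothetical. The second helper only illustrates why materializing the full [S, S] scores is costly.

```python
def expand_kv_heads(kv_heads, num_query_heads):
    """Repeat each KV head so there is one per query head (interleaved layout)."""
    num_kv = len(kv_heads)
    assert num_query_heads % num_kv == 0, "query heads must be a multiple of KV heads"
    group = num_query_heads // num_kv
    # Query head q reads KV head q // group.
    return [kv_heads[q // group] for q in range(num_query_heads)]


def score_matrix_bytes(seq_len, num_heads, bytes_per_element=2):
    """Memory needed to materialize [S, S] scores for every head."""
    return num_heads * seq_len * seq_len * bytes_per_element


# Hypothetical sizes: 2 KV heads serving 8 query heads.
expanded = expand_kv_heads(["kv0", "kv1"], 8)

# Hypothetical long sequence: 65,536 tokens, 32 heads, 2-byte elements.
gib = score_matrix_bytes(65536, 32) / 2**30
```

Expanding KV heads costs extra K/V memory that grows linearly with sequence length, which is a small price next to the quadratic score matrix the math path would allocate.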

Usage

The example runner examples/cosmos3/inference_cosmos3.py works across all modalities (t2i / t2v / i2v / v2v / sound / action) — just launch with torchrun and pass degrees:

# CP=4 (lower latency, Nano):
torchrun --nproc_per_node 4 inference_cosmos3.py --cp-degree 4 --prompt "..."

# TP=2 × CP=2 across 4 GPUs (Super):
torchrun --nproc_per_node 4 inference_cosmos3.py --model super --tp-degree 2 --cp-degree 2 --prompt "..."

--nproc_per_node must equal --tp-degree × --cp-degree.
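
A minimal sketch of that constraint and of how a rank could map onto the (tp, cp) mesh. The row-major layout with CP as the inner dimension is an assumption (it matches how a 2-D mesh of shape (tp, cp) is usually enumerated), and `check_degrees` and `mesh_coords` are hypothetical names, not functions from the PR.

```python
def check_degrees(nproc_per_node, tp_degree, cp_degree):
    """Raise if the launch size does not match the requested mesh."""
    if nproc_per_node != tp_degree * cp_degree:
        raise ValueError(
            f"--nproc_per_node ({nproc_per_node}) must equal "
            f"--tp-degree x --cp-degree ({tp_degree} x {cp_degree})"
        )


def mesh_coords(rank, tp_degree, cp_degree):
    """Return (tp_index, cp_index) for a rank, assuming a row-major (tp, cp) mesh."""
    if not 0 <= rank < tp_degree * cp_degree:
        raise ValueError(f"rank {rank} is outside a {tp_degree}x{cp_degree} mesh")
    return divmod(rank, cp_degree)


check_degrees(4, tp_degree=2, cp_degree=2)
coords = [mesh_coords(r, 2, 2) for r in range(4)]
```

Under this assumed layout, ranks 0 and 1 share a TP index and form one CP group, while ranks 0 and 2 share a CP index and form one TP group.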

atharvajoshi10 marked this pull request as draft June 23, 2026 17:31
github-actions bot added the documentation, models, pipelines, examples, and size/L labels Jun 23, 2026
atharvajoshi10 force-pushed the cosmos3-context-parallel branch 4 times, most recently from 67fb9ec to 6edc5fd June 24, 2026 00:05
Comment thread src/diffusers/models/transformers/transformer_cosmos3.py Outdated
atharvajoshi10 added a commit to atharvajoshi10/diffusers that referenced this pull request Jun 25, 2026
Address review feedback on PR huggingface#14054: the parallel attention processors no
longer subclass Cosmos3AttnProcessor, so the model file needs no override seam.

- transformer_cosmos3.py: revert Cosmos3AttnProcessor to inline the attention
  in __call__ (remove the _run_attention seam); restores it to its base version.
- cosmos_parallel.py: Cosmos3CPAttnProcessor and Cosmos3FlashAttnProcessor each
  get their own full __call__, sharing a _project_qkv_with_rope prologue helper.

Verified behavior-preserving on 4x RTX PRO 6000: cp_unit_test (fp32) passes at
1e-4; cp_numeric_check is byte-identical to the pre-refactor code; the end-to-end
CLI passes in CP-only, TP-only, and TP+CP modes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Atharva Joshi and others added 3 commits June 25, 2026 11:04
Under sharded placement (device_map="balanced"), vae.encode() runs on the
VAE's own device while the mean/inv_std buffers were pinned to x.device,
causing a cross-device RuntimeError. Compute raw_mu first, then pin the
normalization buffers to its device so all tensors share one device.
Cosmos 3 cannot use diffusers' declarative `_cp_plan` CP path: it is grouped-query
attention (the shared Ulysses kernel assumes K/V share the query head count), its
understanding (causal) and generation (full) streams are separate packed sequences
(gen attends to cat(und, gen)), and per-pathway lengths are ragged. The model carries
no parallelism logic -- it exposes only small, CP-agnostic seams; all sharding lives
outside it, in a reusable example module.

Model (transformer_cosmos3.py): adds two default-None `forward` seams -- `_cp_shard_fn`
(shards und/gen + rotary before the decoder layers) and `_cp_gather_fn` (gathers/unpads
after the final norm) -- and extracts `Cosmos3AttnProcessor._run_attention` as an
override point. The non-parallel path is unchanged.

Helpers (examples/cosmos3/cosmos_parallel.py): one importable module, two orthogonal
and composable axes:
  * Context parallelism (Ulysses) -- `enable_cosmos3_context_parallel`. Shards the
    sequence; brackets the two attention pathways with all-to-all (DTensor redistribute),
    repeats GQA KV heads, pads ragged lengths and masks padded generation keys.
  * Tensor parallelism (Megatron) -- `enable_cosmos3_tensor_parallel`. Column/row-shards
    the attention + MLP weights so a checkpoint that does not fit one GPU (Super, ~120 GB)
    loads across several; weights load to CPU then shard layer by layer.
Both expand KV heads to the query-head count and call SDPA with enable_gqa=False so it
dispatches to the flash kernel; enable_gqa=True forces the math path, which materializes
the full [S, S] score matrix and OOMs on long videos. A dense `Cosmos3FlashAttnProcessor`
(`enable_cosmos3_flash_attention`) provides the same for TP without CP.

CLI (examples/cosmos3/inference_cosmos3.py): imports these helpers, so any modality
(text-to-image/video, image-to-video, sound, action) runs single- or multi-GPU via
`--tp-degree` / `--cp-degree` (their product must equal --nproc_per_node). Single-GPU
behavior is unchanged.

Docs + example README updated. Verified: CP attention core is bit-exact vs non-CP in
fp32 (max|d|=0), and a full 36-layer forward matches CP-on vs CP-off to ~1e-6 in fp32
(bf16 differs only by floating-point rounding).
atharvajoshi10 force-pushed the cosmos3-context-parallel branch from 84fc872 to 7d3f093 June 25, 2026 18:08
atharvajoshi10 marked this pull request as ready for review June 25, 2026 19:49