Write all new results after the optimization

This commit is contained in:
qhy
2026-02-19 20:18:31 +08:00
parent 5e0e21d91b
commit 43ab0f71b0
28 changed files with 1776 additions and 1199 deletions


== Task Comprehension: Diffusion Model and UnifoLM-WMA
This section provides a comprehensive overview of the UnifoLM-WMA-0 deep learning architecture, serving as a practical foundation for the optimization strategies discussed in subsequent sections.
=== Overall Inference Pipeline
UnifoLM-WMA-0 is Unitree Robotics' open-source World-Model-Action framework. Its core task is to predict future video frame sequences along with the corresponding robot action and state trajectories, given a current observation image and a text instruction. The model operates in an interactive simulation mode: each iteration consumes the previous prediction as input and generates the next segment of video and actions, thereby forming a closed-loop rollout. A single iteration of this pipeline can be decomposed into four sequential stages: condition encoding, VAE encoding, DDIM diffusion sampling, and VAE decoding, each of which is described below.
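The closed-loop rollout can be sketched as follows. This is a minimal illustration with stand-in stage functions; the names `encode_condition`, `vae_encode`, `ddim_sample`, and `vae_decode` are hypothetical, not the repository's actual API:

```python
import numpy as np

# Stand-ins for the real modules; shapes and logic are illustrative only.
def encode_condition(obs, text):
    return {"obs": obs, "text": text}

def vae_encode(obs):
    return obs * 0.18215                       # latent, scaled by scale_factor

def ddim_sample(ctx, z):
    # 50 denoising steps in the real model; here an identity placeholder
    # that also returns dummy action/state trajectories of shape (T, 16).
    return z, np.zeros((16, 16)), np.zeros((16, 16))

def vae_decode(z):
    return z / 0.18215

def rollout(obs, text, n_iters=3):
    videos, actions = [], []
    for _ in range(n_iters):
        ctx = encode_condition(obs, text)      # 1. condition encoding
        z = vae_encode(obs)                    # 2. VAE encoding
        z0, a, s = ddim_sample(ctx, z)         # 3. DDIM diffusion sampling
        frames = vae_decode(z0)                # 4. VAE decoding
        videos.append(frames)
        actions.append(a)
        obs = frames                           # predicted frames seed the next iteration
    return videos, actions
```

The final line of the loop is what makes the rollout closed-loop: the decoded prediction replaces the observation for the next iteration.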
==== Condition Encoding
The condition encoding stage transforms raw multi-modal inputs into a unified context vector that guides the diffusion denoising process, via three parallel encoding paths. On the image side, the input observation image (320#sym.times 512) is processed by a frozen OpenCLIP ViT-H-14 vision encoder, then compressed through a Resampler, a Perceiver-based cross-attention module (4 layers, 12 heads, dim\_head=64, embed\_dim 1280 #sym.arrow 1024), into 16 image condition tokens per frame, yielding $16 times T = 256$ image tokens for T=16 frames.
On the text side, the instruction is encoded by a frozen OpenCLIP text encoder (`FrozenOpenCLIPEmbedder`, penultimate layer output) into 77 tokens of dimension 1024, computed once and reused across all DDIM steps. On the state side, the robot proprioceptive state (dim 16) is mapped through a SATokenProjector (Perceiver Attention, 1 layer, 16 heads, dim\_head=64, 16 learnable queries) into 16 tokens of dimension 1024.
These three token sets are concatenated to form the unified context vector: `[agent_state(2) | agent_action(16) | text(77) | image(256)]`, totaling 351 tokens per cross-attention operation.
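The token bookkeeping can be verified with a short shape check (a NumPy stand-in; dimensions are taken from the counts above):

```python
import numpy as np

B, D = 1, 1024                        # batch size, context dimension
agent_state  = np.zeros((B,   2, D))  # 2 agent-state tokens
agent_action = np.zeros((B,  16, D))  # 16 agent-action tokens
text         = np.zeros((B,  77, D))  # 77 CLIP text tokens
image        = np.zeros((B, 256, D))  # 16 tokens x 16 frames

context = np.concatenate([agent_state, agent_action, text, image], axis=1)
assert context.shape == (B, 351, D)   # 2 + 16 + 77 + 256 = 351
```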
==== VAE Encoding
The observation images are encoded into a compact latent space through an AutoencoderKL (`autoencoder.py`) — a variational autoencoder regularized by KL divergence. The encoder follows a convolutional architecture with 4-level channel multipliers [1, 2, 4, 4] (base channels ch=128, yielding channel widths [128, 256, 512, 512]), 2 residual blocks per level, and a latent channel count of z\_channels=4. The input RGB frames at resolution 320#sym.times 512 are encoded into latent representations at 1/8 spatial resolution, producing tensors of shape `(B, 4, T, 40, 64)`.
A critical configuration parameter is `perframe_ae=True`, which means the VAE processes each of the T=16 frames independently rather than as a 3D volume. While this per-frame strategy avoids the memory overhead of volumetric convolutions, it introduces a sequential loop of T forward passes through the encoder, a point worth noting for latency optimization. The latent representations are scaled by a fixed factor of `scale_factor=0.18215` before being fed into the diffusion process.
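The per-frame strategy and latent scaling can be sketched as a shape check (NumPy, with a stub standing in for the actual AutoencoderKL encoder):

```python
import numpy as np

SCALE_FACTOR = 0.18215

def encode_frame(frame):
    """Stub for one AutoencoderKL encoder pass: (3, 320, 512) -> (4, 40, 64)."""
    return np.zeros((4, frame.shape[1] // 8, frame.shape[2] // 8))

def encode_perframe(video):
    """perframe_ae=True: T sequential encoder passes instead of one 3D pass."""
    B, C, T, H, W = video.shape
    latents = [encode_frame(video[b, :, t]) for b in range(B) for t in range(T)]
    z = np.stack(latents).reshape(B, T, 4, H // 8, W // 8).transpose(0, 2, 1, 3, 4)
    return z * SCALE_FACTOR            # scaled before entering diffusion

z = encode_perframe(np.zeros((1, 3, 16, 320, 512)))
assert z.shape == (1, 4, 16, 40, 64)
```

The Python-level loop over T frames is exactly the sequential cost the text flags: each iteration launches a full encoder forward pass.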
==== DDIM Diffusion Sampling
This is the core time-consuming part of inference. A DDIM (Denoising Diffusion Implicit Models) sampler (`ddim.py`) is employed with a default of 50 denoising steps. The diffusion process is parameterized with v-prediction (`parameterization="v"`), 1000 training timesteps, and a linear beta schedule from `linear_start=0.00085` to `linear_end=0.012`, with zero-SNR terminal rescaling enabled (`rescale_betas_zero_snr=True`) and dynamic rescaling applied at `base_scale=0.7` to stabilize generation quality.
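With these endpoints, a "linear" schedule in the common latent-diffusion convention interpolates in sqrt-beta space before squaring. The sketch below follows that convention; the repository's exact construction, including the zero-SNR rescaling step, may differ in detail:

```python
import numpy as np

def linear_betas(linear_start=0.00085, linear_end=0.012, n_timesteps=1000):
    # Latent-diffusion-style "linear" schedule: linear in sqrt(beta), then squared.
    return np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timesteps) ** 2

betas = linear_betas()
alphas_cumprod = np.cumprod(1.0 - betas)  # zero-SNR rescaling would further
                                          # shift this so the terminal SNR is 0
```

`rescale_betas_zero_snr=True` rescales `alphas_cumprod` so the terminal signal-to-noise ratio is exactly zero, which matters for v-prediction models because the final timestep should carry pure noise.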
Unlike standard video diffusion models that only predict denoised video latents, UnifoLM-WMA simultaneously produces three outputs per step: a video latent prediction `y` of shape `(B, 4, T, 40, 64)`, an action trajectory prediction `a_y` of shape `(B, T, 16)`, and a state trajectory prediction `s_y` of shape `(B, T, 16)`. The three predictions share the same diffusion timestep but employ heterogeneous noise schedules: the video stream uses the DDPM schedule with v-prediction, while the action and state streams use a `DDIMScheduler` from the `diffusers` library with epsilon-prediction and a `squaredcos_cap_v2` beta schedule. This design allows each modality to adopt its optimal denoising strategy.
The sampler also supports classifier-free guidance with `unconditional_guidance_scale` and guidance rescaling, applied only to the video stream to balance generation quality and diversity.
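A sketch of classifier-free guidance with guidance rescaling as applied to the video stream (the helper and its default values are illustrative, not the repository's code):

```python
import numpy as np

def classifier_free_guidance(v_uncond, v_cond, scale=7.5, rescale=0.7):
    """Standard CFG with optional guidance rescaling.

    Hypothetical defaults; in the model this is applied to the video
    stream only, leaving action/state predictions unguided.
    """
    v = v_uncond + scale * (v_cond - v_uncond)
    if rescale > 0:
        # Pull the guided output's statistics back toward the conditional
        # branch to counteract over-saturation at large guidance scales.
        v_rescaled = v * (v_cond.std() / v.std())
        v = rescale * v_rescaled + (1 - rescale) * v
    return v
```

If the configured `base_scale=0.7` plays the role of the `rescale` weight here (an assumption), it blends the variance-corrected and raw guided predictions 70/30.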
==== VAE Decoding
After the DDIM sampling loop completes, the denoised video latent tensor $x_0$ of shape `(B, 4, T, 40, 64)` is decoded back to RGB pixel space through the AutoencoderKL decoder. Due to the `perframe_ae=True` configuration, decoding is likewise performed frame-by-frame: each of the T=16 latent frames is individually inverse-scaled by $1 slash "scale_factor"$, passed through the decoder's convolutional transpose layers, and reconstructed to a 320#sym.times 512 RGB frame.
In the interactive simulation mode, the decoded video serves a dual purpose: providing the observation image for the next iteration's condition encoding (only the first `exe_steps` frames are needed) and producing the final output video for visualization and evaluation. The action and state trajectories predicted by the DDIM loop are directly used for robot control without further decoding.
=== WMAModel Backbone: Dual-UNet Collaborative Architecture
The WMAModel (`wma_model.py:326`) is the core neural network invoked at every DDIM step, employing a unique dual-UNet collaborative architecture that jointly predicts video, actions, and states within a single forward pass. This tightly-coupled design enables the action and state predictions to directly leverage the rich spatiotemporal features extracted by the video generation backbone, rather than treating them as independent prediction heads.
==== Video UNet
The primary backbone is a 2D convolution-based UNet with temporal extensions. Its key configuration is summarized in the following table:
#figure(
table(
columns: (3fr, 5fr),
[*Parameter*], [*Value*],
[Input / Output channels], [8 (4 latent + 4 conditioning) / 4],
[Base model channels], [320],
[Channel multipliers], [\[1, 2, 4, 4\] #sym.arrow widths \[320, 640, 1280, 1280\]],
[Residual blocks per level], [2],
[Attention resolutions], [\[4, 2, 1\] (3 of 4 resolution levels)],
[Attention head channels], [64],
[Transformer depth], [1 per attention resolution],
[Context dimension], [1024],
[Temporal length], [16 frames],
),
caption: [Video UNet configuration parameters.],
)
The UNet follows the classic encoder-middle-decoder structure with skip connections. At each attention-enabled resolution level, every ResBlock is followed by two transformer modules: a SpatialTransformer that performs spatial self-attention among all $H times W$ tokens within each frame followed by cross-attention with the 351-token context vector, and a TemporalTransformer that performs self-attention among T=16 time-step tokens at each spatial position (configured with `temporal_selfatt_only=True`, i.e., no cross-attention).
During the forward pass, intermediate feature maps are collected after each Downsample layer and the middle block, reshaped from $(B times T, C, H, W)$ to $(B, T, C, H, W)$, accumulating 10 multi-scale feature maps in `hs_a`, the bridge to the Action/State UNets.
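The axis bookkeeping can be traced with a quick NumPy check (shapes from the configuration above; illustrative only):

```python
import numpy as np

B, T, C, H, W = 2, 16, 320, 40, 64

# Inside the video UNet, frames are folded into the batch for 2D convolutions:
x = np.zeros((B * T, C, H, W))

# When a feature map is collected into hs_a, the time axis is restored:
hs_entry = x.reshape(B, T, C, H, W)
assert hs_entry.shape == (B, T, C, H, W)

# The TemporalTransformer instead folds spatial positions into the batch,
# so the T=16 time steps form the token sequence at each position:
x5 = hs_entry.transpose(0, 2, 1, 3, 4)                  # (B, C, T, H, W)
tokens = x5.transpose(0, 3, 4, 1, 2).reshape(B * H * W, C, T)
assert tokens.shape == (B * H * W, C, T)
```

The reshapes are views where possible, so the cost is in the attention itself, not the bookkeeping.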
==== Action UNet and State UNet
The Action UNet (`conditional_unet1d.py`) is a 1D convolutional UNet specifically designed for predicting robot action trajectories. Its configuration is as follows:
#figure(
table(
columns: (3fr, 5fr),
[*Parameter*], [*Value*],
[Input dimension], [16 (agent\_action\_dim)],
[Down channel widths], [\[256, 512, 1024, 2048\]],
[Kernel size], [5],
[GroupNorm groups], [8],
[Diffusion step embedding dim], [128],
[Horizon], [16],
[Action projection dim], [32],
),
caption: [Action UNet (ConditionalUnet1D) configuration parameters.],
)
The Action UNet receives the 10 `hs_a` feature maps from the Video UNet as visual conditioning. The conditioning pipeline involves three stages: (1) SpatialSoftmax compresses each 2D feature map into keypoint coordinates $(B times T, C, 2)$; (2) the compressed features are concatenated with the diffusion timestep embedding and observation encoding (ResNet-18 `MultiImageObsEncoder`), then injected via FiLM modulation to produce per-channel scale/bias for the 1D convolution blocks; (3) `ActionLatentImageCrossAttention` enables action tokens to cross-attend to the Video UNet's spatiotemporal features, allowing visually-grounded action planning.
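Stage (1), the SpatialSoftmax compression, can be sketched as follows. This is a minimal version of the idea; the real module may add a learned temperature and channel projection:

```python
import numpy as np

def spatial_softmax(feat):
    """Compress (N, C, H, W) feature maps to (N, C, 2) expected keypoints."""
    N, C, H, W = feat.shape
    flat = feat.reshape(N, C, H * W)
    attn = np.exp(flat - flat.max(axis=-1, keepdims=True))
    attn /= attn.sum(axis=-1, keepdims=True)             # softmax over H*W
    ys, xs = np.meshgrid(np.linspace(-1, 1, H),
                         np.linspace(-1, 1, W), indexing="ij")
    grid = np.stack([xs.ravel(), ys.ravel()], axis=-1)   # (H*W, 2) in [-1, 1]
    return attn @ grid                                   # expected (x, y) per channel

kp = spatial_softmax(np.random.randn(2, 512, 40, 64))
assert kp.shape == (2, 512, 2)
```

Each channel is reduced to the softmax-weighted expected coordinate of its activation, which is why the output is only two numbers per channel regardless of H and W.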
The input action tensor $(B, T, 16)$ is projected to act\_proj\_dim=32, processed through the 1D UNet, then projected back to $(B, T, 16)$.
The State UNet is an identical `ConditionalUnet1D` instance with the same hyperparameters, operating on the state tensor `x_state` $(B, T, 16)$ instead of the action tensor.
A critical optimization observation: the Action and State UNets are computationally independent, sharing read-only inputs with no data dependencies. The original code executes them sequentially, leaving significant room for CUDA stream parallelization.
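A sketch of that parallelization with CUDA streams (hypothetical wrapper function; real code must also ensure both UNets' inputs are fully produced on the current stream before forking):

```python
import torch

def run_action_state_parallel(action_unet, state_unet, x_act, x_state, cond):
    """Launch the independent Action and State UNets on separate CUDA streams.

    Falls back to the original sequential execution when no GPU is available.
    """
    if not torch.cuda.is_available():
        return action_unet(x_act, cond), state_unet(x_state, cond)

    s_act, s_state = torch.cuda.Stream(), torch.cuda.Stream()
    torch.cuda.current_stream().synchronize()  # inputs ready before forking
    with torch.cuda.stream(s_act):
        a = action_unet(x_act, cond)
    with torch.cuda.stream(s_state):
        s = state_unet(x_state, cond)
    torch.cuda.synchronize()                   # join before consuming outputs
    return a, s
```

Because the two 1D UNets are small relative to the video backbone, the win comes from overlapping their kernel launches and filling the GPU during otherwise underutilized gaps, not from raw FLOP reduction.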
=== Multi-Level Design of Attention Mechanisms
The attention mechanisms in UnifoLM-WMA constitute the core computational bottleneck of inference. Their design encompasses four distinct levels, each serving a different purpose in the model's spatiotemporal reasoning, and understanding their structure is essential for identifying optimization opportunities.
The first level is *spatial self-attention* within the SpatialTransformer. For a latent frame at resolution $H times W$, the token count is $H times W$ (e.g., $40 times 64 = 2560$ at the highest resolution). It is implemented via xformers' `memory_efficient_attention`, which reduces peak memory from $O(N^2)$ to $O(N)$. Q/K/V use bias-free linear layers with head count = channel\_dim / num\_head\_channels (e.g., 1280/64 = 20 heads).
The second level is *multi-source cross-attention*, the most distinctive design in UnifoLM-WMA. The unified context vector is split into four semantic sources, each with dedicated K/V projection layers:
#figure(
table(
columns: (2fr, 1fr, 3fr, 2fr),
[*Source*], [*Tokens*], [*K/V Projections*], [*Scale*],
[Text], [77], [`to_k` / `to_v` (shared base)], [1.0],
[Image], [16#sym.times T], [`to_k_ip` / `to_v_ip`], [`image_cross_attention_scale`],
[Agent state], [2], [`to_k_as` / `to_v_as`], [`agent_state_cross_attention_scale`],
[Agent action], [16], [`to_k_aa` / `to_v_aa`], [`agent_action_cross_attention_scale`],
),
caption: [Multi-source cross-attention configuration.],
)
The Query vector Q is always derived from the video latent features via `to_q`. For each of the four sources, independent attention scores are computed — $"softmax"(Q dot K_i^T \/ sqrt(d)) dot V_i$ — producing four separate attention outputs. These outputs are then combined via weighted summation:
$ "out" = "out"_"text" + alpha_"img" dot "out"_"ip" + alpha_"state" dot "out"_"as" + alpha_"action" dot "out"_"aa" $
In the current configuration, `cross_attention_scale_learnable=False` (fixed scales). This decoupled design maintains eight K/V projection layers in total, six more than standard single-source cross-attention, creating opportunities for KV fusion optimization.
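The decoupled computation can be sketched end to end (NumPy, single head, random data, and illustrative unit scales):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, k, v):
    d = q.shape[-1]
    return softmax(q @ k.swapaxes(-1, -2) / np.sqrt(d)) @ v

def multi_source_cross_attention(q, sources, scales):
    """q: (N, L, d); sources: name -> (K_i, V_i); scales: per-source weight.

    Each source gets its own attention computation; the outputs are
    combined by weighted summation, matching the formula above.
    """
    out = np.zeros_like(q)
    for name, (k, v) in sources.items():
        out += scales[name] * attend(q, k, v)
    return out

rng = np.random.default_rng(0)
d, L = 64, 2560                       # head dim, video tokens at full resolution
q = rng.standard_normal((1, L, d))
sources = {name: (rng.standard_normal((1, n, d)), rng.standard_normal((1, n, d)))
           for name, n in [("text", 77), ("image", 256), ("state", 2), ("action", 16)]}
scales = {"text": 1.0, "image": 1.0, "state": 1.0, "action": 1.0}
out = multi_source_cross_attention(q, sources, scales)
assert out.shape == (1, L, d)
```

Because each source's softmax is computed independently, concatenating all K/V into one attention call would change the normalization; any KV-fusion optimization has to preserve the per-source softmax to stay numerically equivalent.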
The third level is *temporal self-attention* within the TemporalTransformer. The input $(B, C, T, H, W)$ is reshaped to $(B times H times W, C, T)$, so each spatial position becomes an independent batch element and T=16 time steps form the token sequence. Supports relative position encoding via a `RelativePosition` module and optional causal masks; current configuration uses bidirectional temporal attention.
The fourth level is *action-latent-image cross-attention* in the `ActionLatentImageCrossAttention` module. Action tokens $(B, "action_dim", "act_proj_dim")$ as Query cross-attend to Video UNet features reshaped to $(B, T times H times W, C)$ as Key/Value. A `BasicTransformerBlock` (depth=1) performs action self-attention then cross-attention to video features, with zero-initialized `proj_out` and residual connection. This mechanism is the key bridge enabling the action head to access the visual world model's internal representations.