├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤

│ 1   │ CUDA Stream 预创建              │ wma_model.py          │ 50次 → 0次        │
  ├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤
  │ 2   │ noise buffer 预分配             │ ddim.py               │ 50次 alloc → 0次  │
  ├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤
  │ 3   │ global_feature expand提到循环外 │ conditional_unet1d.py │ ~700次 → ~100次   │
  ├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤
  │ 4   │ alpha/sigma dtype 预转换        │ ddim.py               │ 200次 .to() → 0次 │
效果不算特别明显
This commit is contained in:
2026-02-10 13:40:52 +00:00
parent a09d35ae5b
commit 2cef3e9e45
4 changed files with 52 additions and 40 deletions

View File

@@ -567,6 +567,11 @@ class ConditionalUnet1D(nn.Module):
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
# Pre-expand global_feature once (reused in every down/mid/up block)
if self.use_linear_act_proj:
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, T, -1)
else:
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, 2, -1)
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
@@ -603,15 +608,11 @@ class ConditionalUnet1D(nn.Module):
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
h.append(x)
@@ -638,15 +639,11 @@ class ConditionalUnet1D(nn.Module):
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
@@ -683,16 +680,12 @@ class ConditionalUnet1D(nn.Module):
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, cur_global_feature)

View File

@@ -251,6 +251,13 @@ class DDIMSampler(object):
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
noise_buf = torch.empty_like(img)
# Pre-convert schedule arrays to inference dtype (avoid per-step .to())
_dtype = img.dtype
_alphas = (self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas).to(_dtype)
_alphas_prev = (self.model.alphas_cumprod_prev if ddim_use_original_steps else self.ddim_alphas_prev).to(_dtype)
_sqrt_one_minus = (self.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sqrt_one_minus_alphas).to(_dtype)
_sigmas = (self.ddim_sigmas_for_original_num_steps if ddim_use_original_steps else self.ddim_sigmas).to(_dtype)
enable_cross_attn_kv_cache(self.model)
enable_ctx_cache(self.model)
try:
@@ -286,6 +293,8 @@ class DDIMSampler(object):
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
noise_buf=noise_buf,
schedule_arrays=(_alphas, _alphas_prev, _sqrt_one_minus, _sigmas),
**kwargs)
img, pred_x0, model_output_action, model_output_state = outs
@@ -339,6 +348,8 @@ class DDIMSampler(object):
mask=None,
x0=None,
guidance_rescale=0.0,
noise_buf=None,
schedule_arrays=None,
**kwargs):
b, *_, device = *x.shape, x.device
@@ -384,16 +395,18 @@ class DDIMSampler(object):
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if schedule_arrays is not None:
alphas, alphas_prev, sqrt_one_minus_alphas, sigmas = schedule_arrays
else:
alphas = (self.model.alphas_cumprod if use_original_steps else self.ddim_alphas).to(x.dtype)
alphas_prev = (self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev).to(x.dtype)
sqrt_one_minus_alphas = (self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas).to(x.dtype)
sigmas = (self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas).to(x.dtype)
# Use 0-d tensors directly (already on device); broadcasting handles shape
a_t = alphas[index].to(x.dtype)
a_prev = alphas_prev[index].to(x.dtype)
sigma_t = sigmas[index].to(x.dtype)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
a_t = alphas[index]
a_prev = alphas_prev[index]
sigma_t = sigmas[index]
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -411,8 +424,12 @@ class DDIMSampler(object):
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_buf is not None:
noise_buf.normal_()
noise = sigma_t * noise_buf * temperature
else:
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)

View File

@@ -690,6 +690,8 @@ class WMAModel(nn.Module):
self._ctx_cache = {}
# fs_embed cache
self._fs_embed_cache = None
# Pre-created CUDA stream for parallel action/state UNet
self._side_stream = torch.cuda.Stream() if not self.base_model_gen_only else None
def forward(self,
x: Tensor,
@@ -849,8 +851,8 @@ class WMAModel(nn.Module):
if not self.base_model_gen_only:
ba, _, _ = x_action.shape
ts_state = timesteps[:ba] if b > 1 else timesteps
# Run action_unet and state_unet in parallel via CUDA streams
s_stream = torch.cuda.Stream()
# Run action_unet and state_unet in parallel via pre-created CUDA stream
s_stream = self._side_stream
s_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s_stream):
s_y = self.state_unet(x_state, ts_state, hs_a,

View File

@@ -1,14 +1,14 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-10 10:36:44.797852: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 10:36:44.801300: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 10:36:44.837891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-10 10:36:44.837946: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 10:36:44.839880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-10 10:36:44.849073: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 10:36:44.849365: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
2026-02-10 13:30:56.669605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 13:30:56.672987: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 13:30:56.704235: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-10 13:30:56.704271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 13:30:56.706111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-10 13:30:56.714239: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 13:30:56.714546: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-10 10:36:45.644793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-10 13:30:57.511779: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:02<07:20, 63.00s/it]
25%|██▌ | 2/8 [02:02<06:05, 60.84s/it]
@@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>