├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤
│ 1 │ CUDA Stream 预创建 │ wma_model.py │ 50次 → 0次 │ ├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤ │ 2 │ noise buffer 预分配 │ ddim.py │ 50次 alloc → 0次 │ ├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤ │ 3 │ global_feature expand提到循环外 │ conditional_unet1d.py │ ~700次 → ~100次 │ ├─────┼─────────────────────────────────┼───────────────────────┼───────────────────┤ │ 4 │ alpha/sigma dtype 预转换 │ ddim.py │ 200次 .to() → 0次 │ 效果不算特别明显
This commit is contained in:
@@ -567,6 +567,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
timesteps = timesteps.expand(sample.shape[0])
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
global_feature = self.diffusion_step_encoder(timesteps)
|
global_feature = self.diffusion_step_encoder(timesteps)
|
||||||
|
# Pre-expand global_feature once (reused in every down/mid/up block)
|
||||||
|
if self.use_linear_act_proj:
|
||||||
|
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, T, -1)
|
||||||
|
else:
|
||||||
|
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, 2, -1)
|
||||||
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
||||||
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
||||||
|
|
||||||
@@ -603,15 +608,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=2, dim=1)
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
x = resnet2(x, cur_global_feature)
|
x = resnet2(x, cur_global_feature)
|
||||||
h.append(x)
|
h.append(x)
|
||||||
@@ -638,15 +639,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
|
||||||
repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
|
||||||
repeats=2, dim=1)
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
x = resnet2(x, cur_global_feature)
|
x = resnet2(x, cur_global_feature)
|
||||||
|
|
||||||
@@ -683,16 +680,12 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=2, dim=1)
|
|
||||||
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
|
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
|
|||||||
@@ -251,6 +251,13 @@ class DDIMSampler(object):
|
|||||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||||
|
noise_buf = torch.empty_like(img)
|
||||||
|
# Pre-convert schedule arrays to inference dtype (avoid per-step .to())
|
||||||
|
_dtype = img.dtype
|
||||||
|
_alphas = (self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas).to(_dtype)
|
||||||
|
_alphas_prev = (self.model.alphas_cumprod_prev if ddim_use_original_steps else self.ddim_alphas_prev).to(_dtype)
|
||||||
|
_sqrt_one_minus = (self.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sqrt_one_minus_alphas).to(_dtype)
|
||||||
|
_sigmas = (self.ddim_sigmas_for_original_num_steps if ddim_use_original_steps else self.ddim_sigmas).to(_dtype)
|
||||||
enable_cross_attn_kv_cache(self.model)
|
enable_cross_attn_kv_cache(self.model)
|
||||||
enable_ctx_cache(self.model)
|
enable_ctx_cache(self.model)
|
||||||
try:
|
try:
|
||||||
@@ -286,6 +293,8 @@ class DDIMSampler(object):
|
|||||||
x0=x0,
|
x0=x0,
|
||||||
fs=fs,
|
fs=fs,
|
||||||
guidance_rescale=guidance_rescale,
|
guidance_rescale=guidance_rescale,
|
||||||
|
noise_buf=noise_buf,
|
||||||
|
schedule_arrays=(_alphas, _alphas_prev, _sqrt_one_minus, _sigmas),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
img, pred_x0, model_output_action, model_output_state = outs
|
img, pred_x0, model_output_action, model_output_state = outs
|
||||||
@@ -339,6 +348,8 @@ class DDIMSampler(object):
|
|||||||
mask=None,
|
mask=None,
|
||||||
x0=None,
|
x0=None,
|
||||||
guidance_rescale=0.0,
|
guidance_rescale=0.0,
|
||||||
|
noise_buf=None,
|
||||||
|
schedule_arrays=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
@@ -384,16 +395,18 @@ class DDIMSampler(object):
|
|||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||||
**corrector_kwargs)
|
**corrector_kwargs)
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
if schedule_arrays is not None:
|
||||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
alphas, alphas_prev, sqrt_one_minus_alphas, sigmas = schedule_arrays
|
||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
else:
|
||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
alphas = (self.model.alphas_cumprod if use_original_steps else self.ddim_alphas).to(x.dtype)
|
||||||
|
alphas_prev = (self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev).to(x.dtype)
|
||||||
|
sqrt_one_minus_alphas = (self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas).to(x.dtype)
|
||||||
|
sigmas = (self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas).to(x.dtype)
|
||||||
|
|
||||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
a_t = alphas[index]
|
||||||
a_t = alphas[index].to(x.dtype)
|
a_prev = alphas_prev[index]
|
||||||
a_prev = alphas_prev[index].to(x.dtype)
|
sigma_t = sigmas[index]
|
||||||
sigma_t = sigmas[index].to(x.dtype)
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
|
|
||||||
|
|
||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
@@ -411,6 +424,10 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
|
||||||
|
if noise_buf is not None:
|
||||||
|
noise_buf.normal_()
|
||||||
|
noise = sigma_t * noise_buf * temperature
|
||||||
|
else:
|
||||||
noise = sigma_t * noise_like(x.shape, device,
|
noise = sigma_t * noise_like(x.shape, device,
|
||||||
repeat_noise) * temperature
|
repeat_noise) * temperature
|
||||||
if noise_dropout > 0.:
|
if noise_dropout > 0.:
|
||||||
|
|||||||
@@ -690,6 +690,8 @@ class WMAModel(nn.Module):
|
|||||||
self._ctx_cache = {}
|
self._ctx_cache = {}
|
||||||
# fs_embed cache
|
# fs_embed cache
|
||||||
self._fs_embed_cache = None
|
self._fs_embed_cache = None
|
||||||
|
# Pre-created CUDA stream for parallel action/state UNet
|
||||||
|
self._side_stream = torch.cuda.Stream() if not self.base_model_gen_only else None
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -849,8 +851,8 @@ class WMAModel(nn.Module):
|
|||||||
if not self.base_model_gen_only:
|
if not self.base_model_gen_only:
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
# Run action_unet and state_unet in parallel via CUDA streams
|
# Run action_unet and state_unet in parallel via pre-created CUDA stream
|
||||||
s_stream = torch.cuda.Stream()
|
s_stream = self._side_stream
|
||||||
s_stream.wait_stream(torch.cuda.current_stream())
|
s_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(s_stream):
|
with torch.cuda.stream(s_stream):
|
||||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-10 10:36:44.797852: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-10 13:30:56.669605: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 10:36:44.801300: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-10 13:30:56.672987: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 10:36:44.837891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-10 13:30:56.704235: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-10 10:36:44.837946: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-10 13:30:56.704271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 10:36:44.839880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-10 13:30:56.706111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-10 10:36:44.849073: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-10 13:30:56.714239: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 10:36:44.849365: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-10 13:30:56.714546: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-10 10:36:45.644793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-10 13:30:57.511779: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
[rank: 0] Global seed set to 123
|
[rank: 0] Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
@@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:02<07:20, 63.00s/it]
|
12%|█▎ | 1/8 [01:02<07:20, 63.00s/it]
|
||||||
25%|██▌ | 2/8 [02:02<06:05, 60.84s/it]
|
25%|██▌ | 2/8 [02:02<06:05, 60.84s/it]
|
||||||
@@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
Reference in New Issue
Block a user