复用 DDIMSampler + make_schedule微弱提升

This commit is contained in:
2026-02-09 18:26:39 +00:00
parent 6dca3696d8
commit 0b3b0e534a
3 changed files with 21 additions and 13 deletions

View File

@@ -1803,7 +1803,9 @@ class LatentDiffusion(DDPM):
"""
if ddim:
ddim_sampler = DDIMSampler(self)
if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
self._ddim_sampler = DDIMSampler(self)
ddim_sampler = self._ddim_sampler
shape = (self.channels, self.temporal_length, *self.image_size)
samples, actions, states, intermediates = ddim_sampler.sample(
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)

View File

@@ -18,6 +18,7 @@ class DDIMSampler(object):
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.counter = 0
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@@ -30,6 +31,11 @@ class DDIMSampler(object):
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
key = (ddim_num_steps, ddim_discretize, ddim_eta)
if self._schedule_key == key:
return
self._schedule_key = key
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,

View File

@@ -1,14 +1,14 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-09 17:32:41.850068: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 17:32:41.853132: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 17:32:41.886058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 17:32:41.886103: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 17:32:41.887979: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 17:32:41.896994: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 17:32:41.897283: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
2026-02-09 18:16:36.491189: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 18:16:36.494639: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 18:16:36.527202: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 18:16:36.527247: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 18:16:36.529027: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 18:16:36.537430: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 18:16:36.537748: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-09 17:32:42.611394: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-09 18:16:37.281129: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:08<07:58, 68.38s/it]
25%|██▌ | 2/8 [02:13<06:37, 66.30s/it]
@@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>