VAE优化，模型直接加载至GPU (VAE optimization: load the model directly onto the GPU)

This commit is contained in:
qhy
2026-02-07 17:36:00 +08:00
parent aba2a90045
commit 7dcf9e8b89
5 changed files with 679 additions and 622 deletions

View File

@@ -222,7 +222,7 @@ data:
test: test:
target: unifolm_wma.data.wma_data.WMAData target: unifolm_wma.data.wma_data.WMAData
params: params:
data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts' data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length} video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2 frame_stride: 2
load_raw_resolution: True load_raw_resolution: True

File diff suppressed because it is too large Load Diff

View File

@@ -99,13 +99,16 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}") print(f"Restored from {path}")
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
if getattr(self, '_channels_last', False):
x = x.to(memory_format=torch.channels_last)
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments) posterior = DiagonalGaussianDistribution(moments)
return posterior return posterior
def decode(self, z, **kwargs): def decode(self, z, **kwargs):
if getattr(self, '_channels_last', False):
z = z.to(memory_format=torch.channels_last)
z = self.post_quant_conv(z) z = self.post_quant_conv(z)
dec = self.decoder(z) dec = self.decoder(z)
return dec return dec

View File

@@ -1073,15 +1073,19 @@ class LatentDiffusion(DDPM):
if not self.perframe_ae: if not self.perframe_ae:
encoder_posterior = self.first_stage_model.encode(x) encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach() results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower else: ## Batch encode with configurable batch size
results = [] bs = getattr(self, 'vae_encode_bs', 1)
for index in range(x.shape[0]): if bs >= x.shape[0]:
frame_batch = self.first_stage_model.encode(x[index:index + encoder_posterior = self.first_stage_model.encode(x)
1, :, :, :]) results = self.get_first_stage_encoding(encoder_posterior).detach()
frame_result = self.get_first_stage_encoding( else:
frame_batch).detach() results = []
results.append(frame_result) for i in range(0, x.shape[0], bs):
results = torch.cat(results, dim=0) frame_batch = self.first_stage_model.encode(x[i:i + bs])
frame_result = self.get_first_stage_encoding(
frame_batch).detach()
results.append(frame_result)
results = torch.cat(results, dim=0)
if reshape_back: if reshape_back:
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t) results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
@@ -1105,16 +1109,21 @@ class LatentDiffusion(DDPM):
else: else:
reshape_back = False reshape_back = False
z = 1. / self.scale_factor * z
if not self.perframe_ae: if not self.perframe_ae:
z = 1. / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs) results = self.first_stage_model.decode(z, **kwargs)
else: else:
results = [] bs = getattr(self, 'vae_decode_bs', 1)
for index in range(z.shape[0]): if bs >= z.shape[0]:
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :] # all frames in one batch
frame_result = self.first_stage_model.decode(frame_z, **kwargs) results = self.first_stage_model.decode(z, **kwargs)
results.append(frame_result) else:
results = torch.cat(results, dim=0) results = []
for i in range(0, z.shape[0], bs):
results.append(
self.first_stage_model.decode(z[i:i + bs], **kwargs))
results = torch.cat(results, dim=0)
if reshape_back: if reshape_back:
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t) results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)

View File

@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x): def nonlinearity(x):
# swish # swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x)
return x * torch.sigmoid(x) return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):