VAE优化,模型直接加载至GPU
This commit is contained in:
@@ -1073,15 +1073,19 @@ class LatentDiffusion(DDPM):
|
||||
if not self.perframe_ae:
|
||||
encoder_posterior = self.first_stage_model.encode(x)
|
||||
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
||||
else: ## Consume less GPU memory but slower
|
||||
results = []
|
||||
for index in range(x.shape[0]):
|
||||
frame_batch = self.first_stage_model.encode(x[index:index +
|
||||
1, :, :, :])
|
||||
frame_result = self.get_first_stage_encoding(
|
||||
frame_batch).detach()
|
||||
results.append(frame_result)
|
||||
results = torch.cat(results, dim=0)
|
||||
else: ## Batch encode with configurable batch size
|
||||
bs = getattr(self, 'vae_encode_bs', 1)
|
||||
if bs >= x.shape[0]:
|
||||
encoder_posterior = self.first_stage_model.encode(x)
|
||||
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
||||
else:
|
||||
results = []
|
||||
for i in range(0, x.shape[0], bs):
|
||||
frame_batch = self.first_stage_model.encode(x[i:i + bs])
|
||||
frame_result = self.get_first_stage_encoding(
|
||||
frame_batch).detach()
|
||||
results.append(frame_result)
|
||||
results = torch.cat(results, dim=0)
|
||||
|
||||
if reshape_back:
|
||||
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
@@ -1105,16 +1109,21 @@ class LatentDiffusion(DDPM):
|
||||
else:
|
||||
reshape_back = False
|
||||
|
||||
z = 1. / self.scale_factor * z
|
||||
|
||||
if not self.perframe_ae:
|
||||
z = 1. / self.scale_factor * z
|
||||
results = self.first_stage_model.decode(z, **kwargs)
|
||||
else:
|
||||
results = []
|
||||
for index in range(z.shape[0]):
|
||||
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
|
||||
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
|
||||
results.append(frame_result)
|
||||
results = torch.cat(results, dim=0)
|
||||
bs = getattr(self, 'vae_decode_bs', 1)
|
||||
if bs >= z.shape[0]:
|
||||
# all frames in one batch
|
||||
results = self.first_stage_model.decode(z, **kwargs)
|
||||
else:
|
||||
results = []
|
||||
for i in range(0, z.shape[0], bs):
|
||||
results.append(
|
||||
self.first_stage_model.decode(z[i:i + bs], **kwargs))
|
||||
results = torch.cat(results, dim=0)
|
||||
|
||||
if reshape_back:
|
||||
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
|
||||
Reference in New Issue
Block a user