轻量投影/MLP BF16

psnr指标反而比只量化扩散主干模型要低,原因不明
This commit is contained in:
2026-01-18 18:26:37 +08:00
parent 2b634cde90
commit 3c0f409fcf
6 changed files with 96 additions and 48 deletions

View File

@@ -289,13 +289,15 @@ def image_guided_synthesis(model: torch.nn.Module,
if not text_input:
prompts = [""] * batch_size
b, c, t, h, w = videos.shape
img = videos[:, :, 0]
img_emb = model.embedder(img)
img_emb = model.image_proj_model(img_emb)
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
cond_emb = model.get_learned_conditioning(prompts)
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
b, c, t, h, w = videos.shape
img = videos[:, :, 0]
img_emb = model.embedder(img)
cond_emb = model.get_learned_conditioning(prompts)
target_dtype = cond_emb.dtype
img_emb = model._projector_forward(model.image_proj_model, img_emb,
target_dtype)
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
if model.model.conditioning_key == 'hybrid':