KV fusion implemented. Summary of changes: inference is marginally faster and PSNR rises slightly.

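attention.py: 3 changes:
  1. __init__ adds a _kv_fused = False flag
  2. New fuse_kv() method: merges to_k + to_v → to_kv, and also handles the _ip/_as/_aa auxiliary KV pairs
  3. Both branches of bmm_forward now check _kv_fused and use to_kv().chunk(2, dim=-1) instead of calling to_k and to_v separately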
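
For readers without attention.py at hand, the fusion described above can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: the class name TinyCrossAttention and the helper project_kv are hypothetical, the projections are assumed to be bias-free nn.Linear layers, the _ip/_as/_aa auxiliary pairs are omitted, and the boolean return value of fuse_kv() is an assumption inferred from the call site counting truthy results.

```python
import torch
import torch.nn as nn


class TinyCrossAttention(nn.Module):
    """Hypothetical stand-in for CrossAttention; only the K/V projections are modeled."""

    def __init__(self, context_dim: int, inner_dim: int) -> None:
        super().__init__()
        # Assumption: bias-free linear projections.
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self._kv_fused = False

    @torch.no_grad()
    def fuse_kv(self) -> bool:
        """Build to_kv from to_k and to_v; return True only if fusion happened."""
        if self._kv_fused:
            return False
        w_k, w_v = self.to_k.weight, self.to_v.weight
        fused = nn.Linear(w_k.shape[1], w_k.shape[0] + w_v.shape[0], bias=False)
        fused = fused.to(device=w_k.device, dtype=w_k.dtype)
        # nn.Linear weights are (out_features, in_features), so stacking along
        # dim 0 puts K in the first half of the output and V in the second.
        fused.weight.copy_(torch.cat([w_k, w_v], dim=0))
        self.to_kv = fused
        self._kv_fused = True
        return True

    def project_kv(self, context: torch.Tensor):
        """Return (k, v), using one fused matmul when available."""
        if self._kv_fused:
            k, v = self.to_kv(context).chunk(2, dim=-1)
            return k, v
        return self.to_k(context), self.to_v(context)
```

Because the fused weight is just the two original weights stacked, the fused and unfused paths should agree up to floating-point rounding. This sketch keeps to_k and to_v after fusion; whether the real method frees them is not stated in the commit message.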
2026-02-10 18:07:23 +00:00
parent 2cef3e9e45
commit 57ba85d147
4 changed files with 61 additions and 23 deletions

@@ -625,6 +625,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
from unifolm_wma.modules.attention import CrossAttention
kv_count = sum(1 for m in model.modules()
               if isinstance(m, CrossAttention) and m.fuse_kv())
print(f" ✓ KV fused: {kv_count} attention layers")
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt