KV fusion implemented. Summary of changes: slight speed improvement and a marginal PSNR increase.

attention.py — 3 changes:
  1. __init__: add a _kv_fused = False flag
  2. New fuse_kv() method: merges to_k + to_v → to_kv, and also handles the _ip/_as/_aa auxiliary KV pairs
  3. bmm_forward: both branches check _kv_fused and replace the two separate projections with a single to_kv().chunk(2, dim=-1)
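The fusion in steps 2–3 amounts to stacking the K and V projection weights into one linear layer so the two matmuls become a single GEMM, then splitting the result with chunk(). Here is a minimal sketch of that idea; TinyCrossAttention and its kv() helper are hypothetical stand-ins, not the repo's actual CrossAttention class:

```python
import torch
import torch.nn as nn

class TinyCrossAttention(nn.Module):
    """Hypothetical illustration of KV fusion (not the repo's class)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self._kv_fused = False

    def fuse_kv(self) -> bool:
        # Stack the K and V weight matrices row-wise: the fused layer's
        # out_features is 2 * dim, so one matmul produces both outputs.
        w = torch.cat([self.to_k.weight, self.to_v.weight], dim=0)
        self.to_kv = nn.Linear(w.shape[1], w.shape[0], bias=False)
        self.to_kv.weight = nn.Parameter(w)
        self._kv_fused = True
        return True

    def kv(self, context: torch.Tensor):
        if self._kv_fused:
            # One GEMM, then split along the last dim into (k, v).
            return self.to_kv(context).chunk(2, dim=-1)
        return self.to_k(context), self.to_v(context)
```

Because chunk(2, dim=-1) splits the fused output exactly at the boundary between the stacked K and V rows, the fused path is numerically equivalent to the two separate projections.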
This commit is contained in:
qhy
2026-02-11 12:36:38 +08:00
parent b558856e1e
commit 9a08e27a19
4 changed files with 180 additions and 38 deletions


@@ -579,6 +579,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
device = get_device_from_parameters(model)
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
from unifolm_wma.modules.attention import CrossAttention
kv_count = sum(1 for m in model.modules()
if isinstance(m, CrossAttention) and m.fuse_kv())
print(f" ✓ KV fused: {kv_count} attention layers")
# Run over data
assert (args.height % 16 == 0) and (
args.width % 16