Files
qibotn/compare_jit_tn_quimb.py
jaunatisblue 5a692033a6
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
添加MPI并行TN benchmark及辅助脚本,移除旧benchmark
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-05 19:04:09 +08:00

51 lines
1.7 KiB
Python

import numpy as np
import os
import sys
def check_results(ref_path, tn_path):
# 1. 检查文件是否存在
if not os.path.exists(ref_path) or not os.path.exists(tn_path):
print(f"Error: 找不到文件!\n参考文件: {ref_path}\n待测文件: {tn_path}")
return
print(f"正在加载数据并对比: \n [Ref] {ref_path}\n [TN ] {tn_path}\n")
try:
# 2. 加载状态向量
# mmap_mode='r' 可以防止大文件直接撑爆内存
sv_ref = np.load(ref_path, mmap_mode='r')
sv_tn = np.load(tn_path, mmap_mode='r')
# 3. 计算保真度 (Fidelity)
# fid = |<ref|tn>|^2
inner_product = np.dot(sv_ref.conj(), sv_tn)
fidelity = np.abs(inner_product)**2
# 4. 计算 L2 误差 (欧氏距离)
l2_error = np.linalg.norm(sv_ref - sv_tn)
# 5. 打印结果
print("-" * 30)
print(f"保真度 (Fidelity): {fidelity:.12f}")
#print(f"L2 范数误差: {l2_error:.2e}")
print("-" * 30)
# phase-invariant L2: align global phase first
phase = inner_product / np.abs(inner_product)
l2_phase_corrected = np.linalg.norm(sv_ref - sv_tn / phase)
print(f"L2 误差(相位校正后): {l2_phase_corrected:.2e}")
if fidelity > 0.999999:
print("✅ 验证通过:结果高度一致。")
else:
print("❌ 警告:保真度较低,请检查收缩路径或截断误差。")
except Exception as e:
print(f"计算过程中发生错误: {e}")
if __name__ == "__main__":
# 你可以在这里直接修改文件名
REF_FILE = 'data/sv_qibojit_qft30.npy'
TN_FILE = 'data/sv_tn_qft30_mpi.npy'
check_results(REF_FILE, TN_FILE)