添加MPI并行TN benchmark及辅助脚本,移除旧benchmark
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
51
compare_jit_tn_quimb.py
Normal file
51
compare_jit_tn_quimb.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
def check_results(ref_path, tn_path):
|
||||
# 1. 检查文件是否存在
|
||||
if not os.path.exists(ref_path) or not os.path.exists(tn_path):
|
||||
print(f"Error: 找不到文件!\n参考文件: {ref_path}\n待测文件: {tn_path}")
|
||||
return
|
||||
|
||||
print(f"正在加载数据并对比: \n [Ref] {ref_path}\n [TN ] {tn_path}\n")
|
||||
|
||||
try:
|
||||
# 2. 加载状态向量
|
||||
# mmap_mode='r' 可以防止大文件直接撑爆内存
|
||||
sv_ref = np.load(ref_path, mmap_mode='r')
|
||||
sv_tn = np.load(tn_path, mmap_mode='r')
|
||||
|
||||
# 3. 计算保真度 (Fidelity)
|
||||
# fid = |<ref|tn>|^2
|
||||
inner_product = np.dot(sv_ref.conj(), sv_tn)
|
||||
fidelity = np.abs(inner_product)**2
|
||||
|
||||
# 4. 计算 L2 误差 (欧氏距离)
|
||||
l2_error = np.linalg.norm(sv_ref - sv_tn)
|
||||
|
||||
# 5. 打印结果
|
||||
print("-" * 30)
|
||||
print(f"保真度 (Fidelity): {fidelity:.12f}")
|
||||
#print(f"L2 范数误差: {l2_error:.2e}")
|
||||
print("-" * 30)
|
||||
|
||||
# phase-invariant L2: align global phase first
|
||||
phase = inner_product / np.abs(inner_product)
|
||||
l2_phase_corrected = np.linalg.norm(sv_ref - sv_tn / phase)
|
||||
print(f"L2 误差(相位校正后): {l2_phase_corrected:.2e}")
|
||||
|
||||
if fidelity > 0.999999:
|
||||
print("✅ 验证通过:结果高度一致。")
|
||||
else:
|
||||
print("❌ 警告:保真度较低,请检查收缩路径或截断误差。")
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算过程中发生错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 你可以在这里直接修改文件名
|
||||
REF_FILE = 'data/sv_qibojit_qft30.npy'
|
||||
TN_FILE = 'data/sv_tn_qft30_mpi.npy'
|
||||
|
||||
check_results(REF_FILE, TN_FILE)
|
||||
Reference in New Issue
Block a user