Merge branch 'midend' into midend-LoopAnalysis

This commit is contained in:
rain2133
2025-08-11 21:20:34 +08:00
41 changed files with 5604 additions and 2382 deletions

View File

@@ -60,11 +60,7 @@ display_file_content() {
# 清理临时文件的函数
clean_tmp() {
echo "正在清理临时目录: ${TMP_DIR}"
rm -rf "${TMP_DIR}"/*.s \
"${TMP_DIR}"/*_sysyc_riscv64 \
"${TMP_DIR}"/*_sysyc_riscv64.actual_out \
"${TMP_DIR}"/*_sysyc_riscv64.expected_stdout \
"${TMP_DIR}"/*_sysyc_riscv64.o
rm -rf "${TMP_DIR}"/*
echo "清理完成。"
}

View File

@@ -2,64 +2,67 @@
# runit-single.sh - 用于编译和测试单个或少量 SysY 程序的脚本
# 模仿 runit.sh 的功能,但以具体文件路径作为输入。
# 此脚本应该位于 mysysy/script/
export ASAN_OPTIONS=detect_leaks=0
# --- 配置区 ---
# 请根据你的环境修改这些路径
# 假设此脚本位于你的项目根目录或一个脚本目录中
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
# 默认寻找项目根目录下的 build 和 lib
BUILD_BIN_DIR="${SCRIPT_DIR}/../build/bin"
LIB_DIR="${SCRIPT_DIR}/../lib"
# 临时文件会存储在脚本所在目录的 tmp 子目录中
TMP_DIR="${SCRIPT_DIR}/tmp"
# 定义编译器和模拟器
SYSYC="${BUILD_BIN_DIR}/sysyc"
LLC_CMD="llc-19" # 新增
GCC_RISCV64="riscv64-linux-gnu-gcc"
QEMU_RISCV64="qemu-riscv64"
# --- 初始化变量 ---
EXECUTE_MODE=false
IR_EXECUTE_MODE=false # 新增
CLEAN_MODE=false
SYSYC_TIMEOUT=10 # sysyc 编译超时 (秒)
GCC_TIMEOUT=10 # gcc 编译超时 (秒)
EXEC_TIMEOUT=5 # qemu 自动化执行超时 (秒)
MAX_OUTPUT_LINES=50 # 对比失败时显示的最大行数
SY_FILES=() # 存储用户提供的 .sy 文件列表
OPTIMIZE_FLAG=""
SYSYC_TIMEOUT=30
LLC_TIMEOUT=10 # 新增
GCC_TIMEOUT=10
EXEC_TIMEOUT=30
MAX_OUTPUT_LINES=20
SY_FILES=()
PASSED_CASES=0
FAILED_CASES_LIST=""
INTERRUPTED=false # 新增
# =================================================================
# --- 函数定义 ---
# =================================================================
show_help() {
echo "用法: $0 [文件1.sy] [文件2.sy] ... [选项]"
echo "编译并测试指定的 .sy 文件。"
echo ""
echo "如果找到对应的 .in/.out 文件,则进行自动化测试。否则,进入交互模式。"
echo "编译并测试指定的 .sy 文件。必须提供 -e 或 -eir 之一。"
echo ""
echo "选项:"
echo " -e, --executable 编译为可执行文件并运行测试 (必须)。"
echo " -e 通过汇编运行测试 (sysyc -> gcc -> qemu)。"
echo " -eir 通过IR运行测试 (sysyc -> llc -> gcc -> qemu)。"
echo " -c, --clean 清理 tmp 临时目录下的所有文件。"
echo " -sct N 设置 sysyc 编译超时为 N 秒 (默认: 10)。"
echo " -O1 启用 sysyc 的 -O1 优化。"
echo " -sct N 设置 sysyc 编译超时为 N 秒 (默认: 30)。"
echo " -lct N 设置 llc-19 编译超时为 N 秒 (默认: 10)。"
echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
echo " -et N 设置 qemu 自动化执行超时为 N 秒 (默认: 5)。"
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 50)。"
echo " -et N 设置 qemu 自动化执行超时为 N 秒 (默认: 30)。"
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 20)。"
echo " -h, --help 显示此帮助信息并退出。"
echo ""
echo "可在任何时候按 Ctrl+C 来中断测试并显示当前已完成的测例总结。"
}
# --- 新增功能: 显示文件内容并根据行数截断 ---
display_file_content() {
local file_path="$1"
local title="$2"
local max_lines="$3"
if [ ! -f "$file_path" ]; then
return
fi
if [ ! -f "$file_path" ]; then return; fi
echo -e "$title"
local line_count
line_count=$(wc -l < "$file_path")
if [ "$line_count" -gt "$max_lines" ]; then
head -n "$max_lines" "$file_path"
echo -e "\e[33m[... 输出已截断,共 ${line_count} 行 ...]\e[0m"
@@ -68,51 +71,79 @@ display_file_content() {
fi
}
# --- 本次修改点: 整个参数解析逻辑被重写 ---
# 使用标准的 while 循环来健壮地处理任意顺序的参数
# --- 新增:总结报告函数 ---
print_summary() {
local total_cases=${#SY_FILES[@]}
echo ""
echo "======================================================================"
if [ "$INTERRUPTED" = true ]; then
echo -e "\e[33m测试被中断。正在汇总已完成的结果...\e[0m"
else
echo "所有测试完成"
fi
local failed_count
if [ -n "$FAILED_CASES_LIST" ]; then
failed_count=$(echo -e -n "${FAILED_CASES_LIST}" | wc -l)
else
failed_count=0
fi
local executed_count=$((PASSED_CASES + failed_count))
echo "测试结果: [通过: ${PASSED_CASES}, 失败: ${failed_count}, 已执行: ${executed_count}/${total_cases}]"
if [ -n "$FAILED_CASES_LIST" ]; then
echo ""
echo -e "\e[31m未通过的测例:\e[0m"
printf "%b" "${FAILED_CASES_LIST}"
fi
echo "======================================================================"
if [ "$failed_count" -gt 0 ]; then
exit 1
else
exit 0
fi
}
# --- 新增SIGINT 信号处理函数 ---
handle_sigint() {
INTERRUPTED=true
print_summary
}
# =================================================================
# --- 主逻辑开始 ---
# =================================================================
# --- 新增:设置 trap 来捕获 SIGINT ---
trap handle_sigint SIGINT
# --- 参数解析 ---
while [[ "$#" -gt 0 ]]; do
case "$1" in
-e|--executable)
EXECUTE_MODE=true
shift # 消耗选项
;;
-c|--clean)
CLEAN_MODE=true
shift # 消耗选项
;;
-sct)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then SYSYC_TIMEOUT="$2"; shift 2; else echo "错误: -sct 需要一个正整数参数。" >&2; exit 1; fi
;;
-gct)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi
;;
-et)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi
;;
-ml|--max-lines)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi
;;
-h|--help)
show_help
exit 0
;;
-*) # 未知选项
echo "未知选项: $1"
show_help
exit 1
;;
*) # 其他参数被视为文件路径
-e|--executable) EXECUTE_MODE=true; shift ;;
-eir) IR_EXECUTE_MODE=true; shift ;; # 新增
-c|--clean) CLEAN_MODE=true; shift ;;
-O1) OPTIMIZE_FLAG="-O1"; shift ;;
-lct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then LLC_TIMEOUT="$2"; shift 2; else echo "错误: -lct 需要一个正整数参数。" >&2; exit 1; fi ;; # 新增
-sct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then SYSYC_TIMEOUT="$2"; shift 2; else echo "错误: -sct 需要一个正整数参数。" >&2; exit 1; fi ;;
-gct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi ;;
-et) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi ;;
-ml|--max-lines) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi ;;
-h|--help) show_help; exit 0 ;;
-*) echo "未知选项: $1"; show_help; exit 1 ;;
*)
if [[ -f "$1" && "$1" == *.sy ]]; then
SY_FILES+=("$1")
else
echo "警告: 无效文件或不是 .sy 文件,已忽略: $1"
fi
shift # 消耗文件参数
shift
;;
esac
done
if ${CLEAN_MODE}; then
echo "检测到 -c/--clean 选项,正在清空 ${TMP_DIR}..."
if [ -d "${TMP_DIR}" ]; then
@@ -121,19 +152,22 @@ if ${CLEAN_MODE}; then
else
echo "临时目录 ${TMP_DIR} 不存在,无需清理。"
fi
if [ ${#SY_FILES[@]} -eq 0 ] && ! ${EXECUTE_MODE}; then
if [ ${#SY_FILES[@]} -eq 0 ] && ! ${EXECUTE_MODE} && ! ${IR_EXECUTE_MODE}; then
exit 0
fi
fi
# --- 主逻辑开始 ---
if ! ${EXECUTE_MODE}; then
echo "错误: 请提供 -e 或 --executable 选项来运行测试。"
if ! ${EXECUTE_MODE} && ! ${IR_EXECUTE_MODE}; then
echo "错误: 请提供 -e 或 -eir 选项来运行测试。"
show_help
exit 1
fi
if ${EXECUTE_MODE} && ${IR_EXECUTE_MODE}; then
echo -e "\e[31m错误: -e 和 -eir 选项不能同时使用。\e[0m" >&2
exit 1
fi
if [ ${#SY_FILES[@]} -eq 0 ]; then
echo "错误: 未提供任何 .sy 文件作为输入。"
show_help
@@ -144,18 +178,18 @@ mkdir -p "${TMP_DIR}"
TOTAL_CASES=${#SY_FILES[@]}
echo "SysY 单例测试运行器启动..."
echo "超时设置: sysyc=${SYSYC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
echo "失败输出最大行数: ${MAX_OUTPUT_LINES}"
if [ -n "$OPTIMIZE_FLAG" ]; then echo "优化等级: ${OPTIMIZE_FLAG}"; fi
echo "超时设置: sysyc=${SYSYC_TIMEOUT}s, llc=${LLC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
echo ""
for sy_file in "${SY_FILES[@]}"; do
is_passed=1
compilation_ok=1
base_name=$(basename "${sy_file}" .sy)
source_dir=$(dirname "${sy_file}")
ir_file="${TMP_DIR}/${base_name}_sysyc_riscv64.ll"
ir_file="${TMP_DIR}/${base_name}.ll"
assembly_file="${TMP_DIR}/${base_name}.s"
assembly_debug_file="${TMP_DIR}/${base_name}_d.s"
executable_file="${TMP_DIR}/${base_name}"
input_file="${source_dir}/${base_name}.in"
output_reference_file="${source_dir}/${base_name}.out"
@@ -164,37 +198,39 @@ for sy_file in "${SY_FILES[@]}"; do
echo "======================================================================"
echo "正在处理: ${sy_file}"
# 步骤 1: sysyc 编译
echo " 使用 sysyc 编译 (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -s ir "${sy_file}" > "${ir_file}"
SYSYC_STATUS=$?
if [ $SYSYC_STATUS -eq 124 ]; then
echo -e "\e[31m错误: SysY 编译 ${sy_file} IR超时\e[0m"
is_passed=0
elif [ $SYSYC_STATUS -ne 0 ]; then
echo -e "\e[31m错误: SysY 编译 ${sy_file} IR失败退出码: ${SYSYC_STATUS}\e[0m"
is_passed=0
fi
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -S "${sy_file}" -o "${assembly_file}"
if [ $? -ne 0 ]; then
echo -e "\e[31m错误: SysY 编译失败或超时。\e[0m"
is_passed=0
fi
# timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -s asmd "${sy_file}" > "${assembly_debug_file}" 2>&1
# --- 编译阶段 ---
if ${IR_EXECUTE_MODE}; then
# 路径1: sysyc -> llc -> gcc
echo " [1/3] 使用 sysyc 编译为 IR (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -s ir "${sy_file}" ${OPTIMIZE_FLAG} -o "${ir_file}"
if [ $? -ne 0 ]; then echo -e "\e[31m错误: SysY (IR) 编译失败或超时。\e[0m"; compilation_ok=0; fi
# 步骤 2: GCC 编译
if [ "$is_passed" -eq 1 ]; then
echo " 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
if [ "$compilation_ok" -eq 1 ]; then
echo " [2/3] 使用 llc 编译为汇编 (超时 ${LLC_TIMEOUT}s)..."
timeout -s KILL ${LLC_TIMEOUT} "${LLC_CMD}" -march=riscv64 -mcpu=generic-rv64 -mattr=+m,+a,+f,+d,+c -filetype=asm "${ir_file}" -o "${assembly_file}"
if [ $? -ne 0 ]; then echo -e "\e[31m错误: llc 编译失败或超时。\e[0m"; compilation_ok=0; fi
fi
if [ "$compilation_ok" -eq 1 ]; then
echo " [3/3] 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
timeout -s KILL ${GCC_TIMEOUT} "${GCC_RISCV64}" "${assembly_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static
if [ $? -ne 0 ]; then
echo -e "\e[31m错误: GCC 编译失败或超时。\e[0m"
is_passed=0
if [ $? -ne 0 ]; then echo -e "\e[31m错误: GCC 编译失败或超时。\e[0m"; compilation_ok=0; fi
fi
else # EXECUTE_MODE
# 路径2: sysyc -> gcc
echo " [1/2] 使用 sysyc 编译为汇编 (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -S "${sy_file}" ${OPTIMIZE_FLAG} -o "${assembly_file}"
if [ $? -ne 0 ]; then echo -e "\e[31m错误: SysY (汇编) 编译失败或超时。\e[0m"; compilation_ok=0; fi
if [ "$compilation_ok" -eq 1 ]; then
echo " [2/2] 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
timeout -s KILL ${GCC_TIMEOUT} "${GCC_RISCV64}" "${assembly_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static
if [ $? -ne 0 ]; then echo -e "\e[31m错误: GCC 编译失败或超时。\e[0m"; compilation_ok=0; fi
fi
fi
# 步骤 3: 执行与测试
if [ "$is_passed" -eq 1 ]; then
# 检查是自动化测试还是交互模式
# --- 执行与测试阶段 (公共逻辑) ---
if [ "$compilation_ok" -eq 1 ]; then
if [ -f "${input_file}" ] || [ -f "${output_reference_file}" ]; then
# --- 自动化测试模式 ---
echo " 检测到 .in/.out 文件,进入自动化测试模式..."
@@ -217,24 +253,26 @@ for sy_file in "${SY_FILES[@]}"; do
EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
EXPECTED_STDOUT_FILE="${TMP_DIR}/${base_name}.expected_stdout"
head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}"
if [ "$ACTUAL_RETURN_CODE" -ne "$EXPECTED_RETURN_CODE" ]; then echo -e "\e[31m 返回码测试失败: 期望 ${EXPECTED_RETURN_CODE}, 实际 ${ACTUAL_RETURN_CODE}\e[0m"; is_passed=0; fi
ret_ok=1
if [ "$ACTUAL_RETURN_CODE" -ne "$EXPECTED_RETURN_CODE" ]; then echo -e "\e[31m 返回码测试失败: 期望 ${EXPECTED_RETURN_CODE}, 实际 ${ACTUAL_RETURN_CODE}\e[0m"; ret_ok=0; fi
out_ok=1
if ! diff -q <(tr -d '[:space:]' < "${output_actual_file}") <(tr -d '[:space:]' < "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
echo -e "\e[31m 标准输出测试失败。\e[0m"
is_passed=0
echo -e "\e[31m 标准输出测试失败。\e[0m"; out_ok=0
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
echo -e " \e[36m----------------\e[0m"
fi
if [ "$ret_ok" -eq 1 ] && [ "$out_ok" -eq 1 ]; then echo -e "\e[32m 返回码与标准输出测试成功。\e[0m"; else is_passed=0; fi
else
if diff -q <(tr -d '[:space:]' < "${output_actual_file}") <(tr -d '[:space:]' < "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 标准输出测试成功。\e[0m"
else
echo -e "\e[31m 标准输出测试失败。\e[0m"
is_passed=0
echo -e "\e[31m 标准输出测试失败。\e[0m"; is_passed=0
display_file_content "${output_reference_file}" " \e[36m--- 期望输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
display_file_content "${output_actual_file}" " \e[36m--- 实际输出 ---\e[0m" "${MAX_OUTPUT_LINES}"
echo -e " \e[36m----------------\e[0m"
fi
fi
else
@@ -243,20 +281,16 @@ for sy_file in "${SY_FILES[@]}"; do
fi
else
# --- 交互模式 ---
echo -e "\e[33m"
echo " **********************************************************"
echo " ** 未找到 .in 或 .out 文件,进入交互模式。 **"
echo " ** 程序即将运行,你可以直接在终端中输入。 **"
echo " ** 按下 Ctrl+D (EOF) 或以其他方式结束程序以继续。 **"
echo " **********************************************************"
echo -e "\e[0m"
echo -e "\e[33m\n 未找到 .in 或 .out 文件,进入交互模式...\e[0m"
"${QEMU_RISCV64}" "${executable_file}"
INTERACTIVE_RET_CODE=$?
echo -e "\e[33m\n 交互模式执行完毕,程序返回码: ${INTERACTIVE_RET_CODE}\e[0m"
echo " 注意: 交互模式的结果未经验证。"
echo -e "\e[33m\n 交互模式执行完毕,程序返回码: ${INTERACTIVE_RET_CODE} (此结果未经验证)\e[0m"
fi
else
is_passed=0
fi
# --- 状态总结 ---
if [ "$is_passed" -eq 1 ]; then
echo -e "\e[32m状态: 通过\e[0m"
((PASSED_CASES++))
@@ -267,20 +301,4 @@ for sy_file in "${SY_FILES[@]}"; do
done
# --- 打印最终总结 ---
echo "======================================================================"
echo "所有测试完成"
echo "测试通过率: [${PASSED_CASES}/${TOTAL_CASES}]"
if [ -n "$FAILED_CASES_LIST" ]; then
echo ""
echo -e "\e[31m未通过的测例:\e[0m"
echo -e "${FAILED_CASES_LIST}"
fi
echo "======================================================================"
if [ "$PASSED_CASES" -eq "$TOTAL_CASES" ]; then
exit 0
else
exit 1
fi
print_summary

View File

@@ -1,31 +1,41 @@
#!/bin/bash
# runit.sh - 用于编译和测试 SysY 程序的脚本
# 此脚本应该位于 mysysy/test_script/
# 此脚本应该位于 mysysy/script/
export ASAN_OPTIONS=detect_leaks=0
# 定义相对于脚本位置的目录
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
TESTDATA_DIR="${SCRIPT_DIR}/../testdata"
BUILD_BIN_DIR="${SCRIPT_DIR}/../build/bin"
LIB_DIR="${SCRIPT_DIR}/../lib"
# TMP_DIR="${SCRIPT_DIR}/tmp"
TMP_DIR="${SCRIPT_DIR}/tmp"
# 定义编译器和模拟器
SYSYC="${BUILD_BIN_DIR}/sysyc"
LLC_CMD="llc-19"
GCC_RISCV64="riscv64-linux-gnu-gcc"
QEMU_RISCV64="qemu-riscv64"
# --- 新增功能: 初始化变量 ---
# --- 状态变量 ---
EXECUTE_MODE=false
SYSYC_TIMEOUT=10 # sysyc 编译超时 (秒)
GCC_TIMEOUT=10 # gcc 编译超时 (秒)
EXEC_TIMEOUT=5 # qemu 执行超时 (秒)
MAX_OUTPUT_LINES=50 # 对比失败时显示的最大行数
TEST_SETS=() # 用于存储要运行的测试集
IR_EXECUTE_MODE=false
OPTIMIZE_FLAG=""
SYSYC_TIMEOUT=30
LLC_TIMEOUT=10
GCC_TIMEOUT=10
EXEC_TIMEOUT=30
MAX_OUTPUT_LINES=20
TEST_SETS=()
TOTAL_CASES=0
PASSED_CASES=0
FAILED_CASES_LIST="" # 用于存储未通过的测例列表
FAILED_CASES_LIST=""
INTERRUPTED=false # 新增:用于标记是否被中断
# =================================================================
# --- 函数定义 ---
# =================================================================
# 显示帮助信息的函数
show_help() {
@@ -33,30 +43,32 @@ show_help() {
echo "此脚本用于按文件名前缀数字升序编译和测试 .sy 文件。"
echo ""
echo "选项:"
echo " -e, --executable 编译为可执行文件并运行测试。"
echo " -e, --executable 编译为汇编并运行测试 (sysyc -> gcc -> qemu)。"
echo " -eir 通过IR编译为可执行文件并运行测试 (sysyc -> llc -> gcc -> qemu)。"
echo " -c, --clean 清理 'tmp' 目录下的所有生成文件。"
echo " -O1 启用 sysyc 的 -O1 优化。"
echo " -set [f|h|p|all]... 指定要运行的测试集 (functional, h_functional, performance)。可多选,默认为 all。"
echo " -sct N 设置 sysyc 编译超时为 N 秒 (默认: 10)。"
echo " -sct N 设置 sysyc 编译超时为 N 秒 (默认: 30)。"
echo " -lct N 设置 llc-19 编译超时为 N 秒 (默认: 10)。"
echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
echo " -et N 设置 qemu 执行超时为 N 秒 (默认: 5)。"
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 50)。"
echo " -et N 设置 qemu 执行超时为 N 秒 (默认: 30)。"
echo " -ml N, --max-lines N 当输出对比失败时,最多显示 N 行内容 (默认: 20)。"
echo " -h, --help 显示此帮助信息并退出。"
echo ""
echo "注意: 默认行为 (无 -e 或 -eir) 是将 .sy 文件同时编译为 .s (汇编) 和 .ll (IR),不执行。"
echo " 可在任何时候按 Ctrl+C 来中断测试并显示当前已完成的测例总结。"
}
# 显示文件内容并根据行数截断的函数
display_file_content() {
local file_path="$1"
local title="$2"
local max_lines="$3"
if [ ! -f "$file_path" ]; then
return
fi
if [ ! -f "$file_path" ]; then return; fi
echo -e "$title"
local line_count
line_count=$(wc -l < "$file_path")
if [ "$line_count" -gt "$max_lines" ]; then
head -n "$max_lines" "$file_path"
echo -e "\e[33m[... 输出已截断,共 ${line_count} 行 ...]\e[0m"
@@ -71,61 +83,90 @@ clean_tmp() {
rm -rf "${TMP_DIR}"/*
}
# 如果临时目录不存在,则创建它
# --- 新增:总结报告函数 ---
print_summary() {
echo "" # 确保从新的一行开始
echo "========================================"
if [ "$INTERRUPTED" = true ]; then
echo -e "\e[33m测试被中断。正在汇总已完成的结果...\e[0m"
else
echo "测试完成"
fi
local failed_count
if [ -n "$FAILED_CASES_LIST" ]; then
# `wc -l` 计算由换行符分隔的列表项数
failed_count=$(echo -e -n "${FAILED_CASES_LIST}" | wc -l)
else
failed_count=0
fi
local executed_count=$((PASSED_CASES + failed_count))
echo "测试结果: [通过: ${PASSED_CASES}, 失败: ${failed_count}, 已执行: ${executed_count}/${TOTAL_CASES}]"
if [ -n "$FAILED_CASES_LIST" ]; then
echo ""
echo -e "\e[31m未通过的测例:\e[0m"
# 使用 printf 保证原样输出
printf "%b" "${FAILED_CASES_LIST}"
fi
echo "========================================"
if [ "$failed_count" -gt 0 ]; then
exit 1
else
exit 0
fi
}
# --- 新增SIGINT 信号处理函数 ---
handle_sigint() {
INTERRUPTED=true
print_summary
}
# =================================================================
# --- 主逻辑开始 ---
# =================================================================
# --- 新增:设置 trap 来捕获 SIGINT ---
trap handle_sigint SIGINT
mkdir -p "${TMP_DIR}"
# 解析命令行参数
while [[ "$#" -gt 0 ]]; do
case "$1" in
-e|--executable)
EXECUTE_MODE=true
shift
;;
-c|--clean)
clean_tmp
exit 0
;;
-e|--executable) EXECUTE_MODE=true; shift ;;
-eir) IR_EXECUTE_MODE=true; shift ;;
-c|--clean) clean_tmp; exit 0 ;;
-O1) OPTIMIZE_FLAG="-O1"; shift ;;
-set)
shift # 移过 '-set'
# 消耗所有后续参数直到遇到下一个选项
while [[ "$#" -gt 0 && ! "$1" =~ ^- ]]; do
TEST_SETS+=("$1")
shift
done
;;
-sct)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then SYSYC_TIMEOUT="$2"; shift 2; else echo "错误: -sct 需要一个正整数参数。" >&2; exit 1; fi
;;
-gct)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi
;;
-et)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi
;;
-ml|--max-lines)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi
;;
-h|--help)
show_help
exit 0
;;
*)
echo "未知选项: $1"
show_help
exit 1
while [[ "$#" -gt 0 && ! "$1" =~ ^- ]]; do TEST_SETS+=("$1"); shift; done
;;
-sct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then SYSYC_TIMEOUT="$2"; shift 2; else echo "错误: -sct 需要一个正整数参数。" >&2; exit 1; fi ;;
-lct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then LLC_TIMEOUT="$2"; shift 2; else echo "错误: -lct 需要一个正整数参数。" >&2; exit 1; fi ;;
-gct) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift 2; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi ;;
-et) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift 2; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi ;;
-ml|--max-lines) if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then MAX_OUTPUT_LINES="$2"; shift 2; else echo "错误: --max-lines 需要一个正整数参数。" >&2; exit 1; fi ;;
-h|--help) show_help; exit 0 ;;
*) echo "未知选项: $1"; show_help; exit 1 ;;
esac
done
# --- 本次修改点: 根据 -set 参数构建查找路径 ---
if ${EXECUTE_MODE} && ${IR_EXECUTE_MODE}; then
echo -e "\e[31m错误: -e 和 -eir 选项不能同时使用。\e[0m" >&2
exit 1
fi
declare -A SET_MAP
SET_MAP[f]="functional"
SET_MAP[h]="h_functional"
SET_MAP[p]="performance"
SEARCH_PATHS=()
# 如果未指定测试集,或指定了 'all',则搜索所有目录
if [ ${#TEST_SETS[@]} -eq 0 ] || [[ " ${TEST_SETS[@]} " =~ " all " ]]; then
SEARCH_PATHS+=("${TESTDATA_DIR}")
else
@@ -138,23 +179,34 @@ else
done
fi
# 如果没有有效的搜索路径,则退出
if [ ${#SEARCH_PATHS[@]} -eq 0 ]; then
echo -e "\e[31m错误: 没有找到有效的测试集目录,测试中止。\e[0m"
exit 1
fi
echo "SysY 测试运行器启动..."
if [ -n "$OPTIMIZE_FLAG" ]; then echo "优化等级: ${OPTIMIZE_FLAG}"; fi
echo "输入目录: ${SEARCH_PATHS[@]}"
echo "临时目录: ${TMP_DIR}"
echo "执行模式: ${EXECUTE_MODE}"
if ${EXECUTE_MODE}; then
echo "超时设置: sysyc=${SYSYC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
RUN_MODE_INFO=""
if ${IR_EXECUTE_MODE}; then
RUN_MODE_INFO="IR执行模式 (-eir)"
TIMEOUT_INFO="超时设置: sysyc=${SYSYC_TIMEOUT}s, llc=${LLC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
elif ${EXECUTE_MODE}; then
RUN_MODE_INFO="直接执行模式 (-e)"
TIMEOUT_INFO="超时设置: sysyc=${SYSYC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
else
RUN_MODE_INFO="编译模式 (默认)"
TIMEOUT_INFO="超时设置: sysyc=${SYSYC_TIMEOUT}s"
fi
echo "运行模式: ${RUN_MODE_INFO}"
echo "${TIMEOUT_INFO}"
if ${EXECUTE_MODE} || ${IR_EXECUTE_MODE}; then
echo "失败输出最大行数: ${MAX_OUTPUT_LINES}"
fi
echo ""
# 使用构建好的路径查找 .sy 文件并排序
sy_files=$(find "${SEARCH_PATHS[@]}" -name "*.sy" | sort -V)
if [ -z "$sy_files" ]; then
echo "在指定目录中未找到任何 .sy 文件。"
@@ -162,139 +214,229 @@ if [ -z "$sy_files" ]; then
fi
TOTAL_CASES=$(echo "$sy_files" | wc -w)
# --- 修复: 使用 here-string (<<<) 代替管道 (|) 来避免子 shell 问题 ---
while IFS= read -r sy_file; do
is_passed=1 # 1 表示通过, 0 表示失败
is_passed=0 # 0 表示失败, 1 表示通过
relative_path_no_ext=$(realpath --relative-to="${TESTDATA_DIR}" "${sy_file%.*}")
output_base_name=$(echo "${relative_path_no_ext}" | tr '/' '_')
assembly_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64.s"
executable_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64"
assembly_file_S="${TMP_DIR}/${output_base_name}_sysyc_S.s"
executable_file_S="${TMP_DIR}/${output_base_name}_sysyc_S"
output_actual_file_S="${TMP_DIR}/${output_base_name}_sysyc_S.actual_out"
ir_file="${TMP_DIR}/${output_base_name}_sysyc_ir.ll"
assembly_file_from_ir="${TMP_DIR}/${output_base_name}_from_ir.s"
executable_file_from_ir="${TMP_DIR}/${output_base_name}_from_ir"
output_actual_file_from_ir="${TMP_DIR}/${output_base_name}_from_ir.actual_out"
input_file="${sy_file%.*}.in"
output_reference_file="${sy_file%.*}.out"
output_actual_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64.actual_out"
echo "正在处理: $(basename "$sy_file") (路径: ${relative_path_no_ext}.sy)"
# 步骤 1: 使用 sysyc 编译 .sy 到 .s
echo " 使用 sysyc 编译 (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -S "${sy_file}" -o "${assembly_file}"
# --- 模式 1: IR 执行模式 (-eir) ---
if ${IR_EXECUTE_MODE}; then
step_failed=0
test_logic_passed=0
echo " [1/4] 使用 sysyc 编译为 IR (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -s ir "${sy_file}" -o "${ir_file}" ${OPTIMIZE_FLAG}
SYSYC_STATUS=$?
if [ $SYSYC_STATUS -eq 124 ]; then
echo -e "\e[31m错误: SysY 编译 ${sy_file} 超时\e[0m"
is_passed=0
elif [ $SYSYC_STATUS -ne 0 ]; then
echo -e "\e[31m错误: SysY 编译 ${sy_file} 失败,退出码: ${SYSYC_STATUS}\e[0m"
is_passed=0
if [ $SYSYC_STATUS -ne 0 ]; then
[ $SYSYC_STATUS -eq 124 ] && echo -e "\e[31m错误: SysY (IR) 编译超时\e[0m" || echo -e "\e[31m错误: SysY (IR) 编译失败,退出码: ${SYSYC_STATUS}\e[0m"
step_failed=1
fi
# 只有当 EXECUTE_MODE 为 true 且上一步成功时才继续
if ${EXECUTE_MODE} && [ "$is_passed" -eq 1 ]; then
# 步骤 2: 使用 riscv64-linux-gnu-gcc 编译 .s 到可执行文件
echo " 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
timeout -s KILL ${GCC_TIMEOUT} "${GCC_RISCV64}" "${assembly_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static
if [ "$step_failed" -eq 0 ]; then
echo " [2/4] 使用 llc-19 编译为汇编 (超时 ${LLC_TIMEOUT}s)..."
timeout -s KILL ${LLC_TIMEOUT} "${LLC_CMD}" -march=riscv64 -mcpu=generic-rv64 -mattr=+m,+a,+f,+d,+c -filetype=asm "${ir_file}" -o "${assembly_file_from_ir}"
LLC_STATUS=$?
if [ $LLC_STATUS -ne 0 ]; then
[ $LLC_STATUS -eq 124 ] && echo -e "\e[31m错误: llc-19 编译超时\e[0m" || echo -e "\e[31m错误: llc-19 编译失败,退出码: ${LLC_STATUS}\e[0m"
step_failed=1
fi
fi
if [ "$step_failed" -eq 0 ]; then
echo " [3/4] 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
timeout -s KILL ${GCC_TIMEOUT} "${GCC_RISCV64}" "${assembly_file_from_ir}" -o "${executable_file_from_ir}" -L"${LIB_DIR}" -lsysy_riscv -static
GCC_STATUS=$?
if [ $GCC_STATUS -eq 124 ]; then
echo -e "\e[31m错误: GCC 编译 ${assembly_file} 超时\e[0m"
is_passed=0
elif [ $GCC_STATUS -ne 0 ]; then
echo -e "\e[31m错误: GCC 编译 ${assembly_file} 失败,退出码: ${GCC_STATUS}\e[0m"
is_passed=0
if [ $GCC_STATUS -ne 0 ]; then
[ $GCC_STATUS -eq 124 ] && echo -e "\e[31m错误: GCC 编译超时\e[0m" || echo -e "\e[31m错误: GCC 编译失败,退出码: ${GCC_STATUS}\e[0m"
step_failed=1
fi
elif ! ${EXECUTE_MODE}; then
echo " 跳过执行模式。仅生成汇编文件。"
if [ "$is_passed" -eq 1 ]; then
((PASSED_CASES++))
else
FAILED_CASES_LIST+="${relative_path_no_ext}.sy\n"
fi
echo ""
continue
fi
# 步骤 3, 4, 5: 只有当编译都成功时才执行
if [ "$is_passed" -eq 1 ]; then
echo " 正在执行 (超时 ${EXEC_TIMEOUT}s)..."
exec_cmd="${QEMU_RISCV64} \"${executable_file}\""
if [ -f "${input_file}" ]; then
exec_cmd+=" < \"${input_file}\""
fi
exec_cmd+=" > \"${output_actual_file}\""
if [ "$step_failed" -eq 0 ]; then
echo " [4/4] 正在执行 (超时 ${EXEC_TIMEOUT}s)..."
exec_cmd="${QEMU_RISCV64} \"${executable_file_from_ir}\""
[ -f "${input_file}" ] && exec_cmd+=" < \"${input_file}\""
exec_cmd+=" > \"${output_actual_file_from_ir}\""
eval "timeout -s KILL ${EXEC_TIMEOUT} ${exec_cmd}"
ACTUAL_RETURN_CODE=$?
if [ "$ACTUAL_RETURN_CODE" -eq 124 ]; then
echo -e "\e[31m 执行超时: ${sy_file} 运行超过 ${EXEC_TIMEOUT} 秒\e[0m"
is_passed=0
echo -e "\e[31m 执行超时: 运行超过 ${EXEC_TIMEOUT} 秒\e[0m"
else
if [ -f "${output_reference_file}" ]; then
LAST_LINE_TRIMMED=$(tail -n 1 "${output_reference_file}" | tr -d '[:space:]')
test_logic_passed=1
if [[ "$LAST_LINE_TRIMMED" =~ ^[-+]?[0-9]+$ ]]; then
EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
EXPECTED_STDOUT_FILE="${TMP_DIR}/${output_base_name}_sysyc_riscv64.expected_stdout"
EXPECTED_STDOUT_FILE="${TMP_DIR}/${output_base_name}_from_ir.expected_stdout"
head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}"
if [ "$ACTUAL_RETURN_CODE" -eq "$EXPECTED_RETURN_CODE" ]; then
echo -e "\e[32m 返回码测试成功: (${ACTUAL_RETURN_CODE}) 与期望值 (${EXPECTED_RETURN_CODE}) 匹配\e[0m"
else
echo -e "\e[31m 返回码测试失败: 期望: ${EXPECTED_RETURN_CODE}, 实际: ${ACTUAL_RETURN_CODE}\e[0m"
is_passed=0
test_logic_passed=0
fi
if ! diff -q <(tr -d '[:space:]' < "${output_actual_file}") <(tr -d '[:space:]' < "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
if diff -q <(tr -d '[:space:]' < "${output_actual_file_from_ir}") <(tr -d '[:space:]' < "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
[ "$test_logic_passed" -eq 1 ] && echo -e "\e[32m 标准输出测试成功\e[0m"
else
echo -e "\e[31m 标准输出测试失败\e[0m"
is_passed=0
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
display_file_content "${output_actual_file}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
echo -e " \e[36m------------------------------\e[0m"
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
test_logic_passed=0
fi
else
if [ $ACTUAL_RETURN_CODE -ne 0 ]; then
echo -e "\e[33m警告: 程序以非零状态 ${ACTUAL_RETURN_CODE} 退出 (纯输出比较模式)。\e[0m"
fi
if diff -q <(tr -d '[:space:]' < "${output_actual_file}") <(tr -d '[:space:]' < "${output_reference_file}") >/dev/null 2>&1; then
if [ $ACTUAL_RETURN_CODE -ne 0 ]; then echo -e "\e[33m警告: 程序以非零状态 ${ACTUAL_RETURN_CODE} 退出 (纯输出比较模式)。\e[0m"; fi
if diff -q <(tr -d '[:space:]' < "${output_actual_file_from_ir}") <(tr -d '[:space:]' < "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
else
echo -e "\e[31m 失败: 输出不匹配\e[0m"
is_passed=0
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
display_file_content "${output_actual_file}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
echo -e " \e[36m------------------------------\e[0m"
display_file_content "${output_actual_file_from_ir}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
test_logic_passed=0
fi
fi
else
echo " 无参考输出文件。程序返回码: ${ACTUAL_RETURN_CODE}"
test_logic_passed=1
fi
fi
fi
[ "$step_failed" -eq 0 ] && [ "$test_logic_passed" -eq 1 ] && is_passed=1
# --- 模式 2: 直接执行模式 (-e) ---
elif ${EXECUTE_MODE}; then
step_failed=0
test_logic_passed=0
echo " [1/3] 使用 sysyc 编译为汇编 (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -S "${sy_file}" -o "${assembly_file_S}" ${OPTIMIZE_FLAG}
SYSYC_STATUS=$?
if [ $SYSYC_STATUS -ne 0 ]; then
[ $SYSYC_STATUS -eq 124 ] && echo -e "\e[31m错误: SysY (汇编) 编译超时\e[0m" || echo -e "\e[31m错误: SysY (汇编) 编译失败,退出码: ${SYSYC_STATUS}\e[0m"
step_failed=1
fi
if [ "$step_failed" -eq 0 ]; then
echo " [2/3] 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
timeout -s KILL ${GCC_TIMEOUT} "${GCC_RISCV64}" "${assembly_file_S}" -o "${executable_file_S}" -L"${LIB_DIR}" -lsysy_riscv -static
GCC_STATUS=$?
if [ $GCC_STATUS -ne 0 ]; then
[ $GCC_STATUS -eq 124 ] && echo -e "\e[31m错误: GCC 编译超时\e[0m" || echo -e "\e[31m错误: GCC 编译失败,退出码: ${GCC_STATUS}\e[0m"
step_failed=1
fi
fi
if [ "$step_failed" -eq 0 ]; then
echo " [3/3] 正在执行 (超时 ${EXEC_TIMEOUT}s)..."
exec_cmd="${QEMU_RISCV64} \"${executable_file_S}\""
[ -f "${input_file}" ] && exec_cmd+=" < \"${input_file}\""
exec_cmd+=" > \"${output_actual_file_S}\""
eval "timeout -s KILL ${EXEC_TIMEOUT} ${exec_cmd}"
ACTUAL_RETURN_CODE=$?
if [ "$ACTUAL_RETURN_CODE" -eq 124 ]; then
echo -e "\e[31m 执行超时: 运行超过 ${EXEC_TIMEOUT} 秒\e[0m"
else
if [ -f "${output_reference_file}" ]; then
LAST_LINE_TRIMMED=$(tail -n 1 "${output_reference_file}" | tr -d '[:space:]')
test_logic_passed=1
if [[ "$LAST_LINE_TRIMMED" =~ ^[-+]?[0-9]+$ ]]; then
EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
EXPECTED_STDOUT_FILE="${TMP_DIR}/${output_base_name}_sysyc_S.expected_stdout"
head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}"
if [ "$ACTUAL_RETURN_CODE" -eq "$EXPECTED_RETURN_CODE" ]; then
echo -e "\e[32m 返回码测试成功: (${ACTUAL_RETURN_CODE}) 与期望值 (${EXPECTED_RETURN_CODE}) 匹配\e[0m"
else
echo -e "\e[31m 返回码测试失败: 期望: ${EXPECTED_RETURN_CODE}, 实际: ${ACTUAL_RETURN_CODE}\e[0m"
test_logic_passed=0
fi
if diff -q <(tr -d '[:space:]' < "${output_actual_file_S}") <(tr -d '[:space:]' < "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
[ "$test_logic_passed" -eq 1 ] && echo -e "\e[32m 标准输出测试成功\e[0m"
else
echo -e "\e[31m 标准输出测试失败\e[0m"
display_file_content "${EXPECTED_STDOUT_FILE}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
test_logic_passed=0
fi
else
if [ $ACTUAL_RETURN_CODE -ne 0 ]; then echo -e "\e[33m警告: 程序以非零状态 ${ACTUAL_RETURN_CODE} 退出 (纯输出比较模式)。\e[0m"; fi
if diff -q <(tr -d '[:space:]' < "${output_actual_file_S}") <(tr -d '[:space:]' < "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
else
echo -e "\e[31m 失败: 输出不匹配\e[0m"
display_file_content "${output_reference_file}" " \e[36m---------- 期望输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
display_file_content "${output_actual_file_S}" " \e[36m---------- 实际输出 ----------\e[0m" "${MAX_OUTPUT_LINES}"
test_logic_passed=0
fi
fi
else
echo " 无参考输出文件。程序返回码: ${ACTUAL_RETURN_CODE}"
test_logic_passed=1
fi
fi
fi
[ "$step_failed" -eq 0 ] && [ "$test_logic_passed" -eq 1 ] && is_passed=1
# --- 模式 3: 默认编译模式 ---
else
s_compile_ok=0
ir_compile_ok=0
echo " [1/2] 使用 sysyc 编译为汇编 (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -S "${sy_file}" -o "${assembly_file_S}" ${OPTIMIZE_FLAG}
SYSYC_S_STATUS=$?
if [ $SYSYC_S_STATUS -eq 0 ]; then
s_compile_ok=1
echo -e " \e[32m-> ${assembly_file_S} [成功]\e[0m"
else
[ $SYSYC_S_STATUS -eq 124 ] && echo -e " \e[31m-> [编译超时]\e[0m" || echo -e " \e[31m-> [编译失败, 退出码: ${SYSYC_S_STATUS}]\e[0m"
fi
echo " [2/2] 使用 sysyc 编译为 IR (超时 ${SYSYC_TIMEOUT}s)..."
timeout -s KILL ${SYSYC_TIMEOUT} "${SYSYC}" -s ir "${sy_file}" -o "${ir_file}" ${OPTIMIZE_FLAG}
SYSYC_IR_STATUS=$?
if [ $SYSYC_IR_STATUS -eq 0 ]; then
ir_compile_ok=1
echo -e " \e[32m-> ${ir_file} [成功]\e[0m"
else
[ $SYSYC_IR_STATUS -eq 124 ] && echo -e " \e[31m-> [编译超时]\e[0m" || echo -e " \e[31m-> [编译失败, 退出码: ${SYSYC_IR_STATUS}]\e[0m"
fi
if [ "$s_compile_ok" -eq 1 ] && [ "$ir_compile_ok" -eq 1 ]; then
is_passed=1
fi
fi
# --- 统计结果 ---
if [ "$is_passed" -eq 1 ]; then
((PASSED_CASES++))
else
# 确保 FAILED_CASES_LIST 的每一项都以换行符结尾
FAILED_CASES_LIST+="${relative_path_no_ext}.sy\n"
fi
echo ""
done <<< "$sy_files"
echo "========================================"
echo "测试完成"
echo "测试通过率: [${PASSED_CASES}/${TOTAL_CASES}]"
if [ -n "$FAILED_CASES_LIST" ]; then
echo ""
echo -e "\e[31m未通过的测例:\e[0m"
echo -e "${FAILED_CASES_LIST}"
fi
echo "========================================"
if [ "$PASSED_CASES" -eq "$TOTAL_CASES" ]; then
exit 0
else
exit 1
fi
# --- 修改:调用总结函数 ---
print_summary

View File

@@ -8,9 +8,11 @@ add_library(riscv64_backend_lib STATIC
Handler/CalleeSavedHandler.cpp
Handler/LegalizeImmediates.cpp
Handler/PrologueEpilogueInsertion.cpp
Handler/EliminateFrameIndices.cpp
Optimize/Peephole.cpp
Optimize/PostRA_Scheduler.cpp
Optimize/PreRA_Scheduler.cpp
Optimize/DivStrengthReduction.cpp
)
# 包含后端模块所需的头文件路径

View File

@@ -8,11 +8,6 @@ namespace sysy {
char CalleeSavedHandler::ID = 0;
// 辅助函数,用于判断一个物理寄存器是否为浮点寄存器
static bool is_fp_reg(PhysicalReg reg) {
return reg >= PhysicalReg::F0 && reg <= PhysicalReg::F31;
}
bool CalleeSavedHandler::runOnFunction(Function *F, AnalysisManager& AM) {
// This pass works on MachineFunction level, not IR level
return false;
@@ -20,114 +15,37 @@ bool CalleeSavedHandler::runOnFunction(Function *F, AnalysisManager& AM) {
void CalleeSavedHandler::runOnMachineFunction(MachineFunction* mfunc) {
StackFrameInfo& frame_info = mfunc->getFrameInfo();
std::set<PhysicalReg> used_callee_saved;
// 1. 扫描所有指令找出被使用的callee-saved寄存器
// 这个Pass在RegAlloc之后运行所以可以访问到物理寄存器
for (auto& mbb : mfunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
for (auto& op : instr->getOperands()) {
auto check_and_insert_reg = [&](RegOperand* reg_op) {
if (reg_op && !reg_op->isVirtual()) {
PhysicalReg preg = reg_op->getPReg();
// 检查整数 s1-s11
if (preg >= PhysicalReg::S1 && preg <= PhysicalReg::S11) {
used_callee_saved.insert(preg);
}
// 检查浮点 fs0-fs11 (f8,f9,f18-f27)
else if ((preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) || (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27)) {
used_callee_saved.insert(preg);
}
}
};
if (op->getKind() == MachineOperand::KIND_REG) {
check_and_insert_reg(static_cast<RegOperand*>(op.get()));
} else if (op->getKind() == MachineOperand::KIND_MEM) {
check_and_insert_reg(static_cast<MemOperand*>(op.get())->getBase());
}
}
}
}
const std::set<PhysicalReg>& used_callee_saved = frame_info.used_callee_saved_regs;
if (used_callee_saved.empty()) {
frame_info.callee_saved_size = 0;
frame_info.callee_saved_regs_to_store.clear();
return;
}
// 2. 计算并更新 frame_info
frame_info.callee_saved_size = used_callee_saved.size() * 8;
// 为了布局确定性和恢复顺序一致,对寄存器排序
std::vector<PhysicalReg> sorted_regs(used_callee_saved.begin(), used_callee_saved.end());
std::sort(sorted_regs.begin(), sorted_regs.end());
// 3. 在函数序言中插入保存指令
MachineBasicBlock* entry_block = mfunc->getBlocks().front().get();
auto& entry_instrs = entry_block->getInstructions();
// 插入点在函数入口标签之后,或者就是最开始
auto insert_pos = entry_instrs.begin();
if (!entry_instrs.empty() && entry_instrs.front()->getOpcode() == RVOpcodes::LABEL) {
insert_pos = std::next(insert_pos);
// 1. 计算被调用者保存寄存器所需的总空间大小
// s0 总是由 PEI Pass 单独处理,这里不计入大小,但要确保它在列表中
int size = 0;
std::set<PhysicalReg> regs_to_save = used_callee_saved;
if (regs_to_save.count(PhysicalReg::S0)) {
regs_to_save.erase(PhysicalReg::S0);
}
size = regs_to_save.size() * 8; // 每个寄存器占8字节 (64-bit)
frame_info.callee_saved_size = size;
std::vector<std::unique_ptr<MachineInstr>> save_instrs;
// [关键] 从局部变量区域之后开始分配空间
int current_offset = - (16 + frame_info.locals_size);
// 2. 创建一个有序的、需要保存的寄存器列表,以便后续 Pass 确定地生成代码
// s0 不应包含在此列表中,因为它由 PEI Pass 特殊处理
std::vector<PhysicalReg> sorted_regs(regs_to_save.begin(), regs_to_save.end());
std::sort(sorted_regs.begin(), sorted_regs.end(), [](PhysicalReg a, PhysicalReg b){
return static_cast<int>(a) < static_cast<int>(b);
});
frame_info.callee_saved_regs_to_store = sorted_regs;
for (PhysicalReg reg : sorted_regs) {
current_offset -= 8;
RVOpcodes save_op = is_fp_reg(reg) ? RVOpcodes::FSD : RVOpcodes::SD;
auto save_instr = std::make_unique<MachineInstr>(save_op);
save_instr->addOperand(std::make_unique<RegOperand>(reg));
save_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0), // 基址为帧指针 s0
std::make_unique<ImmOperand>(current_offset)
));
save_instrs.push_back(std::move(save_instr));
}
if (!save_instrs.empty()) {
entry_instrs.insert(insert_pos,
std::make_move_iterator(save_instrs.begin()),
std::make_move_iterator(save_instrs.end()));
}
// 4. 在函数结尾ret之前插入恢复指令
for (auto& mbb : mfunc->getBlocks()) {
for (auto it = mbb->getInstructions().begin(); it != mbb->getInstructions().end(); ++it) {
if ((*it)->getOpcode() == RVOpcodes::RET) {
std::vector<std::unique_ptr<MachineInstr>> restore_instrs;
// [关键] 使用与保存时完全相同的逻辑来计算偏移量
current_offset = - (16 + frame_info.locals_size);
for (PhysicalReg reg : sorted_regs) {
current_offset -= 8;
RVOpcodes restore_op = is_fp_reg(reg) ? RVOpcodes::FLD : RVOpcodes::LD;
auto restore_instr = std::make_unique<MachineInstr>(restore_op);
restore_instr->addOperand(std::make_unique<RegOperand>(reg));
restore_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset)
));
restore_instrs.push_back(std::move(restore_instr));
}
if (!restore_instrs.empty()) {
mbb->getInstructions().insert(it,
std::make_move_iterator(restore_instrs.begin()),
std::make_move_iterator(restore_instrs.end()));
}
goto next_block_label;
}
}
next_block_label:;
}
// 3. 更新栈帧总大小。
// 这是初步计算PEI Pass 会进行最终的对齐。
frame_info.total_size = frame_info.locals_size +
frame_info.spill_size +
frame_info.callee_saved_size;
}
} // namespace sysy

View File

@@ -0,0 +1,235 @@
#include "EliminateFrameIndices.h"
#include "RISCv64ISel.h"
#include <cassert>
#include <vector>
namespace sysy {
// getTypeSizeInBytes 是一个通用辅助函数,保持不变
unsigned EliminateFrameIndicesPass::getTypeSizeInBytes(Type* type) {
if (!type) {
assert(false && "Cannot get size of a null type.");
return 0;
}
switch (type->getKind()) {
case Type::kInt:
case Type::kFloat:
return 4;
case Type::kPointer:
return 8;
case Type::kArray: {
auto arrayType = type->as<ArrayType>();
return arrayType->getNumElements() * getTypeSizeInBytes(arrayType->getElementType());
}
default:
assert(false && "Unsupported type for size calculation.");
return 0;
}
}
void EliminateFrameIndicesPass::runOnMachineFunction(MachineFunction* mfunc) {
StackFrameInfo& frame_info = mfunc->getFrameInfo();
Function* F = mfunc->getFunc();
RISCv64ISel* isel = mfunc->getISel();
// 在这里处理栈传递的参数以便在寄存器分配前就将数据流显式化修复溢出逻辑的BUG。
// 2. 只为局部变量(AllocaInst)分配栈空间和计算偏移量
// 局部变量从 s0 下方(负偏移量)开始分配,紧接着为 ra 和 s0 预留的16字节之后
int local_var_offset = 16;
if(F) { // 确保函数指针有效
for (auto& bb : F->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
Type* allocated_type = alloca->getType()->as<PointerType>()->getBaseType();
int size = getTypeSizeInBytes(allocated_type);
// 优化栈帧大小对于大数组使用4字节对齐小对象使用8字节对齐
if (size >= 256) { // 大数组优化
size = (size + 3) & ~3; // 4字节对齐
} else {
size = (size + 7) & ~7; // 8字节对齐
}
if (size == 0) size = 4; // 最小4字节
local_var_offset += size;
unsigned alloca_vreg = isel->getVReg(alloca);
// 局部变量使用相对于s0的负向偏移
frame_info.alloca_offsets[alloca_vreg] = -local_var_offset;
}
}
}
}
// 记录仅由AllocaInst分配的局部变量的总大小
frame_info.locals_size = local_var_offset - 16;
// 记录局部变量区域分配结束的最终偏移量
frame_info.locals_end_offset = -local_var_offset;
// 在函数入口为所有栈传递的参数插入load指令
// 这个步骤至关重要它在寄存器分配之前为这些参数的vreg创建了明确的“定义(def)”指令。
// 这解决了在高寄存器压力下当这些vreg被溢出时`rewriteProgram`找不到其定义点而崩溃的问题。
if (F && isel && !mfunc->getBlocks().empty()) {
MachineBasicBlock* entry_block = mfunc->getBlocks().front().get();
std::vector<std::unique_ptr<MachineInstr>> arg_load_instrs;
// 步骤 3.1: 生成所有加载栈参数的指令,暂存起来
int arg_idx = 0;
for (Argument* arg : F->getArguments()) {
// 根据ABI前8个整型/指针参数通过寄存器传递,这里只处理超出部分。
if (arg_idx >= 8) {
// 计算参数在调用者栈帧中的位置该位置相对于被调用者的帧指针s0是正向偏移。
// 第9个参数(arg_idx=8)位于 0(s0)第10个(arg_idx=9)位于 8(s0),以此类推。
int offset = (arg_idx - 8) * 8;
unsigned arg_vreg = isel->getVReg(arg);
Type* arg_type = arg->getType();
// 根据参数类型选择正确的加载指令
RVOpcodes load_op;
if (arg_type->isFloat()) {
load_op = RVOpcodes::FLW; // 单精度浮点
} else if (arg_type->isPointer()) {
load_op = RVOpcodes::LD; // 64位指针
} else {
load_op = RVOpcodes::LW; // 32位整数
}
// 创建加载指令: lw/ld/flw vreg, offset(s0)
auto load_instr = std::make_unique<MachineInstr>(load_op);
load_instr->addOperand(std::make_unique<RegOperand>(arg_vreg));
load_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0), // 基址为帧指针
std::make_unique<ImmOperand>(offset)
));
arg_load_instrs.push_back(std::move(load_instr));
}
arg_idx++;
}
//仅当有需要加载的栈参数时,才执行插入逻辑
if (!arg_load_instrs.empty()) {
auto& entry_instrs = entry_block->getInstructions();
auto insertion_point = entry_instrs.begin(); // 默认插入点为块的开头
auto last_arg_save_it = entry_instrs.end();
// 步骤 3.2: 寻找一个安全的插入点。
// 遍历入口块的指令,找到最后一条保存“寄存器传递参数”的伪指令。
// 这样可以确保我们在所有 a0-a7 参数被保存之后,才执行可能覆盖它们的加载指令。
for (auto it = entry_instrs.begin(); it != entry_instrs.end(); ++it) {
MachineInstr* instr = it->get();
// 寻找代表保存参数到栈的伪指令
if (instr->getOpcode() == RVOpcodes::FRAME_STORE_W ||
instr->getOpcode() == RVOpcodes::FRAME_STORE_D ||
instr->getOpcode() == RVOpcodes::FRAME_STORE_F) {
// 检查被保存的值是否是寄存器参数 (arg_no < 8)
auto& operands = instr->getOperands();
if (operands.empty() || operands[0]->getKind() != MachineOperand::KIND_REG) continue;
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
Value* ir_value = isel->getVRegValueMap().count(src_vreg) ? isel->getVRegValueMap().at(src_vreg) : nullptr;
if (auto ir_arg = dynamic_cast<Argument*>(ir_value)) {
if (ir_arg->getIndex() < 8) {
last_arg_save_it = it; // 找到了一个保存寄存器参数的指令,更新位置
}
}
}
}
// 如果找到了这样的保存指令,我们的插入点就在它之后
if (last_arg_save_it != entry_instrs.end()) {
insertion_point = std::next(last_arg_save_it);
}
// 步骤 3.3: 在计算出的安全位置,一次性插入所有新创建的参数加载指令
entry_instrs.insert(insertion_point,
std::make_move_iterator(arg_load_instrs.begin()),
std::make_move_iterator(arg_load_instrs.end()));
}
}
// 4. 遍历所有机器指令,将访问局部变量的伪指令展开为真实指令
for (auto& mbb : mfunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
RVOpcodes opcode = instr_ptr->getOpcode();
if (opcode == RVOpcodes::FRAME_LOAD_W || opcode == RVOpcodes::FRAME_LOAD_D || opcode == RVOpcodes::FRAME_LOAD_F) {
RVOpcodes real_load_op;
if (opcode == RVOpcodes::FRAME_LOAD_W) real_load_op = RVOpcodes::LW;
else if (opcode == RVOpcodes::FRAME_LOAD_D) real_load_op = RVOpcodes::LD;
else real_load_op = RVOpcodes::FLW;
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg(Type::getPointerType(Type::getIntType()));
// 展开为: addi addr_vreg, s0, offset
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
// 展开为: lw/ld/flw dest_vreg, 0(addr_vreg)
auto load_instr = std::make_unique<MachineInstr>(real_load_op);
load_instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
load_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(load_instr));
} else if (opcode == RVOpcodes::FRAME_STORE_W || opcode == RVOpcodes::FRAME_STORE_D || opcode == RVOpcodes::FRAME_STORE_F) {
RVOpcodes real_store_op;
if (opcode == RVOpcodes::FRAME_STORE_W) real_store_op = RVOpcodes::SW;
else if (opcode == RVOpcodes::FRAME_STORE_D) real_store_op = RVOpcodes::SD;
else real_store_op = RVOpcodes::FSW;
auto& operands = instr_ptr->getOperands();
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg(Type::getPointerType(Type::getIntType()));
// 展开为: addi addr_vreg, s0, offset
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
// 展开为: sw/sd/fsw src_vreg, 0(addr_vreg)
auto store_instr = std::make_unique<MachineInstr>(real_store_op);
store_instr->addOperand(std::make_unique<RegOperand>(src_vreg));
store_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(store_instr));
} else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_ADDR) {
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
// 将 `frame_addr rd, rs` 展开为 `addi rd, s0, offset`
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(dest_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
} else {
new_instructions.push_back(std::move(instr_ptr));
}
}
mbb->getInstructions() = std::move(new_instructions);
}
}
} // namespace sysy

View File

@@ -1,17 +1,22 @@
#include "PrologueEpilogueInsertion.h"
#include "RISCv64LLIR.h" // 假设包含了 PhysicalReg, RVOpcodes 等定义
#include "RISCv64ISel.h"
#include "RISCv64RegAlloc.h" // 需要访问RegAlloc的结果
#include <algorithm>
#include <vector>
#include <set>
namespace sysy {
char PrologueEpilogueInsertionPass::ID = 0;
void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc) {
StackFrameInfo& frame_info = mfunc->getFrameInfo();
Function* F = mfunc->getFunc();
RISCv64ISel* isel = mfunc->getISel();
// 1. 清理 KEEPALIVE 伪指令
for (auto& mbb : mfunc->getBlocks()) {
auto& instrs = mbb->getInstructions();
// 使用标准的 Erase-Remove Idiom 来删除满足条件的元素
instrs.erase(
std::remove_if(instrs.begin(), instrs.end(),
[](const std::unique_ptr<MachineInstr>& instr) {
@@ -22,39 +27,59 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
);
}
StackFrameInfo& frame_info = mfunc->getFrameInfo();
Function* F = mfunc->getFunc();
RISCv64ISel* isel = mfunc->getISel();
// [关键] 获取寄存器分配的结果 (vreg -> preg 的映射)
// RegAlloc Pass 必须已经运行过
// 2. 确定需要保存的被调用者保存寄存器 (callee-saved)
auto& vreg_to_preg_map = frame_info.vreg_to_preg_map;
std::set<PhysicalReg> used_callee_saved_regs_set;
const auto& callee_saved_int = getCalleeSavedIntRegs();
const auto& callee_saved_fp = getCalleeSavedFpRegs();
// 完全遵循 AsmPrinter 中的计算逻辑
for (const auto& pair : vreg_to_preg_map) {
PhysicalReg preg = pair.second;
bool is_int_cs = std::find(callee_saved_int.begin(), callee_saved_int.end(), preg) != callee_saved_int.end();
bool is_fp_cs = std::find(callee_saved_fp.begin(), callee_saved_fp.end(), preg) != callee_saved_fp.end();
if ((is_int_cs && preg != PhysicalReg::S0) || is_fp_cs) {
used_callee_saved_regs_set.insert(preg);
}
}
frame_info.callee_saved_regs_to_store.assign(
used_callee_saved_regs_set.begin(), used_callee_saved_regs_set.end()
);
std::sort(frame_info.callee_saved_regs_to_store.begin(), frame_info.callee_saved_regs_to_store.end());
frame_info.callee_saved_size = frame_info.callee_saved_regs_to_store.size() * 8;
// 3. 计算最终的栈帧总大小,包含栈溢出保护
int total_stack_size = frame_info.locals_size +
frame_info.spill_size +
frame_info.callee_saved_size +
16; // 为 ra 和 s0 固定的16字节
16;
// 栈溢出保护:增加最大栈帧大小以容纳大型数组
const int MAX_STACK_FRAME_SIZE = 8192; // 8KB to handle large arrays like 256*4*2 = 2048 bytes
if (total_stack_size > MAX_STACK_FRAME_SIZE) {
// 如果仍然超过限制,尝试优化对齐方式
std::cerr << "Warning: Stack frame size " << total_stack_size
<< " exceeds recommended limit " << MAX_STACK_FRAME_SIZE << " for function "
<< mfunc->getName() << std::endl;
}
// 优化减少对齐开销使用16字节对齐而非更大的对齐
int aligned_stack_size = (total_stack_size + 15) & ~15;
frame_info.total_size = aligned_stack_size;
// 只有在需要分配栈空间时才生成指令
if (aligned_stack_size > 0) {
// --- 1. 插入序言 ---
// --- 4. 插入完整的序言 ---
MachineBasicBlock* entry_block = mfunc->getBlocks().front().get();
auto& entry_instrs = entry_block->getInstructions();
std::vector<std::unique_ptr<MachineInstr>> prologue_instrs;
// 1. addi sp, sp, -aligned_stack_size
// 4.1. 分配栈帧
auto alloc_stack = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
alloc_stack->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
alloc_stack->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
alloc_stack->addOperand(std::make_unique<ImmOperand>(-aligned_stack_size));
prologue_instrs.push_back(std::move(alloc_stack));
// 2. sd ra, (aligned_stack_size - 8)(sp)
// 4.2. 保存 ra 和 s0
auto save_ra = std::make_unique<MachineInstr>(RVOpcodes::SD);
save_ra->addOperand(std::make_unique<RegOperand>(PhysicalReg::RA));
save_ra->addOperand(std::make_unique<MemOperand>(
@@ -62,8 +87,6 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
std::make_unique<ImmOperand>(aligned_stack_size - 8)
));
prologue_instrs.push_back(std::move(save_ra));
// 3. sd s0, (aligned_stack_size - 16)(sp)
auto save_fp = std::make_unique<MachineInstr>(RVOpcodes::SD);
save_fp->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
save_fp->addOperand(std::make_unique<MemOperand>(
@@ -72,66 +95,55 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
));
prologue_instrs.push_back(std::move(save_fp));
// 4. addi s0, sp, aligned_stack_size
// 4.3. 设置新的帧指针 s0
auto set_fp = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
set_fp->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
set_fp->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
set_fp->addOperand(std::make_unique<ImmOperand>(aligned_stack_size));
prologue_instrs.push_back(std::move(set_fp));
// --- 在s0设置完毕后使用物理寄存器加载栈参数 ---
if (F && isel) {
int arg_idx = 0;
for (Argument* arg : F->getArguments()) {
if (arg_idx >= 8) {
unsigned vreg = isel->getVReg(arg);
if (frame_info.alloca_offsets.count(vreg) && vreg_to_preg_map.count(vreg)) {
int offset = frame_info.alloca_offsets.at(vreg);
PhysicalReg dest_preg = vreg_to_preg_map.at(vreg);
Type* arg_type = arg->getType();
if (arg_type->isFloat()) {
auto load_arg = std::make_unique<MachineInstr>(RVOpcodes::FLW);
load_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
load_arg->addOperand(std::make_unique<MemOperand>(
// 4.4. 保存所有使用到的被调用者保存寄存器
int next_available_offset = -(16 + frame_info.locals_size + frame_info.spill_size);
for (const auto& reg : frame_info.callee_saved_regs_to_store) {
// 改为“先更新,后使用”逻辑
next_available_offset -= 8; // 先为当前寄存器分配下一个可用槽位
RVOpcodes store_op = isFPR(reg) ? RVOpcodes::FSD : RVOpcodes::SD;
auto save_cs_reg = std::make_unique<MachineInstr>(store_op);
save_cs_reg->addOperand(std::make_unique<RegOperand>(reg));
save_cs_reg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
std::make_unique<ImmOperand>(next_available_offset) // 使用新计算出的正确偏移
));
prologue_instrs.push_back(std::move(load_arg));
} else {
RVOpcodes load_op = arg_type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW;
auto load_arg = std::make_unique<MachineInstr>(load_op);
load_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
load_arg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
prologue_instrs.push_back(std::move(load_arg));
}
}
}
arg_idx++;
}
prologue_instrs.push_back(std::move(save_cs_reg));
// 不再需要在循环末尾递减
}
// 确定插入点
auto insert_pos = entry_instrs.begin();
// 一次性将所有序言指令插入
if (!prologue_instrs.empty()) {
entry_instrs.insert(insert_pos,
// 4.5. 将所有生成的序言指令一次性插入到函数入口
entry_instrs.insert(entry_instrs.begin(),
std::make_move_iterator(prologue_instrs.begin()),
std::make_move_iterator(prologue_instrs.end()));
}
// --- 2. 插入尾声 (此部分逻辑保持不变) ---
// --- 5. 插入完整的尾声 ---
for (auto& mbb : mfunc->getBlocks()) {
for (auto it = mbb->getInstructions().begin(); it != mbb->getInstructions().end(); ++it) {
if ((*it)->getOpcode() == RVOpcodes::RET) {
std::vector<std::unique_ptr<MachineInstr>> epilogue_instrs;
// 1. ld ra
// 5.1. 恢复被调用者保存寄存器
int next_available_offset_restore = -(16 + frame_info.locals_size + frame_info.spill_size);
for (const auto& reg : frame_info.callee_saved_regs_to_store) {
next_available_offset_restore -= 8; // 为下一个寄存器准备偏移
RVOpcodes load_op = isFPR(reg) ? RVOpcodes::FLD : RVOpcodes::LD;
auto restore_cs_reg = std::make_unique<MachineInstr>(load_op);
restore_cs_reg->addOperand(std::make_unique<RegOperand>(reg));
restore_cs_reg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(next_available_offset_restore) // 使用当前偏移
));
epilogue_instrs.push_back(std::move(restore_cs_reg));
}
// 5.2. 恢复 ra 和 s0
auto restore_ra = std::make_unique<MachineInstr>(RVOpcodes::LD);
restore_ra->addOperand(std::make_unique<RegOperand>(PhysicalReg::RA));
restore_ra->addOperand(std::make_unique<MemOperand>(
@@ -139,8 +151,6 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
std::make_unique<ImmOperand>(aligned_stack_size - 8)
));
epilogue_instrs.push_back(std::move(restore_ra));
// 2. ld s0
auto restore_fp = std::make_unique<MachineInstr>(RVOpcodes::LD);
restore_fp->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
restore_fp->addOperand(std::make_unique<MemOperand>(
@@ -149,18 +159,18 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
));
epilogue_instrs.push_back(std::move(restore_fp));
// 3. addi sp, sp, aligned_stack_size
// 5.3. 释放栈帧
auto dealloc_stack = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
dealloc_stack->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
dealloc_stack->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
dealloc_stack->addOperand(std::make_unique<ImmOperand>(aligned_stack_size));
epilogue_instrs.push_back(std::move(dealloc_stack));
if (!epilogue_instrs.empty()) {
// 将尾声指令插入到 RET 指令之前
mbb->getInstructions().insert(it,
std::make_move_iterator(epilogue_instrs.begin()),
std::make_move_iterator(epilogue_instrs.end()));
}
goto next_block;
}
}

View File

@@ -0,0 +1,282 @@
#include "DivStrengthReduction.h"
#include <cmath>
#include <cstdint>
namespace sysy {
char DivStrengthReduction::ID = 0;
bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
// This pass works on MachineFunction level, not IR level
return false;
}
void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
if (!mfunc)
return;
bool debug = false; // Set to true for debugging
if (debug)
std::cout << "Running DivStrengthReduction optimization..." << std::endl;
int next_temp_reg = 1000;
auto createTempReg = [&]() -> int {
return next_temp_reg++;
};
struct MagicInfo {
int64_t magic;
int shift;
};
auto computeMagic = [](int64_t d, bool is_32bit) -> MagicInfo {
int word_size = is_32bit ? 32 : 64;
uint64_t ad = std::abs(d);
if (ad == 0) return {0, 0};
int l = std::floor(std::log2(ad));
if ((ad & (ad - 1)) == 0) { // power of 2
l = 0; // special case for power of 2, shift will be calculated differently
}
__int128_t one = 1;
__int128_t num;
int total_shift;
if (is_32bit) {
total_shift = 31 + l;
num = one << total_shift;
} else {
total_shift = 63 + l;
num = one << total_shift;
}
__int128_t den = ad;
int64_t magic = (num / den) + 1;
return {magic, total_shift};
};
auto isPowerOfTwo = [](int64_t n) -> bool {
return n > 0 && (n & (n - 1)) == 0;
};
auto getPowerOfTwoExponent = [](int64_t n) -> int {
if (n <= 0 || (n & (n - 1)) != 0) return -1;
int shift = 0;
while (n > 1) {
n >>= 1;
shift++;
}
return shift;
};
struct InstructionReplacement {
size_t index;
size_t count_to_erase;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
};
for (auto &mbb_uptr : mfunc->getBlocks()) {
auto &mbb = *mbb_uptr;
auto &instrs = mbb.getInstructions();
std::vector<InstructionReplacement> replacements;
for (size_t i = 0; i < instrs.size(); ++i) {
auto *instr = instrs[i].get();
bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) {
continue;
}
if (instr->getOperands().size() != 3) {
continue;
}
auto *dst_op = instr->getOperands()[0].get();
auto *src1_op = instr->getOperands()[1].get();
auto *src2_op = instr->getOperands()[2].get();
int64_t divisor = 0;
bool const_divisor_found = false;
size_t instructions_to_replace = 1;
if (src2_op->getKind() == MachineOperand::KIND_IMM) {
divisor = static_cast<ImmOperand *>(src2_op)->getValue();
const_divisor_found = true;
} else if (src2_op->getKind() == MachineOperand::KIND_REG) {
if (i > 0) {
auto *prev_instr = instrs[i - 1].get();
if (prev_instr->getOpcode() == RVOpcodes::LI && prev_instr->getOperands().size() == 2) {
auto *li_dst_op = prev_instr->getOperands()[0].get();
auto *li_imm_op = prev_instr->getOperands()[1].get();
if (li_dst_op->getKind() == MachineOperand::KIND_REG && li_imm_op->getKind() == MachineOperand::KIND_IMM) {
auto *div_reg_op = static_cast<RegOperand *>(src2_op);
auto *li_dst_reg_op = static_cast<RegOperand *>(li_dst_op);
if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() &&
div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) {
divisor = static_cast<ImmOperand *>(li_imm_op)->getValue();
const_divisor_found = true;
instructions_to_replace = 2;
}
}
}
}
}
if (!const_divisor_found) {
continue;
}
auto *dst_reg = static_cast<RegOperand *>(dst_op);
auto *src1_reg = static_cast<RegOperand *>(src1_op);
if (divisor == 0) continue;
std::vector<std::unique_ptr<MachineInstr>> newInstrs;
if (divisor == 1) {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
else if (divisor == -1) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
newInstrs.push_back(std::move(negInstr));
}
else if (isPowerOfTwo(std::abs(divisor))) {
int shift = getPowerOfTwoExponent(std::abs(divisor));
int temp_reg = createTempReg();
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
auto srlInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI);
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
srlInstr->addOperand(std::make_unique<ImmOperand>((is_32bit ? 32 : 64) - shift));
newInstrs.push_back(std::move(srlInstr));
auto addInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
addInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
addInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(addInstr));
auto sraInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(shift));
newInstrs.push_back(std::move(sraInstr));
if (divisor < 0) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
} else {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
}
else {
auto magic_info = computeMagic(divisor, is_32bit);
int magic_reg = createTempReg();
int temp_reg = createTempReg();
auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
loadInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
loadInstr->addOperand(std::make_unique<ImmOperand>(magic_info.magic));
newInstrs.push_back(std::move(loadInstr));
if (is_32bit) {
auto mulInstr = std::make_unique<MachineInstr>(RVOpcodes::MUL);
mulInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulInstr));
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(magic_info.shift));
newInstrs.push_back(std::move(sraInstr));
} else {
auto mulhInstr = std::make_unique<MachineInstr>(RVOpcodes::MULH);
mulhInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
mulhInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
mulhInstr->addOperand(std::make_unique<RegOperand>(magic_reg));
newInstrs.push_back(std::move(mulhInstr));
int post_shift = magic_info.shift - 63;
if (post_shift > 0) {
auto sraInstr = std::make_unique<MachineInstr>(RVOpcodes::SRAI);
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
sraInstr->addOperand(std::make_unique<ImmOperand>(post_shift));
newInstrs.push_back(std::move(sraInstr));
}
}
int sign_reg = createTempReg();
auto sraSignInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI);
sraSignInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
sraSignInstr->addOperand(std::make_unique<RegOperand>(*src1_reg));
sraSignInstr->addOperand(std::make_unique<ImmOperand>(is_32bit ? 31 : 63));
newInstrs.push_back(std::move(sraSignInstr));
auto subInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
subInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
subInstr->addOperand(std::make_unique<RegOperand>(sign_reg));
newInstrs.push_back(std::move(subInstr));
if (divisor < 0) {
auto negInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB);
negInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
negInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
negInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
newInstrs.push_back(std::move(negInstr));
} else {
auto moveInstr = std::make_unique<MachineInstr>(is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD);
moveInstr->addOperand(std::make_unique<RegOperand>(*dst_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(temp_reg));
moveInstr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
newInstrs.push_back(std::move(moveInstr));
}
}
if (!newInstrs.empty()) {
size_t start_index = i;
if (instructions_to_replace == 2) {
start_index = i - 1;
}
replacements.push_back({start_index, instructions_to_replace, std::move(newInstrs)});
}
}
for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
instrs.erase(instrs.begin() + it->index, instrs.begin() + it->index + it->count_to_erase);
instrs.insert(instrs.begin() + it->index,
std::make_move_iterator(it->newInstrs.begin()),
std::make_move_iterator(it->newInstrs.end()));
}
}
}
} // namespace sysy

View File

@@ -1,7 +1,8 @@
#include "RISCv64AsmPrinter.h"
#include "RISCv64ISel.h"
#include <stdexcept>
#include <sstream>
#include <iostream>
namespace sysy {
// 检查是否为内存加载/存储指令,以处理特殊的打印格式
@@ -60,7 +61,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::MULH: *OS << "mulh "; break;
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;
@@ -104,7 +105,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::FMV_S: *OS << "fmv.s "; break;
case RVOpcodes::FMV_W_X: *OS << "fmv.w.x "; break;
case RVOpcodes::FMV_X_W: *OS << "fmv.x.w "; break;
case RVOpcodes::CALL: { // [核心修改] 为CALL指令添加特殊处理逻辑
case RVOpcodes::CALL: { // 为CALL指令添加特殊处理逻辑
*OS << "call ";
// 遍历所有操作数,只寻找并打印函数名标签
for (const auto& op : instr->getOperands()) {
@@ -236,4 +237,30 @@ std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) {
}
}
std::string RISCv64AsmPrinter::formatInstr(const MachineInstr* instr) {
if (!instr) return "(null instr)";
// 使用 stringstream 作为临时的输出目标
std::stringstream ss;
// 关键: 临时将类成员 'OS' 指向我们的 stringstream
std::ostream* old_os = this->OS;
this->OS = &ss;
// 修正: 调用正确的内部打印函数 printMachineInstr
printInstruction(const_cast<MachineInstr*>(instr), false);
// 恢复旧的 ostream 指针
this->OS = old_os;
// 获取stringstream的内容并做一些清理
std::string result = ss.str();
size_t endpos = result.find_last_not_of(" \t\n\r");
if (std::string::npos != endpos) {
result = result.substr(0, endpos + 1);
}
return result;
}
} // namespace sysy

View File

@@ -73,7 +73,7 @@ std::string RISCv64CodeGen::module_gen() {
for (const auto& global_ptr : module->getGlobals()) {
GlobalValue* global = global_ptr.get();
// [核心修改] 使用更健壮的逻辑来判断是否为大型零初始化数组
// 使用更健壮的逻辑来判断是否为大型零初始化数组
bool is_all_zeros = true;
const auto& init_values = global->getInitValues();
@@ -139,8 +139,30 @@ std::string RISCv64CodeGen::module_gen() {
ss << ".type " << global->getName() << ", @object\n";
ss << ".size " << global->getName() << ", " << total_size << "\n";
ss << global->getName() << ":\n";
bool is_all_zeros = true;
const auto& init_values = global->getInitValues();
if (init_values.getValues().empty()) {
is_all_zeros = true;
} else {
for (auto val : init_values.getValues()) {
if (auto const_val = dynamic_cast<ConstantValue*>(val)) {
if (!const_val->isZero()) {
is_all_zeros = false;
break;
}
} else {
is_all_zeros = false;
break;
}
}
}
if (is_all_zeros) {
ss << " .zero " << total_size << "\n";
} else {
// 对于有非零初始值的变量,保持原有的打印逻辑。
printInitializer(ss, global->getInitValues());
}
}
// b. 处理全局常量 (ConstantVariable)
for (const auto& const_ptr : module->getConsts()) {
@@ -174,15 +196,43 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
// === 完整的后端处理流水线 ===
// 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers)
DEBUG = 0;
DEEPDEBUG = 0;
RISCv64ISel isel;
std::unique_ptr<MachineFunction> mfunc = isel.runOnFunction(func);
// 第一次调试打印输出
std::stringstream ss1;
RISCv64AsmPrinter printer1(mfunc.get());
printer1.run(ss1, true);
std::stringstream ss_after_isel;
RISCv64AsmPrinter printer_isel(mfunc.get());
printer_isel.run(ss_after_isel, true);
if (DEBUG) {
std::cout << ss_after_isel.str();
}
if (DEBUG) {
std::cerr << "====== Intermediate Representation after Instruction Selection ======\n"
<< ss_after_isel.str();
}
// 阶段 2: 指令调度 (Instruction Scheduling)
// 阶段 2: 消除帧索引 (展开伪指令,计算局部变量偏移)
// 这个Pass必须在寄存器分配之前运行
EliminateFrameIndicesPass efi_pass;
efi_pass.runOnMachineFunction(mfunc.get());
if (DEBUG) {
std::cerr << "====== stack info after eliminate frame indices ======\n";
mfunc->dumpStackFrameInfo(std::cerr);
std::stringstream ss_after_eli;
printer_isel.run(ss_after_eli, true);
std::cerr << "====== LLIR after eliminate frame indices ======\n"
<< ss_after_eli.str();
}
// 阶段 2: 除法强度削弱优化 (Division Strength Reduction)
DivStrengthReduction div_strength_reduction;
div_strength_reduction.runOnMachineFunction(mfunc.get());
// 阶段 2.1: 指令调度 (Instruction Scheduling)
PreRA_Scheduler scheduler;
scheduler.runOnMachineFunction(mfunc.get());
@@ -190,10 +240,20 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
RISCv64RegAlloc reg_alloc(mfunc.get());
reg_alloc.run();
if (DEBUG) {
std::cerr << "====== stack info after reg alloc ======\n";
mfunc->dumpStackFrameInfo(std::cerr);
}
// 阶段 3.1: 处理被调用者保存寄存器
CalleeSavedHandler callee_handler;
callee_handler.runOnMachineFunction(mfunc.get());
if (DEBUG) {
std::cerr << "====== stack info after callee handler ======\n";
mfunc->dumpStackFrameInfo(std::cerr);
}
// 阶段 4: 窥孔优化 (Peephole Optimization)
PeepholeOptimizer peephole;
peephole.runOnMachineFunction(mfunc.get());
@@ -206,7 +266,7 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
PrologueEpilogueInsertionPass pei_pass;
pei_pass.runOnMachineFunction(mfunc.get());
// 阶段 3.3: 清理产生的大立即数
// 阶段 3.3: 大立即数合法化
LegalizeImmediatesPass legalizer;
legalizer.runOnMachineFunction(mfunc.get());
@@ -214,8 +274,9 @@ std::string RISCv64CodeGen::function_gen(Function* func) {
std::stringstream ss;
RISCv64AsmPrinter printer(mfunc.get());
printer.run(ss);
if (DEBUG) ss << "\n" << ss1.str(); // 将指令选择阶段的结果也包含在最终输出中
return ss.str();
}
} // namespace sysy

View File

@@ -1,9 +1,10 @@
#include "RISCv64ISel.h"
#include "IR.h" // For GlobalValue
#include <stdexcept>
#include <set>
#include <functional>
#include <cmath> // For std::fabs
#include <limits> // For std::numeric_limits
#include <cmath>
#include <limits>
#include <iostream>
namespace sysy {
@@ -167,33 +168,6 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) {
select_recursive(node_to_select);
}
}
if (CurMBB == MFunc->getBlocks().front().get()) { // 只对入口块操作
auto keepalive = std::make_unique<MachineInstr>(RVOpcodes::PSEUDO_KEEPALIVE);
for (Argument* arg : F->getArguments()) {
keepalive->addOperand(std::make_unique<RegOperand>(getVReg(arg)));
}
auto& instrs = CurMBB->getInstructions();
auto insert_pos = instrs.end();
// 关键:检查基本块是否以一个“终止指令”结尾
if (!instrs.empty()) {
RVOpcodes last_op = instrs.back()->getOpcode();
// 扩充了判断条件,涵盖所有可能的终止指令
if (last_op == RVOpcodes::J || last_op == RVOpcodes::RET ||
last_op == RVOpcodes::BEQ || last_op == RVOpcodes::BNE ||
last_op == RVOpcodes::BLT || last_op == RVOpcodes::BGE ||
last_op == RVOpcodes::BLTU || last_op == RVOpcodes::BGEU)
{
// 如果是,插入点就在这个终止指令之前
insert_pos = std::prev(instrs.end());
}
}
// 在计算出的正确位置插入伪指令
instrs.insert(insert_pos, std::move(keepalive));
}
}
// 核心函数为DAG节点选择并生成MachineInstr (已修复和增强的完整版本)
@@ -209,9 +183,13 @@ void RISCv64ISel::selectNode(DAGNode* node) {
case DAGNode::CONSTANT:
case DAGNode::ALLOCA_ADDR:
if (node->value) {
// GlobalValue objects (global variables) should not get virtual registers
// since they represent memory addresses, not register-allocated values
if (dynamic_cast<GlobalValue*>(node->value) == nullptr) {
// 确保它有一个关联的虚拟寄存器即可,不生成代码。
getVReg(node->value);
}
}
break;
case DAGNode::FP_CONSTANT: {
@@ -402,7 +380,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
Value* base = nullptr;
Value* offset = nullptr;
// [修改] 扩展基地址的判断,使其可以识别 AllocaInst 或 GlobalValue
// 扩展基地址的判断,使其可以识别 AllocaInst 或 GlobalValue
if (dynamic_cast<AllocaInst*>(lhs) || dynamic_cast<GlobalValue*>(lhs)) {
base = lhs;
offset = rhs;
@@ -421,7 +399,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(li));
}
// 2. [修改] 根据基地址的类型,生成不同的指令来获取基地址
// 2. 根据基地址的类型,生成不同的指令来获取基地址
auto base_addr_vreg = getNewVReg(Type::getIntType()); // 创建一个新的临时vreg来存放基地址
// 情况一:基地址是局部栈变量
@@ -452,7 +430,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}
}
// [V2优点] 在BINARY节点内部按需加载常量操作数。
// 在BINARY节点内部按需加载常量操作数。
auto load_val_if_const = [&](Value* val) {
if (auto c = dynamic_cast<ConstantValue*>(val)) {
if (DEBUG) {
@@ -483,7 +461,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
auto dest_vreg = getVReg(bin);
auto lhs_vreg = getVReg(lhs);
// [V2优点] 融合 ADDIW 优化。
// 融合 ADDIW 优化。
if (rhs_is_imm_opt) {
auto rhs_const = dynamic_cast<ConstantValue*>(rhs);
auto instr = std::make_unique<MachineInstr>(RVOpcodes::ADDIW);
@@ -539,6 +517,15 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kSRA: {
auto rhs_const = dynamic_cast<ConstantInteger*>(rhs);
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SRAIW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kICmpEQ: { // 等于 (a == b) -> (subw; seqz)
auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
@@ -758,11 +745,83 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFtoI: { // 浮点 to 整数
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FCVT_W_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); // 目标是整数vreg
instr->addOperand(std::make_unique<RegOperand>(src_vreg)); // 源是浮点vreg
CurMBB->addInstruction(std::move(instr));
case Instruction::kFtoI: { // 浮点 to 整数 (带向下取整)
// 目标:实现 floor(x) 的效果, C/C++中浮点转整数是截断(truncate)
// 对于正数floor(x) == truncate(x)
// RISC-V的 fcvt.w.s 默认是“四舍五入到偶数”
// 我们需要手动实现截断逻辑
// 逻辑:
// temp_i = fcvt.w.s(x) // 四舍五入
// temp_f = fcvt.s.w(temp_i) // 转回浮点
// if (x < temp_f) { // 如果原数更小,说明被“五入”了
// result = temp_i - 1
// } else {
// result = temp_i
// }
auto temp_i_vreg = getNewVReg(Type::getIntType());
auto temp_f_vreg = getNewVReg(Type::getFloatType());
auto cmp_vreg = getNewVReg(Type::getIntType());
// 1. fcvt.w.s temp_i_vreg, src_vreg
auto fcvt_w = std::make_unique<MachineInstr>(RVOpcodes::FCVT_W_S);
fcvt_w->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
fcvt_w->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(fcvt_w));
// 2. fcvt.s.w temp_f_vreg, temp_i_vreg
auto fcvt_s = std::make_unique<MachineInstr>(RVOpcodes::FCVT_S_W);
fcvt_s->addOperand(std::make_unique<RegOperand>(temp_f_vreg));
fcvt_s->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
CurMBB->addInstruction(std::move(fcvt_s));
// 3. flt.s cmp_vreg, src_vreg, temp_f_vreg
auto flt = std::make_unique<MachineInstr>(RVOpcodes::FLT_S);
flt->addOperand(std::make_unique<RegOperand>(cmp_vreg));
flt->addOperand(std::make_unique<RegOperand>(src_vreg));
flt->addOperand(std::make_unique<RegOperand>(temp_f_vreg));
CurMBB->addInstruction(std::move(flt));
// 创建标签
int unique_id = this->local_label_counter++;
std::string rounded_up_label = MFunc->getName() + "_ftoi_rounded_up_" + std::to_string(unique_id);
std::string done_label = MFunc->getName() + "_ftoi_done_" + std::to_string(unique_id);
// 4. bne cmp_vreg, x0, rounded_up_label
auto bne = std::make_unique<MachineInstr>(RVOpcodes::BNE);
bne->addOperand(std::make_unique<RegOperand>(cmp_vreg));
bne->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
bne->addOperand(std::make_unique<LabelOperand>(rounded_up_label));
CurMBB->addInstruction(std::move(bne));
// 5. else 分支: mv dest_vreg, temp_i_vreg
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(dest_vreg));
mv->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
CurMBB->addInstruction(std::move(mv));
// 6. j done_label
auto j = std::make_unique<MachineInstr>(RVOpcodes::J);
j->addOperand(std::make_unique<LabelOperand>(done_label));
CurMBB->addInstruction(std::move(j));
// 7. rounded_up_label:
auto label_up = std::make_unique<MachineInstr>(RVOpcodes::LABEL);
label_up->addOperand(std::make_unique<LabelOperand>(rounded_up_label));
CurMBB->addInstruction(std::move(label_up));
// 8. addiw dest_vreg, temp_i_vreg, -1
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDIW);
addi->addOperand(std::make_unique<RegOperand>(dest_vreg));
addi->addOperand(std::make_unique<RegOperand>(temp_i_vreg));
addi->addOperand(std::make_unique<ImmOperand>(-1));
CurMBB->addInstruction(std::move(addi));
// 9. done_label:
auto label_done = std::make_unique<MachineInstr>(RVOpcodes::LABEL);
label_done->addOperand(std::make_unique<LabelOperand>(done_label));
CurMBB->addInstruction(std::move(label_done));
break;
}
case Instruction::kFNeg: { // 浮点取负
@@ -943,7 +1002,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
// --- 步骤 3: 生成CALL指令 ---
auto call_instr = std::make_unique<MachineInstr>(RVOpcodes::CALL);
// [协议] 如果函数有返回值,将它的目标虚拟寄存器作为第一个操作数
// 如果函数有返回值,将它的目标虚拟寄存器作为第一个操作数
if (!call->getType()->isVoid()) {
unsigned dest_vreg = getVReg(call);
call_instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
@@ -1020,7 +1079,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
} else {
// --- 处理整数/指针返回值 ---
// 返回值需要被放入 a0
// [V2优点] 在RETURN节点内加载常量返回值
// 在RETURN节点内加载常量返回值
if (auto const_val = dynamic_cast<ConstantValue*>(ret_val)) {
auto li_instr = std::make_unique<MachineInstr>(RVOpcodes::LI);
li_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
@@ -1034,7 +1093,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}
}
}
// [V1设计保留] 函数尾声epilogue不由RETURN节点生成
// 函数尾声epilogue不由RETURN节点生成
// 而是由后续的AsmPrinter或其它Pass统一处理这是一种常见且有效的模块化设计。
auto ret_mi = std::make_unique<MachineInstr>(RVOpcodes::RET);
CurMBB->addInstruction(std::move(ret_mi));
@@ -1048,7 +1107,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
auto then_bb_name = cond_br->getThenBlock()->getName();
auto else_bb_name = cond_br->getElseBlock()->getName();
// [优化] 检查分支条件是否为编译期常量
// 检查分支条件是否为编译期常量
if (auto const_cond = dynamic_cast<ConstantValue*>(condition)) {
// 如果条件是常量直接生成一个无条件跳转J而不是BNE
if (const_cond->getInt() != 0) { // 条件为 true
@@ -1063,7 +1122,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}
// 如果条件不是常量,则执行标准流程
else {
// [修复] 为条件变量生成加载指令(如果它是常量的话,尽管上面已经处理了)
// 为条件变量生成加载指令(如果它是常量的话,尽管上面已经处理了)
// 这一步是为了逻辑完整,以防有其他类型的常量没有被捕获
if (auto const_val = dynamic_cast<ConstantValue*>(condition)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
@@ -1097,7 +1156,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}
case DAGNode::MEMSET: {
// [V1设计保留] Memset的核心展开逻辑在虚拟寄存器层面是正确的无需修改。
// Memset的核心展开逻辑在虚拟寄存器层面是正确的无需修改。
// 之前的bug是由于其输入地址、值、大小的虚拟寄存器未被正确初始化。
// 在修复了CONSTANT/ALLOCA_ADDR的加载问题后此处的逻辑现在可以正常工作。
@@ -1280,14 +1339,19 @@ void RISCv64ISel::selectNode(DAGNode* node) {
if (stride != 0) {
// --- 为当前索引和步长生成偏移计算指令 ---
auto offset_vreg = getNewVReg();
auto index_vreg = getVReg(indexValue);
// 如果索引是常量,先用 LI 指令加载到虚拟寄存器
// 处理索引 - 区分常量与动态值
unsigned index_vreg;
if (auto const_index = dynamic_cast<ConstantValue*>(indexValue)) {
// 对于常量索引,直接创建新的虚拟寄存器
index_vreg = getNewVReg();
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(index_vreg));
li->addOperand(std::make_unique<ImmOperand>(const_index->getInt()));
CurMBB->addInstruction(std::move(li));
} else {
// 对于动态索引,使用已存在的虚拟寄存器
index_vreg = getVReg(indexValue);
}
// 优化如果步长是1可以直接移动(MV)作为偏移量,无需乘法
@@ -1445,7 +1509,7 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
// 依次添加所有索引作为后续的操作数
for (auto index : gep->getIndices()) {
// [修复] 从 Use 对象中获取真正的 Value*
// 从 Use 对象中获取真正的 Value*
gep_node->operands.push_back(get_operand_node(index->getValue(), value_to_node, nodes_storage));
}
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
@@ -1473,7 +1537,7 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
}
}
}
if (bin->getKind() >= Instruction::kFAdd) { // 假设浮点指令枚举值更大
if (bin->isFPBinary()) { // 假设浮点指令枚举值更大
auto fbin_node = create_node(DAGNode::FBINARY, bin, value_to_node, nodes_storage);
fbin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
fbin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
@@ -1549,7 +1613,7 @@ unsigned RISCv64ISel::getTypeSizeInBytes(Type* type) {
}
}
// [新] 打印DAG图以供调试的辅助函数
// 打印DAG图以供调试的辅助函数
void RISCv64ISel::print_dag(const std::vector<std::unique_ptr<DAGNode>>& dag, const std::string& bb_name) {
// 检查是否有DEBUG宏或者全局变量避免在非调试模式下打印
// if (!DEBUG) return;
@@ -1645,4 +1709,8 @@ void RISCv64ISel::print_dag(const std::vector<std::unique_ptr<DAGNode>>& dag, co
std::cerr << "======================================\n\n";
}
unsigned int RISCv64ISel::getVRegCounter() const {
return vreg_counter;
}
} // namespace sysy

View File

@@ -1,6 +1,122 @@
#include "RISCv64LLIR.h"
#include <vector>
#include <iostream> // 用于 std::ostream 和 std::cerr
#include <string> // 用于 std::string
namespace sysy {
// 辅助函数:将 PhysicalReg 枚举转换为可读的字符串
std::string regToString(PhysicalReg reg) {
switch (reg) {
case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra";
case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp";
case PhysicalReg::TP: return "tp"; case PhysicalReg::T0: return "t0";
case PhysicalReg::T1: return "t1"; case PhysicalReg::T2: return "t2";
case PhysicalReg::S0: return "s0"; case PhysicalReg::S1: return "s1";
case PhysicalReg::A0: return "a0"; case PhysicalReg::A1: return "a1";
case PhysicalReg::A2: return "a2"; case PhysicalReg::A3: return "a3";
case PhysicalReg::A4: return "a4"; case PhysicalReg::A5: return "a5";
case PhysicalReg::A6: return "a6"; case PhysicalReg::A7: return "a7";
case PhysicalReg::S2: return "s2"; case PhysicalReg::S3: return "s3";
case PhysicalReg::S4: return "s4"; case PhysicalReg::S5: return "s5";
case PhysicalReg::S6: return "s6"; case PhysicalReg::S7: return "s7";
case PhysicalReg::S8: return "s8"; case PhysicalReg::S9: return "s9";
case PhysicalReg::S10: return "s10"; case PhysicalReg::S11: return "s11";
case PhysicalReg::T3: return "t3"; case PhysicalReg::T4: return "t4";
case PhysicalReg::T5: return "t5"; case PhysicalReg::T6: return "t6";
case PhysicalReg::F0: return "f0"; case PhysicalReg::F1: return "f1";
case PhysicalReg::F2: return "f2"; case PhysicalReg::F3: return "f3";
case PhysicalReg::F4: return "f4"; case PhysicalReg::F5: return "f5";
case PhysicalReg::F6: return "f6"; case PhysicalReg::F7: return "f7";
case PhysicalReg::F8: return "f8"; case PhysicalReg::F9: return "f9";
case PhysicalReg::F10: return "f10"; case PhysicalReg::F11: return "f11";
case PhysicalReg::F12: return "f12"; case PhysicalReg::F13: return "f13";
case PhysicalReg::F14: return "f14"; case PhysicalReg::F15: return "f15";
case PhysicalReg::F16: return "f16"; case PhysicalReg::F17: return "f17";
case PhysicalReg::F18: return "f18"; case PhysicalReg::F19: return "f19";
case PhysicalReg::F20: return "f20"; case PhysicalReg::F21: return "f21";
case PhysicalReg::F22: return "f22"; case PhysicalReg::F23: return "f23";
case PhysicalReg::F24: return "f24"; case PhysicalReg::F25: return "f25";
case PhysicalReg::F26: return "f26"; case PhysicalReg::F27: return "f27";
case PhysicalReg::F28: return "f28"; case PhysicalReg::F29: return "f29";
case PhysicalReg::F30: return "f30"; case PhysicalReg::F31: return "f31";
default: return "UNKNOWN_REG";
}
}
// 打印栈帧信息的完整实现
void MachineFunction::dumpStackFrameInfo(std::ostream& os) const {
const StackFrameInfo& info = frame_info;
os << "--- Stack Frame Info for function '" << getName() << "' ---\n";
// 打印尺寸信息
os << " Sizes:\n";
os << " Total Size: " << info.total_size << " bytes\n";
os << " Locals Size: " << info.locals_size << " bytes\n";
os << " Spill Size: " << info.spill_size << " bytes\n";
os << " Callee-Saved Size: " << info.callee_saved_size << " bytes\n";
os << "\n";
// 打印 Alloca 变量的偏移量
os << " Alloca Offsets (vreg -> offset from FP):\n";
if (info.alloca_offsets.empty()) {
os << " (None)\n";
} else {
for (const auto& pair : info.alloca_offsets) {
os << " %vreg" << pair.first << " -> " << pair.second << "\n";
}
}
os << "\n";
// 打印溢出变量的偏移量
os << " Spill Offsets (vreg -> offset from FP):\n";
if (info.spill_offsets.empty()) {
os << " (None)\n";
} else {
for (const auto& pair : info.spill_offsets) {
os << " %vreg" << pair.first << " -> " << pair.second << "\n";
}
}
os << "\n";
// 打印使用的被调用者保存寄存器
os << " Used Callee-Saved Registers:\n";
if (info.used_callee_saved_regs.empty()) {
os << " (None)\n";
} else {
os << " { ";
for (const auto& reg : info.used_callee_saved_regs) {
os << regToString(reg) << " ";
}
os << "}\n";
}
os << "\n";
// 打印需要保存/恢复的被调用者保存寄存器 (有序)
os << " Callee-Saved Registers to Store/Restore:\n";
if (info.callee_saved_regs_to_store.empty()) {
os << " (None)\n";
} else {
os << " [ ";
for (const auto& reg : info.callee_saved_regs_to_store) {
os << regToString(reg) << " ";
}
os << "]\n";
}
os << "\n";
// 打印最终的寄存器分配结果
os << " Final Register Allocation Map (vreg -> preg):\n";
if (info.vreg_to_preg_map.empty()) {
os << " (None)\n";
} else {
for (const auto& pair : info.vreg_to_preg_map) {
os << " %vreg" << pair.first << " -> " << regToString(pair.second) << "\n";
}
}
os << "---------------------------------------------------\n";
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
#ifndef ELIMINATE_FRAME_INDICES_H
#define ELIMINATE_FRAME_INDICES_H
#include "RISCv64LLIR.h"
namespace sysy {
class EliminateFrameIndicesPass {
public:
// Pass 的主入口函数
void runOnMachineFunction(MachineFunction* mfunc);
private:
// 帮助计算类型大小的辅助函数从原RegAlloc中移出
unsigned getTypeSizeInBytes(Type* type);
};
} // namespace sysy
#endif // ELIMINATE_FRAME_INDICES_H

View File

@@ -0,0 +1,30 @@
#ifndef RISCV64_DIV_STRENGTH_REDUCTION_H
#define RISCV64_DIV_STRENGTH_REDUCTION_H
#include "RISCv64LLIR.h"
#include "Pass.h"
namespace sysy {
/**
* @class DivStrengthReduction
* @brief 除法强度削弱优化器
* * 将除法运算转换为乘法运算使用magic number算法
* 适用于除数为常数的情况,可以显著提高性能
*/
class DivStrengthReduction : public Pass {
public:
static char ID;
DivStrengthReduction() : Pass("div-strength-reduction", Granularity::Function, PassKind::Optimization) {}
void *getPassID() const override { return &ID; }
bool runOnFunction(Function *F, AnalysisManager& AM) override;
void runOnMachineFunction(MachineFunction* mfunc);
};
} // namespace sysy
#endif // RISCV64_DIV_STRENGTH_REDUCTION_H

View File

@@ -20,6 +20,8 @@ public:
void setStream(std::ostream& os) { OS = &os; }
// 辅助函数
std::string regToString(PhysicalReg reg);
std::string formatInstr(const MachineInstr *instr);
private:
// 打印各个部分
void printBasicBlock(MachineBasicBlock* mbb, bool debug = false);

View File

@@ -22,7 +22,6 @@ private:
// 函数级代码生成 (实现新的流水线)
std::string function_gen(Function* func);
// 私有辅助函数,用于根据类型计算其占用的字节数。
unsigned getTypeSizeInBytes(Type* type);

View File

@@ -3,6 +3,12 @@
#include "RISCv64LLIR.h"
// Forward declarations
namespace sysy {
class GlobalValue;
class Value;
}
extern int DEBUG;
extern int DEEPDEBUG;
@@ -18,6 +24,7 @@ public:
unsigned getVReg(Value* val);
unsigned getNewVReg() { return vreg_counter++; }
unsigned getNewVReg(Type* type);
unsigned getVRegCounter() const;
// 获取 vreg_map 的公共接口
const std::map<Value*, unsigned>& getVRegMap() const { return vreg_map; }
const std::map<unsigned, Value*>& getVRegValueMap() const { return vreg_to_value_map; }

View File

@@ -3,6 +3,7 @@
#include "IR.h" // 确保包含了您自己的IR头文件
#include <string>
#include <iostream>
#include <vector>
#include <memory>
#include <cstdint>
@@ -38,14 +39,14 @@ enum class PhysicalReg {
// 用于内部表示物理寄存器在干扰图中的节点ID一个简单的特殊ID确保不与vreg_counter冲突
// 假设 vreg_counter 不会达到这么大的值
PHYS_REG_START_ID = 100000,
PHYS_REG_START_ID = 1000000,
PHYS_REG_END_ID = PHYS_REG_START_ID + 320, // 预留足够的空间
};
// RISC-V 指令操作码枚举
enum class RVOpcodes {
// 算术指令
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW,
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, MULH, DIV, DIVW, REM, REMW,
// 逻辑指令
XOR, XORI, OR, ORI, AND, ANDI,
// 移位指令
@@ -195,6 +196,11 @@ public:
preg = new_preg;
is_virtual = false;
}
void setVRegNum(unsigned new_vreg_num) {
vreg_num = new_vreg_num;
is_virtual = true; // 确保设置vreg时操作数状态正确
}
private:
unsigned vreg_num = 0;
PhysicalReg preg = PhysicalReg::ZERO;
@@ -274,14 +280,15 @@ private:
// 栈帧信息
struct StackFrameInfo {
int locals_size = 0; // 仅为AllocaInst分配的大小
int locals_end_offset = 0; // 记录局部变量分配结束后的偏移量(相对于s0为负)
int spill_size = 0; // 仅为溢出分配的大小
int total_size = 0; // 总大小
int callee_saved_size = 0; // 保存寄存器的大小
std::map<unsigned, int> alloca_offsets; // <AllocaInst的vreg, 栈偏移>
std::map<unsigned, int> spill_offsets; // <溢出vreg, 栈偏移>
std::set<PhysicalReg> used_callee_saved_regs; // 使用的保存寄存器
std::map<unsigned, PhysicalReg> vreg_to_preg_map;
std::vector<PhysicalReg> callee_saved_regs; // 用于存储需要保存的被调用者保存寄存器列表
std::map<unsigned, PhysicalReg> vreg_to_preg_map; // RegAlloc最终的分配结果
std::vector<PhysicalReg> callee_saved_regs_to_store; // 已排序的、需要存取的被调用者保存寄存器
};
// 机器函数
@@ -295,7 +302,7 @@ public:
StackFrameInfo& getFrameInfo() { return frame_info; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() const { return blocks; }
std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() { return blocks; }
void dumpStackFrameInfo(std::ostream& os = std::cerr) const;
void addBlock(std::unique_ptr<MachineBasicBlock> block) {
blocks.push_back(std::move(block));
}

View File

@@ -8,7 +8,10 @@
#include "CalleeSavedHandler.h"
#include "LegalizeImmediates.h"
#include "PrologueEpilogueInsertion.h"
#include "EliminateFrameIndices.h"
#include "Pass.h"
#include "DivStrengthReduction.h"
namespace sysy {

View File

@@ -3,9 +3,15 @@
#include "RISCv64LLIR.h"
#include "RISCv64ISel.h" // 包含 RISCv64ISel.h 以访问 ISel 和 Value 类型
#include <set>
#include <vector>
#include <map>
#include <stack>
extern int DEBUG;
extern int DEEPDEBUG;
extern int DEBUGLENGTH; // 用于限制调试输出的长度
extern int DEEPERDEBUG; // 用于更深层次的调试输出
namespace sysy {
@@ -17,58 +23,98 @@ public:
void run();
private:
using LiveSet = std::set<unsigned>; // 活跃虚拟寄存器集合
using InterferenceGraph = std::map<unsigned, std::set<unsigned>>;
// 类型定义与Python版本对应
using VRegSet = std::set<unsigned>;
using InterferenceGraph = std::map<unsigned, VRegSet>;
using VRegStack = std::vector<unsigned>; // 使用vector模拟栈方便遍历
using MoveList = std::map<unsigned, std::set<const MachineInstr*>>;
using AliasMap = std::map<unsigned, unsigned>;
using ColorMap = std::map<unsigned, PhysicalReg>;
using VRegMoveSet = std::set<const MachineInstr*>;
// 栈帧管理
void eliminateFrameIndices();
// --- 核心算法流程 ---
void initialize();
void build();
void makeWorklist();
void simplify();
void coalesce();
void freeze();
void selectSpill();
void assignColors();
void rewriteProgram();
bool doAllocation();
void applyColoring();
// 活跃性分析
void dumpState(const std::string &stage);
void precolorByCallingConvention();
// --- 辅助函数 ---
void getInstrUseDef(const MachineInstr* instr, VRegSet& use, VRegSet& def);
void getInstrUseDef_Liveness(const MachineInstr *instr, VRegSet &use, VRegSet &def);
void addEdge(unsigned u, unsigned v);
VRegSet adjacent(unsigned n);
VRegMoveSet nodeMoves(unsigned n);
bool moveRelated(unsigned n);
void decrementDegree(unsigned m);
void enableMoves(const VRegSet& nodes);
unsigned getAlias(unsigned n);
void addWorklist(unsigned u);
bool briggsHeuristic(unsigned u, unsigned v);
bool georgeHeuristic(unsigned u, unsigned v);
void combine(unsigned u, unsigned v);
void freezeMoves(unsigned u);
void collectUsedCalleeSavedRegs();
bool isFPVReg(unsigned vreg) const;
std::string regToString(PhysicalReg reg);
std::string regIdToString(unsigned id);
// --- 活跃性分析 ---
void analyzeLiveness();
// 构建干扰图
void buildInterferenceGraph();
// 图着色分配寄存器
void colorGraph();
// 重写函数替换vreg并插入溢出代码
void rewriteFunction();
// 辅助函数获取指令的Use/Def集合
void getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def);
// 辅助函数,处理调用约定
void handleCallingConvention();
MachineFunction* MFunc;
RISCv64ISel* ISel;
// 活跃性分析结果
std::map<const MachineInstr*, LiveSet> live_in_map;
std::map<const MachineInstr*, LiveSet> live_out_map;
// 干扰图
InterferenceGraph interference_graph;
// 图着色结果
std::map<unsigned, PhysicalReg> color_map; // vreg -> preg
std::set<unsigned> spilled_vregs; // 被溢出的vreg集合
// 可用的物理寄存器池
// --- 算法数据结构 ---
// 寄存器池
std::vector<PhysicalReg> allocable_int_regs;
std::vector<PhysicalReg> allocable_fp_regs;
int K_int; // 整数寄存器数量
int K_fp; // 浮点寄存器数量
// 存储vreg到IR Value*的反向映射
// 这个map将在run()函数开始时被填充并在rewriteFunction()中使用。
std::map<unsigned, Value*> vreg_to_value_map;
std::map<PhysicalReg, unsigned> preg_to_vreg_id_map; // 物理寄存器到特殊vreg ID的映射
// 节点集合
VRegSet precolored; // 预着色的节点 (物理寄存器)
VRegSet initial; // 初始的、所有待处理的虚拟寄存器节点
VRegSet simplifyWorklist;
VRegSet freezeWorklist;
VRegSet spillWorklist;
VRegSet spilledNodes;
VRegSet coalescedNodes;
VRegSet coloredNodes;
VRegStack selectStack;
// 用于计算类型大小的辅助函数
unsigned getTypeSizeInBytes(Type* type);
// Move指令相关
std::set<const MachineInstr*> coalescedMoves;
std::set<const MachineInstr*> constrainedMoves;
std::set<const MachineInstr*> frozenMoves;
std::set<const MachineInstr*> worklistMoves;
std::set<const MachineInstr*> activeMoves;
// 辅助函数,用于打印集合
static void printLiveSet(const LiveSet& s, const std::string& name, std::ostream& os);
// 数据结构
InterferenceGraph adjSet;
std::map<unsigned, VRegSet> adjList; // 邻接表
std::map<unsigned, int> degree;
MoveList moveList;
AliasMap alias;
ColorMap color_map;
// 活跃性分析结果
std::map<const MachineInstr*, VRegSet> live_in_map;
std::map<const MachineInstr*, VRegSet> live_out_map;
// VReg -> Value* 和 VReg -> Type* 的映射
const std::map<unsigned, Value*>& vreg_to_value_map;
const std::map<unsigned, Type*>& vreg_type_map;
};
} // namespace sysy

View File

@@ -20,6 +20,10 @@
#include <algorithm>
namespace sysy {
// Global cleanup function to release all statically allocated IR objects
void cleanupIRPools();
/**
* \defgroup type Types
* @brief Sysy的类型系统
@@ -83,6 +87,7 @@ class Type {
auto as() const -> std::enable_if_t<std::is_base_of_v<Type, T>, T *> {
return dynamic_cast<T *>(const_cast<Type *>(this));
}
virtual void print(std::ostream& os) const;
};
class PointerType : public Type {
@@ -95,6 +100,9 @@ class PointerType : public Type {
public:
static PointerType* get(Type *baseType); ///< 获取指向baseType的Pointer类型
// Cleanup method to release all cached pointer types (call at program exit)
static void cleanup();
public:
Type* getBaseType() const { return baseType; } ///< 获取指向的类型
};
@@ -112,6 +120,9 @@ class FunctionType : public Type {
/// 获取返回值类型为returnType 形参类型列表为paramTypes的Function类型
static FunctionType* get(Type *returnType, const std::vector<Type *> &paramTypes = {});
// Cleanup method to release all cached function types (call at program exit)
static void cleanup();
public:
Type* getReturnType() const { return returnType; } ///< 获取返回值类信息
auto getParamTypes() const { return make_range(paramTypes); } ///< 获取形参类型列表
@@ -124,6 +135,9 @@ class ArrayType : public Type {
// numElements该维度的大小 (例如int[3] 的 numElements 是 3)
static ArrayType *get(Type *elementType, unsigned numElements);
// Cleanup method to release all cached array types (call at program exit)
static void cleanup();
Type *getElementType() const { return elementType; }
unsigned getNumElements() const { return numElements; }
@@ -202,9 +216,11 @@ class Use {
public:
unsigned getIndex() const { return index; } ///< 返回value在User操作数中的位置
void setIndex(int newIndex) { index = newIndex; } ///< 设置value在User操作数中的位置
User* getUser() const { return user; } ///< 返回使用者
Value* getValue() const { return value; } ///< 返回被使用的值
void setValue(Value *newValue) { value = newValue; } ///< 将被使用的值设置为newValue
void print(std::ostream& os) const;
};
//! The base class of all value types
@@ -229,7 +245,15 @@ class Value {
std::list<std::shared_ptr<Use>>& getUses() { return uses; } ///< 获取使用关系列表
void addUse(const std::shared_ptr<Use> &use) { uses.push_back(use); } ///< 添加使用关系
void replaceAllUsesWith(Value *value); ///< 将原来使用该value的使用者全变为使用给定参数value并修改相应use关系
void removeUse(const std::shared_ptr<Use> &use) { uses.remove(use); } ///< 删除使用关系use
void removeUse(const std::shared_ptr<Use> &use) {
assert(use != nullptr && "Use cannot be null");
assert(use->getValue() == this && "Use being removed does NOT point to this Value!");
auto it = std::find(uses.begin(), uses.end(), use);
assert(it != uses.end() && "Use not found in Value's uses");
uses.remove(use);
} ///< 删除使用关系use
void removeAllUses();
virtual void print(std::ostream& os) const = 0; ///< 输出值信息到输出流
};
/**
@@ -357,6 +381,9 @@ public:
// Static factory method to get a canonical ConstantValue from the pool
static ConstantValue* get(Type* type, ConstantValVariant val);
// Cleanup method to release all cached constants (call at program exit)
static void cleanup();
// Helper methods to access constant values with appropriate casting
int getInt() const {
auto val = getVal();
@@ -394,6 +421,7 @@ public:
virtual bool isZero() const = 0;
virtual bool isOne() const = 0;
void print(std::ostream& os) const = 0;
};
class ConstantInteger : public ConstantValue {
@@ -420,6 +448,7 @@ public:
bool isZero() const override { return constVal == 0; }
bool isOne() const override { return constVal == 1; }
void print(std::ostream& os) const;
};
class ConstantFloating : public ConstantValue {
@@ -446,6 +475,7 @@ public:
bool isZero() const override { return constFVal == 0.0f; }
bool isOne() const override { return constFVal == 1.0f; }
void print(std::ostream& os) const;
};
class UndefinedValue : public ConstantValue {
@@ -461,6 +491,9 @@ protected:
public:
static UndefinedValue* get(Type* type);
// Cleanup method to release all cached undefined values (call at program exit)
static void cleanup();
size_t hash() const override {
return std::hash<Type*>{}(getType());
}
@@ -477,6 +510,7 @@ public:
bool isZero() const override { return false; }
bool isOne() const override { return false; }
void print(std::ostream& os) const;
};
// --- End of refactored ConstantValue and related classes ---
@@ -617,6 +651,11 @@ public:
}
} ///< 移除指定位置的指令
iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block);
/// 清理基本块中的所有使用关系
void cleanup();
void print(std::ostream& os) const;
};
//! User is the abstract base type of `Value` types which use other `Value` as
@@ -633,21 +672,6 @@ class User : public Value {
explicit User(Type *type, const std::string &name = "") : Value(type, name) {}
public:
// ~User() override {
// // 当 User 对象被销毁时例如LoadInst 或 StoreInst 被删除时),
// // 它必须通知它所使用的所有 Value将对应的 Use 关系从它们的 uses 列表中移除。
// // 这样可以防止 Value 的 uses 列表中出现悬空的 Use 对象。
// for (const auto &use_ptr : operands) {
// // 确保 use_ptr 非空,并且其内部指向的 Value* 也非空
// // (虽然通常情况下不会为空,但为了健壮性考虑)
// if (use_ptr && use_ptr->getValue()) {
// use_ptr->getValue()->removeUse(use_ptr);
// }
// }
// // operands 向量本身是 std::vector<std::shared_ptr<Use>>
// // 在此析构函数结束后operands 向量会被销毁,其内部的 shared_ptr 也会被释放,
// // 如果 shared_ptr 引用计数降为0Use 对象本身也会被销毁。
// }
unsigned getNumOperands() const { return operands.size(); } ///< 获取操作数数量
auto operand_begin() const { return operands.begin(); } ///< 返回操作数列表的开头迭代器
auto operand_end() const { return operands.end(); } ///< 返回操作数列表的结尾迭代器
@@ -657,11 +681,7 @@ class User : public Value {
operands.emplace_back(std::make_shared<Use>(operands.size(), this, value));
value->addUse(operands.back());
} ///< 增加操作数
void removeOperand(unsigned index) {
auto value = getOperand(index);
value->removeUse(operands[index]);
operands.erase(operands.begin() + index);
} ///< 移除操作数
void removeOperand(unsigned index);
template <typename ContainerT>
void addOperands(const ContainerT &newoperands) {
for (auto value : newoperands) {
@@ -670,6 +690,9 @@ class User : public Value {
} ///< 增加多个操作数
void replaceOperand(unsigned index, Value *value); ///< 替换操作数
void setOperand(unsigned index, Value *value); ///< 设置操作数
/// 清理用户的所有操作数使用关系
void cleanup();
};
/*!
@@ -728,6 +751,8 @@ class Instruction : public User {
kPhi = 0x1UL << 39,
kBitItoF = 0x1UL << 40,
kBitFtoI = 0x1UL << 41,
kSRA = 0x1UL << 42,
kMulh = 0x1UL << 43
};
protected:
@@ -745,57 +770,57 @@ public:
std::string getKindString() const{
switch (kind) {
case kInvalid:
return "Invalid";
return "invalid";
case kAdd:
return "Add";
return "add";
case kSub:
return "Sub";
return "sub";
case kMul:
return "Mul";
return "mul";
case kDiv:
return "Div";
return "sdiv";
case kRem:
return "Rem";
return "srem";
case kICmpEQ:
return "ICmpEQ";
return "icmp eq";
case kICmpNE:
return "ICmpNE";
return "icmp ne";
case kICmpLT:
return "ICmpLT";
return "icmp slt";
case kICmpGT:
return "ICmpGT";
return "icmp sgt";
case kICmpLE:
return "ICmpLE";
return "icmp sle";
case kICmpGE:
return "ICmpGE";
return "icmp sge";
case kFAdd:
return "FAdd";
return "fadd";
case kFSub:
return "FSub";
return "fsub";
case kFMul:
return "FMul";
return "fmul";
case kFDiv:
return "FDiv";
return "fdiv";
case kFCmpEQ:
return "FCmpEQ";
return "fcmp oeq";
case kFCmpNE:
return "FCmpNE";
return "fcmp one";
case kFCmpLT:
return "FCmpLT";
return "fcmp olt";
case kFCmpGT:
return "FCmpGT";
return "fcmp ogt";
case kFCmpLE:
return "FCmpLE";
return "fcmp ole";
case kFCmpGE:
return "FCmpGE";
return "fcmp oge";
case kAnd:
return "And";
return "and";
case kOr:
return "Or";
return "or";
case kNeg:
return "Neg";
return "neg";
case kNot:
return "Not";
return "not";
case kFNeg:
return "FNeg";
case kFNot:
@@ -803,27 +828,35 @@ public:
case kFtoI:
return "FtoI";
case kItoF:
return "IToF";
return "iToF";
case kCall:
return "Call";
return "call";
case kCondBr:
return "CondBr";
return "condBr";
case kBr:
return "Br";
return "br";
case kReturn:
return "Return";
return "return";
case kUnreachable:
return "unreachable";
case kAlloca:
return "Alloca";
return "alloca";
case kLoad:
return "Load";
return "load";
case kStore:
return "Store";
return "store";
case kGetElementPtr:
return "GetElementPtr";
return "getElementPtr";
case kMemset:
return "Memset";
return "memset";
case kPhi:
return "Phi";
return "phi";
case kBitItoF:
return "BitItoF";
case kBitFtoI:
return "BitFtoI";
case kSRA:
return "ashr";
default:
return "Unknown";
}
@@ -835,11 +868,15 @@ public:
bool isBinary() const {
static constexpr uint64_t BinaryOpMask =
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
(kAdd | kSub | kMul | kDiv | kRem | kAnd | kOr | kSRA | kMulh) |
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE);
return kind & BinaryOpMask;
}
bool isFPBinary() const {
static constexpr uint64_t FPBinaryOpMask =
(kFAdd | kFSub | kFMul | kFDiv) |
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & BinaryOpMask;
return kind & FPBinaryOpMask;
}
bool isUnary() const {
static constexpr uint64_t UnaryOpMask =
@@ -886,6 +923,10 @@ public:
static constexpr uint64_t DefineOpMask = kAlloca | kStore | kPhi;
return (kind & DefineOpMask) != 0U;
}
virtual ~Instruction() = default;
virtual void print(std::ostream& os) const = 0;
}; // class Instruction
class Function;
@@ -907,66 +948,56 @@ class PhiInst : public Instruction {
const std::string &name = "")
: Instruction(Kind::kPhi, type, parent, name), vsize(rhs.size()) {
assert(rhs.size() == Blocks.size() && "PhiInst: rhs and Blocks must have the same size");
for(size_t i = 0; i < rhs.size(); ++i) {
for(size_t i = 0; i < vsize; ++i) {
addOperand(rhs[i]);
addOperand(Blocks[i]);
blk2val[Blocks[i]] = rhs[i];
}
}
public:
Value* getValue(unsigned k) const {return getOperand(2 * k);} ///< 获取位置为k的值
BasicBlock* getBlock(unsigned k) const {return dynamic_cast<BasicBlock*>(getOperand(2 * k + 1));}
//增加llvm同名方法实现获取value和block
Value* getIncomingValue(unsigned k) const {return getOperand(2 * k);} ///< 获取位置为k的值
BasicBlock* getIncomingBlock(unsigned k) const {return dynamic_cast<BasicBlock*>(getOperand(2 * k + 1));}
Value* getIncomingValue(BasicBlock* blk) const {
return getvalfromBlk(blk);
} ///< 获取指定基本块的传入值
BasicBlock* getIncomingBlock(Value* val) const {
return getBlkfromVal(val);
} ///< 获取指定值的传入基本块
void replaceIncoming(BasicBlock *oldBlock, BasicBlock *newBlock, Value *newValue){
delBlk(oldBlock);
addIncoming(newValue, newBlock);
}
auto& getincomings() const {return blk2val;} ///< 获取所有的基本块和对应的值
auto getIncomingValues() const {
std::vector<std::pair<BasicBlock*, Value*>> result;
for (const auto& [block, value] : blk2val) {
result.emplace_back(block, value);
}
return result;
}
Value* getvalfromBlk(BasicBlock* blk) const ;
BasicBlock* getBlkfromVal(Value* val) const ;
unsigned getNumIncomingValues() const { return vsize; } ///< 获取传入值的数量
Value *getIncomingValue(unsigned Idx) const { return getOperand(Idx * 2); } ///< 获取指定位置的传入值
BasicBlock *getIncomingBlock(unsigned Idx) const {return dynamic_cast<BasicBlock *>(getOperand(Idx * 2 + 1)); } ///< 获取指定位置的传入基本块
Value* getValfromBlk(BasicBlock* block);
BasicBlock* getBlkfromVal(Value* value);
void addIncoming(Value *value, BasicBlock *block) {
assert(value && block && "PhiInst: value and block must not be null");
assert(value && block && "PhiInst: value and block cannot be null");
addOperand(value);
addOperand(block);
blk2val[block] = value;
vsize++;
} ///< 添加传入值和对应的基本块
void removeIncoming(BasicBlock *block){
delBlk(block);
void removeIncoming(unsigned Idx) {
assert(Idx < vsize && "PhiInst: Index out of bounds");
auto blk = getIncomingBlock(Idx);
removeOperand(Idx * 2 + 1); // Remove block
removeOperand(Idx * 2); // Remove value
blk2val.erase(blk);
vsize--;
} ///< 移除指定位置的传入值和对应的基本块
// 移除指定的传入值或基本块
void removeIncomingValue(Value *value);
void removeIncomingBlock(BasicBlock *block);
// 设置指定位置的传入值或基本块
void setIncomingValue(unsigned Idx, Value *value);
void setIncomingBlock(unsigned Idx, BasicBlock *block);
// 替换指定位置的传入值或基本块(原理是删除再添加)保留旧块或者旧值
void replaceIncomingValue(Value *oldValue, Value *newValue);
void replaceIncomingBlock(BasicBlock *oldBlock, BasicBlock *newBlock);
// 替换指定位置的传入值或基本块(原理是删除再添加)
void replaceIncomingValue(Value *oldValue, Value *newValue, BasicBlock *newBlock);
void replaceIncomingBlock(BasicBlock *oldBlock, BasicBlock *newBlock, Value *newValue);
void refreshMap() {
blk2val.clear();
for (unsigned i = 0; i < vsize; ++i) {
blk2val[getIncomingBlock(i)] = getIncomingValue(i);
}
void delValue(Value* val);
void delBlk(BasicBlock* blk);
void replaceBlk(BasicBlock* newBlk, unsigned k);
void replaceold2new(BasicBlock* oldBlk, BasicBlock* newBlk);
void refreshB2VMap();
} ///< 刷新块到值的映射关系
auto getValues() { return make_range(std::next(operand_begin()), operand_end()); }
void print(std::ostream& os) const override;
};
@@ -975,16 +1006,14 @@ class CallInst : public Instruction {
friend class IRBuilder;
protected:
CallInst(Function *callee, const std::vector<Value *> &args = {},
BasicBlock *parent = nullptr, const std::string &name = "");
CallInst(Function *callee, const std::vector<Value *> &args, BasicBlock *parent = nullptr, const std::string &name = "");
public:
Function* getCallee() const;
Function *getCallee() const;
auto getArguments() const {
return make_range(std::next(operand_begin()), operand_end());
}
void print(std::ostream& os) const override;
}; // class CallInst
//! Unary instruction, includes '!', '-' and type conversion.
@@ -1002,7 +1031,7 @@ protected:
public:
Value* getOperand() const { return User::getOperand(0); }
void print(std::ostream& os) const override;
}; // class UnaryInst
//! Binary instruction, e.g., arithmatic, relation, logic, etc.
@@ -1081,6 +1110,7 @@ public:
// 后端处理数组访存操作时需要创建计算地址的指令,需要在外部构造 BinaryInst 对象
return new BinaryInst(kind, type, lhs, rhs, parent, name);
}
void print(std::ostream& os) const override;
}; // class BinaryInst
//! The return statement
@@ -1101,6 +1131,7 @@ class ReturnInst : public Instruction {
Value* getReturnValue() const {
return hasReturnValue() ? getOperand(0) : nullptr;
}
void print(std::ostream& os) const override;
};
//! Unconditional branch
@@ -1130,7 +1161,7 @@ public:
}
return succs;
}
void print(std::ostream& os) const override;
}; // class UncondBrInst
//! Conditional branch
@@ -1170,7 +1201,7 @@ public:
}
return succs;
}
void print(std::ostream& os) const override;
}; // class CondBrInst
class UnreachableInst : public Instruction {
@@ -1178,7 +1209,7 @@ public:
// 构造函数:设置指令类型为 kUnreachable
explicit UnreachableInst(const std::string& name, BasicBlock *parent = nullptr)
: Instruction(kUnreachable, Type::getVoidType(), parent, "") {}
void print(std::ostream& os) const { os << "unreachable"; }
};
//! Allocate memory for stack variables, used for non-global variable declartion
@@ -1196,7 +1227,7 @@ public:
Type* getAllocatedType() const {
return getType()->as<PointerType>()->getBaseType();
} ///< 获取分配的类型
void print(std::ostream& os) const override;
}; // class AllocaInst
@@ -1234,6 +1265,7 @@ public:
BasicBlock *parent = nullptr, const std::string &name = "") {
return new GetElementPtrInst(resultType, basePointer, indices, parent, name);
}
void print(std::ostream& os) const override;
};
//! Load a value from memory address specified by a pointer value
@@ -1251,7 +1283,7 @@ protected:
public:
Value* getPointer() const { return getOperand(0); }
void print(std::ostream& os) const override;
}; // class LoadInst
//! Store a value to memory address specified by a pointer value
@@ -1270,7 +1302,7 @@ protected:
public:
Value* getValue() const { return getOperand(0); }
Value* getPointer() const { return getOperand(1); }
void print(std::ostream& os) const override;
}; // class StoreInst
//! Memset instruction
@@ -1300,7 +1332,7 @@ public:
Value* getBegin() const { return getOperand(1); }
Value* getSize() const { return getOperand(2); }
Value* getValue() const { return getOperand(3); }
void print(std::ostream& os) const override;
};
class GlobalValue;
@@ -1318,6 +1350,11 @@ public:
public:
Function* getParent() const { return func; }
int getIndex() const { return index; }
/// 清理参数的使用关系
void cleanup();
void print(std::ostream& os) const;
};
@@ -1406,6 +1443,11 @@ protected:
blocks.emplace_front(block);
return block;
}
/// 清理函数中的所有使用关系
void cleanup();
void print(std::ostream& os) const;
};
//! Global value declared at file scope
@@ -1471,6 +1513,7 @@ public:
return getByIndex(index);
} ///< 通过多维索引indices获取初始值
const ValueCounter& getInitValues() const { return initValues; }
void print(std::ostream& os) const;
}; // class GlobalValue
@@ -1528,6 +1571,8 @@ class ConstantVariable : public Value {
return getByIndex(index);
} ///< 通过多维索引indices获取初始值
const ValueCounter& getInitValues() const { return initValues; } ///< 获取初始值
void print(std::ostream& os) const;
void print_init(std::ostream& os) const;
};
using SymbolTableNode = struct SymbolTableNode {
@@ -1550,6 +1595,8 @@ class SymbolTable {
Value* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量
Value* addVariable(const std::string &name, Value *variable); ///< 添加变量
void registerParameterName(const std::string &name); ///< 注册函数参数名字避免alloca重名
void addVariableDirectly(const std::string &name, Value *variable); ///< 直接添加变量到当前作用域,不重命名
std::vector<std::unique_ptr<GlobalValue>>& getGlobals(); ///< 获取全局变量列表
const std::vector<std::unique_ptr<ConstantVariable>>& getConsts() const; ///< 获取全局常量列表
void enterNewScope(); ///< 进入新的作用域
@@ -1557,6 +1604,9 @@ class SymbolTable {
bool isInGlobalScope() const; ///< 是否位于全局作用域
void enterGlobalScope(); ///< 进入全局作用域
bool isCurNodeNull() { return curNode == nullptr; }
/// 清理符号表中的所有内容
void cleanup();
};
//! IR unit for representing a SysY compile unit
@@ -1609,6 +1659,12 @@ class Module {
void addVariable(const std::string &name, AllocaInst *variable) {
variableTable.addVariable(name, variable);
} ///< 添加变量
void addVariableDirectly(const std::string &name, AllocaInst *variable) {
variableTable.addVariableDirectly(name, variable);
} ///< 直接添加变量到当前作用域,不重命名
void registerParameterName(const std::string &name) {
variableTable.registerParameterName(name);
} ///< 注册函数参数名字避免alloca重名
Value* getVariable(const std::string &name) {
return variableTable.getVariable(name);
} ///< 根据名字name和当前作用域获取变量
@@ -1621,7 +1677,7 @@ class Module {
} ///< 获取函数
Function* getExternalFunction(const std::string &name) const {
auto result = externalFunctions.find(name);
if (result == functions.end()) {
if (result == externalFunctions.end()) {
return nullptr;
}
return result->second.get();
@@ -1641,6 +1697,11 @@ class Module {
void leaveScope() { variableTable.leaveScope(); } ///< 离开作用域
bool isInGlobalArea() const { return variableTable.isInGlobalScope(); } ///< 是否位于全局作用域
/// 清理模块中的所有对象,包括函数、基本块、指令等
void cleanup();
void print(std::ostream& os) const;
};
/*!

View File

@@ -217,6 +217,12 @@ class IRBuilder {
BinaryInst * createOrInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kOr, Type::getIntType(), lhs, rhs, name);
} ///< 创建按位或指令
BinaryInst * createSRAInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kSRA, Type::getIntType(), lhs, rhs, name);
} ///< 创建算术右移指令
BinaryInst * createMulhInst(Value *lhs, Value *rhs, const std::string &name = "") {
return createBinaryInst(Instruction::kMulh, Type::getIntType(), lhs, rhs, name);
} ///< 创建高位乘法指令
CallInst * createCallInst(Function *callee, const std::vector<Value *> &args, const std::string &name = "") {
std::string newName;
if (name.empty() && callee->getReturnType() != Type::getVoidType()) {
@@ -344,38 +350,31 @@ class IRBuilder {
Type *currentWalkType = pointerType->as<PointerType>()->getBaseType();
// 遍历所有索引来深入类型层次结构。
// `indices` 向量包含了所有 GEP 索引,包括由 `visitLValue` 等函数添加的初始 `0` 索引
// 重要:第一个索引总是用于"解引用"指针,后续索引才用于数组/结构体的索引
for (int i = 0; i < indices.size(); ++i) {
if (i == 0) {
// 第一个索引:总是用于"解引用"基指针不改变currentWalkType
// 例如:对于 `[4 x i32]* ptr, i32 0`第一个0只是说"访问ptr指向的对象"
// currentWalkType 保持为 `[4 x i32]`
continue;
} else {
// 后续索引:用于实际的数组/结构体索引
if (currentWalkType->isArray()) {
// 情况一:当前遍历类型是 `ArrayType`。
// 索引用于选择数组元素,`currentWalkType` 更新为数组的元素类型。
// 数组索引:选择数组中的元素
currentWalkType = currentWalkType->as<ArrayType>()->getElementType();
} else if (currentWalkType->isPointer()) {
// 情况二:当前遍历类型是 `PointerType`。
// 这意味着我们正在通过一个指针来访问其指向的内存。
// 索引用于选择该指针所指向的“数组”的元素。
// `currentWalkType` 更新为该指针所指向的基础类型。
// 例如:如果 `currentWalkType` 是 `i32*`,它将变为 `i32`。
// 如果 `currentWalkType` 是 `[10 x i32]*`,它将变为 `[10 x i32]`。
// 指针索引:解引用指针并继续
currentWalkType = currentWalkType->as<PointerType>()->getBaseType();
} else {
// 情况三:当前遍历类型是标量类型 (例如 `i32`, `float` 等非聚合、非指针类型)。
//
// 如果 `currentWalkType` 是标量,并且当前索引 `i` **不是** `indices` 向量中的最后一个索引,
// 这意味着尝试对一个标量类型进行进一步的结构性索引,这是**无效的**。
// 例如:`int x; x[0];` 对应的 GEP 链中,`x` 的类型是 `i32`,再加 `[0]` 索引就是错误。
//
// 如果 `currentWalkType` 是标量,且这是**最后一个索引** (`i == indices.size() - 1`)
// 那么 GEP 是合法的,它只是计算一个偏移地址,最终的类型就是这个标量类型。
// 此时 `currentWalkType` 保持不变,循环结束。
// 标量类型:不能进一步索引
if (i < indices.size() - 1) {
assert(false && "Invalid GEP indexing: attempting to index into a non-aggregate/non-pointer type with further indices.");
return nullptr; // 返回空指针表示类型推断失败
}
// 如果是最后一个索引,且当前类型是标量,则类型保持不变,这是合法的。
// 循环会自然结束,返回正确的 `currentWalkType`。
return nullptr;
}
}
}
}
// 所有索引处理完毕后,`currentWalkType` 就是 GEP 指令最终计算出的地址所指向的元素的类型。
return currentWalkType;
}

View File

@@ -6,30 +6,82 @@
#include <set>
#include <vector>
#include <algorithm>
#include <functional>
namespace sysy {
// 支配树分析结果类 (保持不变)
// 支配树分析结果类
class DominatorTree : public AnalysisResultBase {
public:
DominatorTree(Function* F);
// 获取指定基本块的所有支配者
const std::set<BasicBlock*>* getDominators(BasicBlock* BB) const;
// 获取指定基本块的即时支配者 (Immediate Dominator)
BasicBlock* getImmediateDominator(BasicBlock* BB) const;
// 获取指定基本块的支配边界 (Dominance Frontier)
const std::set<BasicBlock*>* getDominanceFrontier(BasicBlock* BB) const;
// 获取指定基本块在支配树中的子节点
const std::set<BasicBlock*>* getDominatorTreeChildren(BasicBlock* BB) const;
// 额外的 Getter获取所有支配者、即时支配者和支配边界的完整映射可选主要用于调试或特定场景
const std::map<BasicBlock*, std::set<BasicBlock*>>& getDominatorsMap() const { return Dominators; }
const std::map<BasicBlock*, BasicBlock*>& getIDomsMap() const { return IDoms; }
const std::map<BasicBlock*, std::set<BasicBlock*>>& getDominanceFrontiersMap() const { return DominanceFrontiers; }
// 计算所有基本块的支配者集合
void computeDominators(Function* F);
// 计算所有基本块的即时支配者(内部使用 Lengauer-Tarjan 算法)
void computeIDoms(Function* F);
// 计算所有基本块的支配边界
void computeDominanceFrontiers(Function* F);
// 计算支配树的结构(即每个节点的直接子节点)
void computeDominatorTreeChildren(Function* F);
private:
// 与该支配树关联的函数
Function* AssociatedFunction;
std::map<BasicBlock*, std::set<BasicBlock*>> Dominators;
std::map<BasicBlock*, BasicBlock*> IDoms;
std::map<BasicBlock*, std::set<BasicBlock*>> DominanceFrontiers;
std::map<BasicBlock*, std::set<BasicBlock*>> DominatorTreeChildren;
std::map<BasicBlock*, std::set<BasicBlock*>> Dominators; // 每个基本块的支配者集合
std::map<BasicBlock*, BasicBlock*> IDoms; // 每个基本块的即时支配者
std::map<BasicBlock*, std::set<BasicBlock*>> DominanceFrontiers; // 每个基本块的支配边界
std::map<BasicBlock*, std::set<BasicBlock*>> DominatorTreeChildren; // 支配树中每个基本块的子节点
// ==========================================================
// Lengauer-Tarjan 算法内部所需的数据结构和辅助函数
// 这些成员是私有的,以封装 LT 算法的复杂性并避免命名空间污染
// ==========================================================
// DFS 遍历相关:
std::map<BasicBlock*, int> dfnum_map; // 存储每个基本块的 DFS 编号
std::vector<BasicBlock*> vertex_vec; // 通过 DFS 编号反向查找对应的基本块指针
std::map<BasicBlock*, BasicBlock*> parent_map; // 存储 DFS 树中每个基本块的父节点
int df_counter; // DFS 计数器,也代表 DFS 遍历的总节点数 (N)
// 半支配者 (Semi-dominator) 相关:
std::map<BasicBlock*, BasicBlock*> sdom_map; // 存储每个基本块的半支配者
std::map<BasicBlock*, BasicBlock*> idom_map; // 存储每个基本块的即时支配者 (IDom)
std::map<BasicBlock*, std::vector<BasicBlock*>> bucket_map; // 桶结构,用于存储具有相同半支配者的节点,以延迟 IDom 计算
// 并查集 (Union-Find) 相关(用于 evalAndCompress 函数):
std::map<BasicBlock*, BasicBlock*> ancestor_map; // 并查集中的父节点(用于路径压缩)
std::map<BasicBlock*, BasicBlock*> label_map; // 并查集中,每个集合的代表节点(或其路径上 sdom 最小的节点)
// ==========================================================
// 辅助计算函数 (私有)
// ==========================================================
// 计算基本块的逆后序遍历 (Reverse Post Order, RPO) 顺序
// RPO 用于优化支配者计算和 LT 算法的效率
std::vector<BasicBlock*> computeReversePostOrder(Function* F);
// Lengauer-Tarjan 算法特定的辅助 DFS 函数
// 用于初始化 dfnum_map, vertex_vec, parent_map
void dfs_lt_helper(BasicBlock* u);
// 结合了并查集的 Find 操作和 LT 算法的 Eval 操作
// 用于在路径压缩时更新 label找到路径上 sdom 最小的节点
BasicBlock* evalAndCompress_lt_helper(BasicBlock* i);
// 并查集的 Link 操作
// 将 v_child 挂载到 u_parent 的并查集树下
void link_lt_helper(BasicBlock* u_parent, BasicBlock* v_child);
};

View File

@@ -0,0 +1,20 @@
#pragma once
#include "IR.h"
#include "Pass.h"
#include <queue>
#include <set>
namespace sysy {
class BuildCFG : public OptimizationPass {
public:
static void *ID;
BuildCFG() : OptimizationPass("BuildCFG", Granularity::Function) {}
bool runOnFunction(Function *F, AnalysisManager &AM) override;
void getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const override;
void *getPassID() const override { return &ID; }
};
} // namespace sysy

View File

@@ -0,0 +1,24 @@
#pragma once
#include "../Pass.h"
namespace sysy {
class LargeArrayToGlobalPass : public OptimizationPass {
public:
static void *ID;
LargeArrayToGlobalPass() : OptimizationPass("LargeArrayToGlobal", Granularity::Module) {}
bool runOnModule(Module *M, AnalysisManager &AM) override;
void *getPassID() const override {
return &ID;
}
private:
unsigned calculateTypeSize(Type *type);
void convertAllocaToGlobal(AllocaInst *alloca, Function *F, Module *M);
std::string generateUniqueGlobalName(AllocaInst *alloca, Function *F);
};
} // namespace sysy

View File

@@ -48,13 +48,6 @@ public:
}
}
}
// 清空 User 的 operands 向量。这会递减 User 持有的 shared_ptr<Use> 的引用计数。
// 当引用计数降为 0 时Use 对象本身将被销毁。
// User::operands.clear(); // 这个步骤会在 Instruction 的析构函数中自动完成,因为它是 vector 成员
// 或者我们可以在 User::removeOperand 方法中确保 Use 对象从 operands 中移除。
// 实际上,只要 Value::removeUse(use_ptr) 被调用了,
// 当 Instruction 所在的 unique_ptr 销毁时,它的 operands vector 也会被销毁。
// 所以这里不需要显式 clear()
}
static void usedelete(Instruction *inst) {
assert(inst && "Instruction to delete cannot be null.");
@@ -92,7 +85,7 @@ public:
// 步骤3: 物理删除指令并返回下一个迭代器
return parentBlock->removeInst(inst_it);
}
}
// 判断是否是全局变量
static bool isGlobal(Value *val) {

View File

@@ -299,7 +299,7 @@ private:
IRBuilder *pBuilder;
public:
PassManager() = default;
PassManager() = delete;
~PassManager() = default;
PassManager(Module *module, IRBuilder *builder) : pmodule(module) ,pBuilder(builder), analysisManager(module) {}

View File

@@ -86,7 +86,60 @@ private:
case LPAREN: case RPAREN: return 0; // Parentheses have lowest precedence for stack logic
default: return -1; // Unknown operator
}
};
struct ExpKey {
BinaryOp op; ///< 操作符
Value *left; ///< 左操作数
Value *right; ///< 右操作数
ExpKey(BinaryOp op, Value *left, Value *right) : op(op), left(left), right(right) {}
bool operator<(const ExpKey &other) const {
if (op != other.op)
return op < other.op; ///< 比较操作符
if (left != other.left)
return left < other.left; ///< 比较左操作
return right < other.right; ///< 比较右操作数
} ///< 重载小于运算符用于比较ExpKey
};
struct UnExpKey {
BinaryOp op; ///< 一元操作符
Value *operand; ///< 操作数
UnExpKey(BinaryOp op, Value *operand) : op(op), operand(operand) {}
bool operator<(const UnExpKey &other) const {
if (op != other.op)
return op < other.op; ///< 比较操作符
return operand < other.operand; ///< 比较操作数
} ///< 重载小于运算符用于比较UnExpKey
};
struct GEPKey {
Value *basePointer;
std::vector<Value *> indices;
// 为 std::map 定义比较运算符,使得 GEPKey 可以作为键
bool operator<(const GEPKey &other) const {
if (basePointer != other.basePointer) {
return basePointer < other.basePointer;
}
// 逐个比较索引,确保顺序一致
if (indices.size() != other.indices.size()) {
return indices.size() < other.indices.size();
}
for (size_t i = 0; i < indices.size(); ++i) {
if (indices[i] != other.indices[i]) {
return indices[i] < other.indices[i];
}
}
return false; // 如果 basePointer 和所有索引都相同,则认为相等
}
};
std::map<GEPKey, Value*> availableGEPs; ///< 用于存储 GEP 的缓存
std::map<ExpKey, Value*> availableBinaryExpressions;
std::map<UnExpKey, Value*> availableUnaryExpressions;
std::map<Value*, Value*> availableLoads;
public:
SysYIRGenerator() = default;
@@ -167,6 +220,15 @@ public:
Value* computeExp(SysYParser::ExpContext *ctx, Type* targetType = nullptr);
Value* computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType = nullptr);
void compute();
// 参数是发生 store 操作的目标地址/变量的 Value*
void invalidateExpressionsOnStore(Value* storedAddress);
// 清除因函数调用而失效的表达式缓存(保守策略)
void invalidateExpressionsOnCall();
// 在进入新的基本块时清空所有表达式缓存
void enterNewBasicBlock();
public:
// 获取GEP指令的地址
Value* getGEPAddressInst(Value* basePointer, const std::vector<Value*>& indices);

View File

@@ -18,6 +18,8 @@ add_library(midend_lib STATIC
Pass/Optimize/SysYIRCFGOpt.cpp
Pass/Optimize/SCCP.cpp
Pass/Optimize/LoopNormalization.cpp
Pass/Optimize/BuildCFG.cpp
Pass/Optimize/LargeArrayToGlobal.cpp
)
# 包含中端模块所需的头文件路径

File diff suppressed because it is too large Load Diff

View File

@@ -1,21 +1,30 @@
#include "Dom.h"
#include <algorithm> // for std::set_intersection, std::set_difference, std::set_union
#include <algorithm> // for std::set_intersection, std::reverse
#include <iostream> // for debug output
#include <limits> // for std::numeric_limits
#include <queue>
#include <functional> // for std::function
#include <map>
#include <vector>
#include <set>
namespace sysy {
// 初始化 支配树静态 ID
// ==============================================================
// DominatorTreeAnalysisPass 的静态ID
// ==============================================================
void *DominatorTreeAnalysisPass::ID = (void *)&DominatorTreeAnalysisPass::ID;
// ==============================================================
// DominatorTree 结果类的实现
// ==============================================================
// 构造函数:初始化关联函数,但不进行计算
DominatorTree::DominatorTree(Function *F) : AssociatedFunction(F) {
// 构造时可以不计算,在分析遍运行里计算并填充
// 构造时不需要计算,在分析遍运行里计算并填充
}
// Getter 方法 (保持不变)
const std::set<BasicBlock *> *DominatorTree::getDominators(BasicBlock *BB) const {
auto it = Dominators.find(BB);
if (it != Dominators.end()) {
@@ -48,7 +57,7 @@ const std::set<BasicBlock *> *DominatorTree::getDominatorTreeChildren(BasicBlock
return nullptr;
}
// 辅助函数:打印 BasicBlock 集合
// 辅助函数:打印 BasicBlock 集合 (保持不变)
void printBBSet(const std::string &prefix, const std::set<BasicBlock *> &s) {
if (!DEBUG)
return;
@@ -63,24 +72,52 @@ void printBBSet(const std::string &prefix, const std::set<BasicBlock *> &s) {
std::cout << "}" << std::endl;
}
// 辅助函数:计算逆后序遍历 (RPO) - 保持不变
std::vector<BasicBlock*> DominatorTree::computeReversePostOrder(Function* F) {
std::vector<BasicBlock*> postOrder;
std::set<BasicBlock*> visited;
std::function<void(BasicBlock*)> dfs_rpo =
[&](BasicBlock* bb) {
visited.insert(bb);
for (BasicBlock* succ : bb->getSuccessors()) {
if (visited.find(succ) == visited.end()) {
dfs_rpo(succ);
}
}
postOrder.push_back(bb);
};
dfs_rpo(F->getEntryBlock());
std::reverse(postOrder.begin(), postOrder.end());
if (DEBUG) {
std::cout << "--- Computed RPO: ";
for (BasicBlock* bb : postOrder) {
std::cout << bb->getName() << " ";
}
std::cout << "---" << std::endl;
}
return postOrder;
}
// computeDominators 方法 (保持不变因为它它是独立于IDom算法的)
void DominatorTree::computeDominators(Function *F) {
if (DEBUG)
std::cout << "--- Computing Dominators ---" << std::endl;
BasicBlock *entryBlock = F->getEntryBlock();
std::vector<BasicBlock *> bbs_in_order; // 用于确定遍历顺序,如果需要的话
std::vector<BasicBlock*> bbs_rpo = computeReversePostOrder(F);
// 初始化:入口块只被自己支配,其他块被所有块支配
for (const auto &bb_ptr : F->getBasicBlocks()) {
BasicBlock *bb = bb_ptr.get();
bbs_in_order.push_back(bb); // 收集所有块
for (BasicBlock *bb : bbs_rpo) {
if (bb == entryBlock) {
Dominators[bb].clear();
Dominators[bb].insert(bb);
if (DEBUG)
std::cout << "Init Dominators[" << bb->getName() << "]: {" << bb->getName() << "}" << std::endl;
if (DEBUG) std::cout << "Init Dominators[" << bb->getName() << "]: {" << bb->getName() << "}" << std::endl;
} else {
for (const auto &all_bb_ptr : F->getBasicBlocks()) {
Dominators[bb].insert(all_bb_ptr.get());
Dominators[bb].clear();
for (BasicBlock *all_bb : bbs_rpo) {
Dominators[bb].insert(all_bb);
}
if (DEBUG) {
std::cout << "Init Dominators[" << bb->getName() << "]: ";
@@ -94,23 +131,18 @@ void DominatorTree::computeDominators(Function *F) {
while (changed) {
changed = false;
iteration++;
if (DEBUG)
std::cout << "Iteration " << iteration << std::endl;
if (DEBUG) std::cout << "Iteration " << iteration << std::endl;
// 确保遍历顺序一致性例如可以按照DFS或BFS顺序或者简单的迭代器顺序
// 如果Function::getBasicBlocks()返回的迭代器顺序稳定则无需bbs_in_order
for (const auto &bb_ptr : F->getBasicBlocks()) { // 假设这个迭代器顺序稳定
BasicBlock *bb = bb_ptr.get();
if (bb == entryBlock)
continue;
for (BasicBlock *bb : bbs_rpo) {
if (bb == entryBlock) continue;
// 计算所有前驱的支配者集合的交集
std::set<BasicBlock *> newDom;
bool firstPredProcessed = false;
for (BasicBlock *pred : bb->getPredecessors()) {
// 确保前驱的支配者集合已经计算过
if (Dominators.count(pred)) {
if(DEBUG){
std::cout << " Processing predecessor: " << pred->getName() << std::endl;
}
if (!firstPredProcessed) {
newDom = Dominators[pred];
firstPredProcessed = true;
@@ -121,8 +153,7 @@ void DominatorTree::computeDominators(Function *F) {
newDom = intersection;
}
}
}
newDom.insert(bb); // BB 永远支配自己
newDom.insert(bb);
if (newDom != Dominators[bb]) {
if (DEBUG) {
@@ -140,78 +171,242 @@ void DominatorTree::computeDominators(Function *F) {
std::cout << "--- Dominators Computation Finished ---" << std::endl;
}
void DominatorTree::computeIDoms(Function *F) {
if (DEBUG)
std::cout << "--- Computing Immediate Dominators (IDoms) ---" << std::endl;
// ==============================================================
// Lengauer-Tarjan 算法辅助数据结构和函数 (私有成员)
// ==============================================================
BasicBlock *entryBlock = F->getEntryBlock();
IDoms[entryBlock] = nullptr; // 入口块没有即时支配者
// DFS 遍历,填充 dfnum_map, vertex_vec, parent_map
// 对应用户代码的 dfs 函数
void DominatorTree::dfs_lt_helper(BasicBlock* u) {
dfnum_map[u] = df_counter;
if (df_counter >= vertex_vec.size()) { // 动态调整大小
vertex_vec.resize(df_counter + 1);
}
vertex_vec[df_counter] = u;
if (DEBUG) std::cout << " DFS: Visiting " << u->getName() << ", dfnum = " << df_counter << std::endl;
df_counter++;
// 遍历所有非入口块
for (const auto &bb_ptr : F->getBasicBlocks()) {
BasicBlock *bb = bb_ptr.get();
if (bb == entryBlock)
continue;
BasicBlock *currentIDom = nullptr;
const std::set<BasicBlock *> *domsOfBB = getDominators(bb);
if (!domsOfBB) {
if (DEBUG)
std::cerr << "Warning: Dominators for " << bb->getName() << " not found!" << std::endl;
continue;
}
// 遍历bb的所有严格支配者 D (即 bb 的支配者中除了 bb 自身)
for (BasicBlock *D_candidate : *domsOfBB) {
if (D_candidate == bb)
continue; // 跳过bb自身
bool D_candidate_is_IDom = true;
// 检查是否存在另一个块 X使得 D_candidate 严格支配 X 且 X 严格支配 bb
// 或者更直接的,检查 D_candidate 是否被 bb 的所有其他严格支配者所支配
for (BasicBlock *X_other_dom : *domsOfBB) {
if (X_other_dom == bb || X_other_dom == D_candidate)
continue; // 跳过bb自身和D_candidate
// 如果 X_other_dom 严格支配 bb (它在 domsOfBB 中且不是bb自身)
// 并且 X_other_dom 不被 D_candidate 支配,那么 D_candidate 就不是 IDom
const std::set<BasicBlock *> *domsOfX_other_dom = getDominators(X_other_dom);
if (domsOfX_other_dom && domsOfX_other_dom->count(D_candidate)) { // X_other_dom 支配 D_candidate
// D_candidate 被另一个支配者 X_other_dom 支配
// 这说明 D_candidate 位于 X_other_dom 的“下方”X_other_dom 更接近 bb
// 因此 D_candidate 不是 IDom
D_candidate_is_IDom = false;
break;
for (BasicBlock* v : u->getSuccessors()) {
if (dfnum_map.find(v) == dfnum_map.end()) { // 如果 v 未访问过
parent_map[v] = u;
if (DEBUG) std::cout << " DFS: Setting parent[" << v->getName() << "] = " << u->getName() << std::endl;
dfs_lt_helper(v);
}
}
if (D_candidate_is_IDom) {
currentIDom = D_candidate;
break; // 找到即时支配者,可以退出循环,因为它是唯一的
}
}
IDoms[bb] = currentIDom;
if (DEBUG) {
std::cout << " IDom[" << bb->getName() << "] = " << (currentIDom ? currentIDom->getName() : "nullptr")
<< std::endl;
}
}
if (DEBUG)
std::cout << "--- Immediate Dominators Computation Finished ---" << std::endl;
}
/*
for each node n in a postorder traversal of the dominator tree:
df[n] = empty set
// compute DF_local(n)
for each child y of n in the CFG:
if idom[y] != n:
df[n] = df[n] U {y}
// compute DF_up(n)
for each child c of n in the dominator tree:
for each element w in df[c]:
if idom[w] != n:
df[n] = df[n] U {w}
*/
// 并查集:找到集合的代表,并进行路径压缩
// 同时更新 label确保 label[i] 总是指向其祖先链中 sdom_map 最小的节点
// 对应用户代码的 find 函数,也包含了 eval 的逻辑
BasicBlock* DominatorTree::evalAndCompress_lt_helper(BasicBlock* i) {
if (DEBUG) std::cout << " Eval: Processing " << i->getName() << std::endl;
// 如果 i 是根 (ancestor_map[i] == nullptr)
if (ancestor_map.find(i) == ancestor_map.end() || ancestor_map[i] == nullptr) {
if (DEBUG) std::cout << " Eval: " << i->getName() << " is root, returning itself." << std::endl;
return i; // 根节点自身就是路径上sdom最小的因为它没有祖先
}
// 如果 i 的祖先不是根,则递归查找并进行路径压缩
BasicBlock* root_ancestor = evalAndCompress_lt_helper(ancestor_map[i]);
// 路径压缩时,根据 sdom_map 比较并更新 label_map
// 确保 label_map[i] 存储的是 i 到 root_ancestor 路径上 sdom_map 最小的节点
// 注意:这里的 ancestor_map[i] 已经被递归调用压缩过一次了所以是root_ancestor的旧路径
// 应该比较的是 label_map[ancestor_map[i]] 和 label_map[i]
if (sdom_map.count(label_map[ancestor_map[i]]) && // 确保 label_map[ancestor_map[i]] 存在 sdom
sdom_map.count(label_map[i]) && // 确保 label_map[i] 存在 sdom
dfnum_map[sdom_map[label_map[ancestor_map[i]]]] < dfnum_map[sdom_map[label_map[i]]]) {
if (DEBUG) std::cout << " Eval: Updating label for " << i->getName() << " from "
<< label_map[i]->getName() << " to " << label_map[ancestor_map[i]]->getName() << std::endl;
label_map[i] = label_map[ancestor_map[i]];
}
ancestor_map[i] = root_ancestor; // 执行路径压缩:将 i 直接指向其所属集合的根
if (DEBUG) std::cout << " Eval: Path compression for " << i->getName() << ", new ancestor = "
<< (root_ancestor ? root_ancestor->getName() : "nullptr") << std::endl;
return label_map[i]; // <-- **将这里改为返回 label_map[i]**
}
// Link 函数:将 v 加入 u 的 DFS 树子树中 (实际上是并查集操作)
// 对应用户代码的 fa[u] = fth[u];
void DominatorTree::link_lt_helper(BasicBlock* u_parent, BasicBlock* v_child) {
ancestor_map[v_child] = u_parent; // 设置并查集父节点
label_map[v_child] = v_child; // 初始化 label 为自身
if (DEBUG) std::cout << " Link: " << v_child->getName() << " linked to " << u_parent->getName() << std::endl;
}
// ==============================================================
// Lengauer-Tarjan 算法实现 computeIDoms
// ==============================================================
void DominatorTree::computeIDoms(Function *F) {
if (DEBUG) std::cout << "--- Computing Immediate Dominators (IDoms) using Lengauer-Tarjan ---" << std::endl;
BasicBlock *entryBlock = F->getEntryBlock();
// 1. 初始化所有 LT 相关的数据结构
dfnum_map.clear();
vertex_vec.clear();
parent_map.clear();
sdom_map.clear();
idom_map.clear();
bucket_map.clear();
ancestor_map.clear();
label_map.clear();
df_counter = 0; // DFS 计数器从 0 开始
// 预分配 vertex_vec 的大小避免频繁resize
vertex_vec.resize(F->getBasicBlocks().size() + 1);
// 在 DFS 遍历之前,先为所有基本块初始化 sdom 和 label
// 这是 Lengauer-Tarjan 算法的要求,确保所有节点在 Phase 2 开始前都在 map 中
for (auto &bb_ptr : F->getBasicBlocks()) {
BasicBlock* bb = bb_ptr.get();
sdom_map[bb] = bb; // sdom(bb) 初始化为 bb 自身
label_map[bb] = bb; // label(bb) 初始化为 bb 自身 (用于 Union-Find 的路径压缩)
}
// 确保入口块也被正确初始化(如果它不在 F->getBasicBlocks() 的正常迭代中)
sdom_map[entryBlock] = entryBlock;
label_map[entryBlock] = entryBlock;
// Phase 1: DFS 遍历并预处理
// 对应用户代码的 dfs(st)
dfs_lt_helper(entryBlock);
idom_map[entryBlock] = nullptr; // 入口块没有即时支配者
if (DEBUG) std::cout << " IDom[" << entryBlock->getName() << "] = nullptr" << std::endl;
if (DEBUG) std::cout << " Sdom[" << entryBlock->getName() << "] = " << entryBlock->getName() << std::endl;
// 初始化并查集的祖先和 label
for (auto const& [bb_key, dfn_val] : dfnum_map) {
ancestor_map[bb_key] = nullptr; // 初始为独立集合的根
label_map[bb_key] = bb_key; // 初始 label 为自身
}
if (DEBUG) {
std::cout << " --- DFS Phase Complete ---" << std::endl;
std::cout << " dfnum_map:" << std::endl;
for (auto const& [bb, dfn] : dfnum_map) {
std::cout << " " << bb->getName() << " -> " << dfn << std::endl;
}
std::cout << " vertex_vec (by dfnum):" << std::endl;
for (size_t k = 0; k < df_counter; ++k) {
if (vertex_vec[k]) std::cout << " [" << k << "] -> " << vertex_vec[k]->getName() << std::endl;
}
std::cout << " parent_map:" << std::endl;
for (auto const& [child, parent] : parent_map) {
std::cout << " " << child->getName() << " -> " << (parent ? parent->getName() : "nullptr") << std::endl;
}
std::cout << " ------------------------" << std::endl;
}
// Phase 2: 计算半支配者 (sdom)
// 对应用户代码的 for (int i = dfc; i >= 2; --i) 循环的上半部分
// 按照 DFS 编号递减的顺序遍历所有节点 (除了 entryBlock它的 DFS 编号是 0)
if (DEBUG) std::cout << "--- Phase 2: Computing Semi-Dominators (sdom) ---" << std::endl;
for (int i = df_counter - 1; i >= 1; --i) { // 从 DFS 编号最大的节点开始,到 1
BasicBlock* w = vertex_vec[i]; // 当前处理的节点
if (DEBUG) std::cout << " Processing node w: " << w->getName() << " (dfnum=" << i << ")" << std::endl;
// 对于 w 的每个前驱 v
for (BasicBlock* v : w->getPredecessors()) {
if (DEBUG) std::cout << " Considering predecessor v: " << v->getName() << std::endl;
// 如果前驱 v 未被 DFS 访问过 (即不在 dfnum_map 中),则跳过
if (dfnum_map.find(v) == dfnum_map.end()) {
if (DEBUG) std::cout << " Predecessor " << v->getName() << " not in DFS tree, skipping." << std::endl;
continue;
}
// 调用 evalAndCompress 来找到 v 在其 DFS 树祖先链上具有最小 sdom 的节点
BasicBlock* u_with_min_sdom_on_path = evalAndCompress_lt_helper(v);
if (DEBUG) std::cout << " Eval(" << v->getName() << ") returned "
<< u_with_min_sdom_on_path->getName() << std::endl;
if (DEBUG && sdom_map.count(u_with_min_sdom_on_path) && sdom_map.count(w)) {
std::cout << " Comparing sdom: dfnum[" << sdom_map[u_with_min_sdom_on_path]->getName() << "] (" << dfnum_map[sdom_map[u_with_min_sdom_on_path]]
<< ") vs dfnum[" << sdom_map[w]->getName() << "] (" << dfnum_map[sdom_map[w]] << ")" << std::endl;
}
// 比较 sdom(u) 和 sdom(w)
if (sdom_map.count(u_with_min_sdom_on_path) && sdom_map.count(w) &&
dfnum_map[sdom_map[u_with_min_sdom_on_path]] < dfnum_map[sdom_map[w]]) {
if (DEBUG) std::cout << " Updating sdom[" << w->getName() << "] from "
<< sdom_map[w]->getName() << " to "
<< sdom_map[u_with_min_sdom_on_path]->getName() << std::endl;
sdom_map[w] = sdom_map[u_with_min_sdom_on_path]; // 更新 sdom(w)
if (DEBUG) std::cout << " Sdom update applied. New sdom[" << w->getName() << "] = " << sdom_map[w]->getName() << std::endl;
}
}
// 将 w 加入 sdom(w) 对应的桶中
bucket_map[sdom_map[w]].push_back(w);
if (DEBUG) std::cout << " Adding " << w->getName() << " to bucket of sdom(" << w->getName() << "): "
<< sdom_map[w]->getName() << std::endl;
// 将 w 的父节点加入并查集 (link 操作)
if (parent_map.count(w) && parent_map[w] != nullptr) {
link_lt_helper(parent_map[w], w);
}
// Phase 3-part 1: 处理 parent[w] 的桶中所有节点,确定部分 idom
if (parent_map.count(w) && parent_map[w] != nullptr) {
BasicBlock* p = parent_map[w]; // p 是 w 的父节点
if (DEBUG) std::cout << " Processing bucket for parent " << p->getName() << std::endl;
// 注意这里需要复制桶的内容因为原始桶在循环中会被clear
std::vector<BasicBlock*> nodes_in_p_bucket_copy = bucket_map[p];
for (BasicBlock* y : nodes_in_p_bucket_copy) {
if (DEBUG) std::cout << " Processing node y from bucket: " << y->getName() << std::endl;
// 找到 y 在其 DFS 树祖先链上具有最小 sdom 的节点
BasicBlock* u = evalAndCompress_lt_helper(y);
if (DEBUG) std::cout << " Eval(" << y->getName() << ") returned " << u->getName() << std::endl;
// 确定 idom(y)
// if sdom(eval(y)) == sdom(parent(w)), then idom(y) = parent(w)
// else idom(y) = eval(y)
if (sdom_map.count(u) && sdom_map.count(p) &&
dfnum_map[sdom_map[u]] < dfnum_map[sdom_map[p]]) {
idom_map[y] = u; // 确定的 idom
if (DEBUG) std::cout << " IDom[" << y->getName() << "] set to " << u->getName() << std::endl;
} else {
idom_map[y] = p; // p 是 y 的 idom
if (DEBUG) std::cout << " IDom[" << y->getName() << "] set to " << p->getName() << std::endl;
}
}
bucket_map[p].clear(); // 清空桶,防止重复处理
if (DEBUG) std::cout << " Cleared bucket for parent " << p->getName() << std::endl;
}
}
// Phase 3-part 2: 最终确定 idom (处理那些 idom != sdom 的节点)
if (DEBUG) std::cout << "--- Phase 3: Finalizing Immediate Dominators (idom) ---" << std::endl;
for (int i = 1; i < df_counter; ++i) { // 从 DFS 编号最小的节点 (除了 entryBlock) 开始
BasicBlock* w = vertex_vec[i];
if (DEBUG) std::cout << " Finalizing node w: " << w->getName() << std::endl;
if (idom_map.count(w) && sdom_map.count(w) && idom_map[w] != sdom_map[w]) {
// idom[w] 的 idom 是其真正的 idom
if (DEBUG) std::cout << " idom[" << w->getName() << "] (" << idom_map[w]->getName()
<< ") != sdom[" << w->getName() << "] (" << sdom_map[w]->getName() << ")" << std::endl;
if (idom_map.count(idom_map[w])) {
idom_map[w] = idom_map[idom_map[w]];
if (DEBUG) std::cout << " Updating idom[" << w->getName() << "] to idom(idom(w)): "
<< idom_map[w]->getName() << std::endl;
} else {
if (DEBUG) std::cout << " Warning: idom(idom(" << w->getName() << ")) not found, leaving idom[" << w->getName() << "] as is." << std::endl;
}
}
if (DEBUG) {
std::cout << " Final IDom[" << w->getName() << "] = " << (idom_map[w] ? idom_map[w]->getName() : "nullptr") << std::endl;
}
}
// 将计算结果从 idom_map 存储到 DominatorTree 的成员变量 IDoms 中
IDoms = idom_map;
if (DEBUG) std::cout << "--- Immediate Dominators Computation Finished ---" << std::endl;
}
// ==============================================================
// computeDominanceFrontiers 和 computeDominatorTreeChildren (保持不变)
// ==============================================================
void DominatorTree::computeDominanceFrontiers(Function *F) {
if (DEBUG)
@@ -221,21 +416,17 @@ void DominatorTree::computeDominanceFrontiers(Function *F) {
BasicBlock *X = bb_ptr_X.get();
DominanceFrontiers[X].clear();
// 遍历所有可能的 Z (X支配Z或者Z就是X)
for (const auto &bb_ptr_Z : F->getBasicBlocks()) {
BasicBlock *Z = bb_ptr_Z.get();
const std::set<BasicBlock *> *domsOfZ = getDominators(Z);
// 如果 X 不支配 Z则 Z 与 DF(X) 无关
if (!domsOfZ || domsOfZ->find(X) == domsOfZ->end()) {
if (!domsOfZ || domsOfZ->find(X) == domsOfZ->end()) { // Z 不被 X 支配
continue;
}
// 遍历 Z 的所有后继 Y
for (BasicBlock *Y : Z->getSuccessors()) {
// 如果 Y 不被 X 严格支配,则 Y 在 DF(X) 中
// Y 不被 X 严格支配意味着 (Y不被X支配) 或 (Y就是X)
const std::set<BasicBlock *> *domsOfY = getDominators(Y);
// 如果 Y == X或者 Y 不被 X 严格支配 (即 Y 不被 X 支配)
if (Y == X || (domsOfY && domsOfY->find(X) == domsOfY->end())) {
DominanceFrontiers[X].insert(Y);
}
@@ -274,23 +465,21 @@ void DominatorTree::computeDominatorTreeChildren(Function *F) {
}
// ==============================================================
// DominatorTreeAnalysisPass 的实现
// DominatorTreeAnalysisPass 的实现 (保持不变)
// ==============================================================
bool DominatorTreeAnalysisPass::runOnFunction(Function *F, AnalysisManager &AM) {
// 每次运行时清空旧数据,确保重新计算
CurrentDominatorTree = std::make_unique<DominatorTree>(F);
// 不需要手动清空mapunique_ptr会创建新的DominatorTree对象其map是空的
CurrentDominatorTree->computeDominators(F);
CurrentDominatorTree->computeIDoms(F); // 修正后的IDoms算法
CurrentDominatorTree->computeIDoms(F); // 修正后的LT算法
CurrentDominatorTree->computeDominanceFrontiers(F);
CurrentDominatorTree->computeDominatorTreeChildren(F);
return false; // 分析遍通常返回 false表示不修改 IR
return false;
}
std::unique_ptr<AnalysisResultBase> DominatorTreeAnalysisPass::getResult() {
// 返回计算好的 DominatorTree 实例,所有权转移给 AnalysisManager
return std::move(CurrentDominatorTree);
}

View File

@@ -0,0 +1,79 @@
#include "BuildCFG.h"
#include "Dom.h"
#include "Liveness.h"
#include <iostream>
#include <queue>
#include <set>
namespace sysy {
void *BuildCFG::ID = (void *)&BuildCFG::ID; // 定义唯一的 Pass ID
// 声明Pass的分析使用
void BuildCFG::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
// BuildCFG不依赖其他分析
// analysisDependencies.insert(&DominatorTreeAnalysisPass::ID); // 错误的例子
// BuildCFG会使所有依赖于CFG的分析结果失效所以它必须声明这些失效
analysisInvalidations.insert(&DominatorTreeAnalysisPass::ID);
analysisInvalidations.insert(&LivenessAnalysisPass::ID);
}
bool BuildCFG::runOnFunction(Function *F, AnalysisManager &AM) {
if (DEBUG) {
std::cout << "Running BuildCFG pass on function: " << F->getName() << std::endl;
}
bool changed = false;
// 1. 清空所有基本块的前驱和后继列表
for (auto &bb : F->getBasicBlocks()) {
bb->clearPredecessors();
bb->clearSuccessors();
}
// 2. 遍历每个基本块重建CFG
for (auto &bb : F->getBasicBlocks()) {
// 获取基本块的最后一条指令
auto &inst = *bb->terminator();
Instruction *termInst = inst.get();
// 确保基本块有终结指令
if (!termInst) {
continue;
}
// 根据终结指令类型,建立前驱后继关系
if (termInst->isBranch()) {
// 无条件跳转
if (termInst->isUnconditional()) {
auto brInst = dynamic_cast<UncondBrInst *>(termInst);
BasicBlock *succ = dynamic_cast<BasicBlock *>(brInst->getBlock());
assert(succ && "Branch instruction's target must be a BasicBlock");
bb->addSuccessor(succ);
succ->addPredecessor(bb.get());
changed = true;
// 条件跳转
} else if (termInst->isConditional()) {
auto brInst = dynamic_cast<CondBrInst *>(termInst);
BasicBlock *trueSucc = dynamic_cast<BasicBlock *>(brInst->getThenBlock());
BasicBlock *falseSucc = dynamic_cast<BasicBlock *>(brInst->getElseBlock());
assert(trueSucc && falseSucc && "Branch instruction's targets must be BasicBlocks");
bb->addSuccessor(trueSucc);
trueSucc->addPredecessor(bb.get());
bb->addSuccessor(falseSucc);
falseSucc->addPredecessor(bb.get());
changed = true;
}
} else if (auto retInst = dynamic_cast<ReturnInst *>(termInst)) {
// RetInst没有后继无需处理
// ...
}
}
return changed;
}
} // namespace sysy

View File

@@ -0,0 +1,145 @@
#include "../../include/midend/Pass/Optimize/LargeArrayToGlobal.h"
#include "../../IR.h"
#include <unordered_map>
#include <sstream>
#include <string>
namespace sysy {
// Helper function to convert type to string
static std::string typeToString(Type *type) {
if (!type) return "null";
switch (type->getKind()) {
case Type::kInt:
return "int";
case Type::kFloat:
return "float";
case Type::kPointer:
return "ptr";
case Type::kArray: {
auto *arrayType = type->as<ArrayType>();
return "[" + std::to_string(arrayType->getNumElements()) + " x " +
typeToString(arrayType->getElementType()) + "]";
}
default:
return "unknown";
}
}
void *LargeArrayToGlobalPass::ID = &LargeArrayToGlobalPass::ID;
bool LargeArrayToGlobalPass::runOnModule(Module *M, AnalysisManager &AM) {
bool changed = false;
if (!M) {
return false;
}
// Collect all alloca instructions from all functions
std::vector<std::pair<AllocaInst*, Function*>> allocasToConvert;
for (auto &funcPair : M->getFunctions()) {
Function *F = funcPair.second.get();
if (!F || F->getBasicBlocks().begin() == F->getBasicBlocks().end()) {
continue;
}
for (auto &BB : F->getBasicBlocks()) {
for (auto &inst : BB->getInstructions()) {
if (auto *alloca = dynamic_cast<AllocaInst*>(inst.get())) {
Type *allocatedType = alloca->getAllocatedType();
// Calculate the size of the allocated type
unsigned size = calculateTypeSize(allocatedType);
if(DEBUG){
// Debug: print size information
std::cout << "LargeArrayToGlobalPass: Found alloca with size " << size
<< " for type " << typeToString(allocatedType) << std::endl;
}
// Convert arrays of 1KB (1024 bytes) or larger to global variables
if (size >= 1024) {
if(DEBUG)
std::cout << "LargeArrayToGlobalPass: Converting array of size " << size << " to global" << std::endl;
allocasToConvert.emplace_back(alloca, F);
}
}
}
}
}
// Convert the collected alloca instructions to global variables
for (auto [alloca, F] : allocasToConvert) {
convertAllocaToGlobal(alloca, F, M);
changed = true;
}
return changed;
}
unsigned LargeArrayToGlobalPass::calculateTypeSize(Type *type) {
if (!type) return 0;
switch (type->getKind()) {
case Type::kInt:
case Type::kFloat:
return 4;
case Type::kPointer:
return 8;
case Type::kArray: {
auto *arrayType = type->as<ArrayType>();
return arrayType->getNumElements() * calculateTypeSize(arrayType->getElementType());
}
default:
return 0;
}
}
void LargeArrayToGlobalPass::convertAllocaToGlobal(AllocaInst *alloca, Function *F, Module *M) {
Type *allocatedType = alloca->getAllocatedType();
// Create a unique name for the global variable
std::string globalName = generateUniqueGlobalName(alloca, F);
// Create the global variable - GlobalValue expects pointer type
Type *pointerType = Type::getPointerType(allocatedType);
GlobalValue *globalVar = M->createGlobalValue(globalName, pointerType);
if (!globalVar) {
return;
}
// Replace all uses of the alloca with the global variable
alloca->replaceAllUsesWith(globalVar);
// Remove the alloca instruction from its basic block
for (auto &BB : F->getBasicBlocks()) {
auto &instructions = BB->getInstructions();
for (auto it = instructions.begin(); it != instructions.end(); ++it) {
if (it->get() == alloca) {
instructions.erase(it);
break;
}
}
}
}
std::string LargeArrayToGlobalPass::generateUniqueGlobalName(AllocaInst *alloca, Function *F) {
std::string baseName = alloca->getName();
if (baseName.empty()) {
baseName = "array";
}
// Ensure uniqueness by appending function name and counter
static std::unordered_map<std::string, int> nameCounter;
std::string key = F->getName() + "." + baseName;
int counter = nameCounter[key]++;
std::ostringstream oss;
oss << key << "." << counter;
return oss.str();
}
} // namespace sysy

View File

@@ -148,8 +148,8 @@ void Reg2MemContext::rewritePhis(Function *func) {
// 1. 为 Phi 指令的每个入边,在前驱块的末尾插入 Store 指令
// PhiInst 假设有 getIncomingValues() 和 getIncomingBlocks()
for (unsigned i = 0; i < phiInst->getNumIncomingValues(); ++i) { // 假设 PhiInst 是通过操作数来管理入边的
Value *incomingValue = phiInst->getValue(i); // 获取入值
BasicBlock *incomingBlock = phiInst->getBlock(i); // 获取对应的入块
Value *incomingValue = phiInst->getIncomingValue(i); // 获取入值
BasicBlock *incomingBlock = phiInst->getIncomingBlock(i); // 获取对应的入块
// 在入块的跳转指令之前插入 StoreInst
// 需要找到 incomingBlock 的终结指令 (Terminator Instruction)

View File

@@ -468,6 +468,22 @@ void SCCPContext::ProcessInstruction(Instruction *inst) {
return; // 不处理不可达块中的指令的实际值
}
if(DEBUG) {
std::cout << "Processing instruction: " << inst->getName() << " in block " << inst->getParent()->getName() << std::endl;
std::cout << "Old state: ";
if (oldState.state == LatticeVal::Top) {
std::cout << "Top";
} else if (oldState.state == LatticeVal::Constant) {
if (oldState.constant_type == ValueType::Integer) {
std::cout << "Const<int>(" << std::get<int>(oldState.constantVal) << ")";
} else {
std::cout << "Const<float>(" << std::get<float>(oldState.constantVal) << ")";
}
} else {
std::cout << "Bottom";
}
}
switch (inst->getKind()) {
case Instruction::kAdd:
case Instruction::kSub:
@@ -815,19 +831,71 @@ void SCCPContext::ProcessInstruction(Instruction *inst) {
}
case Instruction::kPhi: {
PhiInst *phi = static_cast<PhiInst *>(inst);
if(DEBUG) {
std::cout << "Processing Phi node: " << phi->getName() << std::endl;
}
// 标准SCCP的phi节点处理
// 只考虑可执行前驱,但要保证单调性
SSAPValue currentPhiState = GetValueState(phi);
SSAPValue phiResult = SSAPValue(); // 初始为 Top
bool hasAnyExecutablePred = false;
for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) {
Value *incomingVal = phi->getIncomingValue(i);
BasicBlock *incomingBlock = phi->getIncomingBlock(i);
if (executableBlocks.count(incomingBlock)) { // 仅考虑可执行前驱
phiResult = Meet(phiResult, GetValueState(incomingVal));
if (phiResult.state == LatticeVal::Bottom)
break; // 如果已经 Bottom则提前退出
if (executableBlocks.count(incomingBlock)) {
hasAnyExecutablePred = true;
Value *incomingVal = phi->getIncomingValue(i);
SSAPValue incomingState = GetValueState(incomingVal);
if(DEBUG) {
std::cout << " Incoming from block " << incomingBlock->getName()
<< " with value " << incomingVal->getName() << " state: ";
if (incomingState.state == LatticeVal::Top)
std::cout << "Top";
else if (incomingState.state == LatticeVal::Constant) {
if (incomingState.constant_type == ValueType::Integer)
std::cout << "Const<int>(" << std::get<int>(incomingState.constantVal) << ")";
else
std::cout << "Const<float>(" << std::get<float>(incomingState.constantVal) << ")";
} else
std::cout << "Bottom";
std::cout << std::endl;
}
phiResult = Meet(phiResult, incomingState);
if (phiResult.state == LatticeVal::Bottom) {
break; // 提前退出优化
}
}
// 不可执行前驱暂时被忽略
// 这是标准SCCP的做法依赖于单调性保证正确性
}
if (!hasAnyExecutablePred) {
// 没有可执行前驱保持Top状态
newState = SSAPValue();
} else {
// 关键修复:使用严格的单调性
// 确保phi的值只能从Top -> Constant -> Bottom单向变化
if (currentPhiState.state == LatticeVal::Top) {
// 从Top状态可以变为任何计算结果
newState = phiResult;
} else if (currentPhiState.state == LatticeVal::Constant) {
// 从Constant状态只能保持相同常量或变为Bottom
if (phiResult.state == LatticeVal::Constant &&
currentPhiState.constantVal == phiResult.constantVal &&
currentPhiState.constant_type == phiResult.constant_type) {
// 保持相同的常量
newState = currentPhiState;
} else {
// 不同的值必须变为Bottom
newState = SSAPValue(LatticeVal::Bottom);
}
} else {
// 已经是Bottom保持Bottom
newState = currentPhiState;
}
}
break;
}
case Instruction::kAlloca: // 对应 kAlloca
@@ -884,6 +952,22 @@ void SCCPContext::ProcessInstruction(Instruction *inst) {
}
}
}
if (DEBUG) {
std::cout << "New state: ";
if (newState.state == LatticeVal::Top) {
std::cout << "Top";
} else if (newState.state == LatticeVal::Constant) {
if (newState.constant_type == ValueType::Integer) {
std::cout << "Const<int>(" << std::get<int>(newState.constantVal) << ")";
} else {
std::cout << "Const<float>(" << std::get<float>(newState.constantVal) << ")";
}
} else {
std::cout << "Bottom";
}
std::cout << std::endl;
}
}
// 辅助函数:处理单条控制流边
@@ -891,14 +975,22 @@ void SCCPContext::ProcessEdge(const std::pair<BasicBlock *, BasicBlock *> &edge)
BasicBlock *fromBB = edge.first;
BasicBlock *toBB = edge.second;
// 检查目标块是否已经可执行
bool wasAlreadyExecutable = executableBlocks.count(toBB) > 0;
// 标记目标块为可执行(如果还不是的话)
MarkBlockExecutable(toBB);
// 对于目标块中的所有 Phi 指令,重新评估其值,因为可能有新的前驱被激活
// 如果目标块之前就已经可执行那么需要重新处理其中的phi节点
// 因为现在有新的前驱变为可执行phi节点的值可能需要更新
if (wasAlreadyExecutable) {
for (auto &inst_ptr : toBB->getInstructions()) {
if (dynamic_cast<PhiInst *>(inst_ptr.get())) {
instWorkList.push(inst_ptr.get());
}
}
}
// 如果目标块是新变为可执行的MarkBlockExecutable已经添加了所有指令
}
// 阶段1: 常量传播与折叠
@@ -913,18 +1005,29 @@ bool SCCPContext::PropagateConstants(Function *func) {
}
}
// 初始化函数参数为Bottom因为它们在编译时是未知的
for (auto arg : func->getArguments()) {
valueState[arg] = SSAPValue(LatticeVal::Bottom);
if (DEBUG) {
std::cout << "Initializing function argument " << arg->getName() << " to Bottom" << std::endl;
}
}
// 标记入口块为可执行
if (!func->getBasicBlocks().empty()) {
MarkBlockExecutable(func->getEntryBlock());
}
// 主循环:处理工作列表直到不动点
// 主循环:标准的SCCP工作列表算法
// 交替处理边工作列表和指令工作列表直到不动点
while (!instWorkList.empty() || !edgeWorkList.empty()) {
// 处理所有待处理的CFG边
while (!edgeWorkList.empty()) {
ProcessEdge(edgeWorkList.front());
edgeWorkList.pop();
}
// 处理所有待处理的指令
while (!instWorkList.empty()) {
Instruction *inst = instWorkList.front();
instWorkList.pop();
@@ -1243,7 +1346,7 @@ void SCCPContext::RemovePhiIncoming(BasicBlock *phiParentBB, BasicBlock *removed
for (Instruction *inst : insts_to_check) {
if (auto phi = dynamic_cast<PhiInst *>(inst)) {
phi->delBlk(removedPred);
phi->removeIncomingBlock(removedPred);
}
}
}

View File

@@ -42,7 +42,7 @@ bool SysYCFGOptUtils::SysYDelInstAfterBr(Function *func) {
++Branchiter;
while (Branchiter != instructions.end()) {
changed = true;
Branchiter = instructions.erase(Branchiter);
Branchiter = SysYIROptUtils::usedelete(Branchiter); // 删除指令
}
if (Branch) { // 更新前驱后继关系
@@ -77,6 +77,11 @@ bool SysYCFGOptUtils::SysYBlockMerge(Function *func) {
bool changed = false;
for (auto blockiter = func->getBasicBlocks().begin(); blockiter != func->getBasicBlocks().end();) {
// 检查当前块是是不是entry块
if( blockiter->get() == func->getEntryBlock() ) {
blockiter++;
continue; // 跳过入口块
}
if (blockiter->get()->getNumSuccessors() == 1) {
// 如果当前块只有一个后继块
// 且后继块只有一个前驱块
@@ -86,7 +91,7 @@ bool SysYCFGOptUtils::SysYBlockMerge(Function *func) {
BasicBlock *block = blockiter->get();
BasicBlock *nextBlock = blockiter->get()->getSuccessors()[0];
// auto nextarguments = nextBlock->getArguments();
// 删除br指令
// 删除block的br指令
if (block->getNumInstructions() != 0) {
auto thelastinstinst = block->terminator();
if (thelastinstinst->get()->isUnconditional()) {
@@ -98,14 +103,21 @@ bool SysYCFGOptUtils::SysYBlockMerge(Function *func) {
if (brinst->getThenBlock() == brinst->getElseBlock()) {
thelastinstinst = SysYIROptUtils::usedelete(thelastinstinst);
}
else{
assert(false && "SysYBlockMerge: unexpected conditional branch with different then and else blocks");
}
}
}
// 将后继块的指令移动到当前块
// 并将后继块的父指针改为当前块
for (auto institer = nextBlock->begin(); institer != nextBlock->end();) {
institer->get()->setParent(block);
block->getInstructions().emplace_back(institer->release());
institer = nextBlock->getInstructions().erase(institer);
// institer->get()->setParent(block);
// block->getInstructions().emplace_back(institer->release());
// 用usedelete删除会导致use关系被删除我只希望移动指令到当前块
// institer = SysYIROptUtils::usedelete(institer);
// institer = nextBlock->getInstructions().erase(institer);
institer = nextBlock->moveInst(institer, block->getInstructions().end(), block);
}
// 更新前驱后继关系,类似树节点操作
block->removeSuccessor(nextBlock);
@@ -189,7 +201,7 @@ bool SysYCFGOptUtils::SysYDelNoPreBLock(Function *func) {
break;
}
// 将这个 Phi 节点中来自不可达前驱unreachableBlock的输入参数删除
dynamic_cast<PhiInst *>(phiInstPtr.get())->delBlk(unreachableBlock);
dynamic_cast<PhiInst *>(phiInstPtr.get())->removeIncomingBlock(unreachableBlock);
}
}
}
@@ -288,13 +300,12 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
continue;
}
std::function<Value *(Value *, BasicBlock *)> getUltimateSourceValue = [&](Value *val,
BasicBlock *currentDefBlock) -> Value * {
// 如果值不是指令,例如常量或函数参数,则它本身就是最终来源
if (auto instr = dynamic_cast<Instruction *>(val)) { // Assuming Value* has a method to check if it's an instruction
std::function<Value *(Value *, BasicBlock *)> getUltimateSourceValue = [&](Value *val, BasicBlock *currentDefBlock) -> Value * {
if(!dynamic_cast<Instruction *>(val)) {
// 如果 val 不是指令,直接返回它
return val;
}
Instruction *inst = dynamic_cast<Instruction *>(val);
// 如果定义指令不在任何空块中,它就是最终来源
if (!emptyBlockRedirectMap.count(currentDefBlock)) {
@@ -311,7 +322,7 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
// 找到在空块链中导致 currentDefBlock 的那个前驱块
if (emptyBlockRedirectMap.count(incomingBlock) || incomingBlock == currentBlock) {
// 递归追溯该传入值
return getUltimateSourceValue(phi->getIncomingValue(incomingBlock), incomingBlock);
return getUltimateSourceValue(phi->getValfromBlk(incomingBlock), incomingBlock);
}
}
}
@@ -354,7 +365,7 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
if (actualEmptyPredecessorOfS) {
// 获取 Phi 节点原本从 actualEmptyPredecessorOfS 接收的值
Value *valueFromEmptyPredecessor = phiInst->getIncomingValue(actualEmptyPredecessorOfS);
Value *valueFromEmptyPredecessor = phiInst->getValfromBlk(actualEmptyPredecessorOfS);
// 追溯这个值,找到它在非空块中的最终来源
// currentBlock 是 P
@@ -364,12 +375,13 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
// 替换 Phi 节点的传入块和传入值
if (ultimateSourceValue) { // 确保成功追溯到有效来源
phiInst->replaceIncoming(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
// phiInst->replaceIncoming(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
phiInst->replaceIncomingBlock(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
} else {
assert(false && "[DelEmptyBlock] Unable to trace a valid source for Phi instruction");
// 无法追溯到有效来源,这可能是个错误或特殊情况
// 此时可能需要移除该 Phi 项,或者插入一个 undef 值
phiInst->removeIncoming(actualEmptyPredecessorOfS);
phiInst->getValfromBlk(actualEmptyPredecessorOfS);
}
}
} else {
@@ -421,7 +433,7 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
if (actualEmptyPredecessorOfS) {
// 获取 Phi 节点原本从 actualEmptyPredecessorOfS 接收的值
Value *valueFromEmptyPredecessor = phiInst->getIncomingValue(actualEmptyPredecessorOfS);
Value *valueFromEmptyPredecessor = phiInst->getValfromBlk(actualEmptyPredecessorOfS);
// 追溯这个值,找到它在非空块中的最终来源
// currentBlock 是 P
@@ -431,12 +443,13 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
// 替换 Phi 节点的传入块和传入值
if (ultimateSourceValue) { // 确保成功追溯到有效来源
phiInst->replaceIncoming(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
// phiInst->replaceIncoming(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
phiInst->replaceIncomingBlock(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
} else {
assert(false && "[DelEmptyBlock] Unable to trace a valid source for Phi instruction");
// 无法追溯到有效来源,这可能是个错误或特殊情况
// 此时可能需要移除该 Phi 项,或者插入一个 undef 值
phiInst->removeIncoming(actualEmptyPredecessorOfS);
phiInst->removeIncomingBlock(actualEmptyPredecessorOfS);
}
}
} else {
@@ -481,7 +494,7 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
if (actualEmptyPredecessorOfS) {
// 获取 Phi 节点原本从 actualEmptyPredecessorOfS 接收的值
Value *valueFromEmptyPredecessor = phiInst->getIncomingValue(actualEmptyPredecessorOfS);
Value *valueFromEmptyPredecessor = phiInst->getValfromBlk(actualEmptyPredecessorOfS);
// 追溯这个值,找到它在非空块中的最终来源
// currentBlock 是 P
@@ -491,12 +504,13 @@ bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder *pBuilder) {
// 替换 Phi 节点的传入块和传入值
if (ultimateSourceValue) { // 确保成功追溯到有效来源
phiInst->replaceIncoming(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
// phiInst->replaceIncoming(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
phiInst->replaceIncomingBlock(actualEmptyPredecessorOfS, currentBlock, ultimateSourceValue);
} else {
assert(false && "[DelEmptyBlock] Unable to trace a valid source for Phi instruction");
// 无法追溯到有效来源,这可能是个错误或特殊情况
// 此时可能需要移除该 Phi 项,或者插入一个 undef 值
phiInst->removeIncoming(actualEmptyPredecessorOfS);
phiInst->removeIncomingBlock(actualEmptyPredecessorOfS);
}
}
} else {
@@ -647,7 +661,7 @@ bool SysYCFGOptUtils::SysYCondBr2Br(Function *func, IRBuilder *pBuilder) {
break;
}
// 使用 delBlk 方法删除 basicblock.get() 对应的传入值
dynamic_cast<PhiInst *>(phiinst.get())->removeIncoming(basicblock.get());
dynamic_cast<PhiInst *>(phiinst.get())->removeIncomingBlock(basicblock.get());
}
} else { // cond为false或0
@@ -665,7 +679,7 @@ bool SysYCFGOptUtils::SysYCondBr2Br(Function *func, IRBuilder *pBuilder) {
break;
}
// 使用 delBlk 方法删除 basicblock.get() 对应的传入值
dynamic_cast<PhiInst *>(phiinst.get())->removeIncoming(basicblock.get());
dynamic_cast<PhiInst *>(phiinst.get())->removeIncomingBlock(basicblock.get());
}
}
}

View File

@@ -11,6 +11,8 @@
#include "Mem2Reg.h"
#include "Reg2Mem.h"
#include "SCCP.h"
#include "BuildCFG.h"
#include "LargeArrayToGlobal.h"
#include "Pass.h"
#include <iostream>
#include <queue>
@@ -40,6 +42,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
3. 添加优化passid
*/
// 注册分析遍
registerAnalysisPass<DominatorTreeAnalysisPass>();
registerAnalysisPass<LivenessAnalysisPass>();
registerAnalysisPass<sysy::DominatorTreeAnalysisPass>();
registerAnalysisPass<sysy::LivenessAnalysisPass>();
registerAnalysisPass<SysYAliasAnalysisPass>(); // 别名分析 (优先级高)
@@ -49,6 +53,9 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
registerAnalysisPass<LoopCharacteristicsPass>(); // 循环特征分析依赖别名分析
// 注册优化遍
registerOptimizationPass<BuildCFG>();
registerOptimizationPass<LargeArrayToGlobalPass>();
registerOptimizationPass<SysYDelInstAfterBrPass>();
registerOptimizationPass<SysYDelNoPreBLockPass>();
registerOptimizationPass<SysYBlockMergePass>();
@@ -68,6 +75,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
if (DEBUG) std::cout << "Applying -O1 optimizations.\n";
if (DEBUG) std::cout << "--- Running custom optimization sequence ---\n";
if(DEBUG) {
std::cout << "=== IR Before CFGOpt Optimizations ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&BuildCFG::ID);
this->addPass(&LargeArrayToGlobalPass::ID);
this->run();
this->clearPasses();
this->addPass(&SysYDelInstAfterBrPass::ID);
this->addPass(&SysYDelNoPreBLockPass::ID);
@@ -77,6 +94,10 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
this->addPass(&SysYAddReturnPass::ID);
this->run();
this->clearPasses();
this->addPass(&BuildCFG::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After CFGOpt Optimizations ===\n";
printPasses();
@@ -117,7 +138,9 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
std::cout << "=== IR After Reg2Mem Optimizations ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&BuildCFG::ID);
this->run();
if (DEBUG) std::cout << "--- Custom optimization sequence finished ---\n";
}
@@ -132,6 +155,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
SysYPrinter printer(moduleIR);
printer.printIR();
}
}
void PassManager::clearPasses() {

View File

@@ -15,6 +15,139 @@
using namespace std;
namespace sysy {
std::pair<long long, int> calculate_signed_magic(int d) {
if (d == 0) throw std::runtime_error("Division by zero");
if (d == 1 || d == -1) return {0, 0}; // Not used by strength reduction
int k = 0;
unsigned int ad = (d > 0) ? d : -d;
unsigned int temp = ad;
while (temp > 0) {
temp >>= 1;
k++;
}
if ((ad & (ad - 1)) == 0) { // if power of 2
k--;
}
unsigned __int128 m_val = 1;
m_val <<= (32 + k - 1);
unsigned __int128 m_prime = m_val / ad;
long long m = m_prime + 1;
return {m, k};
}
// 清除因函数调用而失效的表达式缓存(保守策略)
void SysYIRGenerator::invalidateExpressionsOnCall() {
availableBinaryExpressions.clear();
availableUnaryExpressions.clear();
availableLoads.clear();
availableGEPs.clear();
}
// 在进入新的基本块时清空所有表达式缓存
void SysYIRGenerator::enterNewBasicBlock() {
availableBinaryExpressions.clear();
availableUnaryExpressions.clear();
availableLoads.clear();
availableGEPs.clear();
}
// 清除因变量赋值而失效的表达式缓存
// @param storedAddress: store 指令的目标地址 (例如 AllocaInst* 或 GEPInst*)
void SysYIRGenerator::invalidateExpressionsOnStore(Value *storedAddress) {
// 遍历二元表达式缓存,移除受影响的条目
// 创建一个临时列表来存储要移除的键,避免在迭代时修改容器
std::vector<ExpKey> binaryKeysToRemove;
for (const auto &pair : availableBinaryExpressions) {
// 检查左操作数
// 如果左操作数是 LoadInst并且它从 storedAddress 加载
if (auto loadInst = dynamic_cast<LoadInst *>(pair.first.left)) {
if (loadInst->getPointer() == storedAddress) {
binaryKeysToRemove.push_back(pair.first);
continue; // 这个表达式已标记为移除,跳到下一个
}
}
// 如果左操作数本身就是被存储的地址 (例如,将一个地址值直接作为操作数,虽然不常见)
if (pair.first.left == storedAddress) {
binaryKeysToRemove.push_back(pair.first);
continue;
}
// 检查右操作数,逻辑同左操作数
if (auto loadInst = dynamic_cast<LoadInst *>(pair.first.right)) {
if (loadInst->getPointer() == storedAddress) {
binaryKeysToRemove.push_back(pair.first);
continue;
}
}
if (pair.first.right == storedAddress) {
binaryKeysToRemove.push_back(pair.first);
continue;
}
}
// 实际移除条目
for (const auto &key : binaryKeysToRemove) {
availableBinaryExpressions.erase(key);
}
// 遍历一元表达式缓存,移除受影响的条目
std::vector<UnExpKey> unaryKeysToRemove;
for (const auto &pair : availableUnaryExpressions) {
// 检查操作数
if (auto loadInst = dynamic_cast<LoadInst *>(pair.first.operand)) {
if (loadInst->getPointer() == storedAddress) {
unaryKeysToRemove.push_back(pair.first);
continue;
}
}
if (pair.first.operand == storedAddress) {
unaryKeysToRemove.push_back(pair.first);
continue;
}
}
// 实际移除条目
for (const auto &key : unaryKeysToRemove) {
availableUnaryExpressions.erase(key);
}
availableLoads.erase(storedAddress);
std::vector<GEPKey> gepKeysToRemove;
for (const auto &pair : availableGEPs) {
// 检查 GEP 的基指针是否受存储影响
if (auto loadInst = dynamic_cast<LoadInst *>(pair.first.basePointer)) {
if (loadInst->getPointer() == storedAddress) {
gepKeysToRemove.push_back(pair.first);
continue; // 标记此GEP为移除跳过后续检查
}
}
// 如果基指针本身就是存储的目标地址 (不常见,但可能)
if (pair.first.basePointer == storedAddress) {
gepKeysToRemove.push_back(pair.first);
continue;
}
// 检查 GEP 的每个索引是否受存储影响
for (const auto &indexVal : pair.first.indices) {
if (auto loadInst = dynamic_cast<LoadInst *>(indexVal)) {
if (loadInst->getPointer() == storedAddress) {
gepKeysToRemove.push_back(pair.first);
break; // 标记此GEP为移除并跳出内部循环
}
}
// 如果索引本身就是存储的目标地址
if (indexVal == storedAddress) {
gepKeysToRemove.push_back(pair.first);
break;
}
}
}
// 实际移除条目
for (const auto &key : gepKeysToRemove) {
availableGEPs.erase(key);
}
}
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈
@@ -244,6 +377,13 @@ void SysYIRGenerator::compute() {
}
} else {
// 否则创建相应的IR指令
ExpKey currentExpKey(static_cast<BinaryOp>(op), lhs, rhs);
auto it = availableBinaryExpressions.find(currentExpKey);
if (it != availableBinaryExpressions.end()) {
// 在缓存中找到,重用结果
resultValue = it->second;
} else {
if (commonType == Type::getIntType()) {
switch (op) {
case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break;
@@ -266,6 +406,9 @@ void SysYIRGenerator::compute() {
std::cerr << "Error: Unsupported type for binary instruction." << std::endl;
return;
}
// 将新创建的指令结果添加到缓存
availableBinaryExpressions[currentExpKey] = resultValue;
}
}
break;
}
@@ -316,7 +459,13 @@ void SysYIRGenerator::compute() {
return;
}
} else {
// 否则创建相应的IR指令
// 否则创建相应的IR指令 (在这里应用CSE)
UnExpKey currentUnExpKey(static_cast<BinaryOp>(op), operand);
auto it = availableUnaryExpressions.find(currentUnExpKey);
if (it != availableUnaryExpressions.end()) {
// 在缓存中找到,重用结果
resultValue = it->second;
} else {
switch (op) {
case BinaryOp::PLUS:
resultValue = operand; // 一元加指令通常直接返回操作数
@@ -347,6 +496,9 @@ void SysYIRGenerator::compute() {
std::cerr << "Error: Unknown unary operator for instructions: " << op << std::endl;
return;
}
// 将新创建的指令结果添加到缓存
availableUnaryExpressions[currentUnExpKey] = resultValue;
}
}
break;
}
@@ -487,7 +639,19 @@ Value* SysYIRGenerator::getGEPAddressInst(Value* basePointer, const std::vector<
// `indices` 向量现在由调用方(如 visitLValue, visitVarDecl, visitAssignStmt负责完整准备
// 包括是否需要添加初始的 `0` 索引。
// 所以这里直接将其传递给 `builder.createGetElementPtrInst`。
return builder.createGetElementPtrInst(basePointer, indices);
GEPKey key = {basePointer, indices};
// 尝试从缓存中查找
auto it = availableGEPs.find(key);
if (it != availableGEPs.end()) {
return it->second; // 缓存命中,返回已有的 GEPInst*
}
// 缓存未命中,创建新的 GEPInst
Value* gepInst = builder.createGetElementPtrInst(basePointer, indices); // 假设 builder 提供了 createGEPInst 方法
availableGEPs[key] = gepInst; // 将新的 GEPInst* 加入缓存
return gepInst;
}
/*
@@ -586,7 +750,13 @@ std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) {
// 显式地为局部常量在栈上分配空间
// alloca 的类型将是指针指向常量类型,例如 `int*` 或 `int[2][3]*`
// 将alloca全部集中到entry中
auto entry = builder.getBasicBlock()->getParent()->getEntryBlock();
auto it = builder.getPosition();
auto nowblk = builder.getBasicBlock();
builder.setPosition(entry, entry->terminator());
AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(variableType), name);
builder.setPosition(nowblk, it);
ArrayValueTree *root = std::any_cast<ArrayValueTree *>(constDef->constInitVal()->accept(this));
ValueCounter values;
@@ -743,8 +913,12 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) {
// 对于数组alloca 的类型将是指针指向数组类型,例如 `int[2][3]*`
// 对于标量alloca 的类型将是指针指向标量类型,例如 `int*`
AllocaInst* alloca =
builder.createAllocaInst(Type::getPointerType(variableType), name);
auto entry = builder.getBasicBlock()->getParent()->getEntryBlock();
auto it = builder.getPosition();
auto nowblk = builder.getBasicBlock();
builder.setPosition(entry, entry->terminator());
AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(variableType), name);
builder.setPosition(nowblk, it);
if (varDef->initVal() != nullptr) {
ValueCounter values;
@@ -946,6 +1120,8 @@ std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) {
std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
// 更新作用域
module->enterNewScope();
// 清除CSE缓存
enterNewBasicBlock();
auto name = ctx->Ident()->getText();
std::vector<Type *> paramActualTypes;
@@ -1015,15 +1191,25 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
for(int i = 0; i < paramActualTypes.size(); ++i) {
Argument* arg = new Argument(paramActualTypes[i], function, i, paramNames[i]);
function->insertArgument(arg);
}
// 先将所有参数名字注册到符号表中确保alloca不会使用相同的名字
for (int i = 0; i < paramNames.size(); ++i) {
// 预先注册参数名字这样addVariable就会使用不同的后缀
module->registerParameterName(paramNames[i]);
}
auto funcArgs = function->getArguments();
std::vector<AllocaInst *> allocas;
for (int i = 0; i < paramActualTypes.size(); ++i) {
AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(paramActualTypes[i]), paramNames[i]);
// 使用函数特定的前缀来确保参数alloca名字唯一
std::string allocaName = name + "_param_" + paramNames[i];
AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(paramActualTypes[i]), allocaName);
// 直接设置唯一名字不依赖addVariable的命名逻辑
alloca->setName(allocaName);
allocas.push_back(alloca);
module->addVariable(paramNames[i], alloca);
// 直接添加到符号表,使用原参数名作为查找键
module->addVariableDirectly(paramNames[i], alloca);
}
for(int i = 0; i < paramActualTypes.size(); ++i) {
@@ -1037,6 +1223,7 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
// 从 entryBB 无条件跳转到 funcBodyEntry
builder.createUncondBrInst(funcBodyEntry);
BasicBlock::conectBlocks(entry, funcBodyEntry); // 连接 entryBB 和 funcBodyEntry
builder.setPosition(funcBodyEntry,funcBodyEntry->end()); // 将插入点设置到 funcBodyEntry
for (auto item : ctx->blockStmt()->blockItem()) {
@@ -1091,42 +1278,8 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
if (dynamic_cast<AllocaInst*>(variable) || dynamic_cast<GlobalValue*>(variable)) {
LValue = variable;
}
}
else {
// 对于数组或多维数组的左值处理
// 需要获取 GEP 地址
Value* gepBasePointer = nullptr;
std::vector<Value*> gepIndices;
if (AllocaInst *alloc = dynamic_cast<AllocaInst *>(variable)) {
Type* allocatedType = alloc->getType()->as<PointerType>()->getBaseType();
if (allocatedType->isPointer()) {
gepBasePointer = builder.createLoadInst(alloc);
gepIndices = indices;
} else {
gepBasePointer = alloc;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
}
} else if (GlobalValue *glob = dynamic_cast<GlobalValue *>(variable)) {
// 情况 B: 全局变量 (GlobalValue)
gepBasePointer = glob;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
} else if (ConstantVariable *constV = dynamic_cast<ConstantVariable *>(variable)) {
gepBasePointer = constV;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
}
// 左值为地址
LValue = getGEPAddressInst(gepBasePointer, gepIndices);
}
// Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
// 先推断 LValue 的类型
// 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型
// 如果 LValue 是标量,则直接使用其类型
// 注意LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断
// 标量变量的类型推断
Type* LType = builder.getIndexedType(variable->getType(), indices);
Value* RValue = computeExp(ctx->exp(), LType); // 右值计算
@@ -1151,20 +1304,98 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
} else if (dynamic_cast<ConstantInteger *>(constValue)) {
// 如果是整型常量,直接使用
RValue = ConstantInteger::get(static_cast<int>(constValue->getInt()));
}
}
} else {
if (LType == Type::getFloatType()) {
if (LType == Type::getFloatType() && RType != Type::getFloatType()) {
RValue = builder.createItoFInst(RValue);
} else { // 假设如果不是浮点型,就是整型
} else if (LType != Type::getFloatType() && RType == Type::getFloatType()) {
RValue = builder.createFtoIInst(RValue);
}
// 如果两者都是同一类型,就不需要转换
}
}
builder.createStoreInst(RValue, LValue);
}
else {
// 对于数组或多维数组的左值处理
// 需要获取 GEP 地址
Value* gepBasePointer = nullptr;
std::vector<Value*> gepIndices;
if (AllocaInst *alloc = dynamic_cast<AllocaInst *>(variable)) {
Type* allocatedType = alloc->getType()->as<PointerType>()->getBaseType();
if (allocatedType->isPointer()) {
// 尝试从缓存中获取 builder.createLoadInst(alloc) 的结果
auto it = availableLoads.find(alloc);
if (it != availableLoads.end()) {
gepBasePointer = it->second; // 缓存命中,重用
} else {
gepBasePointer = builder.createLoadInst(alloc); // 缓存未命中,创建新的 LoadInst
availableLoads[alloc] = gepBasePointer; // 将结果加入缓存
}
// --- CSE 结束 ---
// gepBasePointer = builder.createLoadInst(alloc);
gepIndices = indices;
} else {
gepBasePointer = alloc;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
}
} else if (GlobalValue *glob = dynamic_cast<GlobalValue *>(variable)) {
// 情况 B: 全局变量 (GlobalValue)
gepBasePointer = glob;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
} else if (ConstantVariable *constV = dynamic_cast<ConstantVariable *>(variable)) {
gepBasePointer = constV;
gepIndices.push_back(ConstantInteger::get(0));
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
}
// 左值为地址
LValue = getGEPAddressInst(gepBasePointer, gepIndices);
// 数组变量的类型推断使用gepIndices和gepBasePointer的类型
Type* LType = builder.getIndexedType(gepBasePointer->getType(), gepIndices);
Value* RValue = computeExp(ctx->exp(), LType); // 右值计算
Type* RType = RValue->getType();
// TODO:computeExp处理了类型转换可以考虑删除判断逻辑
if (LType != RType) {
ConstantValue *constValue = dynamic_cast<ConstantValue *>(RValue);
if (constValue != nullptr) {
if (LType == Type::getFloatType()) {
if(dynamic_cast<ConstantInteger *>(constValue)) {
// 如果是整型常量,转换为浮点型
RValue = ConstantFloating::get(static_cast<float>(constValue->getInt()));
} else if (dynamic_cast<ConstantFloating *>(constValue)) {
// 如果是浮点型常量,直接使用
RValue = ConstantFloating::get(static_cast<float>(constValue->getFloat()));
}
} else { // 假设如果不是浮点型,就是整型
if(dynamic_cast<ConstantFloating *>(constValue)) {
// 如果是浮点型常量,转换为整型
RValue = ConstantInteger::get(static_cast<int>(constValue->getFloat()));
} else if (dynamic_cast<ConstantInteger *>(constValue)) {
// 如果是整型常量,直接使用
RValue = ConstantInteger::get(static_cast<int>(constValue->getInt()));
}
}
} else {
if (LType == Type::getFloatType() && RType != Type::getFloatType()) {
RValue = builder.createItoFInst(RValue);
} else if (LType != Type::getFloatType() && RType == Type::getFloatType()) {
RValue = builder.createFtoIInst(RValue);
}
// 如果两者都是同一类型,就不需要转换
}
}
builder.createStoreInst(RValue, LValue);
}
invalidateExpressionsOnStore(LValue);
return std::any();
}
@@ -1201,6 +1432,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(thenBlock);
builder.setPosition(thenBlock, thenBlock->end());
// CSE清除缓存
enterNewBasicBlock();
auto block = dynamic_cast<SysYParser::BlockStmtContext *>(ctx->stmt(0));
// 如果是块语句,直接访问
@@ -1220,6 +1453,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(elseBlock);
builder.setPosition(elseBlock, elseBlock->end());
// CSE清除缓存
enterNewBasicBlock();
block = dynamic_cast<SysYParser::BlockStmtContext *>(ctx->stmt(1));
if (block != nullptr) {
@@ -1237,6 +1472,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(exitBlock);
builder.setPosition(exitBlock, exitBlock->end());
// CSE清除缓存
enterNewBasicBlock();
} else {
builder.pushTrueBlock(thenBlock);
@@ -1250,6 +1487,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(thenBlock);
builder.setPosition(thenBlock, thenBlock->end());
// CSE清除缓存
enterNewBasicBlock();
auto block = dynamic_cast<SysYParser::BlockStmtContext *>(ctx->stmt(0));
if (block != nullptr) {
@@ -1267,6 +1506,9 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(exitBlock);
builder.setPosition(exitBlock, exitBlock->end());
// CSE清除缓存
enterNewBasicBlock();
}
return std::any();
}
@@ -1284,6 +1526,8 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
builder.createUncondBrInst(headBlock);
BasicBlock::conectBlocks(curBlock, headBlock);
builder.setPosition(headBlock, headBlock->end());
// CSE清除缓存
enterNewBasicBlock();
BasicBlock* bodyBlock = new BasicBlock(function);
BasicBlock* exitBlock = new BasicBlock(function);
@@ -1300,6 +1544,8 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(bodyBlock);
builder.setPosition(bodyBlock, bodyBlock->end());
// CSE清除缓存
enterNewBasicBlock();
builder.pushBreakBlock(exitBlock);
builder.pushContinueBlock(headBlock);
@@ -1315,7 +1561,7 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
}
builder.createUncondBrInst(headBlock);
BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock);
BasicBlock::conectBlocks(builder.getBasicBlock(), headBlock);
builder.popBreakBlock();
builder.popContinueBlock();
@@ -1324,6 +1570,8 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
labelstring.str("");
function->addBasicBlock(exitBlock);
builder.setPosition(exitBlock, exitBlock->end());
// CSE清除缓存
enterNewBasicBlock();
return std::any();
}
@@ -1430,90 +1678,101 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
break;
}
}
if (allIndicesConstant) {
// 如果是常量变量且所有索引都是常量,并且不是数组名单独出现的情况
if (allIndicesConstant && !dims.empty()) {
// 如果是常量变量且所有索引都是常量,直接通过 getByIndices 获取编译时值
// 这个方法会根据索引深度返回最终的标量值或指向子数组的指针 (作为 ConstantValue/Variable)
return constVar->getByIndices(dims);
}
// 如果dims为空检查是否是常量标量
if (dims.empty() && declaredNumDims == 0) {
// 常量标量,直接返回其值
// 默认传入空索引列表,表示访问标量本身
return constVar->getByIndices(dims);
}
// 如果dims为空但不是标量数组名单独出现需要走GEP路径来实现数组到指针的退化
}
// 3. 处理可变变量 (AllocaInst/GlobalValue) 或带非常量索引的常量变量
// 这里区分标量访问和数组元素/子数组访问
Value *targetAddress = nullptr;
// 检查是否是访问标量变量本身没有索引且声明维度为0
if (dims.empty() && declaredNumDims == 0) {
// 对于标量变量,直接加载其值。
// variable 本身就是指向标量的指针 (e.g., int* %a)
if (dynamic_cast<AllocaInst*>(variable) || dynamic_cast<GlobalValue*>(variable)) {
value = builder.createLoadInst(variable);
} else {
// 如果走到这里且不是AllocaInst/GlobalValue但dims为空且declaredNumDims为0
// 且又不是ConstantVariable (前面已处理),则可能是错误情况。
targetAddress = variable;
}
else {
assert(false && "Unhandled scalar variable type in LValue access.");
return static_cast<Value*>(nullptr);
}
} else {
// 访问数组元素或子数组(有索引,或变量本身是数组/多维指针)
Value* gepBasePointer = nullptr;
std::vector<Value*> gepIndices; // 准备传递给 getGEPAddressInst 的索引列表
// GEP 的基指针就是变量本身(它是一个指向内存的指针)
std::vector<Value*> gepIndices;
if (AllocaInst *alloc = dynamic_cast<AllocaInst *>(variable)) {
// 情况 A: 局部变量 (AllocaInst)
// 获取 AllocaInst 分配的内存的实际类型。
// 例如:对于 `int b[10][20];``allocatedType` 是 `[10 x [20 x i32]]`。
// 对于 `int b[][20]` 的函数参数,其 AllocaInst 存储的是一个指针,
// 此时 `allocatedType` 是 `[20 x i32]*`。
Type* allocatedType = alloc->getType()->as<PointerType>()->getBaseType();
if (allocatedType->isPointer()) {
// 如果 AllocaInst 分配的是一个指针类型 (例如,用于存储函数参数的指针,如 int b[][20] 中的 b)
// 即 `allocatedType` 是一个指向数组指针的指针 (e.g., [20 x i32]**)
// 那么 GEP 的基指针是加载这个指针变量的值。
gepBasePointer = builder.createLoadInst(alloc); // 加载出实际的指针值 (e.g., [20 x i32]*)
// 对于这种参数指针,用户提供的索引直接作用于它。不需要额外的 0。
gepBasePointer = builder.createLoadInst(alloc);
gepIndices = dims;
} else {
// 如果 AllocaInst 分配的是实际的数组数据 (例如int b[10][20] 中的 b)
// 那么 AllocaInst 本身就是 GEP 的基指针。
// 这里的 `alloc` 是指向数组的指针 (e.g., [10 x [20 x i32]]*)
gepBasePointer = alloc; // 类型是 [10 x [20 x i32]]*
// 对于这种完整的数组分配GEP 的第一个索引必须是 0用于“步过”整个数组。
gepBasePointer = alloc;
gepIndices.push_back(ConstantInteger::get(0));
if (dims.empty() && declaredNumDims > 0) {
// 数组名单独出现没有索引在SysY中多维数组名应该退化为指向第一行的指针
// 对于二维数组 T[M][N],退化为 T(*)[N]需要GEP: getelementptr T[M][N], T[M][N]* ptr, i32 0, i32 0
// 第一个i32 0: 选择数组本身第二个i32 0: 选择第0行
// 结果类型: T[N]*
gepIndices.push_back(ConstantInteger::get(0));
} else {
// 正常的数组元素访问
gepIndices.insert(gepIndices.end(), dims.begin(), dims.end());
}
}
} else if (GlobalValue *glob = dynamic_cast<GlobalValue *>(variable)) {
// 情况 B: 全局变量 (GlobalValue)
// GlobalValue 总是指向全局数据的指针。
gepBasePointer = glob; // 类型是 [61 x [67 x i32]]*
// 对于全局数组GEP 的第一个索引必须是 0用于“步过”整个数组。
gepBasePointer = glob;
gepIndices.push_back(ConstantInteger::get(0));
if (dims.empty() && declaredNumDims > 0) {
// 全局数组名单独出现(没有索引):应该退化为指向第一行的指针
// 需要添加一个额外的i32 0索引
gepIndices.push_back(ConstantInteger::get(0));
} else {
// 正常的数组元素访问
gepIndices.insert(gepIndices.end(), dims.begin(), dims.end());
}
} else if (ConstantVariable *constV = dynamic_cast<ConstantVariable *>(variable)) {
// 情况 C: 常量变量 (ConstantVariable),如果它代表全局数组常量
// 假设 ConstantVariable 可以直接作为 GEP 的基指针。
gepBasePointer = constV;
// 对于常量数组,也需要 0 索引来“步过”整个数组。
// 这里可以进一步检查 constV->getType()->as<PointerType>()->getBaseType()->isArray()
// 但为了简洁,假设所有 ConstantVariable 作为 GEP 基指针时都需要此 0。
gepIndices.push_back(ConstantInteger::get(0));
if (dims.empty() && declaredNumDims > 0) {
// 常量数组名单独出现(没有索引):应该退化为指向第一行的指针
// 需要添加一个额外的i32 0索引
gepIndices.push_back(ConstantInteger::get(0));
} else {
// 正常的数组元素访问
gepIndices.insert(gepIndices.end(), dims.begin(), dims.end());
}
} else {
assert(false && "LValue variable type not supported for GEP base pointer.");
return static_cast<Value *>(nullptr);
}
// 现在调用 getGEPAddressInst传入正确准备的基指针和索引列表
Value *targetAddress = getGEPAddressInst(gepBasePointer, gepIndices);
targetAddress = getGEPAddressInst(gepBasePointer, gepIndices);
// 如果提供的索引数量少于声明的维度数量,则表示访问的是子数组,返回其地址
}
// 如果提供的索引数量少于声明的维度数量,则表示访问的是子数组,返回其地址 (无需加载)
if (dims.size() < declaredNumDims) {
value = targetAddress;
} else {
// 否则,表示访问的是最终的标量元素,加载其值
// 假设 createLoadInst 接受 Value* pointer
// value = builder.createLoadInst(targetAddress);
auto it = availableLoads.find(targetAddress);
if (it != availableLoads.end()) {
value = it->second; // 缓存命中,重用已有的 LoadInst 结果
} else {
// 缓存未命中,创建新的 LoadInst
value = builder.createLoadInst(targetAddress);
availableLoads[targetAddress] = value; // 将新的 LoadInst 结果加入缓存
}
}
return value;
}
@@ -1571,10 +1830,10 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
// 获取形参列表。`getArguments()` 返回的是 `Argument*` 的集合,
// 每个 `Argument` 代表一个函数形参,其 `getType()` 就是指向形参的类型的指针类型。
auto formalParams = function->getArguments();
const auto& formalParams = function->getArguments();
// 检查实参和形参数量是否匹配。
if (args.size() != formalParams.size()) {
if (args.size() != function->getNumArguments()) {
std::cerr << "Error: Function call argument count mismatch for function '" << funcName << "'." << std::endl;
assert(false && "Function call argument count mismatch!");
}
@@ -1606,15 +1865,27 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
} else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) {
args[i] = builder.createItoFInst(args[i]);
}
// 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) TODO不清楚有没有这种样例
// 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间)
// 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型,
// 而形参是其退化后的基础指针类型。LLVM 的 `bitcast` 指令可以用于
// 在相同大小的指针类型之间进行转换,这对于数组退化至关重要。
// else if (formalParamType->isPointer() && actualArgType->isPointer()) {
// 检查指针基类型是否兼容,或者是否是数组退化导致的类型不同。
// 使用 bitcast
// args[i] = builder.createBitCastInst(args[i], formalParamType);
// }
// 而形参是其退化后的基础指针类型。
else if (formalParamExpectedValueType->isPointer() && actualArgType->isPointer()) {
// 检查是否是数组指针到元素指针的decay
// 例如:[N x T]* -> T*
auto formalPtrType = formalParamExpectedValueType->as<PointerType>();
auto actualPtrType = actualArgType->as<PointerType>();
if (formalPtrType && actualPtrType && actualPtrType->getBaseType()->isArray()) {
auto actualArrayType = actualPtrType->getBaseType()->as<ArrayType>();
if (actualArrayType &&
formalPtrType->getBaseType() == actualArrayType->getElementType()) {
// 这是数组decay的情况添加GEP来获取数组的第一个元素
std::vector<Value*> indices;
indices.push_back(ConstantInteger::get(0)); // 第一个索引:解引用指针
indices.push_back(ConstantInteger::get(0)); // 第二个索引:获取数组第一个元素
args[i] = getGEPAddressInst(args[i], indices);
}
}
}
// 3. 其他未预期的类型不匹配
// 如果代码执行到这里,说明存在编译器前端未处理的类型不兼容或错误。
else {
@@ -1633,6 +1904,7 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
visitPrimaryExp(ctx->primaryExp());
} else if (ctx->call() != nullptr) {
BinaryExpStack.push_back(std::any_cast<Value *>(visitCall(ctx->call())));BinaryExpLenStack.back()++;
invalidateExpressionsOnCall();
} else if (ctx->unaryOp() != nullptr) {
// 遇到一元操作符,将其压入 BinaryExpStack
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->unaryOp()->children[0]);
@@ -1997,15 +2269,23 @@ void Utils::createExternalFunction(
const std::vector<std::string> &paramNames,
const std::vector<std::vector<Value *>> &paramDims, Type *returnType,
const std::string &funcName, Module *pModule, IRBuilder *pBuilder) {
auto funcType = Type::getFunctionType(returnType, paramTypes);
// 根据paramDims调整参数类型数组参数需要转换为指针类型
std::vector<Type *> adjustedParamTypes = paramTypes;
for (int i = 0; i < paramTypes.size() && i < paramDims.size(); ++i) {
if (!paramDims[i].empty()) {
// 如果参数有维度信息,说明是数组参数,转换为指针类型
adjustedParamTypes[i] = Type::getPointerType(paramTypes[i]);
}
}
auto funcType = Type::getFunctionType(returnType, adjustedParamTypes);
auto function = pModule->createExternalFunction(funcName, funcType);
auto entry = function->getEntryBlock();
pBuilder->setPosition(entry, entry->end());
for (int i = 0; i < paramTypes.size(); ++i) {
auto arg = new Argument(paramTypes[i], function, i, paramNames[i]);
auto arg = new Argument(adjustedParamTypes[i], function, i, paramNames[i]);
auto alloca = pBuilder->createAllocaInst(
Type::getPointerType(paramTypes[i]), paramNames[i]);
Type::getPointerType(adjustedParamTypes[i]), paramNames[i]);
function->insertArgument(arg);
auto store = pBuilder->createStoreInst(arg, alloca);
pModule->addVariable(paramNames[i], alloca);

View File

@@ -240,6 +240,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
case Kind::kMul:
case Kind::kDiv:
case Kind::kRem:
case Kind::kSRA:
case Kind::kMulh:
case Kind::kFAdd:
case Kind::kFSub:
case Kind::kFMul:
@@ -272,6 +274,8 @@ void SysYPrinter::printInst(Instruction *pInst) {
case Kind::kMul: std::cout << "mul"; break;
case Kind::kDiv: std::cout << "sdiv"; break;
case Kind::kRem: std::cout << "srem"; break;
case Kind::kSRA: std::cout << "ashr"; break;
case Kind::kMulh: std::cout << "mulh"; break;
case Kind::kFAdd: std::cout << "fadd"; break;
case Kind::kFSub: std::cout << "fsub"; break;
case Kind::kFMul: std::cout << "fmul"; break;
@@ -295,7 +299,12 @@ void SysYPrinter::printInst(Instruction *pInst) {
// Types and operands
std::cout << " ";
// For comparison operations, print operand types instead of result type
if (pInst->getKind() >= Kind::kICmpEQ && pInst->getKind() <= Kind::kFCmpGE) {
printType(binInst->getLhs()->getType());
} else {
printType(binInst->getType());
}
std::cout << " ";
printValue(binInst->getLhs());
std::cout << ", ";
@@ -508,9 +517,9 @@ void SysYPrinter::printInst(Instruction *pInst) {
if (!firstPair) std::cout << ", ";
firstPair = false;
std::cout << "[ ";
printValue(phiInst->getValue(i));
printValue(phiInst->getIncomingValue(i));
std::cout << ", %";
printBlock(phiInst->getBlock(i));
printBlock(phiInst->getIncomingBlock(i));
std::cout << " ]";
}
std::cout << std::endl;

View File

@@ -21,6 +21,8 @@ using namespace sysy;
int DEBUG = 0;
int DEEPDEBUG = 0;
int DEEPERDEBUG = 0;
int DEBUGLENGTH = 50;
static string argStopAfter;
static string argInputFile;
@@ -108,6 +110,7 @@ int main(int argc, char **argv) {
// 如果指定停止在 AST 阶段,则打印并退出
if (argStopAfter == "ast") {
cout << moduleAST->toStringTree(true) << '\n';
sysy::cleanupIRPools(); // 清理内存池
return EXIT_SUCCESS;
}
@@ -130,7 +133,7 @@ int main(int argc, char **argv) {
if (DEBUG) {
cout << "=== Init IR ===\n";
SysYPrinter(moduleIR).printIR(); // 临时打印器用于调试
moduleIR->print(cout); // 使用新实现的print方法直接打印IR
}
// 创建 Pass 管理器并运行优化管道
@@ -142,10 +145,26 @@ int main(int argc, char **argv) {
// a) 如果指定停止在 IR 阶段,则打印最终 IR 并退出
if (argStopAfter == "ir" || argStopAfter == "ird") {
// 打印最终 IR
cout << "=== Final IR ===\n";
SysYPrinter printer(moduleIR); // 在这里创建打印器,因为可能之前调试时用过临时打印器
printer.printIR();
if (DEBUG) cerr << "=== Final IR ===\n";
if (!argOutputFilename.empty()) {
// 输出到指定文件
ofstream fout(argOutputFilename);
if (not fout.is_open()) {
cerr << "Failed to open output file: " << argOutputFilename << endl;
moduleIR->cleanup(); // 清理模块
sysy::cleanupIRPools(); // 清理内存池
return EXIT_FAILURE;
}
moduleIR->print(fout);
fout.close();
} else {
// 输出到标准输出
moduleIR->print(cout);
}
moduleIR->cleanup(); // 清理模块
sysy::cleanupIRPools(); // 清理内存池
return EXIT_SUCCESS;
}
// b) 如果未停止在 IR 阶段,则继续生成汇编 (后端)
@@ -164,6 +183,8 @@ int main(int argc, char **argv) {
ofstream fout(argOutputFilename);
if (not fout.is_open()) {
cerr << "Failed to open output file: " << argOutputFilename << endl;
moduleIR->cleanup(); // 清理模块
sysy::cleanupIRPools(); // 清理内存池
return EXIT_FAILURE;
}
fout << asmCode << endl;
@@ -171,6 +192,8 @@ int main(int argc, char **argv) {
} else {
cout << asmCode << endl;
}
moduleIR->cleanup(); // 清理模块
sysy::cleanupIRPools(); // 清理内存池
return EXIT_SUCCESS;
}
@@ -179,5 +202,7 @@ int main(int argc, char **argv) {
cout << "Compilation completed. No output specified (neither -s nor -S). Exiting.\n";
// return EXIT_SUCCESS; // 或者这里调用一个链接器生成可执行文件
moduleIR->cleanup(); // 清理模块
sysy::cleanupIRPools(); // 清理内存池
return EXIT_SUCCESS;
}