From f879a0f521e3033d3f928a5976ac095af6905942 Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Sat, 2 Aug 2025 22:06:37 +0800 Subject: [PATCH] =?UTF-8?q?[midend]=E4=BF=AE=E5=A4=8D=E4=BA=86=E5=90=8E?= =?UTF-8?q?=E7=AB=AF=E4=B8=8D=E9=80=82=E9=85=8D=E4=B8=AD=E7=AB=AF=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E5=8F=98=E9=87=8F=E5=AE=9A=E4=B9=89=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/RISCv64/RISCv64Backend.cpp | 93 ++++++++++++++++---- src/include/backend/RISCv64/RISCv64Backend.h | 4 + 2 files changed, 82 insertions(+), 15 deletions(-) diff --git a/src/backend/RISCv64/RISCv64Backend.cpp b/src/backend/RISCv64/RISCv64Backend.cpp index 2797eb7..2edb21d 100644 --- a/src/backend/RISCv64/RISCv64Backend.cpp +++ b/src/backend/RISCv64/RISCv64Backend.cpp @@ -12,6 +12,39 @@ std::string RISCv64CodeGen::code_gen() { return module_gen(); } +unsigned RISCv64CodeGen::getTypeSizeInBytes(Type* type) { + if (!type) { + assert(false && "Cannot get size of a null type."); + return 0; + } + + switch (type->getKind()) { + // 对于SysY语言,基本类型int和float都占用4字节 + case Type::kInt: + case Type::kFloat: + return 4; + + // 指针类型在RISC-V 64位架构下占用8字节 + // 虽然SysY没有'int*'语法,但数组变量在IR层面本身就是指针类型 + case Type::kPointer: + return 8; + + // 数组类型的总大小 = 元素数量 * 单个元素的大小 + case Type::kArray: { + auto arrayType = type->as(); + // 递归调用以计算元素大小 + return arrayType->getNumElements() * getTypeSizeInBytes(arrayType->getElementType()); + } + + // 其他类型,如Void, Label等不占用栈空间,或者不应该出现在这里 + default: + // 如果遇到未处理的类型,触发断言,方便调试 + // assert(false && "Unsupported type for size calculation."); + return 0; // 对于像Label或Void这样的类型,返回0是合理的 + } +} + + void printInitializer(std::stringstream& ss, const ValueCounter& init_values) { for (size_t i = 0; i < init_values.getValues().size(); ++i) { auto val = init_values.getValues()[i]; @@ -39,18 +72,36 @@ std::string RISCv64CodeGen::module_gen() { for (const auto& global_ptr : module->getGlobals()) { GlobalValue* global = global_ptr.get(); + + // [核心修改] 使用更健壮的逻辑来判断是否为大型零初始化数组 + bool is_all_zeros = true; const auto& init_values = global->getInitValues(); - // 判断是否为大型零初始化数组,以便放入.bss段 - bool is_large_zero_array = false; - if (init_values.getValues().size() == 1) { - if (auto const_val = dynamic_cast(init_values.getValues()[0])) { - if (const_val->isInt() && const_val->getInt() == 0 && init_values.getNumbers()[0] > 16) { - is_large_zero_array = true; + // 检查初始化值是否全部为0 + if (init_values.getValues().empty()) { + // 如果 ValueCounter 为空,GlobalValue 的构造函数会确保它是零初始化的 + is_all_zeros = true; + } else { + for (auto val : init_values.getValues()) { + if (auto const_val = dynamic_cast(val)) { + if (!const_val->isZero()) { + is_all_zeros = false; + break; + } + } else { + // 如果初始值包含非常量(例如,另一个全局变量的地址),则不认为是纯零初始化 + is_all_zeros = false; + break; } } } + // 使用 getTypeSizeInBytes 检查总大小是否超过阈值 (16个整数 = 64字节) + Type* allocated_type = global->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); + + bool is_large_zero_array = is_all_zeros && (total_size > 64); + if (is_large_zero_array) { bss_globals.push_back(global); } else { @@ -58,12 +109,12 @@ std::string RISCv64CodeGen::module_gen() { } } - // --- 步骤2:生成 .bss 段的代码 (这部分不变) --- + // --- 步骤2:生成 .bss 段的代码 --- if (!bss_globals.empty()) { ss << ".bss\n"; for (GlobalValue* global : bss_globals) { - unsigned count = global->getInitValues().getNumbers()[0]; - unsigned total_size = count * 4; // 假设元素都是4字节 + Type* allocated_type = global->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); ss << " .align 3\n"; ss << ".globl " << global->getName() << "\n"; @@ -74,33 +125,45 @@ std::string RISCv64CodeGen::module_gen() { } } - // --- [修改] 步骤3:生成 .data 段的代码 --- - // 我们需要检查 data_globals 和 常量列表是否都为空 + // --- 步骤3:生成 .data 段的代码 --- if (!data_globals.empty() || !module->getConsts().empty()) { ss << ".data\n"; - // a. 先处理普通的全局变量 (GlobalValue) + // a. 处理普通的全局变量 (GlobalValue) for (GlobalValue* global : data_globals) { + Type* allocated_type = global->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); + + ss << " .align 3\n"; ss << ".globl " << global->getName() << "\n"; + ss << ".type " << global->getName() << ", @object\n"; + ss << ".size " << global->getName() << ", " << total_size << "\n"; ss << global->getName() << ":\n"; printInitializer(ss, global->getInitValues()); } - // b. [新增] 再处理全局常量 (ConstantVariable) + // b. 处理全局常量 (ConstantVariable) for (const auto& const_ptr : module->getConsts()) { ConstantVariable* cnst = const_ptr.get(); + Type* allocated_type = cnst->getType()->as()->getBaseType(); + unsigned total_size = getTypeSizeInBytes(allocated_type); + + ss << " .align 3\n"; ss << ".globl " << cnst->getName() << "\n"; + ss << ".type " << cnst->getName() << ", @object\n"; + ss << ".size " << cnst->getName() << ", " << total_size << "\n"; ss << cnst->getName() << ":\n"; printInitializer(ss, cnst->getInitValues()); } } - // --- 处理函数 (.text段) 的逻辑保持不变 --- + // --- 步骤4:处理函数 (.text段) 的逻辑 --- if (!module->getFunctions().empty()) { ss << ".text\n"; for (const auto& func_pair : module->getFunctions()) { - if (func_pair.second.get()) { + if (func_pair.second.get() && !func_pair.second->getBasicBlocks().empty()) { ss << function_gen(func_pair.second.get()); + if (DEBUG) std::cerr << "Function: " << func_pair.first << " generated.\n"; } } } diff --git a/src/include/backend/RISCv64/RISCv64Backend.h b/src/include/backend/RISCv64/RISCv64Backend.h index 403d586..9e179d9 100644 --- a/src/include/backend/RISCv64/RISCv64Backend.h +++ b/src/include/backend/RISCv64/RISCv64Backend.h @@ -22,6 +22,10 @@ private: // 函数级代码生成 (实现新的流水线) std::string function_gen(Function* func); + + // 私有辅助函数,用于根据类型计算其占用的字节数。 + unsigned getTypeSizeInBytes(Type* type); + Module* module; };