Files
mysysy/src/midend/SysYIRGenerator.cpp

1750 lines
70 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SysYIRGenerator.cpp
// TODO类型转换及其检查
// TODOsysy库函数处理
// TODO数组处理
// TODO对while、continue、break的测试
#include "IR.h"
#include <any>
#include <memory>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>
#include "SysYIRGenerator.h"
using namespace std;
namespace sysy {
/**
 * @brief Build the nested ArrayType described by `dims` around `baseType`.
 *
 * `dims` is ordered outermost-first, e.g. [2, 3] for `int a[2][3]`. The type is
 * built from the innermost extent outwards: first [3 x base], then
 * [2 x [3 x base]].
 *
 * @param baseType Element type of the innermost dimension.
 * @param dims     Dimension extents; each must be a ConstantInteger.
 * @return The resulting type, or nullptr (after a failed assertion) if some
 *         extent is not a constant integer. With no dims, `baseType` itself.
 */
Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector<Value*>& dims){
    Type* result = baseType;
    for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
        // Array types can only be formed from compile-time constant extents.
        auto* extent = dynamic_cast<ConstantInteger*>(*it);
        if (extent == nullptr) {
            assert(false && "Array dimension must be a constant integer!");
            return nullptr;
        }
        unsigned length = extent->getInt();
        result = Type::getArrayType(result, length);
    }
    return result;
}
/**
 * @brief Emit a GEP instruction computing an address from a base pointer.
 * @param basePointer Pointer-typed base value. Callers pass e.g. the AllocaInst
 *        of a local array, or the loaded pointer of an array parameter.
 * @param indices Complete index list prepared by the caller, including any
 *        leading 0 that steps through the array object itself. This function
 *        does not add or remove indices.
 * @return The address produced by the GEP (a pointer-typed value).
 */
Value* SysYIRGenerator::getGEPAddressInst(Value* basePointer, const std::vector<Value*>& indices) {
    assert(basePointer->getType()->isPointer() && "Base pointer must be a pointer type!");
    Value* address = builder.createGetElementPtrInst(basePointer, indices);
    return address;
}
/*
 * @brief Visit the compilation unit (root of the parse tree).
 * @details
 *   compUnit: (globalDecl | funcDef)+;
 *   Creates the IR Module, hands it to the `module` member, registers the
 *   external library functions via Utils::initExternalFunction, and visits all
 *   top-level declarations/definitions inside one outermost scope.
 * @return The created Module* wrapped in std::any.
 */
std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) {
    Module *irModule = new Module();
    assert(irModule);
    module.reset(irModule);
    Utils::initExternalFunction(irModule, &builder);

    // Everything declared at file level lives in this outermost scope.
    irModule->enterNewScope();
    visitChildren(ctx);
    irModule->leaveScope();

    return irModule;
}
/**
 * @brief Lower a file-scope `const` declaration into module-level constants.
 *
 * For each constDef: evaluate the dimension expressions, flatten the
 * initializer tree into (value, repeat count) runs via Utils::tree2Array, and
 * register the constant with the module under its scalar or array type.
 *
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx){
    auto decl = ctx->constDecl();
    Type *elemType = std::any_cast<Type *>(visitBType(decl->bType()));
    for (const auto &def : decl->constDef()) {
        std::string ident = def->Ident()->getText();

        // Dimension extents, outermost first (empty for scalars).
        std::vector<Value *> extents;
        for (const auto &dimExp : def->constExp())
            extents.push_back(std::any_cast<Value *>(visitConstExp(dimExp)));

        // Flatten the nested initializer; the tree is no longer needed afterwards.
        ArrayValueTree *initTree = std::any_cast<ArrayValueTree *>(def->constInitVal()->accept(this));
        ValueCounter flatInit;
        Utils::tree2Array(elemType, initTree, extents, extents.size(), flatInit, &builder);
        delete initTree;

        Type *storedType = extents.empty() ? elemType : buildArrayType(elemType, extents);
        module->createConstVar(ident, Type::getPointerType(storedType), flatInit);
    }
    return std::any();
}
/**
 * @brief Lower a file-scope variable declaration into module-level globals.
 *
 * For each varDef: evaluate the dimension expressions, flatten the optional
 * initializer via Utils::tree2Array (left empty when there is no initializer),
 * and create the global value with its scalar or array type.
 *
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) {
    auto decl = ctx->varDecl();
    Type *elemType = std::any_cast<Type *>(visitBType(decl->bType()));
    for (const auto &def : decl->varDef()) {
        std::string ident = def->Ident()->getText();

        std::vector<Value *> extents;
        for (const auto &dimExp : def->constExp())
            extents.push_back(std::any_cast<Value *>(visitConstExp(dimExp)));

        // Without an initializer the value list is passed on empty.
        ValueCounter flatInit = {};
        if (auto initCtx = def->initVal()) {
            ArrayValueTree *initTree = std::any_cast<ArrayValueTree *>(initCtx->accept(this));
            Utils::tree2Array(elemType, initTree, extents, extents.size(), flatInit, &builder);
            delete initTree;
        }

        Type *storedType = extents.empty() ? elemType : buildArrayType(elemType, extents);
        module->createGlobalValue(ident, Type::getPointerType(storedType), flatInit);
    }
    return std::any();
}
/**
 * @brief Lower a local (block-scope) `const` declaration.
 *
 * For each constDef this:
 *   1. evaluates the dimension expressions and builds the scalar/array type;
 *   2. allocates a stack slot (AllocaInst) for it;
 *   3. flattens the initializer into (value, repeat count) runs with
 *      Utils::tree2Array;
 *   4. emits the initialization. A scalar gets one store. An array gets a
 *      single memset over totalSizeInBytes when every flattened value is the
 *      integer constant 0, and otherwise one GEP + store per flattened element;
 *   5. binds the name to the alloca in the current scope.
 *
 * @return An empty std::any (also returned early after a failed assertion).
 */
std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) {
    Type *type = std::any_cast<Type *>(visitBType(ctx->bType()));
    for (const auto constDef : ctx->constDef()) {
        std::vector<Value *> dims = {};
        std::string name = constDef->Ident()->getText();
        auto constExps = constDef->constExp();
        if (!constExps.empty()) {
            for (const auto constExp : constExps) {
                dims.push_back(std::any_cast<Value *>(visitConstExp(constExp)));
            }
        }
        Type *variableType = type;
        if (!dims.empty()) {
            variableType = buildArrayType(type, dims); // build the full nested ArrayType
        }
        // Explicitly allocate stack space for the local constant.
        // The alloca's type is a pointer to the constant's type, e.g. `int*` or `int[2][3]*`.
        AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(variableType), name);
        ArrayValueTree *root = std::any_cast<ArrayValueTree *>(constDef->constInitVal()->accept(this));
        ValueCounter values;
        // Flatten the nested initializer into (value, repeat count) runs.
        Utils::tree2Array(type, root, dims, dims.size(), values, &builder);
        delete root;
        // Emit the initializing stores according to the dimensions.
        if (dims.empty()) { // scalar constant
            // A local scalar constant must have an initial value (a single one).
            if (!values.getValues().empty()) {
                builder.createStoreInst(values.getValue(0), alloca);
            } else {
                // Error: local scalar constant without an initialization value.
                // Defaulting to 0 would not match the semantics of a constant.
                assert(false && "Local scalar constant must have an initialization value!");
                return std::any(); // bail out instead of continuing
            }
        } else { // array constant
            const std::vector<sysy::Value *> &counterValues = values.getValues();
            const std::vector<unsigned> &counterNumbers = values.getNumbers();
            int numElements = 1;
            std::vector<int> dimSizes;
            for (Value *dimVal : dims) {
                if (ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(dimVal)) {
                    int dimSize = constInt->getInt();
                    numElements *= dimSize;
                    dimSizes.push_back(dimSize);
                }
                // Array dimensions must be constants (static allocation).
                else {
                    assert(false && "Array dimension must be a constant integer!");
                    return std::any(); // bail out instead of continuing
                }
            }
            unsigned int elementSizeInBytes = type->getSize();
            unsigned int totalSizeInBytes = numElements * elementSizeInBytes;
            // Check whether every initializer value is the integer constant 0.
            // Non-ConstantInteger values (including float constants) count as non-zero.
            bool allValuesAreZero = false;
            if (counterValues.empty()) { // no initializer values: treat as all-zero
                allValuesAreZero = true;
            } else {
                allValuesAreZero = true;
                for (Value *val : counterValues) {
                    if (ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(val)) {
                        if (constInt->getInt() != 0) {
                            allValuesAreZero = false;
                            break;
                        }
                    } else { // not an integer constant: cannot prove it is zero
                        allValuesAreZero = false;
                        break;
                    }
                }
            }
            if (allValuesAreZero) {
                builder.createMemsetInst(alloca, ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes),
                                         ConstantInteger::get(0));
            } else {
                int linearIndexOffset = 0; // flat (row-major) index of the first element of the current run
                for (int k = 0; k < counterValues.size(); ++k) {
                    // Value of the current run and how many consecutive elements it covers.
                    Value *currentValue = counterValues[k];
                    unsigned currentRepeatNum = counterNumbers[k];
                    for (unsigned i = 0; i < currentRepeatNum; ++i) {
                        std::vector<Value *> currentIndices;
                        int tempLinearIndex = linearIndexOffset + i; // flat index of this element
                        // Decompose the flat index into per-dimension indices (innermost first).
                        for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) {
                            currentIndices.insert(currentIndices.begin(),
                                                  ConstantInteger::get(static_cast<int>(tempLinearIndex % dimSizes[dimIdx])));
                            tempLinearIndex /= dimSizes[dimIdx];
                        }
                        // For a local array the alloca itself is the GEP base pointer.
                        // The first GEP index must be 0 to step through the array object itself.
                        std::vector<Value *> gepIndicesForInit;
                        gepIndicesForInit.push_back(ConstantInteger::get(0));
                        gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end());
                        // Address of the element.
                        Value *elementAddress = getGEPAddressInst(alloca, gepIndicesForInit);
                        // Store the initial value.
                        builder.createStoreInst(currentValue, elementAddress);
                    }
                    // Advance to the first element of the next run.
                    linearIndexOffset += currentRepeatNum;
                }
            }
        }
        // Bind the constant's name to its AllocaInst in the current scope.
        module->addVariable(name, alloca);
    }
    return std::any();
}
/**
 * @brief Lower a local (block-scope) variable declaration.
 *
 * For each varDef:
 *   - evaluate the constant dimensions and allocate a stack slot (AllocaInst)
 *     of the scalar/array type;
 *   - scalar with initializer: emit one store;
 *   - array with initializer: flatten it into (value, repeat count) runs and
 *     emit a single memset when every value is the integer constant 0,
 *     otherwise one GEP + store per flattened element;
 *   - array without initializer: zero it with a memset of the whole array;
 *   - scalar without initializer: no store is emitted;
 *   - finally bind the name to the alloca in the current scope.
 *
 * Non-constant dimensions are rejected with the same assertion as
 * visitConstDecl. They used to be silently skipped, which left `dimSizes`
 * shorter than `dims` and produced wrong element indices and memset sizes.
 *
 * @return An empty std::any (also returned early after a failed assertion).
 */
std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) {
    Type* type = std::any_cast<Type *>(visitBType(ctx->bType()));
    for (const auto varDef : ctx->varDef()) {
        std::vector<Value *> dims;
        std::string name = varDef->Ident()->getText();
        for (const auto &constExp : varDef->constExp()) {
            dims.push_back(std::any_cast<Value *>(visitConstExp(constExp)));
        }

        // Constant extents (outermost first) and the total element count.
        std::vector<int> dimSizes;
        int numElements = 1;
        for (Value *dimVal : dims) {
            auto *constInt = dynamic_cast<ConstantInteger *>(dimVal);
            if (constInt == nullptr) {
                assert(false && "Array dimension must be a constant integer!");
                return std::any(); // bail out instead of emitting malformed IR
            }
            int dimSize = constInt->getInt();
            numElements *= dimSize;
            dimSizes.push_back(dimSize);
        }

        Type* variableType = dims.empty() ? type : buildArrayType(type, dims);
        // The alloca's type is a pointer to the variable type, e.g. `int*` or `int[2][3]*`.
        AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(variableType), name);

        if (varDef->initVal() != nullptr) {
            ValueCounter values;
            ArrayValueTree* root = std::any_cast<ArrayValueTree *>(varDef->initVal()->accept(this));
            Utils::tree2Array(type, root, dims, dims.size(), values, &builder);
            delete root;
            if (dims.empty()) { // scalar initialization
                builder.createStoreInst(values.getValue(0), alloca);
            } else { // array initialization
                const std::vector<sysy::Value *> &counterValues = values.getValues();
                const std::vector<unsigned> &counterNumbers = values.getNumbers();
                unsigned int elementSizeInBytes = type->getSize();
                unsigned int totalSizeInBytes = numElements * elementSizeInBytes;

                // A memset is used only when every value is the integer constant 0
                // (an empty list counts as all-zero); anything else is stored element-wise.
                bool allValuesAreZero = true;
                for (Value *val : counterValues) {
                    auto *constInt = dynamic_cast<ConstantInteger *>(val);
                    if (constInt == nullptr || constInt->getInt() != 0) {
                        allValuesAreZero = false;
                        break;
                    }
                }

                if (allValuesAreZero) {
                    builder.createMemsetInst(alloca,
                                             ConstantInteger::get(0),
                                             ConstantInteger::get(totalSizeInBytes),
                                             ConstantInteger::get(0));
                } else {
                    int linearIndexOffset = 0; // flat index of the first element of the current run
                    for (std::size_t k = 0; k < counterValues.size(); ++k) {
                        Value *currentValue = counterValues[k];
                        unsigned currentRepeatNum = counterNumbers[k];
                        for (unsigned i = 0; i < currentRepeatNum; ++i) {
                            // GEP indices: a leading 0 steps through the array object itself,
                            // followed by the row-major decomposition of the flat index.
                            // The per-dimension constants are created innermost-first and the
                            // leading 0 last, matching the original creation order.
                            std::vector<Value *> gepIndicesForInit(dimSizes.size() + 1);
                            int tempLinearIndex = linearIndexOffset + i;
                            for (std::size_t d = dimSizes.size(); d-- > 0;) {
                                gepIndicesForInit[d + 1] =
                                    ConstantInteger::get(static_cast<int>(tempLinearIndex % dimSizes[d]));
                                tempLinearIndex /= dimSizes[d];
                            }
                            gepIndicesForInit[0] = ConstantInteger::get(0);
                            Value *elementAddress = getGEPAddressInst(alloca, gepIndicesForInit);
                            builder.createStoreInst(currentValue, elementAddress);
                        }
                        linearIndexOffset += currentRepeatNum;
                    }
                }
            }
        } else if (!dims.empty()) {
            // No explicit initializer: arrays are zero-filled by default.
            unsigned int elementSizeInBytes = type->getSize();
            unsigned int totalSizeInBytes = numElements * elementSizeInBytes;
            builder.createMemsetInst(alloca,
                                     ConstantInteger::get(0),
                                     ConstantInteger::get(totalSizeInBytes),
                                     ConstantInteger::get(0));
        }
        module->addVariable(name, alloca);
    }
    return std::any();
}
/**
 * @brief Map a basic type keyword to its IR type: `int` to the integer type,
 *        anything else to the float type.
 * @return The Type* wrapped in std::any.
 */
std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) {
    if (ctx->INT() == nullptr) {
        return Type::getFloatType();
    }
    return Type::getIntType();
}
/**
 * @brief Build a leaf initializer-tree node holding one expression value.
 * @return A heap-allocated ArrayValueTree* (ownership passes to the caller).
 */
std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) {
    auto *exprValue = std::any_cast<Value *>(visitExp(ctx->exp()));
    auto *leaf = new ArrayValueTree();
    leaf->setValue(exprValue);
    return leaf;
}
/**
 * @brief Build an inner initializer-tree node with one child per nested
 *        initializer, in source order.
 * @return A heap-allocated ArrayValueTree* (ownership passes to the caller).
 */
std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) {
    std::vector<ArrayValueTree *> subtrees;
    for (auto *nested : ctx->initVal()) {
        subtrees.push_back(std::any_cast<ArrayValueTree *>(nested->accept(this)));
    }
    auto *node = new ArrayValueTree();
    node->addChildren(subtrees);
    return node;
}
/**
 * @brief Build a leaf initializer-tree node holding one constant-expression value.
 * @return A heap-allocated ArrayValueTree* (ownership passes to the caller).
 */
std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) {
    auto *constValue = std::any_cast<Value *>(visitConstExp(ctx->constExp()));
    auto *leaf = new ArrayValueTree();
    leaf->setValue(constValue);
    return leaf;
}
/**
 * @brief Build an inner initializer-tree node for a braced constant
 *        initializer, one child per nested constInitVal, in source order.
 * @return A heap-allocated ArrayValueTree* (ownership passes to the caller).
 */
std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) {
    std::vector<ArrayValueTree *> subtrees;
    for (auto *nested : ctx->constInitVal()) {
        subtrees.push_back(std::any_cast<ArrayValueTree *>(nested->accept(this)));
    }
    auto *node = new ArrayValueTree();
    node->addChildren(subtrees);
    return node;
}
/**
 * @brief Map a function return-type keyword to its IR type: `int`, `float`,
 *        otherwise `void`.
 * @return The Type* wrapped in std::any.
 */
std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) {
    Type *result = (ctx->INT() != nullptr)   ? Type::getIntType()
                 : (ctx->FLOAT() != nullptr) ? Type::getFloatType()
                                             : Type::getVoidType();
    return result;
}
/**
 * @brief Lower a function definition.
 *
 * Computes the IR type of every parameter (array parameters decay to a pointer
 * to the array of their remaining dimensions), creates the Function and its
 * Arguments, and gives each parameter a stack slot in the entry block that
 * holds the incoming value. The entry block then branches to a separate body
 * block ("funcBodyEntry_<name>") where the body items are lowered. If the
 * final block does not end in a return, a default return is appended.
 *
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
    // Enter the function scope. Parameters and top-level body items share it,
    // because the body items below are visited directly rather than via visitBlockStmt.
    module->enterNewScope();
    auto name = ctx->Ident()->getText();
    std::vector<Type *> paramActualTypes;
    std::vector<std::string> paramNames;
    // NOTE(review): paramDims is never read or written in this function.
    std::vector<std::vector<Value *>> paramDims;
    if (ctx->funcFParams() != nullptr) {
        auto params = ctx->funcFParams()->funcFParam();
        for (const auto &param : params) {
            Type* baseBType = std::any_cast<Type *>(visitBType(param->bType()));
            std::string paramName = param->Ident()->getText();
            // Dimension info of the current parameter (only populated for arrays).
            std::vector<Value *> currentParamDims;
            if (!param->LBRACK().empty()) { // brackets in the declaration => array parameter
                // The first dimension of a SysY array parameter is omitted (e.g. int arr[] or int arr[][10]).
                // ConstantInteger::get(-1) marks it as unknown; only the remaining dims shape the type.
                currentParamDims.push_back(ConstantInteger::get(-1)); // first dimension: unknown
                for (const auto &exp : param->exp()) {
                    // Evaluate the dimension expression; it must fold to a constant.
                    Value* dimVal = std::any_cast<Value *>(visitExp(exp));
                    // Must be a ConstantInteger, otherwise buildArrayType would assert.
                    assert(dynamic_cast<ConstantInteger*>(dimVal) && "Array dimension in parameter must be a constant integer!");
                    currentParamDims.push_back(dimVal);
                }
            }
            // Determine the parameter's actual IR type from the parsed information.
            Type* actualParamType;
            if (currentParamDims.empty()) { // case 1: scalar parameter (e.g. int x)
                actualParamType = baseBType; // the type is the base type itself
            } else { // cases 2 & 3: array parameter (e.g. int arr[] or int arr[][10])
                // Array parameters decay to a pointer whose pointee is the array type
                // formed by all dimensions except the first.
                // Drop the leading -1 placeholder.
                std::vector<Value*> fixedDimsForTypeBuilding;
                if (currentParamDims.size() > 1) { // there are fixed dims (e.g. int arr[][10])
                    // Copy every dimension after the -1 placeholder.
                    fixedDimsForTypeBuilding.assign(currentParamDims.begin() + 1, currentParamDims.end());
                }
                Type* pointedToArrayType = baseBType; // start from the base type
                // buildArrayType expects dims outermost-first and iterates them in reverse,
                // so they are passed as-is. E.g. for int arr[][10] the fixed dims are [10],
                // giving [10 x i32].
                if (!fixedDimsForTypeBuilding.empty()) {
                    pointedToArrayType = buildArrayType(baseBType, fixedDimsForTypeBuilding);
                }
                // The parameter type is a pointer to that array type.
                actualParamType = Type::getPointerType(pointedToArrayType); // e.g. i32* or [10 x i32]*
            }
            paramActualTypes.push_back(actualParamType); // actual IR type of the parameter
            paramNames.push_back(paramName); // parameter name
        }
    }
    Type* returnType = std::any_cast<Type *>(visitFuncType(ctx->funcType()));
    Type* funcType = Type::getFunctionType(returnType, paramActualTypes);
    Function* function = module->createFunction(name, funcType);
    BasicBlock* entry = function->getEntryBlock();
    builder.setPosition(entry, entry->end());
    // One Argument object per formal parameter, in declaration order.
    for(int i = 0; i < paramActualTypes.size(); ++i) {
        Argument* arg = new Argument(paramActualTypes[i], function, i, paramNames[i]);
        function->insertArgument(arg);
    }
    auto funcArgs = function->getArguments();
    // Give every parameter a stack slot and bind its name to that slot, so the body
    // accesses parameters through loads/stores like ordinary locals.
    std::vector<AllocaInst *> allocas;
    for (int i = 0; i < paramActualTypes.size(); ++i) {
        AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(paramActualTypes[i]), paramNames[i]);
        allocas.push_back(alloca);
        module->addVariable(paramNames[i], alloca);
    }
    // Store each incoming argument value into its slot.
    for(int i = 0; i < paramActualTypes.size(); ++i) {
        Value *argValue = funcArgs[i];
        builder.createStoreInst(argValue, allocas[i]);
    }
    // Create a separate block as the real entry of the function body, so the
    // entry block only holds the parameter setup and then jumps here.
    BasicBlock* funcBodyEntry = function->addBasicBlock("funcBodyEntry_" + name);
    // Unconditional jump from the entry block to funcBodyEntry.
    builder.createUncondBrInst(funcBodyEntry);
    builder.setPosition(funcBodyEntry,funcBodyEntry->end()); // insertion point: funcBodyEntry
    for (auto item : ctx->blockStmt()->blockItem()) {
        visitBlockItem(item);
    }
    // If the current (last) block does not end in a return, append a default one:
    // `ret` for void, 0 for int, 0.0f for float.
    // NOTE(review): assumes terminator() yields a dereferenceable handle even when the
    // block is empty; confirm against BasicBlock::terminator().
    ReturnInst* retinst = nullptr;
    retinst = dynamic_cast<ReturnInst*>(builder.getBasicBlock()->terminator()->get());
    if (!retinst) {
        if (returnType->isVoid()) {
            builder.createReturnInst();
        } else if (returnType->isInt()) {
            builder.createReturnInst(ConstantInteger::get(0)); // default return 0
        } else if (returnType->isFloat()) {
            builder.createReturnInst(ConstantFloating::get(0.0f)); // default return 0.0f
        } else {
            assert(false && "Function with no explicit return and non-void type should return a value.");
        }
    }
    module->leaveScope();
    return std::any();
}
/**
 * @brief Lower a `{ ... }` block. The block opens its own lexical scope for
 *        the declarations inside it.
 * @return std::any holding the int 0.
 */
std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) {
    auto items = ctx->blockItem();
    module->enterNewScope();
    for (auto *blockItem : items) {
        visitBlockItem(blockItem);
    }
    module->leaveScope();
    return 0;
}
/**
 * @brief Lower `lValue = exp;`.
 *
 * The destination address is computed first:
 *   - a scalar alloca or global is used directly;
 *   - an indexed access goes through a GEP. Array parameters (whose slot holds
 *     a pointer) are loaded first and indexed directly; arrays stored in place
 *     get a leading 0 index.
 * Then the right-hand side is evaluated and converted to the destination
 * element type (literal constants are folded; other values get an int<->float
 * instruction), and a store is emitted.
 *
 * Invalid destinations used to crash on a null pointer. These cases now throw
 * std::runtime_error: an undeclared name, a scalar that is neither an
 * AllocaInst nor a GlobalValue, and an unsupported GEP base.
 *
 * @throws std::runtime_error if the left-hand side is not an assignable location.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
    auto lVal = ctx->lValue();
    std::string name = lVal->Ident()->getText();
    Value* variable = module->getVariable(name);
    if (variable == nullptr) {
        throw std::runtime_error("Assignment to undeclared variable '" + name + "'.");
    }

    std::vector<Value *> indices;
    for (const auto &exp : lVal->exp()) {
        indices.push_back(std::any_cast<Value *>(visitExp(exp)));
    }

    Value* LValue = nullptr; // destination address
    if (indices.empty()) {
        // A scalar variable is itself the pointer to its storage (e.g. int* %a).
        if (dynamic_cast<AllocaInst*>(variable) || dynamic_cast<GlobalValue*>(variable)) {
            LValue = variable;
        }
    } else {
        Value* gepBasePointer = nullptr;
        std::vector<Value*> gepIndices;
        if (AllocaInst *alloc = dynamic_cast<AllocaInst *>(variable)) {
            Type* allocatedType = alloc->getType()->as<PointerType>()->getBaseType();
            if (allocatedType->isPointer()) {
                // Array parameter: the slot holds a pointer; index through the loaded pointer.
                gepBasePointer = builder.createLoadInst(alloc);
                gepIndices = indices;
            } else {
                // Array stored in place: leading 0 steps through the array object itself.
                gepBasePointer = alloc;
                gepIndices.push_back(ConstantInteger::get(0));
                gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
            }
        } else if (GlobalValue *glob = dynamic_cast<GlobalValue *>(variable)) {
            gepBasePointer = glob;
            gepIndices.push_back(ConstantInteger::get(0));
            gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
        } else if (ConstantVariable *constV = dynamic_cast<ConstantVariable *>(variable)) {
            gepBasePointer = constV;
            gepIndices.push_back(ConstantInteger::get(0));
            gepIndices.insert(gepIndices.end(), indices.begin(), indices.end());
        }
        if (gepBasePointer == nullptr) {
            throw std::runtime_error("Unsupported array variable '" + name + "' on the left-hand side of an assignment.");
        }
        LValue = getGEPAddressInst(gepBasePointer, gepIndices);
    }
    if (LValue == nullptr) {
        throw std::runtime_error("Left-hand side '" + name + "' is not an assignable location.");
    }

    Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp()));

    // Element type the store writes, derived from the variable's type and the indices.
    Type* LType = builder.getIndexedType(variable->getType(), indices);
    Type* RType = RValue->getType();
    if (LType != RType) {
        if (ConstantValue *constValue = dynamic_cast<ConstantValue *>(RValue)) {
            // Fold the conversion of literal constants at compile time.
            if (LType == Type::getFloatType()) {
                if (dynamic_cast<ConstantInteger *>(constValue)) {
                    RValue = ConstantFloating::get(static_cast<float>(constValue->getInt()));
                } else if (dynamic_cast<ConstantFloating *>(constValue)) {
                    RValue = ConstantFloating::get(static_cast<float>(constValue->getFloat()));
                }
            } else { // any non-float destination is treated as int
                if (dynamic_cast<ConstantFloating *>(constValue)) {
                    RValue = ConstantInteger::get(static_cast<int>(constValue->getFloat()));
                } else if (dynamic_cast<ConstantInteger *>(constValue)) {
                    RValue = ConstantInteger::get(static_cast<int>(constValue->getInt()));
                }
            }
        } else if (LType == Type::getFloatType()) {
            RValue = builder.createIToFInst(RValue);
        } else { // any non-float destination is treated as int
            RValue = builder.createFtoIInst(RValue);
        }
    }
    builder.createStoreInst(RValue, LValue);
    return std::any();
}
/**
 * @brief Lower `if (cond) stmt [else stmt]`.
 *
 * Control flow:
 *   cond --true--> if_then --> if_exit
 *   cond --false-> if_else --> if_exit   (with else)
 *   cond --false-> if_exit               (without else)
 * Each branch body is lowered in its own scope and ends with a branch to the
 * exit block. Blocks are named "<prefix><label index>" when they are inserted
 * into the function.
 *
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
    Function *function = builder.getBasicBlock()->getParent();
    BasicBlock *thenBlock = new BasicBlock(function);
    BasicBlock *exitBlock = new BasicBlock(function);
    const bool hasElse = ctx->stmt().size() > 1;
    BasicBlock *elseBlock = hasElse ? new BasicBlock(function) : nullptr;

    // Name `bb`, append it to the function and continue emitting into it.
    auto openBlock = [&](BasicBlock *bb, const char *prefix) {
        std::stringstream label;
        label << prefix << builder.getLabelIndex();
        bb->setName(label.str());
        function->addBasicBlock(bb);
        builder.setPosition(bb, bb->end());
    };

    // Lower one branch body in its own scope, then jump to the exit block.
    auto emitBranch = [&](auto *stmt) {
        auto *blockStmt = dynamic_cast<SysYParser::BlockStmtContext *>(stmt);
        if (blockStmt != nullptr) {
            this->visitBlockStmt(blockStmt);
        } else {
            module->enterNewScope();
            stmt->accept(this);
            module->leaveScope();
        }
        builder.createUncondBrInst(exitBlock);
        BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock);
    };

    // The condition jumps to the else block if present, otherwise straight to exit.
    builder.pushTrueBlock(thenBlock);
    builder.pushFalseBlock(hasElse ? elseBlock : exitBlock);
    visitCond(ctx->cond());
    builder.popTrueBlock();
    builder.popFalseBlock();

    openBlock(thenBlock, "if_then.L");
    emitBranch(ctx->stmt(0));

    if (hasElse) {
        openBlock(elseBlock, "if_else.L");
        emitBranch(ctx->stmt(1));
    }

    openBlock(exitBlock, "if_exit.L");
    return std::any();
}
/**
 * @brief Lower `while (cond) stmt`.
 *
 * Control flow:
 *   curBlock --> while_head (evaluates cond)
 *   while_head --true--> while_body --> while_head   (back edge)
 *   while_head --false-> while_exit
 * Inside the body, `break` targets while_exit and `continue` targets while_head.
 *
 * Every emitted branch is paired with a BasicBlock::conectBlocks call for the
 * same target. The body's back edge used to be recorded as an edge to the exit
 * block even though the emitted branch goes to the header; it now records the
 * header.
 *
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
    BasicBlock* curBlock = builder.getBasicBlock();
    Function* function = builder.getBasicBlock()->getParent();
    std::stringstream labelstring;
    labelstring << "while_head.L" << builder.getLabelIndex();
    BasicBlock *headBlock = function->addBasicBlock(labelstring.str());
    labelstring.str("");
    builder.createUncondBrInst(headBlock);
    BasicBlock::conectBlocks(curBlock, headBlock);
    builder.setPosition(headBlock, headBlock->end());

    BasicBlock* bodyBlock = new BasicBlock(function);
    BasicBlock* exitBlock = new BasicBlock(function);
    builder.pushTrueBlock(bodyBlock);
    builder.pushFalseBlock(exitBlock);
    // Evaluate the loop condition in the header block.
    visitCond(ctx->cond());
    builder.popTrueBlock();
    builder.popFalseBlock();

    labelstring << "while_body.L" << builder.getLabelIndex();
    bodyBlock->setName(labelstring.str());
    labelstring.str("");
    function->addBasicBlock(bodyBlock);
    builder.setPosition(bodyBlock, bodyBlock->end());
    builder.pushBreakBlock(exitBlock);
    builder.pushContinueBlock(headBlock);
    auto block = dynamic_cast<SysYParser::BlockStmtContext *>(ctx->stmt());
    if (block != nullptr) {
        visitBlockStmt(block);
    } else {
        module->enterNewScope();
        ctx->stmt()->accept(this);
        module->leaveScope();
    }
    // Back edge: the end of the body returns to the loop header, and the CFG edge
    // must name the same target as the branch just emitted.
    builder.createUncondBrInst(headBlock);
    BasicBlock::conectBlocks(builder.getBasicBlock(), headBlock);
    builder.popBreakBlock();
    builder.popContinueBlock();

    labelstring << "while_exit.L" << builder.getLabelIndex();
    exitBlock->setName(labelstring.str());
    labelstring.str("");
    function->addBasicBlock(exitBlock);
    builder.setPosition(exitBlock, exitBlock->end());
    return std::any();
}
/**
 * @brief Lower `break;`: jump to the current break target (the exit block
 *        pushed by the innermost enclosing loop) and record the CFG edge.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) {
    BasicBlock *loopExit = builder.getBreakBlock();
    builder.createUncondBrInst(loopExit);
    BasicBlock *source = builder.getBasicBlock();
    BasicBlock::conectBlocks(source, loopExit);
    return std::any();
}
/**
 * @brief Lower `continue;`: jump to the current continue target (the header
 *        block pushed by the innermost enclosing loop) and record the CFG edge.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) {
    BasicBlock *loopHeader = builder.getContinueBlock();
    builder.createUncondBrInst(loopHeader);
    BasicBlock *source = builder.getBasicBlock();
    BasicBlock::conectBlocks(source, loopHeader);
    return std::any();
}
/**
 * @brief Lower `return [exp];`.
 *
 * When the returned value's type differs from the function's return type it
 * is converted first: literal constants are folded at compile time, other
 * values get an int<->float conversion instruction. A bare `return;` emits a
 * ReturnInst with a null value.
 *
 * The constant-folding branches read each constant through the accessor that
 * matches its own kind: getFloat() for ConstantFloating, getInt() for
 * ConstantInteger. This matches visitAssignStmt; the previous code had the two
 * accessors swapped in the "same kind" branches.
 *
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
    Value* returnValue = nullptr;
    if (ctx->exp() != nullptr) {
        returnValue = std::any_cast<Value *>(visitExp(ctx->exp()));
    }
    Type* funcType = builder.getBasicBlock()->getParent()->getReturnType();
    if (returnValue != nullptr && funcType != returnValue->getType()) {
        if (ConstantValue *constValue = dynamic_cast<ConstantValue *>(returnValue)) {
            if (funcType == Type::getFloatType()) {
                if (dynamic_cast<ConstantInteger *>(constValue)) {
                    // Integer constant converted to float.
                    returnValue = ConstantFloating::get(static_cast<float>(constValue->getInt()));
                } else if (dynamic_cast<ConstantFloating *>(constValue)) {
                    // Already a float constant: keep its float value.
                    returnValue = ConstantFloating::get(static_cast<float>(constValue->getFloat()));
                }
            } else {
                if (dynamic_cast<ConstantFloating *>(constValue)) {
                    // Float constant truncated to int.
                    returnValue = ConstantInteger::get(static_cast<int>(constValue->getFloat()));
                } else if (dynamic_cast<ConstantInteger *>(constValue)) {
                    // Already an integer constant: keep its integer value.
                    returnValue = ConstantInteger::get(static_cast<int>(constValue->getInt()));
                }
            }
        } else if (funcType == Type::getFloatType()) {
            returnValue = builder.createIToFInst(returnValue);
        } else {
            returnValue = builder.createFtoIInst(returnValue);
        }
    }
    builder.createReturnInst(returnValue);
    return std::any();
}
// Helper: count the nested array levels of `type`, after stripping at most one
// pointer level. Examples:
//   i32*                 -> pointee i32,                 0 dimensions
//   [10 x i32]*          -> pointee [10 x i32],          1 dimension
//   [20 x [10 x i32]]*   -> pointee [20 x [10 x i32]],   2 dimensions
unsigned SysYIRGenerator::countArrayDimensions(Type* type) {
    Type* cursor = type->isPointer() ? type->as<PointerType>()->getBaseType() : type;
    unsigned depth = 0;
    for (; cursor != nullptr && cursor->isArray(); cursor = cursor->as<ArrayType>()->getElementType()) {
        ++depth;
    }
    return depth;
}
/**
 * @brief Lower an lvalue used as an rvalue (a variable read or an array access).
 *
 * - ConstantVariable with only constant indices: folded via getByIndices.
 * - Scalar variable (no indices, 0 declared dimensions): load from its slot.
 * - Otherwise a GEP address is computed. If fewer indices than declared
 *   dimensions are given, the address of the sub-array is returned; otherwise
 *   the element is loaded.
 *
 * NOTE(review): for array parameters (e.g. `int a[][3]`) the slot's type is a
 * pointer to `[3 x i32]*`. countArrayDimensions strips only one pointer level,
 * so declaredNumDims appears to be 0 here and `a[i]` would be loaded rather
 * than returned as an address. Confirm intended behavior for partially-indexed
 * parameter arrays.
 *
 * @return The resulting Value* wrapped in std::any (nullptr after a failed assertion).
 */
std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
    std::string name = ctx->Ident()->getText();
    Value* variable = module->getVariable(name);
    Value* value = nullptr;
    std::vector<Value *> dims;
    for (const auto &exp : ctx->exp()) {
        dims.push_back(std::any_cast<Value *>(visitExp(exp)));
    }
    // 1. Number of array dimensions in the variable's type (after one pointer level).
    unsigned declaredNumDims = countArrayDimensions(variable->getType());
    // 2. ConstantVariable accessed with only constant indices.
    ConstantVariable* constVar = dynamic_cast<ConstantVariable *>(variable);
    if (constVar != nullptr) {
        bool allIndicesConstant = true;
        for (const auto &dim : dims) {
            if (dynamic_cast<ConstantValue *>(dim) == nullptr) {
                allIndicesConstant = false;
                break;
            }
        }
        if (allIndicesConstant) {
            // Constant variable with all-constant indices: obtain the value at
            // compile time via getByIndices instead of emitting a load.
            return constVar->getByIndices(dims);
        }
    }
    // 3. Mutable variables (AllocaInst/GlobalValue) or constants with non-constant indices.
    // Distinguish scalar access from array element / sub-array access.
    // Scalar access: no indices and 0 declared dimensions.
    if (dims.empty() && declaredNumDims == 0) {
        // Scalar variable: load its value directly.
        // `variable` is itself the pointer to the scalar (e.g. int* %a).
        if (dynamic_cast<AllocaInst*>(variable) || dynamic_cast<GlobalValue*>(variable)) {
            value = builder.createLoadInst(variable);
        } else {
            // Not an AllocaInst/GlobalValue, no indices, 0 dimensions, and not a
            // constant handled above: treated as an error.
            assert(false && "Unhandled scalar variable type in LValue access.");
            return static_cast<Value*>(nullptr);
        }
    } else {
        // Array element or sub-array access (indices given, or the variable is an array/pointer).
        Value* gepBasePointer = nullptr;
        std::vector<Value*> gepIndices; // index list passed to getGEPAddressInst
        // The GEP base pointer is derived from the variable (a pointer to memory).
        if (AllocaInst *alloc = dynamic_cast<AllocaInst *>(variable)) {
            // Case A: local variable (AllocaInst).
            // Type stored in the slot. For `int b[10][20];` it is `[10 x [20 x i32]]`.
            // For a parameter `int b[][20]` the slot stores a pointer, so it is `[20 x i32]*`.
            Type* allocatedType = alloc->getType()->as<PointerType>()->getBaseType();
            if (allocatedType->isPointer()) {
                // The slot holds a pointer (e.g. array parameter `b` in `int b[][20]`),
                // i.e. the alloca itself has type [20 x i32]**.
                // The GEP base is the pointer value loaded from the slot.
                gepBasePointer = builder.createLoadInst(alloc); // loaded pointer (e.g. [20 x i32]*)
                // The user's indices apply directly to that pointer; no extra leading 0.
                gepIndices = dims;
            } else {
                // The slot holds the array itself (e.g. `b` in `int b[10][20]`),
                // so the alloca is the GEP base pointer.
                // `alloc` points to the array (e.g. [10 x [20 x i32]]*).
                gepBasePointer = alloc; // type [10 x [20 x i32]]*
                // A leading 0 index steps through the array object itself.
                gepIndices.push_back(ConstantInteger::get(0));
                gepIndices.insert(gepIndices.end(), dims.begin(), dims.end());
            }
        } else if (GlobalValue *glob = dynamic_cast<GlobalValue *>(variable)) {
            // Case B: global variable (GlobalValue), used directly as the base pointer.
            gepBasePointer = glob; // e.g. type [61 x [67 x i32]]*
            // A leading 0 index steps through the global array object.
            gepIndices.push_back(ConstantInteger::get(0));
            gepIndices.insert(gepIndices.end(), dims.begin(), dims.end());
        } else if (ConstantVariable *constV = dynamic_cast<ConstantVariable *>(variable)) {
            // Case C: ConstantVariable accessed with non-constant indices.
            // Assumes a ConstantVariable can serve directly as a GEP base pointer.
            gepBasePointer = constV;
            // A leading 0 is always added here. One could first check
            // constV->getType()->as<PointerType>()->getBaseType()->isArray().
            gepIndices.push_back(ConstantInteger::get(0));
            gepIndices.insert(gepIndices.end(), dims.begin(), dims.end());
        } else {
            assert(false && "LValue variable type not supported for GEP base pointer.");
            return static_cast<Value *>(nullptr);
        }
        // Compute the address from the prepared base pointer and index list.
        Value *targetAddress = getGEPAddressInst(gepBasePointer, gepIndices);
        // Fewer indices than declared dimensions: this is a sub-array, return its address.
        if (dims.size() < declaredNumDims) {
            value = targetAddress;
        } else {
            // Otherwise the final scalar element is accessed: load its value.
            value = builder.createLoadInst(targetAddress);
        }
    }
    return value;
}
/**
 * @brief Lower a primary expression: a parenthesized expression, an lvalue,
 *        or a numeric literal.
 *
 * String literals (and any other unmatched alternative) cannot be lowered
 * here. They previously fell through to visitNumber(ctx->number()) with a null
 * context, which is a null-pointer dereference. They now fail with a
 * diagnosable exception, consistent with visitNumber.
 *
 * @throws std::runtime_error for string literals or unknown alternatives.
 * @return The resulting Value* wrapped in std::any.
 */
std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) {
    if (ctx->exp() != nullptr)
        return visitExp(ctx->exp());
    if (ctx->lValue() != nullptr)
        return visitLValue(ctx->lValue());
    if (ctx->number() != nullptr)
        return visitNumber(ctx->number());
    if (ctx->string() != nullptr) {
        throw std::runtime_error("String literal not supported in SysYIRGenerator.");
    }
    throw std::runtime_error("Unknown primary expression.");
}
/**
 * @brief Lower a numeric literal into a ConstantInteger or ConstantFloating.
 *
 * Integer literals are parsed with base auto-detection (base 0: decimal,
 * `0x` hexadecimal, leading-`0` octal). They are parsed as `long long` rather
 * than `long`: `long` is 32-bit on LLP64 platforms, where std::stol throws
 * std::out_of_range for valid SysY literals such as `2147483648` (the operand
 * of `-2147483648`) or `0x80000000`. The parsed value is then narrowed to int.
 *
 * @throws std::runtime_error if the literal is neither an integer nor a float.
 * @return The constant as a Value* wrapped in std::any.
 */
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
    if (ctx->ILITERAL() != nullptr) {
        long long parsed = std::stoll(ctx->ILITERAL()->getText(), nullptr, 0);
        int value = static_cast<int>(parsed);
        return static_cast<Value *>(ConstantInteger::get(value));
    }
    if (ctx->FLITERAL() != nullptr) {
        float value = std::stof(ctx->FLITERAL()->getText());
        return static_cast<Value *>(ConstantFloating::get(value));
    }
    throw std::runtime_error("Unknown number type.");
}
// Lowers a function call and returns (inside std::any) the Value* of the call
// instruction.
// The callee is looked up among the module's functions first, then among the
// external (runtime) functions. `starttime`/`stoptime` are special-cased: their only
// argument is the source line of the call. For every other callee, each actual
// argument whose type differs from the formal parameter's type is converted
// int<->float (folded for constants, FtoI/IToF instruction otherwise); any other
// mismatch (e.g. pointer types) is passed through unchanged.
std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
  std::string funcName = ctx->Ident()->getText();
  Function *function = module->getFunction(funcName);
  if (function == nullptr) {
    function = module->getExternalFunction(funcName);
    if (function == nullptr) {
      std::cout << "The function " << funcName << " no defined." << std::endl;
      assert(function);
      // The assert vanishes under NDEBUG; stop here rather than dereference nullptr.
      throw std::runtime_error("Call to undefined function '" + funcName + "'.");
    }
  }

  std::vector<Value *> args = {};
  if (funcName == "starttime" || funcName == "stoptime") {
    args.emplace_back(
        ConstantInteger::get(static_cast<int>(ctx->getStart()->getLine())));
  } else {
    if (ctx->funcRParams() != nullptr) {
      args = std::any_cast<std::vector<Value *>>(visitFuncRParams(ctx->funcRParams()));
    }
    // Each Argument in the formal list reports the declared parameter type via getType().
    auto formalParams = function->getArguments();
    if (args.size() != formalParams.size()) {
      std::cerr << "Error: Function call argument count mismatch for function '" << funcName << "'." << std::endl;
      assert(false && "Function call argument count mismatch!");
    }
    // Only visit positions present in BOTH lists so that, when the assert above is
    // compiled out, a mismatched call cannot index past the end of either container.
    const std::size_t numChecked =
        args.size() < formalParams.size() ? args.size() : formalParams.size();
    for (std::size_t i = 0; i < numChecked; ++i) {
      Type *expectedType = formalParams[i]->getType();
      Type *actualType = args[i]->getType();
      if (expectedType == actualType)
        continue;
      const bool floatToInt = expectedType->isInt() && actualType->isFloat();
      const bool intToFloat = expectedType->isFloat() && actualType->isInt();
      ConstantValue *constValue = dynamic_cast<ConstantValue *>(args[i]);
      if (floatToInt) {
        if (constValue != nullptr)
          args[i] = ConstantInteger::get(static_cast<int>(constValue->getFloat()));
        else
          args[i] = builder.createFtoIInst(args[i]);
      } else if (intToFloat) {
        if (constValue != nullptr)
          args[i] = ConstantFloating::get(static_cast<float>(constValue->getInt()));
        else
          args[i] = builder.createIToFInst(args[i]);
      }
      // Any other mismatch (e.g. an array argument of type [N x T]* passed to a T*
      // parameter) is deliberately left as is: no bitcast is emitted yet.
    }
  }
  return static_cast<Value *>(builder.createCallInst(function, args));
}
// Lowers a unary expression and returns (inside std::any) the resulting Value*.
// Primary expressions and calls are delegated. For `-x` and `!x`, constant operands
// are folded; otherwise a Neg/Not instruction is emitted when the operand's type is
// int, and an FNeg/FNot instruction when it is not. Unary `+` returns the operand.
std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
  if (ctx->primaryExp() != nullptr)
    return visitPrimaryExp(ctx->primaryExp());
  if (ctx->call() != nullptr)
    return visitCall(ctx->call());
  Value *value = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp()));
  Value *result = value;
  if (ctx->unaryOp()->SUB() != nullptr) {
    ConstantValue *constValue = dynamic_cast<ConstantValue *>(value);
    if (constValue != nullptr) {
      if (constValue->isFloat()) {
        result = ConstantFloating::get(-constValue->getFloat());
      } else {
        // Negate in unsigned arithmetic: `-INT_MIN` is signed overflow (UB). This
        // is reachable, e.g. when the literal 2147483648 (narrowed to INT_MIN by
        // visitNumber) is negated; two's-complement wrap keeps INT_MIN.
        const unsigned bits = static_cast<unsigned>(constValue->getInt());
        result = ConstantInteger::get(static_cast<int>(0U - bits));
      }
    } else if (value != nullptr) {
      if (value->getType() == Type::getIntType()) {
        result = builder.createNegInst(value);
      } else {
        result = builder.createFNegInst(value);
      }
    } else {
      std::cout << "UnExp: value is nullptr." << std::endl;
      assert(false);
    }
  } else if (ctx->unaryOp()->NOT() != nullptr) {
    auto constValue = dynamic_cast<ConstantValue *>(value);
    if (constValue != nullptr) {
      // Logical not of a constant: 1 if the operand is zero, 0 otherwise.
      if (constValue->isFloat()) {
        result =
            ConstantFloating::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0));
      } else {
        result = ConstantInteger::get(1 - (constValue->getInt() != 0 ? 1 : 0));
      }
    } else if (value != nullptr) {
      if (value->getType() == Type::getIntType()) {
        result = builder.createNotInst(value);
      } else {
        result = builder.createFNotInst(value);
      }
    } else {
      std::cout << "UnExp: value is nullptr." << std::endl;
      assert(false);
    }
  }
  return result;
}
// Evaluates every actual-argument expression of a call, left to right, and returns
// (inside std::any) a std::vector<Value *> holding the resulting values in order.
std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) {
  const auto argExps = ctx->exp();
  std::vector<Value *> values;
  values.reserve(argExps.size());
  for (auto *argExp : argExps) {
    Value *evaluated = std::any_cast<Value *>(visitExp(argExp));
    values.push_back(evaluated);
  }
  return values;
}
// Lowers a left-associative chain of `*`, `/` and `%` and returns (inside std::any)
// the resulting Value*.
// When either side of a step is float, both sides are promoted to float (constants
// converted at compile time, other values through an IToF instruction) and FMul/FDiv
// is folded or emitted; `%` on floats is rejected. Otherwise integer Mul/Div/Rem is
// folded when both operands are constants and safe to fold, and emitted otherwise.
std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
  // Converts `v` to float unless it already has float type.
  auto promoteToFloat = [this](Value *v) -> Value * {
    if (v->getType() == Type::getFloatType())
      return v;
    if (auto *constValue = dynamic_cast<ConstantValue *>(v)) {
      if (dynamic_cast<ConstantInteger *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getInt()));
      if (dynamic_cast<ConstantFloating *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getFloat()));
      return v;
    }
    return builder.createIToFInst(v);
  };

  Value *result = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(0)));
  for (std::size_t i = 1; i < ctx->unaryExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the i-th operator sits at 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(i)));
    Type *floatType = Type::getFloatType();
    if (result->getType() == floatType || operand->getType() == floatType) {
      result = promoteToFloat(result);
      operand = promoteToFloat(operand);
      ConstantFloating *constResult = dynamic_cast<ConstantFloating *>(result);
      ConstantFloating *constOperand = dynamic_cast<ConstantFloating *>(operand);
      const bool bothConst = constResult != nullptr && constOperand != nullptr;
      if (opType == SysYParser::MUL) {
        if (bothConst)
          result = ConstantFloating::get(constResult->getFloat() * constOperand->getFloat());
        else
          result = builder.createFMulInst(result, operand);
      } else if (opType == SysYParser::DIV) {
        if (bothConst)
          result = ConstantFloating::get(constResult->getFloat() / constOperand->getFloat());
        else
          result = builder.createFDivInst(result, operand);
      } else {
        // `%` is not defined for float operands.
        std::cout << "MulExp: float type mod operation is not allowed." << std::endl;
        assert(false);
      }
    } else {
      ConstantInteger *constResult = dynamic_cast<ConstantInteger *>(result);
      ConstantInteger *constOperand = dynamic_cast<ConstantInteger *>(operand);
      const bool bothConst = constResult != nullptr && constOperand != nullptr;
      if (opType == SysYParser::MUL) {
        if (bothConst) {
          // Multiply in unsigned arithmetic so overflow wraps instead of being UB.
          const unsigned product = static_cast<unsigned>(constResult->getInt()) *
                                   static_cast<unsigned>(constOperand->getInt());
          result = ConstantInteger::get(static_cast<int>(product));
        } else {
          result = builder.createMulInst(result, operand);
        }
      } else {
        const bool isDiv = opType == SysYParser::DIV;
        if (bothConst && constOperand->getInt() == -1) {
          // x / -1 == -x and x % -1 == 0, computed without signed overflow so that
          // INT_MIN / -1 (UB in C++) folds to INT_MIN.
          const unsigned negated = 0U - static_cast<unsigned>(constResult->getInt());
          result = ConstantInteger::get(isDiv ? static_cast<int>(negated) : 0);
        } else if (bothConst && constOperand->getInt() != 0) {
          const int lhs = constResult->getInt();
          const int rhs = constOperand->getInt();
          result = ConstantInteger::get(isDiv ? lhs / rhs : lhs % rhs);
        } else if (isDiv) {
          // A constant zero divisor is not folded (that would be UB and may crash
          // the compiler); the division is left to run time.
          result = builder.createDivInst(result, operand);
        } else {
          result = builder.createRemInst(result, operand);
        }
      }
    }
  }
  return result;
}
// Lowers a left-associative chain of `+` and `-` and returns (inside std::any) the
// resulting Value*.
// When either side of a step is float, both sides are promoted to float (constants
// converted at compile time, other values through an IToF instruction) and FAdd/FSub
// is folded or emitted. Otherwise integer Add/Sub is folded when both operands are
// constants and emitted otherwise.
std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
  // Converts `v` to float unless it already has float type.
  auto promoteToFloat = [this](Value *v) -> Value * {
    if (v->getType() == Type::getFloatType())
      return v;
    if (auto *constValue = dynamic_cast<ConstantValue *>(v)) {
      if (dynamic_cast<ConstantInteger *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getInt()));
      if (dynamic_cast<ConstantFloating *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getFloat()));
      return v;
    }
    return builder.createIToFInst(v);
  };

  Value *result = std::any_cast<Value *>(visitMulExp(ctx->mulExp(0)));
  for (std::size_t i = 1; i < ctx->mulExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the i-th operator sits at 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitMulExp(ctx->mulExp(i)));
    Type *floatType = Type::getFloatType();
    if (result->getType() == floatType || operand->getType() == floatType) {
      result = promoteToFloat(result);
      operand = promoteToFloat(operand);
      ConstantFloating *constResult = dynamic_cast<ConstantFloating *>(result);
      ConstantFloating *constOperand = dynamic_cast<ConstantFloating *>(operand);
      const bool bothConst = constResult != nullptr && constOperand != nullptr;
      if (opType == SysYParser::ADD) {
        if (bothConst)
          result = ConstantFloating::get(constResult->getFloat() + constOperand->getFloat());
        else
          result = builder.createFAddInst(result, operand);
      } else {
        if (bothConst)
          result = ConstantFloating::get(constResult->getFloat() - constOperand->getFloat());
        else
          result = builder.createFSubInst(result, operand);
      }
    } else {
      ConstantInteger *constResult = dynamic_cast<ConstantInteger *>(result);
      ConstantInteger *constOperand = dynamic_cast<ConstantInteger *>(operand);
      if (constResult != nullptr && constOperand != nullptr) {
        // Fold in unsigned arithmetic so that overflow wraps (two's complement)
        // instead of being undefined behaviour inside the compiler itself.
        const unsigned lhs = static_cast<unsigned>(constResult->getInt());
        const unsigned rhs = static_cast<unsigned>(constOperand->getInt());
        const unsigned folded = opType == SysYParser::ADD ? lhs + rhs : lhs - rhs;
        result = ConstantInteger::get(static_cast<int>(folded));
      } else if (opType == SysYParser::ADD) {
        result = builder.createAddInst(result, operand);
      } else {
        result = builder.createSubInst(result, operand);
      }
    }
  }
  return result;
}
// Lowers a left-associative chain of `<`, `>`, `<=`, `>=` and returns (inside
// std::any) the resulting Value*.
// Two constant operands are folded to ConstantInteger 0/1. Otherwise an FCmp is
// emitted when either side is float (the int side is promoted to float first),
// and an ICmp is emitted when neither is.
std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
  // Converts `v` to float unless it already has float type.
  auto promoteToFloat = [this](Value *v) -> Value * {
    if (v->getType() == Type::getFloatType())
      return v;
    if (auto *constValue = dynamic_cast<ConstantValue *>(v)) {
      if (dynamic_cast<ConstantInteger *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getInt()));
      if (dynamic_cast<ConstantFloating *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getFloat()));
      return v;
    }
    return builder.createIToFInst(v);
  };
  // Evaluates `lhs <op> rhs` for a relational operator token type.
  auto relate = [](int opType, auto lhs, auto rhs) -> bool {
    if (opType == SysYParser::LT) return lhs < rhs;
    if (opType == SysYParser::GT) return lhs > rhs;
    if (opType == SysYParser::LE) return lhs <= rhs;
    if (opType == SysYParser::GE) return lhs >= rhs;
    assert(false);
    return false;
  };

  Value *result = std::any_cast<Value *>(visitAddExp(ctx->addExp(0)));
  for (std::size_t i = 1; i < ctx->addExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the i-th operator sits at 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitAddExp(ctx->addExp(i)));
    ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
    ConstantValue *constOperand = dynamic_cast<ConstantValue *>(operand);
    if (constResult != nullptr && constOperand != nullptr) {
      bool holds;
      if (!constResult->isFloat() && !constOperand->isFloat()) {
        // Compare two ints as ints. Routing them through a float-typed ?: (as
        // before) loses precision above 2^24, e.g. 16777217 > 16777216 folded to 0.
        holds = relate(opType, constResult->getInt(), constOperand->getInt());
      } else {
        // Mixed/float operands: compare in floating point, as C's usual
        // arithmetic conversions do.
        const auto lhs = constResult->isFloat() ? constResult->getFloat()
                                                : constResult->getInt();
        const auto rhs = constOperand->isFloat() ? constOperand->getFloat()
                                                 : constOperand->getInt();
        holds = relate(opType, lhs, rhs);
      }
      result = ConstantInteger::get(holds ? 1 : 0);
    } else {
      Type *floatType = Type::getFloatType();
      if (result->getType() == floatType || operand->getType() == floatType) {
        result = promoteToFloat(result);
        operand = promoteToFloat(operand);
        if (opType == SysYParser::LT) result = builder.createFCmpLTInst(result, operand);
        else if (opType == SysYParser::GT) result = builder.createFCmpGTInst(result, operand);
        else if (opType == SysYParser::LE) result = builder.createFCmpLEInst(result, operand);
        else if (opType == SysYParser::GE) result = builder.createFCmpGEInst(result, operand);
        else assert(false);
      } else {
        if (opType == SysYParser::LT) result = builder.createICmpLTInst(result, operand);
        else if (opType == SysYParser::GT) result = builder.createICmpGTInst(result, operand);
        else if (opType == SysYParser::LE) result = builder.createICmpLEInst(result, operand);
        else if (opType == SysYParser::GE) result = builder.createICmpGEInst(result, operand);
        else assert(false);
      }
    }
  }
  return result;
}
// Lowers a left-associative chain of `==` and `!=` and returns (inside std::any)
// the resulting Value*.
// Two constant operands are folded to ConstantInteger 0/1. Otherwise an FCmp is
// emitted when either side is float (the int side is promoted to float first),
// and an ICmp is emitted when neither is. When the expression is a single relExp
// whose value is a constant, that constant is normalised to ConstantInteger 0/1.
std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
  // Converts `v` to float unless it already has float type.
  auto promoteToFloat = [this](Value *v) -> Value * {
    if (v->getType() == Type::getFloatType())
      return v;
    if (auto *constValue = dynamic_cast<ConstantValue *>(v)) {
      if (dynamic_cast<ConstantInteger *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getInt()));
      if (dynamic_cast<ConstantFloating *>(constValue) != nullptr)
        return ConstantFloating::get(static_cast<float>(constValue->getFloat()));
      return v;
    }
    return builder.createIToFInst(v);
  };
  // Evaluates `lhs <op> rhs` for an equality operator token type.
  auto equate = [](int opType, auto lhs, auto rhs) -> bool {
    if (opType == SysYParser::EQ) return lhs == rhs;
    if (opType == SysYParser::NE) return lhs != rhs;
    assert(false);
    return false;
  };

  Value *result = std::any_cast<Value *>(visitRelExp(ctx->relExp(0)));
  for (std::size_t i = 1; i < ctx->relExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the i-th operator sits at 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitRelExp(ctx->relExp(i)));
    ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
    ConstantValue *constOperand = dynamic_cast<ConstantValue *>(operand);
    if (constResult != nullptr && constOperand != nullptr) {
      bool holds;
      if (!constResult->isFloat() && !constOperand->isFloat()) {
        // Compare two ints as ints; going through float would make e.g.
        // 16777217 == 16777216 fold to 1.
        holds = equate(opType, constResult->getInt(), constOperand->getInt());
      } else {
        const auto lhs = constResult->isFloat() ? constResult->getFloat()
                                                : constResult->getInt();
        const auto rhs = constOperand->isFloat() ? constOperand->getFloat()
                                                 : constOperand->getInt();
        holds = equate(opType, lhs, rhs);
      }
      result = ConstantInteger::get(holds ? 1 : 0);
    } else {
      Type *floatType = Type::getFloatType();
      if (result->getType() == floatType || operand->getType() == floatType) {
        result = promoteToFloat(result);
        operand = promoteToFloat(operand);
        if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand);
        else if (opType == SysYParser::NE) result = builder.createFCmpNEInst(result, operand);
        else assert(false);
      } else {
        if (opType == SysYParser::EQ) result = builder.createICmpEQInst(result, operand);
        else if (opType == SysYParser::NE) result = builder.createICmpNEInst(result, operand);
        else assert(false);
      }
    }
  }
  if (ctx->relExp().size() == 1) {
    // A lone relational expression used as a condition: normalise a constant
    // value to 0 or 1.
    ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
    if (constResult != nullptr) {
      if (constResult->isFloat())
        result = ConstantInteger::get(constResult->getFloat() != 0.0F ? 1 : 0);
      else
        result = ConstantInteger::get(constResult->getInt() != 0 ? 1 : 0);
    }
  }
  return result;
}
// Emits short-circuit control flow for `c0 && c1 && ... && cn`; produces no value
// (returns an empty std::any).
// The overall true/false targets come from builder.getTrueBlock()/getFalseBlock().
// Every operand except the last branches to a freshly created "AND.L<n>" block on
// success and to the false target on failure; emission then continues in that new
// block. The last operand branches to the true/false targets. CFG edges are
// recorded with BasicBlock::conectBlocks for every branch emitted.
std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){
  BasicBlock *current = builder.getBasicBlock();
  Function *parentFn = current->getParent();
  BasicBlock *const onTrue = builder.getTrueBlock();
  BasicBlock *const onFalse = builder.getFalseBlock();
  const auto operands = ctx->eqExp();
  const std::size_t lastIndex = operands.size() - 1;

  std::stringstream blockName;
  for (std::size_t idx = 0; idx < lastIndex; ++idx) {
    // The continuation block is created before the operand is lowered.
    blockName << "AND.L" << builder.getLabelIndex();
    BasicBlock *continuation = parentFn->addBasicBlock(blockName.str());
    blockName.str("");

    Value *condition = std::any_cast<Value *>(visitEqExp(operands[idx]));
    builder.createCondBrInst(condition, continuation, onFalse);
    BasicBlock::conectBlocks(current, continuation);
    BasicBlock::conectBlocks(current, onFalse);

    current = continuation;
    builder.setPosition(current, current->end());
  }

  Value *lastCondition = std::any_cast<Value *>(visitEqExp(operands[lastIndex]));
  builder.createCondBrInst(lastCondition, onTrue, onFalse);
  BasicBlock::conectBlocks(current, onTrue);
  BasicBlock::conectBlocks(current, onFalse);
  return std::any();
}
// Emits short-circuit control flow for `a0 || a1 || ... || an`; produces no value
// (returns an empty std::any).
// For every operand but the last, a freshly created "OR.L<n>" block is installed as
// the false target (builder.pushFalseBlock, undone by popFalseBlock afterwards)
// while that operand is lowered; emission then continues in that block, where the
// next operand is tried. The last operand uses the builder's current targets.
auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any {
  Function *parentFn = builder.getBasicBlock()->getParent();
  const auto operands = ctx->lAndExp();
  const std::size_t lastIndex = operands.size() - 1;

  std::stringstream blockName;
  for (std::size_t idx = 0; idx < lastIndex; ++idx) {
    blockName << "OR.L" << builder.getLabelIndex();
    BasicBlock *nextTry = parentFn->addBasicBlock(blockName.str());
    blockName.str("");

    builder.pushFalseBlock(nextTry);
    visitLAndExp(operands[idx]);
    builder.popFalseBlock();

    builder.setPosition(nextTry, nextTry->end());
  }

  visitLAndExp(operands[lastIndex]);
  return std::any();
}
// Attention: `type` here is the array ELEMENT type, not the array type.
//
// Flattens an initializer tree into `result` in row-major order.
// A leaf value is appended after conversion to `type` (int<->float; folded for
// constants, FtoI/IToF instruction otherwise). For an inner node, `dims` holds the
// declared dimensions outermost-first and `numDims` says how many of the innermost
// ones form the sub-array this node initializes; each child is flattened
// recursively, and the sub-array is then padded with zeros up to its full size.
void Utils::tree2Array(Type *type, ArrayValueTree *root,
                       const std::vector<Value *> &dims, unsigned numDims,
                       ValueCounter &result, IRBuilder *builder) {
  Value* value = root->getValue();
  auto &children = root->getChildren();
  // Leaf: a single value, converted to the element type when needed.
  if (value != nullptr) {
    if (type == value->getType()) {
      result.push_back(value);
    } else {
      if (type == Type::getFloatType()) {
        ConstantValue* constValue = dynamic_cast<ConstantValue *>(value);
        if (constValue != nullptr) {
          if(dynamic_cast<ConstantInteger *>(constValue))
            result.push_back(ConstantFloating::get(static_cast<float>(constValue->getInt())));
          else if (dynamic_cast<ConstantFloating *>(constValue))
            result.push_back(ConstantFloating::get(static_cast<float>(constValue->getFloat())));
          else
            assert(false && "Unknown constant type for float conversion.");
        }
        else
          result.push_back(builder->createIToFInst(value));
      } else {
        ConstantValue* constValue = dynamic_cast<ConstantValue *>(value);
        if (constValue != nullptr){
          if(dynamic_cast<ConstantInteger *>(constValue))
            result.push_back(ConstantInteger::get(constValue->getInt()));
          else if (dynamic_cast<ConstantFloating *>(constValue))
            result.push_back(ConstantInteger::get(static_cast<int>(constValue->getFloat())));
          else
            assert(false && "Unknown constant type for int conversion.");
        }
        else
          result.push_back(builder->createFtoIInst(value));
      }
    }
    return;
  }
  auto beforeSize = result.size();
  for (const auto &child : children) {
    // A nested brace list initializes the largest trailing sub-array whose start
    // is aligned with the current flat position: peel off inner dimensions that
    // evenly divide the position, never including this level's outermost one.
    int begin = result.size();
    int newNumDims = 0;
    // `i + 1 < numDims` instead of `i < numDims - 1`: with numDims == 0 (a brace
    // list where a scalar is expected) the unsigned subtraction would wrap and the
    // loop would read past the end of `dims`.
    for (unsigned i = 0; i + 1 < numDims; i++) {
      auto dim = dynamic_cast<ConstantValue *>(*(dims.rbegin() + i))->getInt();
      if (begin % dim == 0) {
        newNumDims += 1;
        begin /= dim;
      } else {
        break;
      }
    }
    tree2Array(type, child.get(), dims, newNumDims, result, builder);
  }
  auto afterSize = result.size();
  // Number of scalar elements covered by this level's sub-array.
  int blockSize = 1;
  for (unsigned i = 0; i < numDims; i++) {
    blockSize *= dynamic_cast<ConstantValue *>(*(dims.rbegin() + i))->getInt();
  }
  // Elements still missing; computed in signed arithmetic so an over-long
  // initializer yields a non-positive count rather than an unsigned wrap.
  int num = blockSize - static_cast<int>(afterSize - beforeSize);
  if (num > 0) {
    if (type == Type::getFloatType())
      result.push_back(ConstantFloating::get(0.0F), num);
    else
      result.push_back(ConstantInteger::get(0), num);
  }
}
// Declares an external (runtime-library) function `funcName` in `pModule`.
// paramTypes/paramNames/paramDims describe the parameters positionally. A parameter
// whose paramDims entry is non-empty is an array parameter and is declared as a
// pointer to its element type (only that pointer level is modelled; any further
// dimensions in the entry are not used). For each parameter an Argument is attached
// to the function, and in the entry block an alloca is created and the argument is
// stored into it.
void Utils::createExternalFunction(
    const std::vector<Type *> &paramTypes,
    const std::vector<std::string> &paramNames,
    const std::vector<std::vector<Value *>> &paramDims, Type *returnType,
    const std::string &funcName, Module *pModule, IRBuilder *pBuilder) {
  // Resolve each parameter's declared type. Array parameters (e.g. getfarray's and
  // putfarray's `float a[]`) must be pointer-typed, not declared as scalars.
  std::vector<Type *> declaredTypes;
  declaredTypes.reserve(paramTypes.size());
  for (std::size_t i = 0; i < paramTypes.size(); ++i) {
    Type *declared = paramTypes[i];
    if (i < paramDims.size() && !paramDims[i].empty())
      declared = Type::getPointerType(declared);
    declaredTypes.push_back(declared);
  }

  auto funcType = Type::getFunctionType(returnType, declaredTypes);
  auto function = pModule->createExternalFunction(funcName, funcType);
  auto entry = function->getEntryBlock();
  pBuilder->setPosition(entry, entry->end());
  for (int i = 0; i < static_cast<int>(declaredTypes.size()); ++i) {
    auto arg = new Argument(declaredTypes[i], function, i, paramNames[i]);
    auto alloca = pBuilder->createAllocaInst(
        Type::getPointerType(declaredTypes[i]), paramNames[i]);
    function->insertArgument(arg);
    pBuilder->createStoreInst(arg, alloca);
    // NOTE(review): parameter names are registered through Module::addVariable;
    // confirm this cannot clash with user-declared identifiers of the same name.
    pModule->addVariable(paramNames[i], alloca);
  }
}
// Declares the SysY runtime library (input, output and timing functions) as
// external functions of `pModule`, in a fixed order, via
// Utils::createExternalFunction. In a parameter's dims entry, {-1} marks an array
// parameter of unknown length and an empty entry marks a scalar.
void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) {
  using Dims = std::vector<Value *>;
  Type *const intTy = Type::getIntType();
  Type *const floatTy = Type::getFloatType();
  Type *const voidTy = Type::getVoidType();
  const Dims scalar;
  // A fresh dims entry for an array parameter of unknown length (`T a[]`).
  const auto unsized = [] { return Dims{ConstantInteger::get(-1)}; };
  const auto declare = [&](const std::string &name, Type *ret,
                           const std::vector<Type *> &types,
                           const std::vector<std::string> &names,
                           const std::vector<Dims> &dims) {
    Utils::createExternalFunction(types, names, dims, ret, name, pModule,
                                  pBuilder);
  };

  // Input.
  declare("getint", intTy, {}, {}, {});
  declare("getch", intTy, {}, {}, {});
  declare("getarray", intTy, {intTy}, {"x"}, {unsized()});
  declare("getfloat", floatTy, {}, {}, {});
  declare("getfarray", intTy, {floatTy}, {"x"}, {unsized()});
  // Output.
  declare("putint", voidTy, {intTy}, {"x"}, {scalar});
  declare("putch", voidTy, {intTy}, {"x"}, {scalar});
  declare("putarray", voidTy, {intTy, intTy}, {"n", "a"}, {scalar, unsized()});
  declare("putfloat", voidTy, {floatTy}, {"a"}, {scalar});
  declare("putfarray", voidTy, {intTy, floatTy}, {"n", "a"}, {scalar, unsized()});
  // Timing: each receives the source line of the call.
  declare("starttime", voidTy, {intTy}, {"__LINE__"}, {scalar});
  declare("stoptime", voidTy, {intTy}, {"__LINE__"}, {scalar});
}
} // namespace sysy