Files
mysysy/src/SysYIRGenerator.cpp
2025-07-15 12:53:03 +08:00

1197 lines
43 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SysYIRGenerator.cpp
// TODO: type conversion and the associated checks
// TODO: handling of the SysY library functions
// TODO: array handling
// TODO: test while, continue and break
#include "IR.h"
#include <any>
#include <cassert>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
#include "SysYIRGenerator.h"
namespace sysy {
/**
 * @brief Entry point of IR generation: visits the compilation unit.
 * @details Grammar rule: compUnit: (globalDecl | funcDef)+;
 *   A fresh Module is allocated and handed to `module`, then
 *   Utils::initExternalFunction is run on it (presumably registering the SysY
 *   runtime functions -- confirm against Utils), and finally every child of the
 *   compilation unit is visited inside a single scope.
 * @return std::any holding the raw `Module *` that was passed to `module.reset`.
 */
std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) {
  Module *freshModule = new Module();
  assert(freshModule);
  module.reset(freshModule);

  Utils::initExternalFunction(freshModule, &builder);

  // All top-level declarations and definitions live in one enclosing scope.
  freshModule->enterNewScope();
  visitChildren(ctx);
  freshModule->leaveScope();

  return freshModule;
}
/**
 * @brief Lowers a global `const` declaration.
 * @details For every constDef: evaluates the dimension expressions, flattens
 *   the initializer tree with Utils::tree2Array, and registers the result via
 *   Module::createConstVar under the declared name.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx){
  auto *decl = ctx->constDecl();
  Type *elemType = std::any_cast<Type *>(visitBType(decl->bType()));

  for (const auto &def : decl->constDef()) {
    std::string ident = def->Ident()->getText();

    // One Value per declared dimension; stays empty for a scalar.
    std::vector<Value *> extents;
    for (const auto &extentExp : def->constExp())
      extents.push_back(std::any_cast<Value *>(visitConstExp(extentExp)));

    // The initializer visitor returns a heap-allocated tree that we own.
    ArrayValueTree *initTree =
        std::any_cast<ArrayValueTree *>(def->constInitVal()->accept(this));
    ValueCounter flattened;
    Utils::tree2Array(elemType, initTree, extents, extents.size(), flattened, &builder);
    delete initTree;

    // Create the global constant; this also records it in the symbol table.
    module->createConstVar(ident, Type::getPointerType(elemType), flattened, extents);
  }
  return std::any();
}
/**
 * @brief Lowers a global variable declaration.
 * @details For every varDef: evaluates the dimension expressions, flattens the
 *   optional initializer with Utils::tree2Array (the value list stays empty
 *   when there is no initializer), and registers the variable via
 *   Module::createGlobalValue.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) {
  auto *decl = ctx->varDecl();
  Type *elemType = std::any_cast<Type *>(visitBType(decl->bType()));

  for (const auto &def : decl->varDef()) {
    std::string ident = def->Ident()->getText();

    // One Value per declared dimension; stays empty for a scalar.
    std::vector<Value *> extents;
    for (const auto &extentExp : def->constExp())
      extents.push_back(std::any_cast<Value *>(visitConstExp(extentExp)));

    ValueCounter initValues = {};
    if (auto *init = def->initVal(); init != nullptr) {
      // The initializer visitor returns a heap-allocated tree that we own.
      ArrayValueTree *initTree = std::any_cast<ArrayValueTree *>(init->accept(this));
      Utils::tree2Array(elemType, initTree, extents, extents.size(), initValues, &builder);
      delete initTree;
    }

    // Create the global variable; this also records it in the symbol table.
    module->createGlobalValue(ident, Type::getPointerType(elemType), extents, initValues);
  }
  return std::any();
}
/**
 * @brief Lowers a `const` declaration reached through the constDecl rule.
 * @details Same steps as visitGlobalConstDecl: evaluate the dimensions,
 *   flatten the initializer tree and register the constant with
 *   Module::createConstVar.
 * @return std::any holding the int 0.
 */
std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx){
  Type *elemType = std::any_cast<Type *>(visitBType(ctx->bType()));

  for (const auto def : ctx->constDef()) {
    std::string ident = def->Ident()->getText();

    // One Value per declared dimension; stays empty for a scalar.
    std::vector<Value *> extents;
    for (const auto extentExp : def->constExp())
      extents.push_back(std::any_cast<Value *>(visitConstExp(extentExp)));

    // The initializer visitor returns a heap-allocated tree that we own.
    ArrayValueTree *initTree =
        std::any_cast<ArrayValueTree *>(def->constInitVal()->accept(this));
    ValueCounter flattened;
    Utils::tree2Array(elemType, initTree, extents, extents.size(), flattened, &builder);
    delete initTree;

    module->createConstVar(ident, Type::getPointerType(elemType), flattened, extents);
  }
  return 0;
}
/**
 * @brief Lowers a local variable declaration.
 * @details For every varDef an alloca is created and registered under the
 *   variable name. When an initializer is present it is flattened with
 *   Utils::tree2Array; a scalar gets a single store, an array gets one memset
 *   per entry of the flattened ValueCounter.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) {
  Type *elemType = std::any_cast<Type *>(visitBType(ctx->bType()));

  for (const auto def : ctx->varDef()) {
    std::string ident = def->Ident()->getText();

    // One Value per declared dimension; stays empty for a scalar.
    std::vector<Value *> extents;
    for (const auto &extentExp : def->constExp())
      extents.push_back(std::any_cast<Value *>(visitConstExp(extentExp)));

    AllocaInst *slot =
        builder.createAllocaInst(Type::getPointerType(elemType), extents, ident);

    if (auto *init = def->initVal(); init != nullptr) {
      // init may be a ScalarInitValue or an ArrayInitValue; both visitors
      // return a heap-allocated ArrayValueTree that we own.
      ValueCounter flattened;
      ArrayValueTree *initTree = std::any_cast<ArrayValueTree *>(init->accept(this));
      Utils::tree2Array(elemType, initTree, extents, extents.size(), flattened, &builder);
      delete initTree;

      if (extents.empty()) {
        builder.createStoreInst(flattened.getValue(0), slot);
      } else {
        // Arrays are initialized with memset instructions. As used here,
        // getNumbers()[k] is the element count of the k-th memset and
        // getValues()[k] is the value written by it; `offset` is the running
        // start index into the array behind `slot`.
        const std::vector<unsigned int> &runLengths = flattened.getNumbers();
        const std::vector<sysy::Value *> &runValues = flattened.getValues();
        unsigned offset = 0;
        for (size_t run = 0; run < runLengths.size(); ++run) {
          builder.createMemsetInst(
              slot, ConstantValue::get(static_cast<int>(offset)),
              ConstantValue::get(static_cast<int>(runLengths[run])),
              runValues[run]);
          offset += runLengths[run];
        }
      }
    }

    module->addVariable(ident, slot);
  }
  return std::any();
}
/**
 * @brief Maps a basic-type node to its IR type.
 * @return std::any holding the int type when the INT token is present,
 *   otherwise the float type.
 */
std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) {
  if (ctx->INT() != nullptr)
    return Type::getIntType();
  return Type::getFloatType();
}
/**
 * @brief Wraps a scalar initializer expression in a leaf ArrayValueTree.
 * @return std::any holding a newly allocated `ArrayValueTree *`; the caller
 *   is responsible for deleting it.
 */
std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) {
  Value *scalar = std::any_cast<Value *>(visitExp(ctx->exp()));
  auto *leaf = new ArrayValueTree();
  leaf->setValue(scalar);
  return leaf;
}
/**
 * @brief Builds an inner ArrayValueTree node for a braced initializer list.
 * @details Each nested initVal is visited in order and its tree becomes a
 *   child of the returned node.
 * @return std::any holding a newly allocated `ArrayValueTree *`.
 */
std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) {
  const auto &items = ctx->initVal();
  std::vector<ArrayValueTree *> subtrees;
  subtrees.reserve(items.size());
  for (const auto &item : items) {
    subtrees.push_back(std::any_cast<ArrayValueTree *>(item->accept(this)));
  }

  auto *node = new ArrayValueTree();
  node->addChildren(subtrees);
  return node;
}
/**
 * @brief Wraps a scalar constant initializer in a leaf ArrayValueTree.
 * @return std::any holding a newly allocated `ArrayValueTree *`; the caller
 *   is responsible for deleting it.
 */
std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) {
  Value *scalar = std::any_cast<Value *>(visitConstExp(ctx->constExp()));
  auto *leaf = new ArrayValueTree();
  leaf->setValue(scalar);
  return leaf;
}
/**
 * @brief Builds an inner ArrayValueTree node for a braced constant
 *   initializer list.
 * @details Each nested constInitVal is visited in order and its tree becomes
 *   a child of the returned node.
 * @return std::any holding a newly allocated `ArrayValueTree *`.
 */
std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) {
  const auto &items = ctx->constInitVal();
  std::vector<ArrayValueTree *> subtrees;
  subtrees.reserve(items.size());
  for (const auto &item : items) {
    subtrees.push_back(std::any_cast<ArrayValueTree *>(item->accept(this)));
  }

  auto *node = new ArrayValueTree();
  node->addChildren(subtrees);
  return node;
}
/**
 * @brief Maps a function return-type node to its IR type.
 * @return std::any holding the int type for INT, the float type for FLOAT,
 *   and the void type otherwise.
 */
std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) {
  const bool isInt = ctx->INT() != nullptr;
  if (!isInt && ctx->FLOAT() == nullptr) {
    return Type::getVoidType();
  }
  if (isInt) {
    return Type::getIntType();
  }
  return Type::getFloatType();
}
/**
 * @brief Lowers a function definition.
 * @details Opens a new scope, collects the parameter types, names and
 *   dimensions, creates the Function, materializes one alloca per parameter
 *   in the entry block, visits the body items, and appends a default return
 *   when no return statement set HasReturnInst.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
  // The function body (including its parameters) gets its own scope.
  module->enterNewScope();
  HasReturnInst = false;

  auto funcName = ctx->Ident()->getText();

  std::vector<Type *> argTypes;
  std::vector<std::string> argNames;
  std::vector<std::vector<Value *>> argDims;
  if (auto *formalList = ctx->funcFParams(); formalList != nullptr) {
    for (const auto &formal : formalList->funcFParam()) {
      argTypes.push_back(std::any_cast<Type *>(visitBType(formal->bType())));
      argNames.push_back(formal->Ident()->getText());

      std::vector<Value *> extents;
      if (!formal->LBRACK().empty()) {
        // Array parameter: the first dimension is unknown, encoded as -1.
        extents.push_back(ConstantValue::get(-1));
        for (const auto &extentExp : formal->exp())
          extents.push_back(std::any_cast<Value *>(visitExp(extentExp)));
      }
      argDims.emplace_back(extents);
    }
  }

  Type *retType = std::any_cast<Type *>(visitFuncType(ctx->funcType()));
  Type *signature = Type::getFunctionType(retType, argTypes);
  Function *func = module->createFunction(funcName, signature);

  BasicBlock *entryBlock = func->getEntryBlock();
  builder.setPosition(entryBlock, entryBlock->end());

  // Each parameter is represented by an alloca that is both an argument of
  // the entry block and a named variable in the current scope.
  for (size_t idx = 0; idx < argTypes.size(); ++idx) {
    AllocaInst *slot = builder.createAllocaInst(
        Type::getPointerType(argTypes[idx]), argDims[idx], argNames[idx]);
    entryBlock->insertArgument(slot);
    module->addVariable(argNames[idx], slot);
  }

  // The body items are visited directly, so the body shares the scope opened
  // above instead of going through visitBlockStmt.
  for (auto bodyItem : ctx->blockStmt()->blockItem())
    visitBlockItem(bodyItem);

  if (!HasReturnInst) {
    // No return statement was generated: fall back to returning zero
    // (or nothing for a void function).
    if (retType == Type::getVoidType()) {
      builder.createReturnInst();
    } else {
      Value *fallback = ConstantValue::get(0);
      if (retType == Type::getFloatType())
        fallback = ConstantValue::get(0.0f);
      builder.createReturnInst(fallback);
    }
  }

  module->leaveScope();
  return std::any();
}
/**
 * @brief Visits every item of a `{ ... }` block inside a fresh scope.
 * @return std::any holding the int 0.
 */
std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) {
  module->enterNewScope();
  const auto &items = ctx->blockItem();
  for (size_t idx = 0; idx < items.size(); ++idx) {
    visitBlockItem(items[idx]);
  }
  module->leaveScope();
  return 0;
}
/**
 * @brief Lowers `lValue = exp;` to a store instruction.
 * @details The index expressions of the left-hand side are evaluated first,
 *   then the right-hand side. When the right-hand side type differs from the
 *   variable's element type the value is converted: constants are folded,
 *   other values get an int-to-float or float-to-int instruction.
 * @throws std::runtime_error if the assigned name is unknown or the variable
 *   does not have a pointer type (previously both cases dereferenced null).
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
  auto lVal = ctx->lValue();
  std::string name = lVal->Ident()->getText();

  std::vector<Value *> dims;
  for (const auto &exp : lVal->exp()) {
    dims.push_back(std::any_cast<Value *>(visitExp(exp)));
  }

  auto variable = module->getVariable(name);
  if (variable == nullptr) {
    // Same diagnostic as visitLValue for an undeclared identifier.
    throw std::runtime_error("Variable " + name + " not found.");
  }

  Value *value = std::any_cast<Value *>(visitExp(ctx->exp()));

  auto *pointerType = dynamic_cast<PointerType *>(variable->getType());
  if (pointerType == nullptr) {
    throw std::runtime_error("Variable " + name + " is not assignable.");
  }
  Type *variableType = pointerType->getBaseType();

  // Convert the right-hand side when its type differs from the target's.
  if (variableType != value->getType()) {
    const bool toFloat = variableType == Type::getFloatType();
    ConstantValue *constValue = dynamic_cast<ConstantValue *>(value);
    if (constValue != nullptr) {
      // Constants are converted at compile time.
      if (toFloat) {
        value = ConstantValue::get(static_cast<float>(constValue->getInt()));
      } else {
        value = ConstantValue::get(static_cast<int>(constValue->getFloat()));
      }
    } else if (toFloat) {
      value = builder.createIToFInst(value);
    } else {
      value = builder.createFtoIInst(value);
    }
  }

  builder.createStoreInst(value, variable, dims, variable->getName());
  return std::any();
}
/**
 * @brief Lowers `if (cond) stmt [else stmt]`.
 * @details Creates a then block, an exit block and (only with an else arm) an
 *   else block. The condition is visited with the then block as true target
 *   and the else/exit block as false target. Blocks are named and appended to
 *   the function in the order then, else, exit. With an else arm both arms end
 *   with an unconditional branch to the exit block; without one the then arm
 *   is only connected to the exit block in the CFG (no branch is emitted).
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
  Function *owner = builder.getBasicBlock()->getParent();
  const bool hasElse = ctx->stmt().size() > 1;

  BasicBlock *thenBlock = new BasicBlock(owner);
  BasicBlock *exitBlock = new BasicBlock(owner);
  BasicBlock *elseBlock = hasElse ? new BasicBlock(owner) : nullptr;

  // Names the block "<prefix><label index>", appends it to the function and
  // moves the insertion point to its end.
  auto enterBlock = [&](BasicBlock *block, const char *prefix) {
    std::stringstream label;
    label << prefix << builder.getLabelIndex();
    block->setName(label.str());
    owner->addBasicBlock(block);
    builder.setPosition(block, block->end());
  };

  // A braced arm opens its own scope inside visitBlockStmt; any other
  // statement is wrapped in an explicit scope here.
  auto emitArm = [&](auto *arm) {
    auto *braced = dynamic_cast<SysYParser::BlockStmtContext *>(arm);
    if (braced != nullptr) {
      visitBlockStmt(braced);
      return;
    }
    module->enterNewScope();
    arm->accept(this);
    module->leaveScope();
  };

  // Evaluate the condition with the branch targets of this statement.
  builder.pushTrueBlock(thenBlock);
  builder.pushFalseBlock(hasElse ? elseBlock : exitBlock);
  visitCond(ctx->cond());
  builder.popTrueBlock();
  builder.popFalseBlock();

  enterBlock(thenBlock, "if_then.L");
  emitArm(ctx->stmt(0));
  if (hasElse) {
    builder.createUncondBrInst(exitBlock, {});
  }
  BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock);

  if (hasElse) {
    enterBlock(elseBlock, "if_else.L");
    emitArm(ctx->stmt(1));
    builder.createUncondBrInst(exitBlock, {});
    BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock);
  }

  enterBlock(exitBlock, "if_exit.L");
  return std::any();
}
/**
 * @brief Lowers `while (cond) stmt`.
 * @details Block layout: current -> head (condition) -> body -> exit.
 *   The condition is visited in the head block with the body as true target
 *   and the exit block as false target. While the body is visited, `break`
 *   targets the exit block and `continue` targets the head block. The body
 *   ends with an unconditional branch back to the head block, and the matching
 *   CFG edge (end of body -> head) is recorded. The edge from the head to the
 *   exit block is created while the condition is visited.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) {
  BasicBlock* curBlock = builder.getBasicBlock();
  Function* function = builder.getBasicBlock()->getParent();

  std::stringstream labelstring;
  labelstring << "while_head.L" << builder.getLabelIndex();
  BasicBlock *headBlock = function->addBasicBlock(labelstring.str());
  labelstring.str("");
  // No branch is emitted here; only the CFG edge is recorded.
  BasicBlock::conectBlocks(curBlock, headBlock);
  builder.setPosition(headBlock, headBlock->end());

  BasicBlock* bodyBlock = new BasicBlock(function);
  BasicBlock* exitBlock = new BasicBlock(function);

  // Evaluate the loop condition inside the head block.
  builder.pushTrueBlock(bodyBlock);
  builder.pushFalseBlock(exitBlock);
  visitCond(ctx->cond());
  builder.popTrueBlock();
  builder.popFalseBlock();

  labelstring << "while_body.L" << builder.getLabelIndex();
  bodyBlock->setName(labelstring.str());
  labelstring.str("");
  function->addBasicBlock(bodyBlock);
  builder.setPosition(bodyBlock, bodyBlock->end());

  builder.pushBreakBlock(exitBlock);
  builder.pushContinueBlock(headBlock);
  auto block = dynamic_cast<SysYParser::BlockStmtContext *>(ctx->stmt());
  if (block != nullptr) {
    // A braced body opens its own scope inside visitBlockStmt.
    visitBlockStmt(block);
  } else {
    module->enterNewScope();
    ctx->stmt()->accept(this);
    module->leaveScope();
  }

  // Back edge: the branch goes to the head block, so the CFG edge must also
  // go to the head block (it used to be recorded towards the exit block).
  builder.createUncondBrInst(headBlock, {});
  BasicBlock::conectBlocks(builder.getBasicBlock(), headBlock);
  builder.popBreakBlock();
  builder.popContinueBlock();

  labelstring << "while_exit.L" << builder.getLabelIndex();
  exitBlock->setName(labelstring.str());
  labelstring.str("");
  function->addBasicBlock(exitBlock);
  builder.setPosition(exitBlock, exitBlock->end());
  return std::any();
}
/**
 * @brief Lowers `break;` to an unconditional branch to the builder's current
 *   break block and records the matching CFG edge.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) {
  BasicBlock *target = builder.getBreakBlock();
  builder.createUncondBrInst(target, {});
  BasicBlock *source = builder.getBasicBlock();
  BasicBlock::conectBlocks(source, target);
  return std::any();
}
/**
 * @brief Lowers `continue;` to an unconditional branch to the builder's
 *   current continue block and records the matching CFG edge.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) {
  BasicBlock *target = builder.getContinueBlock();
  builder.createUncondBrInst(target, {});
  BasicBlock *source = builder.getBasicBlock();
  BasicBlock::conectBlocks(source, target);
  return std::any();
}
/**
 * @brief Lowers `return [exp];`.
 * @details A bare `return;` emits a return without a value. Otherwise the
 *   expression is evaluated and, when its type differs from the function's
 *   return type, converted (constants are folded, other values get an
 *   int-to-float or float-to-int instruction). Sets HasReturnInst so that
 *   visitFuncDef does not append a default return.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
  if (ctx->exp() == nullptr) {
    // `return;` has no value; the previous code dereferenced the null value
    // while checking its type.
    builder.createReturnInst();
    HasReturnInst = true;
    return std::any();
  }

  Value *returnValue = std::any_cast<Value *>(visitExp(ctx->exp()));
  Type *funcType = builder.getBasicBlock()->getParent()->getReturnType();

  if (returnValue != nullptr && funcType != returnValue->getType()) {
    const bool toFloat = funcType == Type::getFloatType();
    ConstantValue *constValue = dynamic_cast<ConstantValue *>(returnValue);
    if (constValue != nullptr) {
      // Constants are converted at compile time.
      if (toFloat) {
        returnValue = ConstantValue::get(static_cast<float>(constValue->getInt()));
      } else {
        returnValue = ConstantValue::get(static_cast<int>(constValue->getFloat()));
      }
    } else if (toFloat) {
      returnValue = builder.createIToFInst(returnValue);
    } else {
      returnValue = builder.createFtoIInst(returnValue);
    }
  }

  builder.createReturnInst(returnValue);
  HasReturnInst = true;
  return std::any();
}
/**
 * @brief Evaluates an l-value used as an expression.
 * @details Resolution order:
 *   1. constant variable with all-constant indices: value taken directly via
 *      getByIndices;
 *   2. global variable while the module is in the global area: value taken via
 *      getByIndices (indices are asserted to be constant);
 *   3. fewer indices than dimensions: a sub-array is obtained through
 *      createGetSubArray;
 *   4. otherwise a load instruction is emitted.
 * @throws std::runtime_error if the name is not a known variable.
 * @return std::any holding the resulting `Value *`.
 */
std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
  std::string ident = ctx->Ident()->getText();
  User *symbol = module->getVariable(ident);

  std::vector<Value *> indices;
  for (const auto &indexExp : ctx->exp())
    indices.push_back(std::any_cast<Value *>(visitExp(indexExp)));

  if (symbol == nullptr)
    throw std::runtime_error("Variable " + ident + " not found.");

  // True when every index is a compile-time constant (also for no indices).
  bool allConstant = true;
  for (size_t k = 0; allConstant && k < indices.size(); ++k)
    allConstant = dynamic_cast<ConstantValue *>(indices[k]) != nullptr;

  ConstantVariable *asConst = dynamic_cast<ConstantVariable *>(symbol);
  GlobalValue *asGlobal = dynamic_cast<GlobalValue *>(symbol);
  AllocaInst *asLocal = dynamic_cast<AllocaInst *>(symbol);

  Value *resolved = nullptr;
  if (asConst != nullptr && allConstant) {
    // Constant variable with constant indices: read the value directly.
    resolved = asConst->getByIndices(indices);
  } else if (module->isInGlobalArea() && (asGlobal != nullptr)) {
    assert(allConstant);
    resolved = asGlobal->getByIndices(indices);
  } else {
    const bool partiallyIndexed =
        (asGlobal != nullptr && asGlobal->getNumDims() > indices.size()) ||
        (asLocal != nullptr && asLocal->getNumDims() > indices.size()) ||
        (asConst != nullptr && asConst->getNumDims() > indices.size());
    if (partiallyIndexed) {
      // Fewer indices than dimensions: the result is a sub-array.
      auto subArrayInst =
          builder.createGetSubArray(dynamic_cast<LVal *>(symbol), indices);
      resolved = subArrayInst->getChildArray();
    } else {
      resolved = builder.createLoadInst(symbol, indices);
    }
  }
  return resolved;
}
/**
 * @brief Dispatches a primary expression to the matching visitor.
 * @details Handles a parenthesized expression, an l-value, or a number.
 * @throws std::runtime_error for a string literal (not supported) or when no
 *   alternative is present. Previously these cases fell through to
 *   visitNumber with a null context.
 * @return Whatever the selected visitor returns (a std::any holding a
 *   `Value *` for the visitors visible in this file).
 */
std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) {
  if (ctx->exp() != nullptr)
    return visitExp(ctx->exp());
  if (ctx->lValue() != nullptr)
    return visitLValue(ctx->lValue());
  if (ctx->number() != nullptr)
    return visitNumber(ctx->number());
  if (ctx->string() != nullptr)
    throw std::runtime_error("String literal not supported in SysYIRGenerator.");
  throw std::runtime_error("PrimaryExp: unknown primary expression.");
}
/**
 * @brief Converts a numeric literal into a constant value.
 * @details Integer literals are parsed with base auto-detection (base 0 of
 *   std::stol); float literals are parsed with std::stof.
 * @throws std::runtime_error if the node is neither an integer nor a float
 *   literal.
 * @return std::any holding the constant as a `Value *`.
 */
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
  if (auto *intLiteral = ctx->ILITERAL(); intLiteral != nullptr) {
    int parsed = std::stol(intLiteral->getText(), nullptr, 0);
    return static_cast<Value *>(ConstantValue::get(parsed));
  }
  if (auto *floatLiteral = ctx->FLITERAL(); floatLiteral != nullptr) {
    float parsed = std::stof(floatLiteral->getText());
    return static_cast<Value *>(ConstantValue::get(parsed));
  }
  throw std::runtime_error("Unknown number type.");
}
/**
 * @brief Lowers a function call.
 * @details The callee is looked up among the module's functions first and
 *   among the external functions second. Arguments are evaluated and, where
 *   the scalar type differs from the formal parameter's element type,
 *   converted between int and float. Calls to `starttime` / `stoptime` are
 *   currently emitted without arguments.
 * @throws std::runtime_error if the callee is unknown. Previously this was
 *   only caught by an assert, i.e. a null dereference in release builds.
 * @return std::any holding the call instruction as a `Value *`.
 */
std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
  std::string funcName = ctx->Ident()->getText();

  Function *function = module->getFunction(funcName);
  if (function == nullptr) {
    function = module->getExternalFunction(funcName);
  }
  if (function == nullptr) {
    throw std::runtime_error("The function " + funcName + " is not defined.");
  }

  std::vector<Value *> args = {};
  if (funcName == "starttime" || funcName == "stoptime") {
    // TODO: the arguments of starttime / stoptime still need to be handled.
  } else {
    if (ctx->funcRParams() != nullptr) {
      args = std::any_cast<std::vector<Value *>>(visitFuncRParams(ctx->funcRParams()));
    }
    auto params = function->getEntryBlock()->getArguments();
    // Only positions that exist on both sides can be converted; the bound on
    // params prevents indexing past the formal parameter list.
    for (size_t i = 0; i < args.size() && i < params.size(); i++) {
      // A conversion is needed when the types differ and the formal is either
      // an array or a scalar whose element type is not the argument's type.
      if (params[i]->getType() != args[i]->getType() &&
          (params[i]->getNumDims() != 0 ||
           params[i]->getType()->as<PointerType>()->getBaseType() != args[i]->getType())) {
        const bool toFloat =
            params[i]->getType() == Type::getPointerType(Type::getFloatType());
        ConstantValue *constValue = dynamic_cast<ConstantValue *>(args[i]);
        if (constValue != nullptr) {
          // Constants are converted at compile time.
          if (toFloat) {
            args[i] = ConstantValue::get(static_cast<float>(constValue->getInt()));
          } else {
            args[i] = ConstantValue::get(static_cast<int>(constValue->getFloat()));
          }
        } else if (toFloat) {
          args[i] = builder.createIToFInst(args[i]);
        } else {
          args[i] = builder.createFtoIInst(args[i]);
        }
      }
    }
  }
  return static_cast<Value *>(builder.createCallInst(function, args));
}
/**
 * @brief Lowers a unary expression.
 * @details Primary expressions and calls are forwarded to their visitors.
 *   For `op unaryExp` the operand is evaluated first; unary minus and logical
 *   not are folded for constants and emitted as instructions otherwise. Any
 *   other operator leaves the operand unchanged.
 * @return std::any holding the resulting `Value *` (or whatever the forwarded
 *   visitor returned).
 */
std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
  if (ctx->primaryExp() != nullptr)
    return visitPrimaryExp(ctx->primaryExp());
  if (ctx->call() != nullptr)
    return visitCall(ctx->call());

  Value *inner = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp()));

  const bool isNegation = ctx->unaryOp()->SUB() != nullptr;
  const bool isLogicalNot = !isNegation && ctx->unaryOp()->NOT() != nullptr;
  if (!isNegation && !isLogicalNot)
    return inner;

  Value *outcome = inner;
  ConstantValue *literal = dynamic_cast<ConstantValue *>(inner);
  if (literal != nullptr) {
    // Constant operand: fold at compile time.
    if (isNegation) {
      if (literal->isFloat())
        outcome = ConstantValue::get(-literal->getFloat());
      else
        outcome = ConstantValue::get(-literal->getInt());
    } else {
      // Logical not always yields the int 0 or 1.
      if (literal->isFloat())
        outcome = ConstantValue::get(literal->getFloat() != 0.0F ? 0 : 1);
      else
        outcome = ConstantValue::get(literal->getInt() != 0 ? 0 : 1);
    }
  } else if (inner != nullptr) {
    const bool isIntOperand = inner->getType() == Type::getIntType();
    if (isNegation) {
      if (isIntOperand)
        outcome = builder.createNegInst(inner);
      else
        outcome = builder.createFNegInst(inner);
    } else {
      if (isIntOperand)
        outcome = builder.createNotInst(inner);
      else
        outcome = builder.createFNotInst(inner);
    }
  } else {
    std::cout << "UnExp: value is nullptr." << std::endl;
    assert(false);
  }
  return outcome;
}
/**
 * @brief Evaluates the actual arguments of a call, in source order.
 * @return std::any holding a `std::vector<Value *>` with one entry per
 *   argument expression.
 */
std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) {
  const auto &argExps = ctx->exp();
  std::vector<Value *> evaluated;
  for (size_t idx = 0; idx < argExps.size(); ++idx) {
    evaluated.push_back(std::any_cast<Value *>(visitExp(argExps[idx])));
  }
  return evaluated;
}
/**
 * @brief Lowers a chain of `*`, `/` and `%` operations, left to right.
 * @details If either operand is float, the int operand is converted and a
 *   float instruction is used; otherwise integer instructions are used.
 *   Operations on two constants are folded at compile time. Integer `/` and
 *   `%` are folded only when the host computation is well defined (divisor
 *   non-zero, and not INT_MIN with divisor -1); otherwise the instruction is
 *   emitted so the generator itself cannot trap.
 * @throws std::runtime_error for `%` with a float operand.
 * @return std::any holding the resulting `Value *`.
 */
std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
  // Integer division / remainder is undefined for these inputs in C++.
  const auto canFoldIntDivision = [](int lhs, int rhs) {
    return rhs != 0 && !(rhs == -1 && lhs == std::numeric_limits<int>::min());
  };

  Value *result = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(0)));
  for (size_t i = 1; i < ctx->unaryExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the operator of the
    // i-th operand sits at index 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(i)));

    Type *resultType = result->getType();
    Type *operandType = operand->getType();
    Type *floatType = Type::getFloatType();

    if (resultType == floatType || operandType == floatType) {
      if (opType != SysYParser::MUL && opType != SysYParser::DIV) {
        // The remainder operator is not defined for float operands.
        throw std::runtime_error("MulExp: float type mod operation is not allowed.");
      }

      // Convert whichever side is not float yet.
      if (operandType != floatType) {
        ConstantValue *constValue = dynamic_cast<ConstantValue *>(operand);
        if (constValue != nullptr)
          operand = ConstantValue::get(static_cast<float>(constValue->getInt()));
        else
          operand = builder.createIToFInst(operand);
      } else if (resultType != floatType) {
        ConstantValue *constValue = dynamic_cast<ConstantValue *>(result);
        if (constValue != nullptr)
          result = ConstantValue::get(static_cast<float>(constValue->getInt()));
        else
          result = builder.createIToFInst(result);
      }

      ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
      ConstantValue *constOperand = dynamic_cast<ConstantValue *>(operand);
      const bool bothConstant = (constResult != nullptr) && (constOperand != nullptr);
      if (opType == SysYParser::MUL) {
        if (bothConstant)
          result = ConstantValue::get(constResult->getFloat() * constOperand->getFloat());
        else
          result = builder.createFMulInst(result, operand);
      } else {
        if (bothConstant)
          result = ConstantValue::get(constResult->getFloat() / constOperand->getFloat());
        else
          result = builder.createFDivInst(result, operand);
      }
    } else {
      ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
      ConstantValue *constOperand = dynamic_cast<ConstantValue *>(operand);
      const bool bothConstant = (constResult != nullptr) && (constOperand != nullptr);
      if (opType == SysYParser::MUL) {
        if (bothConstant)
          result = ConstantValue::get(constResult->getInt() * constOperand->getInt());
        else
          result = builder.createMulInst(result, operand);
      } else if (opType == SysYParser::DIV) {
        if (bothConstant && canFoldIntDivision(constResult->getInt(), constOperand->getInt()))
          result = ConstantValue::get(constResult->getInt() / constOperand->getInt());
        else
          result = builder.createDivInst(result, operand);
      } else {
        if (bothConstant && canFoldIntDivision(constResult->getInt(), constOperand->getInt()))
          result = ConstantValue::get(constResult->getInt() % constOperand->getInt());
        else
          result = builder.createRemInst(result, operand);
      }
    }
  }
  return result;
}
/**
 * @brief Lowers a chain of `+` and `-` operations, left to right.
 * @details If either operand is float, the int operand is converted and a
 *   float instruction is used; otherwise integer instructions are used.
 *   Operations on two constants are folded at compile time.
 * @return std::any holding the resulting `Value *`.
 */
std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
  // Converts an int value to float: constants are folded, anything else gets
  // an int-to-float instruction.
  auto promoteToFloat = [&](Value *source) -> Value * {
    ConstantValue *literal = dynamic_cast<ConstantValue *>(source);
    if (literal != nullptr)
      return ConstantValue::get(static_cast<float>(literal->getInt()));
    return builder.createIToFInst(source);
  };

  Value *acc = std::any_cast<Value *>(visitMulExp(ctx->mulExp(0)));
  for (size_t term = 1; term < ctx->mulExp().size(); term++) {
    // Children alternate operand, operator, operand, ...; the operator of the
    // term-th operand sits at index 2*term-1.
    auto opToken = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * term - 1]);
    const bool isAddition = opToken->getSymbol()->getType() == SysYParser::ADD;
    Value *rhs = std::any_cast<Value *>(visitMulExp(ctx->mulExp(term)));

    Type *accType = acc->getType();
    Type *rhsType = rhs->getType();
    Type *floatType = Type::getFloatType();
    const bool floatArithmetic = accType == floatType || rhsType == floatType;

    if (floatArithmetic) {
      // At least one side is float already; convert the other one if needed.
      if (rhsType != floatType)
        rhs = promoteToFloat(rhs);
      else if (accType != floatType)
        acc = promoteToFloat(acc);
    }

    ConstantValue *accLiteral = dynamic_cast<ConstantValue *>(acc);
    ConstantValue *rhsLiteral = dynamic_cast<ConstantValue *>(rhs);
    const bool foldable = (accLiteral != nullptr) && (rhsLiteral != nullptr);

    if (floatArithmetic) {
      if (isAddition) {
        if (foldable)
          acc = ConstantValue::get(accLiteral->getFloat() + rhsLiteral->getFloat());
        else
          acc = builder.createFAddInst(acc, rhs);
      } else {
        if (foldable)
          acc = ConstantValue::get(accLiteral->getFloat() - rhsLiteral->getFloat());
        else
          acc = builder.createFSubInst(acc, rhs);
      }
    } else {
      if (isAddition) {
        if (foldable)
          acc = ConstantValue::get(accLiteral->getInt() + rhsLiteral->getInt());
        else
          acc = builder.createAddInst(acc, rhs);
      } else {
        if (foldable)
          acc = ConstantValue::get(accLiteral->getInt() - rhsLiteral->getInt());
        else
          acc = builder.createSubInst(acc, rhs);
      }
    }
  }
  return acc;
}
/**
 * @brief Lowers a chain of `<`, `>`, `<=`, `>=` comparisons, left to right.
 * @details Two constants are compared at compile time and yield the int 0 or
 *   1. Two int constants are compared as ints; they used to be converted to
 *   float first, which gives wrong answers for values that float cannot
 *   represent exactly. If either constant is float, both are compared as
 *   float, matching the conversion applied to non-constant operands.
 *   Non-constant comparisons emit a float compare when either side is float
 *   (converting the other side) and an integer compare otherwise.
 * @return std::any holding the resulting `Value *`.
 */
std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
  Value *result = std::any_cast<Value *>(visitAddExp(ctx->addExp(0)));
  for (size_t i = 1; i < ctx->addExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the operator of the
    // i-th operand sits at index 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitAddExp(ctx->addExp(i)));

    ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
    ConstantValue *constOperand = dynamic_cast<ConstantValue *>(operand);

    if ((constResult != nullptr) && (constOperand != nullptr)) {
      // Constant comparison, performed in the operands' own domain.
      auto fold = [&](auto lhs, auto rhs) {
        if (opType == SysYParser::LT) result = ConstantValue::get(lhs < rhs ? 1 : 0);
        else if (opType == SysYParser::GT) result = ConstantValue::get(lhs > rhs ? 1 : 0);
        else if (opType == SysYParser::LE) result = ConstantValue::get(lhs <= rhs ? 1 : 0);
        else if (opType == SysYParser::GE) result = ConstantValue::get(lhs >= rhs ? 1 : 0);
        else assert(false);
      };
      if (constResult->isFloat() || constOperand->isFloat()) {
        const float lhs = constResult->isFloat()
                              ? constResult->getFloat()
                              : static_cast<float>(constResult->getInt());
        const float rhs = constOperand->isFloat()
                              ? constOperand->getFloat()
                              : static_cast<float>(constOperand->getInt());
        fold(lhs, rhs);
      } else {
        fold(constResult->getInt(), constOperand->getInt());
      }
    } else {
      Type *resultType = result->getType();
      Type *operandType = operand->getType();
      Type *floatType = Type::getFloatType();
      if (resultType == floatType || operandType == floatType) {
        // Float comparison: convert whichever side is still int.
        if (resultType != floatType) {
          if (constResult != nullptr)
            result = ConstantValue::get(static_cast<float>(constResult->getInt()));
          else
            result = builder.createIToFInst(result);
        }
        if (operandType != floatType) {
          if (constOperand != nullptr)
            operand = ConstantValue::get(static_cast<float>(constOperand->getInt()));
          else
            operand = builder.createIToFInst(operand);
        }
        if (opType == SysYParser::LT) result = builder.createFCmpLTInst(result, operand);
        else if (opType == SysYParser::GT) result = builder.createFCmpGTInst(result, operand);
        else if (opType == SysYParser::LE) result = builder.createFCmpLEInst(result, operand);
        else if (opType == SysYParser::GE) result = builder.createFCmpGEInst(result, operand);
        else assert(false);
      } else {
        // Integer comparison.
        if (opType == SysYParser::LT) result = builder.createICmpLTInst(result, operand);
        else if (opType == SysYParser::GT) result = builder.createICmpGTInst(result, operand);
        else if (opType == SysYParser::LE) result = builder.createICmpLEInst(result, operand);
        else if (opType == SysYParser::GE) result = builder.createICmpGEInst(result, operand);
        else assert(false);
      }
    }
  }
  return result;
}
/**
 * @brief Lowers a chain of `==` and `!=` comparisons, left to right.
 * @details Two constants are compared at compile time and yield the int 0 or
 *   1. Two int constants are compared as ints; they used to be converted to
 *   float first, which makes distinct large ints compare equal. If either
 *   constant is float, both are compared as float. Non-constant comparisons
 *   emit a float compare when either side is float (converting the other
 *   side) and an integer compare otherwise. When the expression consists of a
 *   single constant operand, it is normalized to the int 0 or 1.
 * @return std::any holding the resulting `Value *`.
 */
std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
  Value *result = std::any_cast<Value *>(visitRelExp(ctx->relExp(0)));
  for (size_t i = 1; i < ctx->relExp().size(); i++) {
    // Children alternate operand, operator, operand, ...; the operator of the
    // i-th operand sits at index 2*i-1.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode *>(ctx->children[2 * i - 1]);
    int opType = opNode->getSymbol()->getType();
    Value *operand = std::any_cast<Value *>(visitRelExp(ctx->relExp(i)));

    ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
    ConstantValue *constOperand = dynamic_cast<ConstantValue *>(operand);

    if ((constResult != nullptr) && (constOperand != nullptr)) {
      // Constant comparison, performed in the operands' own domain.
      auto fold = [&](auto lhs, auto rhs) {
        if (opType == SysYParser::EQ) result = ConstantValue::get(lhs == rhs ? 1 : 0);
        else if (opType == SysYParser::NE) result = ConstantValue::get(lhs != rhs ? 1 : 0);
        else assert(false);
      };
      if (constResult->isFloat() || constOperand->isFloat()) {
        const float lhs = constResult->isFloat()
                              ? constResult->getFloat()
                              : static_cast<float>(constResult->getInt());
        const float rhs = constOperand->isFloat()
                              ? constOperand->getFloat()
                              : static_cast<float>(constOperand->getInt());
        fold(lhs, rhs);
      } else {
        fold(constResult->getInt(), constOperand->getInt());
      }
    } else {
      Type *resultType = result->getType();
      Type *operandType = operand->getType();
      Type *floatType = Type::getFloatType();
      if (resultType == floatType || operandType == floatType) {
        // Float comparison: convert whichever side is still int.
        if (resultType != floatType) {
          if (constResult != nullptr)
            result = ConstantValue::get(static_cast<float>(constResult->getInt()));
          else
            result = builder.createIToFInst(result);
        }
        if (operandType != floatType) {
          if (constOperand != nullptr)
            operand = ConstantValue::get(static_cast<float>(constOperand->getInt()));
          else
            operand = builder.createIToFInst(operand);
        }
        if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand);
        else if (opType == SysYParser::NE) result = builder.createFCmpNEInst(result, operand);
        else assert(false);
      } else {
        if (opType == SysYParser::EQ) result = builder.createICmpEQInst(result, operand);
        else if (opType == SysYParser::NE) result = builder.createICmpNEInst(result, operand);
        else assert(false);
      }
    }
  }

  if (ctx->relExp().size() == 1) {
    // A lone constant operand is turned into the int 0 or 1.
    ConstantValue *constResult = dynamic_cast<ConstantValue *>(result);
    if (constResult != nullptr) {
      if (constResult->isFloat())
        result = ConstantValue::get(constResult->getFloat() != 0.0F ? 1 : 0);
      else
        result = ConstantValue::get(constResult->getInt() != 0 ? 1 : 0);
    }
  }
  return result;
}
/**
 * @brief Lowers a chain of `&&` conditions into conditional branches.
 * @details Every condition except the last branches to a freshly added
 *   "AND.L<n>" block when true and to the builder's current false block when
 *   false; the last condition branches to the builder's current true block
 *   when true. The matching CFG edges are recorded after each branch.
 * @return An empty std::any.
 */
std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){
  BasicBlock *activeBlock = builder.getBasicBlock();
  Function *owner = builder.getBasicBlock()->getParent();
  BasicBlock *onTrue = builder.getTrueBlock();
  BasicBlock *onFalse = builder.getFalseBlock();

  auto terms = ctx->eqExp();
  const size_t lastIdx = terms.size() - 1;
  for (size_t idx = 0; idx < lastIdx; ++idx) {
    std::stringstream label;
    label << "AND.L" << builder.getLabelIndex();
    BasicBlock *nextBlock = owner->addBasicBlock(label.str());

    Value *test = std::any_cast<Value *>(visitEqExp(terms[idx]));
    builder.createCondBrInst(test, nextBlock, onFalse, {}, {});
    BasicBlock::conectBlocks(activeBlock, nextBlock);
    BasicBlock::conectBlocks(activeBlock, onFalse);

    // The next condition is evaluated in the block reached when this one
    // was true.
    activeBlock = nextBlock;
    builder.setPosition(activeBlock, activeBlock->end());
  }

  Value *finalTest = std::any_cast<Value *>(visitEqExp(terms.back()));
  builder.createCondBrInst(finalTest, onTrue, onFalse, {}, {});
  BasicBlock::conectBlocks(activeBlock, onTrue);
  BasicBlock::conectBlocks(activeBlock, onFalse);
  return std::any();
}
/**
 * @brief Lowers a chain of `||` conditions.
 * @details Every alternative except the last is visited with a freshly added
 *   "OR.L<n>" block pushed as its false target; generation then continues in
 *   that block. The last alternative is visited with the false target that
 *   was active on entry.
 * @return An empty std::any.
 */
auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any {
  BasicBlock *startBlock = builder.getBasicBlock();
  Function *owner = startBlock->getParent();

  auto alternatives = ctx->lAndExp();
  const size_t lastIdx = alternatives.size() - 1;
  for (size_t idx = 0; idx < lastIdx; ++idx) {
    std::stringstream label;
    label << "OR.L" << builder.getLabelIndex();
    BasicBlock *nextAlternative = owner->addBasicBlock(label.str());

    // If this alternative is false, control continues with the next one.
    builder.pushFalseBlock(nextAlternative);
    visitLAndExp(alternatives[idx]);
    builder.popFalseBlock();
    builder.setPosition(nextAlternative, nextAlternative->end());
  }

  visitLAndExp(alternatives.back());
  return std::any();
}
/*
 * @brief: flatten an initializer tree into a linear list of element values
 * @details:
 *    type    - element type of the variable being initialized; leaf values
 *              of a different type are converted (int <-> float), folding
 *              constants and emitting a conversion instruction otherwise.
 *    root    - current tree node: either a leaf carrying a value or an
 *              inner node (brace list) carrying children.
 *    dims    - all array extents, outermost first; each entry is expected
 *              to be a ConstantValue.
 *    numDims - how many of the innermost entries of dims this node covers.
 *    result  - receives the flattened values; an inner node is padded with
 *              zero constants up to the product of its numDims extents.
 *    builder - used to emit conversion instructions for non-constant leaves.
 */
void Utils::tree2Array(Type *type, ArrayValueTree *root,
                       const std::vector<Value *> &dims, unsigned numDims,
                       ValueCounter &result, IRBuilder *builder) {
  Value *value = root->getValue();
  auto &children = root->getChildren();

  // Leaf: emit the value, converting it to the element type if necessary.
  if (value != nullptr) {
    if (type == value->getType()) {
      result.push_back(value);
    } else if (type == Type::getFloatType()) {
      ConstantValue *constValue = dynamic_cast<ConstantValue *>(value);
      if (constValue != nullptr)
        result.push_back(ConstantValue::get(static_cast<float>(constValue->getInt())));
      else
        result.push_back(builder->createIToFInst(value));
    } else {
      ConstantValue *constValue = dynamic_cast<ConstantValue *>(value);
      if (constValue != nullptr)
        result.push_back(ConstantValue::get(static_cast<int>(constValue->getFloat())));
      else
        result.push_back(builder->createFtoIInst(value));
    }
    return;
  }

  // Extent of the i-th dimension counted from the innermost one.
  const auto innerExtent = [&dims](unsigned i) {
    return dynamic_cast<ConstantValue *>(*(dims.rbegin() + i))->getInt();
  };

  auto beforeSize = result.size();
  for (const auto &child : children) {
    // Decide how many trailing dimensions the child may cover: starting at
    // the innermost dimension, each extent that divides the (progressively
    // reduced) current size of result adds one level.
    int begin = result.size();
    int newNumDims = 0;
    // Written as `i + 1 < numDims` instead of `i < numDims - 1`: numDims is
    // unsigned, so the latter wraps around when numDims == 0 (a brace list
    // nested deeper than the array rank) and would index far past dims.
    for (unsigned i = 0; i + 1 < numDims; i++) {
      auto dim = innerExtent(i);
      if (begin % dim == 0) {
        newNumDims += 1;
        begin /= dim;
      } else {
        break;
      }
    }
    tree2Array(type, child.get(), dims, newNumDims, result, builder);
  }
  auto afterSize = result.size();

  // Number of elements a node covering numDims dimensions must occupy.
  int blockSize = 1;
  for (unsigned i = 0; i < numDims; i++) {
    blockSize *= innerExtent(i);
  }

  // Pad with zeros of the element type when the children supplied fewer
  // values than the block holds.
  const int num = blockSize - static_cast<int>(afterSize - beforeSize);
  if (num > 0) {
    if (type == Type::getFloatType())
      result.push_back(ConstantValue::get(0.0F), num);
    else
      result.push_back(ConstantValue::get(0), num);
  }
}
/*
 * @brief: declare one external function in the module
 * @details:
 *    Builds the function type from returnType and paramTypes, creates the
 *    external function funcName in pModule, moves pBuilder to the end of the
 *    function's entry block and, for every parameter, creates an alloca of
 *    pointer-to-parameter-type (with that parameter's dims and name) which
 *    is inserted as an argument of the entry block.
 *    paramNames and paramDims are indexed by parameter position, so both
 *    must hold at least paramTypes.size() entries.
 */
void Utils::createExternalFunction(
    const std::vector<Type *> &paramTypes,
    const std::vector<std::string> &paramNames,
    const std::vector<std::vector<Value *>> &paramDims, Type *returnType,
    const std::string &funcName, Module *pModule, IRBuilder *pBuilder) {
  auto function = pModule->createExternalFunction(
      funcName, Type::getFunctionType(returnType, paramTypes));

  auto entry = function->getEntryBlock();
  pBuilder->setPosition(entry, entry->end());

  size_t position = 0;
  for (Type *paramType : paramTypes) {
    entry->insertArgument(pBuilder->createAllocaInst(
        Type::getPointerType(paramType), paramDims[position],
        paramNames[position]));
    ++position;
  }
}
/*
 * @brief: declare the built-in external functions
 * @details:
 *    Declares, in this order: getint, getch, getarray, getfloat, getfarray,
 *    putint, putch, putarray, putfloat, putfarray, starttime, stoptime.
 *    Each declaration is forwarded to Utils::createExternalFunction.
 *    Scalar parameters carry an empty dimension list; array parameters carry
 *    a dimension list holding the single constant -1.
 */
void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) {
  using TypeList = std::vector<Type *>;
  using NameList = std::vector<std::string>;
  using DimList = std::vector<Value *>;
  using DimTable = std::vector<DimList>;

  // Forwards one declaration, in a reader-friendly argument order.
  const auto declare = [pModule, pBuilder](const std::string &funcName,
                                           Type *returnType,
                                           const TypeList &paramTypes,
                                           const NameList &paramNames,
                                           const DimTable &paramDims) {
    Utils::createExternalFunction(paramTypes, paramNames, paramDims,
                                  returnType, funcName, pModule, pBuilder);
  };

  // Dimension list of an array parameter: one -1 constant.
  const auto arrayDims = [] { return DimList{ConstantValue::get(-1)}; };
  // Dimension list of a scalar parameter: empty.
  const DimList scalarDims{};

  Type *const intType = Type::getIntType();
  Type *const floatType = Type::getFloatType();
  Type *const voidType = Type::getVoidType();

  // Input functions.
  declare("getint", intType, TypeList{}, NameList{}, DimTable{});
  declare("getch", intType, TypeList{}, NameList{}, DimTable{});
  declare("getarray", intType, TypeList{intType}, NameList{"x"},
          DimTable{arrayDims()});
  declare("getfloat", floatType, TypeList{}, NameList{}, DimTable{});
  declare("getfarray", intType, TypeList{floatType}, NameList{"x"},
          DimTable{arrayDims()});

  // Output functions.
  declare("putint", voidType, TypeList{intType}, NameList{"x"},
          DimTable{scalarDims});
  declare("putch", voidType, TypeList{intType}, NameList{"x"},
          DimTable{scalarDims});
  declare("putarray", voidType, TypeList{intType, intType},
          NameList{"n", "a"}, DimTable{scalarDims, arrayDims()});
  declare("putfloat", voidType, TypeList{floatType}, NameList{"a"},
          DimTable{scalarDims});
  declare("putfarray", voidType, TypeList{intType, floatType},
          NameList{"n", "a"}, DimTable{scalarDims, arrayDims()});

  // Timing functions; each takes one int parameter named "__LINE__".
  declare("starttime", voidType, TypeList{intType}, NameList{"__LINE__"},
          DimTable{scalarDims});
  declare("stoptime", voidType, TypeList{intType}, NameList{"__LINE__"},
          DimTable{scalarDims});
}
} // namespace sysy