WE CAN MAKE FUNCTIONS NOW!!!!!!!!!!!!!!!!!!!!!!!!!

This commit is contained in:
apio 2022-08-25 18:49:43 +02:00
parent 06d7f6bb09
commit bae0d82f26
36 changed files with 419 additions and 40 deletions

View File

@ -51,6 +51,16 @@ add_executable(
src/AST/SyscallNode.h
src/AST/UnaryOpNode.cpp
src/AST/UnaryOpNode.h
src/AST/ProgramNode.cpp
src/AST/ProgramNode.h
src/AST/TopLevelNode.cpp
src/AST/TopLevelNode.h
src/AST/FunctionPrototype.cpp
src/AST/FunctionPrototype.h
src/AST/FunctionNode.cpp
src/AST/FunctionNode.h
src/GlobalContext.cpp
src/GlobalContext.h
src/utils.cpp
src/utils.h
src/Parser.cpp

View File

@ -1,5 +1,4 @@
#pragma once
#include "../IRBuilder.h"
#include "llvm/IR/Value.h"
class IRBuilder;

54
src/AST/FunctionNode.cpp Normal file
View File

@ -0,0 +1,54 @@
#include "FunctionNode.h"
#include "../Error.h"
#include "../IRBuilder.h"
#include "../utils.h"
#include "llvm/IR/Verifier.h"
FunctionNode::FunctionNode(FunctionPrototype prototype, std::shared_ptr<ExprNode> body)
: prototype(prototype), body(body)
{
}
void FunctionNode::codegen(IRBuilder* builder, llvm::Module* module)
{
llvm::Function* Function = module->getFunction(prototype.name);
if (!Function)
{
llvm::FunctionType* Type = prototype.toFunctionType();
Function = llvm::Function::Create(Type, llvm::Function::ExternalLinkage, prototype.name, *module);
}
else
{
if (!equals(Function->getFunctionType(), prototype.toFunctionType()))
{
// FIXME: add location information to AST nodes, to add information to these errors
Error::throw_error_without_location(
format_string("Function %s redefined with different prototype", prototype.name.c_str()));
}
}
if (!Function) return;
if (!Function->empty())
{
Error::throw_error_without_location(format_string("Function %s already has a body", prototype.name.c_str()));
}
llvm::BasicBlock* BB = llvm::BasicBlock::Create(builder->getBuilder()->getContext(), "entry", Function);
builder->getBuilder()->SetInsertPoint(BB);
if (llvm::Value* retVal = body->codegen(builder))
{
builder->getBuilder()->CreateRet(retVal);
if (llvm::verifyFunction(*Function))
{
Error::throw_error_without_location(format_string("Invalid function %s", prototype.name.c_str()));
}
return;
}
Function->eraseFromParent();
Error::throw_error_without_location(format_string("Error generating code of function %s", prototype.name.c_str()));
}

17
src/AST/FunctionNode.h Normal file
View File

@ -0,0 +1,17 @@
#pragma once
#include "ExprNode.h"
#include "FunctionPrototype.h"
#include "TopLevelNode.h"
class FunctionNode final : public TopLevelNode
{
private:
FunctionPrototype prototype;
std::shared_ptr<ExprNode> body;
public:
FunctionNode(FunctionPrototype prototype, std::shared_ptr<ExprNode> body);
~FunctionNode() = default;
void codegen(IRBuilder* builder, llvm::Module* module) override;
};

View File

@ -0,0 +1,18 @@
#include "FunctionPrototype.h"
llvm::FunctionType* FunctionPrototype::toFunctionType()
{
if (arguments.size() == 0)
{
return llvm::FunctionType::get(returnType, false);
}
else
{
std::vector<llvm::Type*> args;
for (auto& item : arguments)
{
args.push_back(item.first);
}
return llvm::FunctionType::get(returnType, args, false);
}
}

View File

@ -0,0 +1,12 @@
#pragma once
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
struct FunctionPrototype
{
llvm::Type* returnType;
std::string name;
std::vector<std::pair<llvm::Type*, std::string>> arguments;
llvm::FunctionType* toFunctionType();
};

View File

@ -1,4 +1,5 @@
#include "MulNode.h"
#include "../IRBuilder.h"
MulNode::MulNode(std::shared_ptr<ExprNode> left, std::shared_ptr<ExprNode> right, char op)
: BinaryOpNode(left, right), op(op)

View File

@ -1,5 +1,8 @@
#pragma once
#include "../IRBuilder.h"
#include "ExprNode.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Type.h"
#include <string>
class NumberNode : public ExprNode

18
src/AST/ProgramNode.cpp Normal file
View File

@ -0,0 +1,18 @@
#include "ProgramNode.h"
ProgramNode::ProgramNode() : TopLevelNode()
{
}
void ProgramNode::append(std::shared_ptr<TopLevelNode> item)
{
program.push_back(item);
}
void ProgramNode::walk(std::function<void(std::shared_ptr<TopLevelNode>)> callback)
{
for (auto& item : program)
{
callback(item);
}
}

16
src/AST/ProgramNode.h Normal file
View File

@ -0,0 +1,16 @@
#pragma once
#include "TopLevelNode.h"
class ProgramNode final : public TopLevelNode
{
private:
std::vector<std::shared_ptr<TopLevelNode>> program;
public:
ProgramNode();
~ProgramNode() = default;
void append(std::shared_ptr<TopLevelNode>);
void walk(std::function<void(std::shared_ptr<TopLevelNode>)> callback);
};

View File

@ -1,4 +1,5 @@
#include "SumNode.h"
#include "../IRBuilder.h"
SumNode::SumNode(std::shared_ptr<ExprNode> left, std::shared_ptr<ExprNode> right, char op)
: BinaryOpNode(left, right), op(op)

View File

@ -1,6 +1,8 @@
#include "SyscallNode.h"
#include "../Arguments.h"
#include "../Error.h"
#include "../IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InlineAsm.h"
Syscall0Node::Syscall0Node(int syscall_number) : sys_num(syscall_number), ExprNode()

11
src/AST/TopLevelNode.cpp Normal file
View File

@ -0,0 +1,11 @@
#include "TopLevelNode.h"
void TopLevelNode::codegen(IRBuilder*, llvm::Module*)
{
return;
}
llvm::Value* TopLevelNode::codegen(IRBuilder*)
{
return nullptr;
}

15
src/AST/TopLevelNode.h Normal file
View File

@ -0,0 +1,15 @@
#pragma once
#include "ASTNode.h"
#include "llvm/IR/Module.h"
class IRBuilder;
class TopLevelNode : public ASTNode
{
public:
TopLevelNode() = default;
~TopLevelNode() = default;
virtual void codegen(IRBuilder* builder, llvm::Module* module);
llvm::Value* codegen(IRBuilder* builder) override;
};

8
src/GlobalContext.cpp Normal file
View File

@ -0,0 +1,8 @@
#include "GlobalContext.h"
std::shared_ptr<llvm::LLVMContext> globalContext;
void initGlobalContext()
{
globalContext = std::make_shared<llvm::LLVMContext>();
}

5
src/GlobalContext.h Normal file
View File

@ -0,0 +1,5 @@
#include "llvm/IR/LLVMContext.h"
extern std::shared_ptr<llvm::LLVMContext> globalContext;
void initGlobalContext();

View File

@ -1,6 +1,7 @@
#include "IRBuilder.h"
#include "Arguments.h"
#include "Error.h"
#include "GlobalContext.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Value.h"
@ -15,7 +16,7 @@
IRBuilder::IRBuilder()
{
context = std::make_unique<llvm::LLVMContext>();
context = globalContext;
builder = std::unique_ptr<llvm::IRBuilder<>>(new llvm::IRBuilder<>(*context));
module = std::make_unique<llvm::Module>(Arguments::input_fname, *context);
}
@ -25,19 +26,9 @@ llvm::IRBuilder<>* IRBuilder::getBuilder()
return builder.get();
}
void IRBuilder::create_main_function(std::shared_ptr<ASTNode> expression)
void IRBuilder::create_program(std::shared_ptr<ProgramNode> program)
{
llvm::FunctionType* mainType =
llvm::FunctionType::get(llvm::IntegerType::getInt32Ty(*context), std::vector<llvm::Type*>(), false);
llvm::Function* main = llvm::Function::Create(mainType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
llvm::BasicBlock* entryBlock = llvm::BasicBlock::Create(*context, "entry", main);
builder->SetInsertPoint(entryBlock);
llvm::Value* returnValue = expression->codegen(this);
builder->CreateRet(returnValue);
llvm::verifyFunction(*main);
program->walk([&](std::shared_ptr<TopLevelNode> node) { node->codegen(this, module.get()); });
}
void IRBuilder::resolveToLLVMIR(std::string path)

View File

@ -1,5 +1,5 @@
#pragma once
#include "AST/ASTNode.h"
#include "AST/ProgramNode.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
@ -9,13 +9,13 @@ class ASTNode;
class IRBuilder
{
std::unique_ptr<llvm::LLVMContext> context;
std::shared_ptr<llvm::LLVMContext> context;
std::unique_ptr<llvm::IRBuilder<>> builder;
std::unique_ptr<llvm::Module> module;
public:
IRBuilder();
void create_main_function(std::shared_ptr<ASTNode> expression);
void create_program(std::shared_ptr<ProgramNode> program);
llvm::IRBuilder<>* getBuilder();

View File

@ -147,10 +147,10 @@ TokenStream Lexer::lex(const std::string& text)
result.push_back(Token::make_with_line({TT_RParen, loc}, current_line_text));
break;
case '{':
result.push_back(Token::make_with_line({TT_RBracket, loc}, current_line_text));
result.push_back(Token::make_with_line({TT_LBracket, loc}, current_line_text));
break;
case '}':
result.push_back(Token::make_with_line({TT_LBracket, loc}, current_line_text));
result.push_back(Token::make_with_line({TT_RBracket, loc}, current_line_text));
break;
case ';':
result.push_back(Token::make_with_line({TT_Semicolon, loc}, current_line_text));
@ -252,6 +252,10 @@ Token Lexer::create_identifier()
if (identifier == "compmacro")
return Token::make_with_line({TT_CompilerMacro, {prev_line, prev_column, loc.fname}},
current_line_text);
if (identifier == "let")
return Token::make_with_line({TT_Let, {prev_line, prev_column, loc.fname}}, current_line_text);
if (identifier == "in")
return Token::make_with_line({TT_In, {prev_line, prev_column, loc.fname}}, current_line_text);
return Token::make_with_line({TT_Identifier, identifier, {prev_line, prev_column, loc.fname}},
current_line_text);
}
@ -281,6 +285,10 @@ Token Lexer::create_identifier()
return Token::make_with_line({TT_Syscall5, {prev_line, prev_column, loc.fname}}, current_line_text);
if (identifier == "compmacro")
return Token::make_with_line({TT_CompilerMacro, {prev_line, prev_column, loc.fname}}, current_line_text);
if (identifier == "let")
return Token::make_with_line({TT_Let, {prev_line, prev_column, loc.fname}}, current_line_text);
if (identifier == "in")
return Token::make_with_line({TT_In, {prev_line, prev_column, loc.fname}}, current_line_text);
return Token::make_with_line({TT_Identifier, identifier, {prev_line, prev_column, loc.fname}}, current_line_text);
}

View File

@ -1,6 +1,9 @@
#include "Parser.h"
#include "AST/FunctionNode.h"
#include "AST/MulNode.h"
#include "AST/SumNode.h"
#include "FormatString/FormatString.hpp"
#include "GlobalContext.h"
Parser::Parser(const TokenStream& tokens) : tokens(tokens)
{
@ -12,16 +15,18 @@ std::shared_ptr<Parser> Parser::new_parser(const TokenStream& tokens)
new Parser(tokens)); // As always, not using std::make_shared 'cause constructor is private
}
std::shared_ptr<ASTNode> Parser::parse()
std::shared_ptr<ProgramNode> Parser::parse()
{
advance();
auto result = expr();
if (result.is_error()) result.ethrow();
if (current_token->tk_type != TT_EOF)
std::shared_ptr<ProgramNode> final_result = std::make_shared<ProgramNode>();
while (true)
{
Err<ExprNode>("expected *, /, + or -", current_token).ethrow();
auto result = toplevel();
if (result.is_error()) result.ethrow();
final_result->append(result.get());
if (current_token->tk_type == TT_EOF) break;
}
return result.get();
return final_result;
}
int Parser::advance()
@ -82,3 +87,41 @@ Result<ExprNode> Parser::expr()
}
return left;
}
Result<TopLevelNode> Parser::toplevel()
{
// FIXME: Add more top-level stuff later, for now it's only functions.
return function();
}
Result<TopLevelNode> Parser::function()
{
FunctionPrototype proto;
proto.returnType = llvm::IntegerType::getInt32Ty(*globalContext); // FIXME: allow specifying return type
proto.arguments = {}; // FIXME: allow specifying arguments
if (current_token->tk_type != TT_Let)
return Err<TopLevelNode>("Expected let at the beginning of a function", current_token);
advance();
if (current_token->tk_type != TT_At)
return Err<TopLevelNode>("Expected @ at the beginning of a function", current_token);
advance();
if (current_token->tk_type != TT_Identifier) return Err<TopLevelNode>("Expected an identifier", current_token);
else
proto.name = current_token->string_value;
advance();
if (current_token->tk_type != TT_In && current_token->tk_type != TT_Semicolon)
return Err<TopLevelNode>("Expected 'in'", current_token);
if (current_token->tk_type == TT_Semicolon)
return Err<TopLevelNode>("Functions without a body are unsupported (for now)", current_token);
advance();
if (current_token->tk_type != TT_LBracket)
return Err<TopLevelNode>("Invalid syntax",
current_token); // FIXME: Do not be lazy and return a meaningful error message.
advance();
Result<ExprNode> body = expr();
if (body.is_error()) return Err<TopLevelNode>(body.error(), body.token());
if (current_token->tk_type != TT_RBracket)
return Err<TopLevelNode>(format_string("Invalid syntax %d", current_token->tk_type), current_token);
advance();
return Ok<TopLevelNode>(new FunctionNode(proto, body.get()), current_token);
}

View File

@ -1,5 +1,6 @@
#pragma once
#include "AST/NumberNode.h"
#include "AST/ProgramNode.h"
#include "AST/SumNode.h"
#include "Error.h"
#include "Lexer.h"
@ -20,9 +21,12 @@ class Parser
Result<ExprNode> expr();
Result<ExprNode> term();
Result<TopLevelNode> toplevel();
Result<TopLevelNode> function();
public:
/* Construct a new Parser with the given TokenStream. */
static std::shared_ptr<Parser> new_parser(const TokenStream& tokens);
/* Parse the stored TokenStream and return the top-level node of the result Abstract Syntax Tree. */
std::shared_ptr<ASTNode> parse();
std::shared_ptr<ProgramNode> parse();
};

View File

@ -24,6 +24,10 @@ template<typename T> class Result
{
return m_result;
}
std::string error()
{
return m_error;
}
protected:
Token* m_token;

View File

@ -150,6 +150,10 @@ std::string Token::to_string() const
return "SYSCALL5 " + details;
case TT_CompilerMacro:
return "COMPMACRO " + details;
case TT_Let:
return "LET " + details;
case TT_In:
return "IN " + details;
}
return "";
}

View File

@ -44,7 +44,9 @@ enum TokenType
TT_Syscall3,
TT_Syscall4,
TT_Syscall5,
TT_CompilerMacro
TT_CompilerMacro,
TT_Let,
TT_In
};
extern const std::string token_strings[];

View File

@ -1,5 +1,6 @@
#include "Arguments.h"
#include "FileIO.h"
#include "GlobalContext.h"
#include "IRBuilder.h"
#include "Importer.h"
#include "Lexer.h"
@ -33,9 +34,11 @@ int main(int argc, char** argv)
result = Normalizer::normalize(result);
}
initGlobalContext();
auto parser = Parser::new_parser(result);
std::shared_ptr<ASTNode> ast;
std::shared_ptr<ProgramNode> ast;
{
benchmark("Parsing");
ast = parser->parse();
@ -44,8 +47,8 @@ int main(int argc, char** argv)
IRBuilder builder;
{
benchmark("IR generation");
builder.create_main_function(ast);
benchmark("Code generation");
builder.create_program(ast);
}
if (Arguments::emit_llvm) builder.resolveToLLVMIR(Arguments::output_fname);

View File

@ -1,5 +1,6 @@
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <memory>
#include <string>

View File

@ -1,6 +1,78 @@
#include "utils.h"
#include <llvm/IR/DerivedTypes.h>
#include <llvm/Transforms/Utils/FunctionComparator.h>
#include <sstream>
bool equals(const llvm::Type* left, const llvm::Type* right)
{
auto left_ptr = llvm::dyn_cast<llvm::PointerType>(left);
auto right_ptr = llvm::dyn_cast<llvm::PointerType>(right);
if (left == right) return true;
if (left->getTypeID() == right->getTypeID()) return true;
switch (left->getTypeID())
{
case llvm::Type::IntegerTyID:
return llvm::cast<llvm::IntegerType>(left)->getBitWidth() ==
llvm::cast<llvm::IntegerType>(right)->getBitWidth();
// left == right would have returned true earlier, because types are uniqued.
case llvm::Type::VoidTyID:
case llvm::Type::FloatTyID:
case llvm::Type::DoubleTyID:
case llvm::Type::X86_FP80TyID:
case llvm::Type::FP128TyID:
case llvm::Type::PPC_FP128TyID:
case llvm::Type::LabelTyID:
case llvm::Type::MetadataTyID:
case llvm::Type::TokenTyID:
return true;
case llvm::Type::PointerTyID:
assert(left_ptr && right_ptr && "Both types must be pointers here.");
return left_ptr->getAddressSpace() == right_ptr->getAddressSpace();
case llvm::Type::StructTyID: {
auto left_struct = llvm::cast<llvm::StructType>(left);
auto right_struct = llvm::cast<llvm::StructType>(right);
if (left_struct->getNumElements() != right_struct->getNumElements()) return false;
if (left_struct->isPacked() != right_struct->isPacked()) return false;
for (unsigned i = 0, e = left_struct->getNumElements(); i != e; ++i)
{
if (!equals(left_struct->getElementType(i), right_struct->getElementType(i))) return false;
}
return true;
}
case llvm::Type::FunctionTyID: {
auto left_function = llvm::cast<llvm::FunctionType>(left);
auto right_function = llvm::cast<llvm::FunctionType>(right);
if (left_function->getNumParams() != right_function->getNumParams()) return false;
if (left_function->isVarArg() != right_function->isVarArg()) return false;
if (!equals(left_function->getReturnType(), right_function->getReturnType())) return false;
for (unsigned i = 0, e = left_function->getNumParams(); i != e; ++i)
{
if (!equals(left_function->getParamType(i), right_function->getParamType(i))) return false;
}
return true;
}
default:
return false;
}
}
bool replace(std::string& str, const std::string& from, const std::string& to)
{
size_t start_pos = str.find(from);

View File

@ -7,8 +7,11 @@
#pragma once
#include "FormatString/FormatString.hpp"
#include "sapphirepch.h"
#include "llvm/IR/Type.h"
#include <chrono>
bool equals(const llvm::Type* left, const llvm::Type* right);
/*
* Replaces all ocurrences of a substring with another one in a string.
* @param[in] str The input string.

View File

@ -45,7 +45,7 @@ def test_test_case(test_case: dict) -> bool:
runtime = test_case.get("run", False)
if runtime is not False:
print("-> Running command: gcc .tests-bin/output.o -o .tests-bin/output")
link_task = subprocess.Popen(["gcc", ".tests-bin/output.o", "-o", ".tests-bin/output"] + extra_flags,
link_task = subprocess.Popen(["gcc", ".tests-bin/output.o", "-o", ".tests-bin/output"],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if link_task.wait() != 0:
print(f"-> Failed to link program")
@ -111,7 +111,7 @@ def create_test(filename: str, extra_flags: list):
compiler["stderr"] = retstderr
test_case["compile"] = compiler
if retcode == 0:
link_task = subprocess.Popen(["gcc", ".tests-bin/output.o", "-o", ".tests-bin/output"] + extra_flags,
link_task = subprocess.Popen(["gcc", ".tests-bin/output.o", "-o", ".tests-bin/output"],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
link_status = link_task.wait()
if link_status == 0:
@ -127,7 +127,7 @@ def create_test(filename: str, extra_flags: list):
test_case["run"] = program
ofilepath = ".".join(filepath.split(".")[:-1]) + ".json"
ofile = open(ofilepath, "w+")
json.dump(test_case, ofile)
json.dump(test_case, ofile, indent=4)
ofile.close()
shutil.rmtree(".tests-bin")

View File

@ -1 +1,14 @@
{"file": "calc.sp", "compile": {"flags": [], "exit-code": 0, "stdout": "", "stderr": ""}, "run": {"exit-code": 16, "stdout": "", "stderr": ""}}
{
"file": "calc.sp",
"compile": {
"flags": [],
"exit-code": 0,
"stdout": "",
"stderr": ""
},
"run": {
"exit-code": 16,
"stdout": "",
"stderr": ""
}
}

View File

@ -1 +1,3 @@
1 + 3 * 5
let @main in {
1 + 3 * 5
}

View File

@ -1 +1,9 @@
{"file": "import-inexistent.sp", "compile": {"flags": [], "exit-code": 1, "stdout": "", "stderr": "\u001b[1;1mtests/import-inexistent.sp:1:8: \u001b[31;49merror: \u001b[0;0mfile not found\n1 import penguin_boi;\n \u001b[31;49m^\u001b[0;0m\n"}}
{
"file": "import-inexistent.sp",
"compile": {
"flags": [],
"exit-code": 1,
"stdout": "",
"stderr": "\u001b[1;1mtests/import-inexistent.sp:1:8: \u001b[31;49merror: \u001b[0;0mfile not found\n1 import penguin_boi;\n \u001b[31;49m^\u001b[0;0m\n"
}
}

View File

@ -1 +1,5 @@
import penguin_boi;
let @main in {
6 + 3
}

View File

@ -1 +1,9 @@
{"file": "simple.sp", "compile": {"flags": [], "exit-code": 1, "stdout": "", "stderr": "\u001b[1;1mtests/simple.sp:1:1: \u001b[31;49merror: \u001b[0;0mexpected a number\n1 const { outln } from @'core/io';\n \u001b[31;49m^\u001b[0;0m\n"}}
{
"file": "simple.sp",
"compile": {
"flags": [],
"exit-code": 1,
"stdout": "",
"stderr": "\u001b[1;1mtests/simple.sp:1:1: \u001b[31;49merror: \u001b[0;0mExpected let at the beginning of a function\n1 const { outln } from @'core/io';\n \u001b[31;49m^\u001b[0;0m\n"
}
}

View File

@ -1 +1,16 @@
{"file": "wimport.sp", "compile": {"flags": ["--wimport"], "exit-code": 1, "stdout": "\u001b[1;1mtests/wimport.sp:1:8: \u001b[33;49mwarning: \u001b[0;0mfile already imported, skipping\n1 import tests/wimport;\n \u001b[33;49m^\u001b[0;0m\n", "stderr": "\u001b[1;1mtests/wimport.sp:1:22: \u001b[31;49merror: \u001b[0;0mexpected a number\n1 \n \u001b[31;49m^\u001b[0;0m\n"}}
{
"file": "wimport.sp",
"compile": {
"flags": [
"--wimport"
],
"exit-code": 0,
"stdout": "\u001b[1;1mtests/wimport.sp:1:8: \u001b[33;49mwarning: \u001b[0;0mfile already imported, skipping\n1 import tests/wimport;\n \u001b[33;49m^\u001b[0;0m\n",
"stderr": ""
},
"run": {
"exit-code": 0,
"stdout": "",
"stderr": ""
}
}

View File

@ -1 +1,5 @@
import tests/wimport;
let @main in {
0
}