ADD: basic function call and int literal.

main
bog 2023-09-28 21:37:10 +02:00
parent a8b4978860
commit bd819bfc54
15 changed files with 594 additions and 24 deletions

View File

@ -2,15 +2,28 @@ PROG ::= INSTR*
INSTR ::=
| DIR
| EXPR semicolon
| FUNDECL
| return EXPR semicolon
| EXTERN semicolon
EXTERN ::= extern fun ident opar PARAMS cpar RET
FUNDECL ::= fun ident opar PARAMS cpar RET BLOCK
PARAMS ::= (ident type? (comma ident type?)*)
RET ::= type?
BLOCK ::= obrace INSTR* cbrace
DIR ::= hash ident EXPR
EXPR ::=
| ADDSUB
ADDSUB ::= MULDIVMOD ((add|sub) MULDIVMOD)*
MULDIVMOD ::= LITERAL ((mul|div|mod) LITERAL)*
LITERAL ::=
| ident
| int
| CALL
CALL ::= ident opar ARGS cpar
ARGS ::= (EXPR (comma EXPR)*)?

View File

@ -1,5 +1,7 @@
#include "Compiler.hpp"
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IR/Type.h>
#include <llvm/TargetParser/Host.h>
#include <llvm/Target/TargetOptions.h>
#include <llvm/Target/TargetMachine.h>
@ -12,6 +14,7 @@ namespace wg
{
/*explicit*/ Compiler::Compiler()
{
}
/*virtual*/ Compiler::~Compiler()
@ -64,10 +67,13 @@ namespace wg
llvm::errs() << "Target machine cannot emit a file of this type";
}
llvm::verifyModule(*m_module);
pass.run(*m_module);
dest.flush();
m_module->print(llvm::outs(), nullptr);
dest.flush();
m_module->print(llvm::errs(), nullptr);
}
llvm::Value* Compiler::compile(std::shared_ptr<Node> node)
@ -75,16 +81,151 @@ namespace wg
switch (node->type())
{
case NODE_PROG: {
auto* block = llvm::BasicBlock::Create(*m_context,
"entry");
m_builder->SetInsertPoint(block);
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i));
}
return block;
return nullptr;
} break;
case NODE_BLOCK: {
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i));
}
return nullptr;
} break;
case NODE_RETURN: {
return m_builder->CreateRet(compile(node->child(0)));
} break;
case NODE_EXTERN: {
auto ident = node->child(0)->repr();
auto params = node->child(1);
auto ret = node->child(2);
std::vector<std::string> names;
std::vector<llvm::Type*> types;
for (size_t i=0; i<params->size(); i++)
{
auto param = params->child(i);
if (param->type() == NODE_IDENT)
{
names.push_back(param->repr());
}
else if (param->type() == NODE_TYPE)
{
auto ty = llvm::Type::getInt32Ty(*m_context);
for (auto name: names)
{
m_sym->declare(name, ty, node->loc());
types.push_back(ty);
}
names.clear();
}
}
llvm::Type* ret_type = llvm::Type::getVoidTy(*m_context);
if (ret->size() > 0)
{
ret_type = llvm::Type::getInt32Ty(*m_context);
}
auto fun_type = llvm::FunctionType::get(ret_type, types, false);
auto fun = llvm::Function::Create(fun_type,
llvm::Function::ExternalLinkage,
ident,
*m_module);
m_sym->declare_prototype(ident, fun_type, node->loc());
return fun;
} break;
case NODE_FUNDECL: {
auto ident = node->child(0)->repr();
auto params = node->child(1);
auto ret = node->child(2);
std::vector<std::string> names;
std::vector<llvm::Type*> types;
for (size_t i=0; i<params->size(); i++)
{
auto param = params->child(i);
if (param->type() == NODE_IDENT)
{
names.push_back(param->repr());
}
else if (param->type() == NODE_TYPE)
{
auto ty = llvm::Type::getInt32Ty(*m_context);
for (auto name: names)
{
m_sym->declare(name, ty, node->loc());
types.push_back(ty);
}
names.clear();
}
}
llvm::Type* ret_type = llvm::Type::getVoidTy(*m_context);
auto body = node->child(3);
if (ret->size() > 0)
{
ret_type = llvm::Type::getInt32Ty(*m_context);
}
auto fun_type = llvm::FunctionType::get(ret_type, types, false);
auto fun = llvm::Function::Create(fun_type,
llvm::Function::ExternalLinkage,
ident,
*m_module);
m_sym->declare(ident, fun_type, node->loc());
llvm::BasicBlock* old_bb = m_builder->GetInsertBlock();
auto* bb = llvm::BasicBlock::Create(*m_context,
"entry",
fun);
m_builder->SetInsertPoint(bb);
compile(body);
m_builder->SetInsertPoint(old_bb);
llvm::verifyFunction(*fun);
return fun;
} break;
case NODE_CALL: {
std::string ident = node->child(0)->repr();
auto fun = m_module->getFunction(ident);
WG_ASSERT(fun, "cannot call unknown function '" + ident + "'");
std::vector<llvm::Value*> values;
for (size_t i=0; i<node->child(1)->size(); i++)
{
auto arg = node->child(1)->child(i);
auto val = compile(arg);
values.push_back(val);
}
return m_builder->CreateCall(fun, values);
} break;
case NODE_ADD: {
@ -118,7 +259,11 @@ namespace wg
} break;
case NODE_INT: {
return llvm::ConstantInt::get(*m_context, llvm::APInt(32, 0, true));
return llvm::ConstantInt::get(*m_context,
llvm::APInt(32,
std::stoi(node->repr()),
true));
} break;
default:

View File

@ -7,6 +7,7 @@
#include "commons.hpp"
#include "Node.hpp"
#include "SymTable.hpp"
namespace wg
{
@ -28,6 +29,8 @@ namespace wg
std::unique_ptr<llvm::Module> m_module =
std::make_unique<llvm::Module>("my module", *m_context);
std::unique_ptr<SymTable> m_sym = std::make_unique<SymTable>();
};
}

View File

@ -5,6 +5,14 @@ namespace wg
{
/*explicit*/ Lexer::Lexer()
{
add_keyword("int", NODE_TYPE, true);
add_keyword("fun", NODE_FUN);
add_keyword("return", NODE_RETURN);
add_keyword("extern", NODE_EXTERN);
add_text("{", NODE_OBRACE);
add_text("}", NODE_CBRACE);
add_text(",", NODE_COMMA);
add_text("#", NODE_HASH);
add_text("+", NODE_ADD);
add_text("-", NODE_SUB);
@ -35,6 +43,18 @@ namespace wg
skip_spaces();
while (m_cursor + 1 < m_source.size()
&& m_source[m_cursor] == ':'
&& m_source[m_cursor + 1] == ':')
{
while (m_source[m_cursor] != '\n')
{
m_cursor++;
}
skip_spaces();
}
for (auto scanner: m_scanners)
{
auto info = scanner();
@ -87,6 +107,20 @@ namespace wg
node, has_value));
}
void Lexer::add_keyword(std::string const& text,
NodeType node,
bool has_value)
{
if (text.size() == 1)
{
m_seps.push_back(text[0]);
}
m_scanners.push_back(std::bind(&Lexer::scan_keyword,
this, text,
node, has_value));
}
bool Lexer::is_sep(size_t index) const
{
WG_ASSERT(index < m_source.size(), "cannot find separator");
@ -141,6 +175,35 @@ namespace wg
};
}
std::optional<ScanInfo> Lexer::scan_keyword(std::string const& text,
NodeType type,
bool has_value) const
{
if (m_cursor + text.size() > m_source.size())
{
return std::nullopt;
}
for (size_t i=0; i<text.size(); i++)
{
if (m_source[m_cursor + i] != text[i])
{
return std::nullopt;
}
}
if (!is_sep(m_cursor + text.size()))
{
return std::nullopt;
}
return ScanInfo {
m_cursor + text.size(),
type,
has_value ? text : ""
};
}
std::optional<ScanInfo> Lexer::scan_ident() const
{
size_t cursor = m_cursor;

View File

@ -36,6 +36,10 @@ namespace wg
NodeType node,
bool has_value=false);
void add_keyword(std::string const& text,
NodeType node,
bool has_value=false);
bool is_sep(size_t index) const;
void skip_spaces();
@ -44,6 +48,10 @@ namespace wg
NodeType type,
bool has_value) const;
std::optional<ScanInfo> scan_keyword(std::string const& text,
NodeType type,
bool has_value) const;
std::optional<ScanInfo> scan_ident() const;
std::optional<ScanInfo> scan_int() const;
};

View File

@ -16,10 +16,10 @@ namespace wg
int line() const { return m_line; }
template <typename T>
void error(std::string const& what);
void error(std::string const& what) const;
template <typename T>
void error(std::stringstream const& what);
void error(std::stringstream const& what) const;
private:
std::filesystem::path m_origin;
@ -27,7 +27,7 @@ namespace wg
};
template <typename T>
void Loc::error(std::string const& what)
void Loc::error(std::string const& what) const
{
std::stringstream ss;
ss << m_origin.string() << ": ERROR " << what;
@ -36,7 +36,7 @@ namespace wg
}
template <typename T>
void Loc::error(std::stringstream const& what)
void Loc::error(std::stringstream const& what) const
{
error<T>(what.str());
}

View File

@ -20,7 +20,9 @@ namespace wg
std::shared_ptr<Node> Node::child(size_t index) const
{
WG_ASSERT(index < size(), "aze");
WG_ASSERT(index < size(), "Cannot get child node of '"
+ string()
+ "'");
return m_children.at(index);
}

View File

@ -13,7 +13,11 @@
G(NODE_ADD), G(NODE_SUB), \
G(NODE_MUL),G(NODE_DIV), \
G(NODE_MOD), G(NODE_OPAR), G(NODE_CPAR), \
G(NODE_SEMICOLON)
G(NODE_SEMICOLON), G(NODE_COMMA), G(NODE_CALL), \
G(NODE_ARGS), G(NODE_TYPE), G(NODE_RETURN), \
G(NODE_FUN), G(NODE_PARAMS), G(NODE_BLOCK), \
G(NODE_OBRACE), G(NODE_CBRACE), G(NODE_FUNDECL), \
G(NODE_EXTERN), G(NODE_RET)
namespace wg
{

View File

@ -1,4 +1,5 @@
#include "Parser.hpp"
#include "lib/Node.hpp"
namespace wg
{
@ -23,7 +24,7 @@ namespace wg
{
if (m_cursor >= m_tokens.size())
{
return Loc {};
return m_tokens.back()->loc();
}
return m_tokens[m_cursor]->loc();
@ -110,11 +111,120 @@ namespace wg
return parse_dir();
}
if (type_is(NODE_FUN))
{
return parse_fundecl();
}
if (type_is(NODE_RETURN))
{
auto node = consume();
node->add_child(parse_expr());
consume(NODE_SEMICOLON);
return node;
}
if (type_is(NODE_EXTERN))
{
auto node = parse_extern();
consume(NODE_SEMICOLON);
return node;
}
auto expr = parse_expr();
consume(NODE_SEMICOLON);
return expr;
}
std::shared_ptr<Node> Parser::parse_extern()
{
auto node = consume(NODE_EXTERN);
consume(NODE_FUN);
node->add_child(consume(NODE_IDENT));
consume(NODE_OPAR);
node->add_child(parse_params());
consume(NODE_CPAR);
node->add_child(parse_ret());
return node;
}
std::shared_ptr<Node> Parser::parse_fundecl()
{
auto node = make_node(NODE_FUNDECL);
consume(NODE_FUN);
node->add_child(consume(NODE_IDENT));
consume(NODE_OPAR);
node->add_child(parse_params());
consume(NODE_CPAR);
node->add_child(parse_ret());
node->add_child(parse_block());
return node;
}
std::shared_ptr<Node> Parser::parse_params()
{
auto node = make_node(NODE_PARAMS);
if (type_is(NODE_CPAR))
{
return node;
}
node->add_child(consume(NODE_IDENT));
if (type_is(NODE_TYPE))
{
node->add_child(consume());
}
while (type_is(NODE_COMMA))
{
consume();
node->add_child(consume(NODE_IDENT));
if (type_is(NODE_TYPE))
{
node->add_child(consume());
}
}
return node;
}
std::shared_ptr<Node> Parser::parse_ret()
{
auto node = make_node(NODE_RET);
if (type_is(NODE_TYPE))
{
node->add_child(consume());
}
return node;
}
std::shared_ptr<Node> Parser::parse_block()
{
auto node = make_node(NODE_BLOCK);
consume(NODE_OBRACE);
while (type_isnt(NODE_CBRACE))
{
node->add_child(parse_instr());
}
consume(NODE_CBRACE);
return node;
}
std::shared_ptr<Node> Parser::parse_dir()
{
auto node = make_node(NODE_DIR);
@ -164,6 +274,12 @@ namespace wg
std::shared_ptr<Node> Parser::parse_literal()
{
if (type_is(NODE_IDENT)
&& type_is(NODE_OPAR, 1))
{
return parse_call();
}
if (type_is(NODE_INT)
|| type_is(NODE_IDENT))
{
@ -189,4 +305,36 @@ namespace wg
+ "'");
return nullptr;
}
std::shared_ptr<Node> Parser::parse_call()
{
auto node = make_node(NODE_CALL);
node->add_child(consume(NODE_IDENT));
consume(NODE_OPAR);
node->add_child(parse_args());
consume(NODE_CPAR);
return node;
}
std::shared_ptr<Node> Parser::parse_args()
{
auto node = make_node(NODE_ARGS);
if (type_is(NODE_CPAR))
{
return node;
}
node->add_child(parse_expr());
while (type_is(NODE_COMMA))
{
consume();
node->add_child(parse_expr());
}
return node;
}
}

View File

@ -30,12 +30,21 @@ namespace wg
std::shared_ptr<Node> parse_prog();
std::shared_ptr<Node> parse_instr();
std::shared_ptr<Node> parse_extern();
std::shared_ptr<Node> parse_fundecl();
std::shared_ptr<Node> parse_params();
std::shared_ptr<Node> parse_ret();
std::shared_ptr<Node> parse_block();
std::shared_ptr<Node> parse_dir();
std::shared_ptr<Node> parse_expr();
std::shared_ptr<Node> parse_addsub();
std::shared_ptr<Node> parse_muldivmod();
std::shared_ptr<Node> parse_literal();
std::shared_ptr<Node> parse_call();
std::shared_ptr<Node> parse_args();
};
}

90
lib/SymTable.cpp Normal file
View File

@ -0,0 +1,90 @@
#include "SymTable.hpp"
#include "commons.hpp"
namespace wg
{
/*explicit*/ SymTable::SymTable()
{
}
/*virtual*/ SymTable::~SymTable()
{
}
bool SymTable::exists(std::string const& name) const
{
return m_entries.find(name) != std::end(m_entries);
}
void SymTable::declare_prototype(std::string const& name,
llvm::Type* type,
Loc const& loc)
{
if (auto itr=m_entries.find(name);
itr != std::end(m_entries))
{
loc.error<symbol_error>("cannot declare existing symbol '"
+ name
+ "'");
}
SymEntry entry;
entry.name = name;
entry.type = type;
entry.prototype = true;
m_entries[name] = entry;
}
void SymTable::declare(std::string const& name,
llvm::Type* type,
Loc const& loc)
{
if (auto itr=m_entries.find(name);
itr != std::end(m_entries)
&& itr->second.prototype == false)
{
loc.error<symbol_error>("cannot declare existing symbol '"
+ name
+ "'");
}
SymEntry entry;
entry.name = name;
entry.type = type;
entry.prototype = false;
m_entries[name] = entry;
}
void SymTable::set(std::string const& name,
llvm::Type* type,
Loc const& loc)
{
if (auto itr=m_entries.find(name);
itr == std::end(m_entries))
{
loc.error<symbol_error>("cannot set inexisting symbol '"
+ name
+ "'");
}
SymEntry entry;
entry.name = name;
entry.type = type;
m_entries[name] = entry;
}
SymEntry& SymTable::get(std::string const& name, Loc const& loc)
{
if (auto itr=m_entries.find(name);
itr != std::end(m_entries))
{
return itr->second;
}
else
{
loc.error<symbol_error>("cannot find symbol '" + name + "'");
}
abort();
}
}

39
lib/SymTable.hpp Normal file
View File

@ -0,0 +1,39 @@
#ifndef wg_SYMTABLE_HPP
#define wg_SYMTABLE_HPP
#include <llvm/IR/Type.h>
#include "commons.hpp"
#include "Loc.hpp"
namespace wg
{
WG_ERROR(symbol_error);
struct SymEntry {
std::string name;
llvm::Type* type;
bool prototype = false;
};
class SymTable
{
public:
explicit SymTable();
virtual ~SymTable();
bool exists(std::string const& name) const;
void declare_prototype(std::string const& name,
llvm::Type* type,
Loc const& loc);
void declare(std::string const& name, llvm::Type* type, Loc const& loc);
void set(std::string const& name, llvm::Type* type, Loc const& loc);
SymEntry& get(std::string const& name, Loc const& loc);
private:
std::unordered_map<std::string, SymEntry> m_entries;
};
}
#endif

View File

@ -15,6 +15,7 @@ wongola_lib = static_library(
'lib/Parser.cpp',
'lib/Compiler.cpp',
'lib/Loc.cpp',
'lib/SymTable.cpp',
],
dependencies: [
dependency('LLVM')

View File

@ -58,3 +58,13 @@ TEST_CASE_METHOD(LexerTest, "Lexer_int_arith")
test_next(lex, "CPAR");
test_end(lex);
}
TEST_CASE_METHOD(LexerTest, "Lexer_fun_call")
{
wg::Lexer lex;
lex.scan(" , int extern ");
test_next(lex, "COMMA");
test_next(lex, "TYPE[int]");
test_next(lex, "EXTERN");
test_end(lex);
}

View File

@ -55,3 +55,38 @@ TEST_CASE_METHOD(ParserTest, "Parser_int")
test_parse("PROG(MOD(ADD(INT[1],INT[2]),INT[3]))",
" (1 + 2) % 3; ");
}
TEST_CASE_METHOD(ParserTest, "Parser_call")
{
test_parse("PROG(CALL(IDENT[hello],ARGS))",
" hello(); ");
test_parse("PROG(CALL(IDENT[hello_world],ARGS(IDENT[x])))",
" hello_world(x); ");
test_parse("PROG(CALL(IDENT[hello_world],ARGS(IDENT[x],INT[78])))",
" hello_world(x, 78); ");
}
TEST_CASE_METHOD(ParserTest, "Parser_fundecl")
{
test_parse("PROG(EXTERN(IDENT[hello],"
"PARAMS(IDENT[x],IDENT[y],TYPE[int]),RET(TYPE[int])))",
" extern fun hello(x, y int) int; ");
test_parse("PROG(FUNDECL(IDENT[couc],PARAMS,RET,BLOCK))",
" fun couc() {} ");
test_parse("PROG(FUNDECL(IDENT[couc],PARAMS("
"IDENT[x],IDENT[y],TYPE[int]"
"),RET,BLOCK(RETURN(INT[4]))))",
" fun couc(x, y int) { return 4; } ");
test_parse("PROG(FUNDECL(IDENT[couc],PARAMS("
"IDENT[x],IDENT[y],TYPE[int]"
"),RET(TYPE[int]),BLOCK(RETURN(INT[4]))))",
" fun couc(x, y int) int { return 4; } ");
test_parse("PROG(RETURN(ADD(CALL(IDENT[a],ARGS),INT[1])))",
" return a() + 1; ");
}