ADD: int arithmetic.

main
bog 2023-09-27 23:05:04 +02:00
parent 85a7af18b9
commit a8b4978860
11 changed files with 322 additions and 10 deletions

View File

@ -1,4 +1,16 @@
PROG ::= INSTR*
INSTR ::= DIR
INSTR ::=
| DIR
| EXPR semicolon
DIR ::= hash ident EXPR
EXPR ::= ident
EXPR ::=
| ADDSUB
ADDSUB ::= MULDIVMOD ((add|sub) MULDIVMOD)*
MULDIVMOD ::= LITERAL ((mul|div|mod) LITERAL)*
LITERAL ::=
| ident
| int

View File

@ -1,4 +1,12 @@
#include "Compiler.hpp"
#include <llvm/IR/BasicBlock.h>
#include <llvm/TargetParser/Host.h>
#include <llvm/Target/TargetOptions.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/MC/TargetRegistry.h>
#include <llvm/IR/LegacyPassManager.h>
namespace wg
{
@ -10,8 +18,115 @@ namespace wg
{
}
void Compiler::compile(std::shared_ptr<Node> node)
void Compiler::gen()
{
std::cout << node->string() << std::endl;
auto target_triple = llvm::sys::getDefaultTargetTriple();
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
std::string err;
auto target = llvm::TargetRegistry::lookupTarget(target_triple, err);
if (!target)
{
llvm::errs() << err;
abort();
}
auto cpu = "generic";
auto features = "";
llvm::TargetOptions opt;
auto rm = std::optional<llvm::Reloc::Model>();
auto target_machine = target->createTargetMachine(target_triple,
cpu,
features,
opt,
rm);
m_module->setDataLayout(target_machine->createDataLayout());
m_module->setTargetTriple(target_triple);
auto filename = "output.o";
std::error_code ec;
llvm::raw_fd_ostream dest(filename, ec, llvm::sys::fs::OF_None);
WG_ASSERT(!ec, "cannot write output file.");
llvm::legacy::PassManager pass;
auto file_type = llvm::CodeGenFileType::CGFT_ObjectFile;
if (target_machine->addPassesToEmitFile(pass, dest, nullptr, file_type))
{
llvm::errs() << "Target machine cannot emit a file of this type";
}
pass.run(*m_module);
dest.flush();
m_module->print(llvm::outs(), nullptr);
}
llvm::Value* Compiler::compile(std::shared_ptr<Node> node)
{
switch (node->type())
{
case NODE_PROG: {
auto* block = llvm::BasicBlock::Create(*m_context,
"entry");
m_builder->SetInsertPoint(block);
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i));
}
return block;
} break;
case NODE_ADD: {
auto* lhs = compile(node->child(0));
auto* rhs = compile(node->child(1));
return m_builder->CreateAdd(lhs, rhs);
} break;
case NODE_SUB: {
auto* lhs = compile(node->child(0));
auto* rhs = compile(node->child(1));
return m_builder->CreateSub(lhs, rhs);
} break;
case NODE_MUL: {
auto* lhs = compile(node->child(0));
auto* rhs = compile(node->child(1));
return m_builder->CreateMul(lhs, rhs);
} break;
case NODE_DIV: {
auto* lhs = compile(node->child(0));
auto* rhs = compile(node->child(1));
return m_builder->CreateSDiv(lhs, rhs);
} break;
case NODE_MOD: {
auto* lhs = compile(node->child(0));
auto* rhs = compile(node->child(1));
return m_builder->CreateSRem(lhs, rhs);
} break;
case NODE_INT: {
return llvm::ConstantInt::get(*m_context, llvm::APInt(32, 0, true));
} break;
default:
WG_ASSERT(false,
std::string()
+ "cannot compile unknown node '"
+ NodeTypeStr[node->type()]
+ "'");
}
}
}

View File

@ -16,7 +16,9 @@ namespace wg
explicit Compiler();
virtual ~Compiler();
void compile(std::shared_ptr<Node> node);
void gen();
llvm::Value* compile(std::shared_ptr<Node> node);
private:
std::unique_ptr<llvm::LLVMContext> m_context =
std::make_unique<llvm::LLVMContext>();

View File

@ -1,11 +1,21 @@
#include "Lexer.hpp"
#include "lib/Node.hpp"
namespace wg
{
/*explicit*/ Lexer::Lexer()
{
add_text("#", NODE_HASH);
add_text("+", NODE_ADD);
add_text("-", NODE_SUB);
add_text("*", NODE_MUL);
add_text("/", NODE_DIV);
add_text("%", NODE_MOD);
add_text("(", NODE_OPAR);
add_text(")", NODE_CPAR);
add_text(";", NODE_SEMICOLON);
m_scanners.push_back(std::bind(&Lexer::scan_int, this));
m_scanners.push_back(std::bind(&Lexer::scan_ident, this));
}
@ -154,4 +164,35 @@ namespace wg
return std::nullopt;
}
std::optional<ScanInfo> Lexer::scan_int() const
{
size_t cursor = m_cursor;
std::string repr;
if (cursor < m_source.size()
&& m_source[cursor] == '-')
{
repr += '-';
cursor++;
}
while (cursor < m_source.size()
&& std::isdigit(m_source[cursor]))
{
repr += m_source[cursor];
cursor++;
}
if (repr.empty() || repr.back() == '-')
{
return std::nullopt;
}
return ScanInfo {
cursor,
NODE_INT,
repr
};
}
}

View File

@ -45,7 +45,7 @@ namespace wg
bool has_value) const;
std::optional<ScanInfo> scan_ident() const;
std::optional<ScanInfo> scan_int() const;
};
}

View File

@ -8,7 +8,12 @@
G(NODE_PROG), \
G(NODE_IDENT), \
G(NODE_HASH), \
G(NODE_DIR),
G(NODE_DIR), \
G(NODE_INT), \
G(NODE_ADD), G(NODE_SUB), \
G(NODE_MUL),G(NODE_DIV), \
G(NODE_MOD), G(NODE_OPAR), G(NODE_CPAR), \
G(NODE_SEMICOLON)
namespace wg
{

View File

@ -21,11 +21,26 @@ namespace wg
Loc Parser::loc() const
{
if (m_cursor >= m_tokens.size())
{
return Loc {};
}
return m_tokens[m_cursor]->loc();
}
std::shared_ptr<Node> Parser::consume(NodeType type)
{
if (m_cursor >= m_tokens.size())
{
std::stringstream ss;
ss << "type mismatch, expected '"
<< (NodeTypeStr[type] + strlen("NODE_"))
<< "', got nothing.";
loc().error<syntax_error>(ss);
}
auto current = m_tokens[m_cursor];
if (current->type() != type)
@ -89,10 +104,17 @@ namespace wg
}
std::shared_ptr<Node> Parser::parse_instr()
{
if (type_is(NODE_HASH))
{
return parse_dir();
}
auto expr = parse_expr();
consume(NODE_SEMICOLON);
return expr;
}
std::shared_ptr<Node> Parser::parse_dir()
{
auto node = make_node(NODE_DIR);
@ -104,6 +126,67 @@ namespace wg
std::shared_ptr<Node> Parser::parse_expr()
{
return consume(NODE_IDENT);
return parse_addsub();
}
std::shared_ptr<Node> Parser::parse_addsub()
{
auto lhs = parse_muldivmod();
while (type_is(NODE_ADD)
|| type_is(NODE_SUB))
{
auto node = consume();
node->add_child(lhs);
node->add_child(parse_muldivmod());
lhs = node;
}
return lhs;
}
std::shared_ptr<Node> Parser::parse_muldivmod()
{
auto lhs = parse_literal();
while (type_is(NODE_MUL)
|| type_is(NODE_DIV)
|| type_is(NODE_MOD))
{
auto node = consume();
node->add_child(lhs);
node->add_child(parse_literal());
lhs = node;
}
return lhs;
}
std::shared_ptr<Node> Parser::parse_literal()
{
if (type_is(NODE_INT)
|| type_is(NODE_IDENT))
{
return consume();
}
// Groups
if (type_is(NODE_OPAR))
{
consume(NODE_OPAR);
auto expr = parse_expr();
consume(NODE_CPAR);
return expr;
}
loc().error<syntax_error>(std::string()
+ "unknown literal '"
+ (NodeTypeStr[m_tokens[m_cursor]->type()]
+ strlen("NODE_"))
+ "'");
return nullptr;
}
}

View File

@ -33,6 +33,9 @@ namespace wg
std::shared_ptr<Node> parse_dir();
std::shared_ptr<Node> parse_expr();
std::shared_ptr<Node> parse_addsub();
std::shared_ptr<Node> parse_muldivmod();
std::shared_ptr<Node> parse_literal();
};
}

View File

@ -1,5 +1,6 @@
#include <iostream>
#include <fstream>
#include <Lexer.hpp>
#include <Parser.hpp>
#include <Compiler.hpp>
@ -34,6 +35,7 @@ int main(int argc, char** argv)
wg::Compiler compiler;
compiler.compile(ast);
compiler.gen();
return 0;
}

View File

@ -34,3 +34,27 @@ TEST_CASE_METHOD(LexerTest, "Lexer_")
test_next(lex, "IDENT[canard]");
test_end(lex);
}
TEST_CASE_METHOD(LexerTest, "Lexer_int_literal")
{
wg::Lexer lex;
lex.scan(" 3 -2 78 ");
test_next(lex, "INT[3]");
test_next(lex, "INT[-2]");
test_next(lex, "INT[78]");
test_end(lex);
}
TEST_CASE_METHOD(LexerTest, "Lexer_int_arith")
{
wg::Lexer lex;
lex.scan(" +-*/% ()");
test_next(lex, "ADD");
test_next(lex, "SUB");
test_next(lex, "MUL");
test_next(lex, "DIV");
test_next(lex, "MOD");
test_next(lex, "OPAR");
test_next(lex, "CPAR");
test_end(lex);
}

View File

@ -25,8 +25,33 @@ public:
protected:
};
TEST_CASE_METHOD(ParserTest, "Parser_")
TEST_CASE_METHOD(ParserTest, "Parser_dir")
{
test_parse("PROG(DIR(IDENT[hello],IDENT[world]))",
"#hello world");
}
TEST_CASE_METHOD(ParserTest, "Parser_int")
{
test_parse("PROG(INT[45])",
" 45; ");
test_parse("PROG(ADD(INT[1],MUL(INT[2],INT[3])))",
" 1 + 2 * 3; ");
test_parse("PROG(MUL(ADD(INT[1],INT[2]),INT[3]))",
" (1 + 2) * 3; ");
test_parse("PROG(SUB(INT[1],DIV(INT[2],INT[3])))",
" 1 - 2 / 3; ");
test_parse("PROG(DIV(SUB(INT[1],INT[2]),INT[3]))",
" (1 - 2) / 3; ");
test_parse("PROG(ADD(INT[1],MOD(INT[2],INT[3])))",
" 1 + 2 % 3; ");
test_parse("PROG(MOD(ADD(INT[1],INT[2]),INT[3]))",
" (1 + 2) % 3; ");
}