roza/lib/StaticPass.cpp

246 lines
6.2 KiB
C++

#include "StaticPass.hpp"
#include "lib/Node.hpp"
#include "TypeResolver.hpp"
namespace roza
{
/*explicit*/ StaticPass::StaticPass(StatusLog& log)
: StaticPass (log, SymTable {})
{
}
// Builds a pass that reports to `log` and starts from a copy of `sym_table`;
// the caller's table is never modified by this pass.
/*explicit*/ StaticPass::StaticPass(StatusLog& log, SymTable const& sym_table)
    : m_log(log)
    , m_sym(sym_table)
{
}
// Nothing to release by hand; members clean up after themselves.
/*virtual*/ StaticPass::~StaticPass() = default;
// Recursively walks the tree rooted at `root` and performs the static checks:
// function arity and argument types, return types, mutability of assigned
// names, operand types of operators and the boolean type of `if` conditions.
// Constant/variable declarations and function parameters met along the way
// are recorded in `m_sym`. Problems detected here are reported through
// `m_log.fatal`.
//
// Every type comparison below is written as
// check_types(root, expected, actual), so diagnostics consistently read
// "expected '<expected>', got '<actual>'".
void StaticPass::check(std::shared_ptr<Node> root)
{
    assert(root);
    TypeResolver resolver {m_log};

    switch (root->type())
    {
    case NODE_FUN: {
        // Parameters live in the function's own scope and are mutable.
        m_sym.enter_scope();
        auto params = root->child(0);
        for (size_t i=0; i<params->size(); i++)
        {
            auto name = params->child(i)->child(0)->repr();
            auto node = params->child(i)->child(1);
            m_sym.declare_mut(name, node);
        }

        // Remember the enclosing function's return type (if any) so a nested
        // function does not clobber it once its own body has been checked.
        auto enclosing_ret = m_outer_fun_ret;
        m_outer_fun_ret = root->child(1)->child(0);
        check(root->child(2)->child(0));
        m_outer_fun_ret = enclosing_ret;

        m_sym.leave_scope();
    } break;

    case NODE_RETURN: {
        check_children(root);

        if (!m_outer_fun_ret)
        {
            // No enclosing function: there is no declared return type to
            // compare against, and the resolver would receive a null node.
            m_log.fatal(root->loc(), "return outside of a function");
            break;
        }

        auto actual_ty = resolver.find(root->child(0), m_sym);
        auto fun_ty = resolver.find(m_outer_fun_ret, m_sym);
        check_types(root, fun_ty, actual_ty);
    } break;

    case NODE_CALL: {
        check_children(root);
        std::string fname = root->child(0)->repr();
        auto args = root->child(1);
        auto entry = m_sym.find(fname);
        auto params = entry.node->child(0);

        if (args->size() != params->size())
        {
            std::stringstream ss;
            ss << "function '"<< fname << "' expects " << params->size();
            ss << " arguments,";
            ss << " got " << args->size();
            m_log.fatal(root->loc(), ss.str());
            // Pairing arguments with parameters below would index past the
            // shorter of the two lists.
            break;
        }

        for (size_t i=0; i<args->size(); i++)
        {
            auto arg = args->child(i);
            auto param = params->child(i);
            auto arg_ty = resolver.find(arg, m_sym);
            auto param_ty = resolver.find(param->child(1), m_sym);
            check_types(root, param_ty, arg_ty);
        }
    } break;

    case NODE_IF: {
        m_sym.enter_scope();

        // Check the condition itself before resolving its type, as every
        // other expression case does, then require it to be boolean.
        check(root->child(0));
        auto cond_type = resolver.find(root->child(0), m_sym);
        check_types(root, std::make_shared<Type>(TY_BOOL), cond_type);

        // Remaining children are the branches.
        for (size_t i=1; i<root->size(); i++)
        {
            check(root->child(i));
        }

        m_sym.leave_scope();
    } break;

    case NODE_ARGS:
    case NODE_THEN:
    case NODE_ELSE: {
        check_children(root);
    } break;

    case NODE_CONSTDECL: {
        check(root->child(1));
        m_sym.declare(root->child(0)->repr(), root->child(1));
    } break;

    case NODE_VARDECL: {
        check(root->child(1));
        m_sym.declare_mut(root->child(0)->repr(), root->child(1));
    } break;

    case NODE_ASSIGN: {
        // Validate the assigned expression, as declarations do.
        check(root->child(1));

        auto const& entry = m_sym.find(root->child(0)->repr());
        if (!entry.is_mut)
        {
            m_log.fatal(root->child(0)->loc(),
                        root->child(0)->repr() + " is not mutable");
        }
        auto lhs = resolver.find(entry.node, m_sym);
        auto rhs = resolver.find(root->child(1), m_sym);
        check_types(root, lhs, rhs);
    } break;

    case NODE_IDENT:
    case NODE_ASSERT_STATIC_FAIL:
    case NODE_INT:
    case NODE_BOOL:
        break;

    case NODE_EQ:
    case NODE_NE: {
        // Both operands must share a type; the left one sets the expectation.
        check_children(root);
        auto lhs = resolver.find(root->child(0), m_sym);
        auto rhs = resolver.find(root->child(1), m_sym);
        check_types(root, lhs, rhs);
    } break;

    case NODE_IMP:
    case NODE_OR:
    case NODE_AND: {
        // Boolean binary operators: left operand must be bool, right must
        // match it.
        check_children(root);
        auto lhs = resolver.find(root->child(0), m_sym);
        auto rhs = resolver.find(root->child(1), m_sym);
        check_types(root, std::make_shared<Type>(TY_BOOL), lhs);
        check_types(root, lhs, rhs);
    } break;

    case NODE_ASSERT:
    case NODE_NOT: {
        check_children(root);
        auto operand = resolver.find(root->child(0), m_sym);
        check_types(root, std::make_shared<Type>(TY_BOOL), operand);
    } break;

    case NODE_LT:
    case NODE_LE:
    case NODE_GT:
    case NODE_GE:
    case NODE_ADD:
    case NODE_SUB:
    case NODE_MUL:
    case NODE_DIV:
    case NODE_MOD:
    case NODE_POW: {
        // Integer binary operators: left operand must be int, right must
        // match it.
        check_children(root);
        auto lhs = resolver.find(root->child(0), m_sym);
        auto rhs = resolver.find(root->child(1), m_sym);
        check_types(root, std::make_shared<Type>(TY_INT), lhs);
        check_types(root, lhs, rhs);
    } break;

    case NODE_UADD:
    case NODE_USUB: {
        check_children(root);
        auto operand = resolver.find(root->child(0), m_sym);
        check_types(root, std::make_shared<Type>(TY_INT), operand);
    } break;

    case NODE_PROG: {
        check_children(root);
    } break;

    default:
        m_log.fatal(root->loc(), "cannot check node '" + root->string() + "'");
    }
}
// Runs `check` on each direct child of `root`, first to last.
// The child count is re-read on every iteration, exactly as before.
void StaticPass::check_children(std::shared_ptr<Node> root)
{
    size_t index = 0;
    while (index < root->size())
    {
        check(root->child(index));
        ++index;
    }
}
// Reports a fatal "type mismatch" at `root`'s location unless `expected`
// and `actual` compare equal via Type::equals. Both types must be non-null.
void StaticPass::check_types(std::shared_ptr<Node> root,
                             std::shared_ptr<Type> expected,
                             std::shared_ptr<Type> actual)
{
    assert(expected);
    assert(actual);

    if (expected->equals(*actual))
    {
        return;
    }

    std::string message = "type mismatch, expected '";
    message += expected->string();
    message += "', got '";
    message += actual->string();
    message += "'";
    m_log.fatal(root->loc(), message);
}
// Accepts `lhs` (the type actually found) if it equals any type in `rhs`
// (the accepted candidates). Otherwise it reports a fatal "type mismatch"
// at `root`'s location that lists every candidate, one per line.
// `lhs` and every candidate must be non-null.
void StaticPass::check_types(std::shared_ptr<Node> root,
                             std::shared_ptr<Type> lhs,
                             std::vector<std::shared_ptr<Type>> const& rhs)
{
    assert(lhs);

    for (auto const& ty: rhs)
    {
        assert(ty);
        if (lhs->equals(*ty))
        {
            return;
        }
    }

    std::stringstream ss;
    // Separate the found type from the candidate list; previously the
    // message ran together as "...'int'candidates are:".
    ss << "type mismatch, got '" << lhs->string() << "', ";
    ss << "candidates are:" << '\n';
    // '\n' rather than std::endl: flushing a stringstream is pointless.
    for (auto const& ty: rhs)
    {
        ss << "\t-> " << ty->string() << '\n';
    }
    m_log.fatal(root->loc(), ss.str());
}
}