✨ node substitution and unification.
parent
93f7430e11
commit
cd77e92682
|
@ -13,6 +13,7 @@ sp_lib = static_library(
|
|||
'src/fol/Node.cpp',
|
||||
'src/fol/Lexer.cpp',
|
||||
'src/fol/Parser.cpp',
|
||||
'src/fol/Substitution.cpp',
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -32,6 +33,7 @@ executable('sine-patre-tests',
|
|||
'tests/main.cpp',
|
||||
'tests/Lexer.cpp',
|
||||
'tests/Parser.cpp',
|
||||
'tests/Node.cpp',
|
||||
],
|
||||
dependencies: [
|
||||
sp_dep,
|
||||
|
|
181
src/fol/Node.cpp
181
src/fol/Node.cpp
|
@ -1,9 +1,168 @@
|
|||
#include "Node.hpp"
|
||||
#include "Lexer.hpp"
|
||||
#include "Parser.hpp"
|
||||
#include "Substitution.hpp"
|
||||
|
||||
namespace sp
|
||||
{
|
||||
namespace fol
|
||||
{
|
||||
/*static*/ NodeType Node::type_from_value(std::string const& value)
|
||||
{
|
||||
if (std::all_of(std::begin(value),
|
||||
std::end(value),
|
||||
[](auto const& el){
|
||||
return std::islower(el);}))
|
||||
{
|
||||
return NODE_VAR;
|
||||
}
|
||||
else if (std::all_of(std::begin(value),
|
||||
std::end(value),
|
||||
[](auto const& el){
|
||||
return std::isupper(el);}))
|
||||
{
|
||||
return NODE_CONST;
|
||||
}
|
||||
else
|
||||
{
|
||||
return NODE_PRED;
|
||||
}
|
||||
}
|
||||
|
||||
/*static*/ std::shared_ptr<Node>
|
||||
Node::from_string(std::string const& formula)
|
||||
{
|
||||
Lexer lexer;
|
||||
lexer.scan(formula);
|
||||
Parser parser;
|
||||
return parser.parse(lexer.all());
|
||||
}
|
||||
|
||||
/*static*/ std::shared_ptr<Node>
|
||||
Node::apply(std::shared_ptr<Node> node,
|
||||
Substitution const& subst)
|
||||
{
|
||||
std::string value = node->value();
|
||||
NodeType type = node->type();
|
||||
|
||||
if (subst.has(value))
|
||||
{
|
||||
value = subst.get(value);
|
||||
type = Node::type_from_value(value);
|
||||
}
|
||||
|
||||
auto res = std::make_shared<Node>(type,
|
||||
value);
|
||||
|
||||
for (size_t i=0; i<node->size(); i++)
|
||||
{
|
||||
auto child = Node::apply(node->child(i), subst);
|
||||
res->add_child(child);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/*static*/ std::optional<Substitution>
|
||||
Node::unify(std::shared_ptr<Node> lhs,
|
||||
std::shared_ptr<Node> rhs)
|
||||
{
|
||||
if (lhs->type() == NODE_FORMULA)
|
||||
{
|
||||
return unify(lhs->child(0), rhs);
|
||||
}
|
||||
|
||||
if (rhs->type() == NODE_FORMULA)
|
||||
{
|
||||
return unify(lhs, rhs->child(0));
|
||||
}
|
||||
|
||||
if ((lhs->type() == NODE_AND
|
||||
|| lhs->type() == NODE_OR
|
||||
|| lhs->type() == NODE_NOT
|
||||
|| lhs->type() == NODE_IMP
|
||||
|| lhs->type() == NODE_FUNC)
|
||||
&& rhs->type() == lhs->type()
|
||||
&& lhs->size() == rhs->size())
|
||||
{
|
||||
Substitution total_subst;
|
||||
|
||||
for (size_t i=0; i<lhs->size(); i++)
|
||||
{
|
||||
auto subst = unify(lhs->child(i), rhs->child(i));
|
||||
|
||||
if (!subst)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto cat = Substitution::concat(total_subst, *subst);
|
||||
cat)
|
||||
{
|
||||
total_subst = *cat;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
return total_subst;
|
||||
}
|
||||
|
||||
if (lhs->equals(*rhs))
|
||||
{
|
||||
return Substitution {};
|
||||
}
|
||||
|
||||
if (rhs->type() == NODE_VAR)
|
||||
{
|
||||
std::swap(lhs, rhs);
|
||||
}
|
||||
|
||||
if (lhs->type() == NODE_PRED
|
||||
&& rhs->type() == NODE_PRED
|
||||
&& lhs->value() == rhs->value()
|
||||
&& lhs->size() == rhs->size())
|
||||
{
|
||||
Substitution pred_subst;
|
||||
|
||||
for (size_t i=0; i<lhs->size(); i++)
|
||||
{
|
||||
auto subst = Node::unify(lhs->child(i),
|
||||
rhs->child(i));
|
||||
if (!subst)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto cat=Substitution::concat(pred_subst, *subst);
|
||||
cat)
|
||||
{
|
||||
pred_subst = *cat;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
return pred_subst;
|
||||
}
|
||||
|
||||
if (lhs->type() == NODE_VAR
|
||||
&& (rhs->type() == NODE_VAR
|
||||
|| rhs->type() == NODE_CONST))
|
||||
{
|
||||
Substitution subst;
|
||||
subst.set(lhs->value(), rhs->value());
|
||||
|
||||
return subst;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/*explicit*/ Node::Node(NodeType type, std::string const& value)
|
||||
: m_type { type }
|
||||
, m_value { value }
|
||||
|
@ -19,7 +178,7 @@ namespace sp
|
|||
m_children.push_back(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> Node::child(size_t index)
|
||||
std::shared_ptr<Node> Node::child(size_t index) const
|
||||
{
|
||||
SP_ASSERT(index < size(), "cannot get child at index '"
|
||||
+ std::to_string(index)
|
||||
|
@ -54,5 +213,25 @@ namespace sp
|
|||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool Node::equals(Node const& rhs) const
|
||||
{
|
||||
if (m_type != rhs.m_type
|
||||
|| m_value != rhs.m_value
|
||||
|| size() != rhs.size())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i=0; i<size(); i++)
|
||||
{
|
||||
if ( !child(i)->equals(*rhs.child(i)) )
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,10 +27,19 @@ namespace sp
|
|||
namespace fol
|
||||
{
|
||||
SP_ENUM(NodeType, NODE_TYPES);
|
||||
class Substitution;
|
||||
|
||||
class Node
|
||||
{
|
||||
public:
|
||||
static NodeType type_from_value(std::string const& value);
|
||||
static std::shared_ptr<Node> from_string(std::string const& formula);
|
||||
static std::shared_ptr<Node> apply(std::shared_ptr<Node> node,
|
||||
Substitution const& subst);
|
||||
|
||||
static std::optional<Substitution> unify(std::shared_ptr<Node> lhs,
|
||||
std::shared_ptr<Node> rhs);
|
||||
|
||||
explicit Node(NodeType type, std::string const& value);
|
||||
virtual ~Node();
|
||||
|
||||
|
@ -39,10 +48,12 @@ namespace sp
|
|||
|
||||
size_t size() const { return m_children.size(); }
|
||||
void add_child(std::shared_ptr<Node> child);
|
||||
std::shared_ptr<Node> child(size_t index);
|
||||
std::shared_ptr<Node> child(size_t index) const;
|
||||
|
||||
std::string string() const;
|
||||
|
||||
bool equals(Node const& rhs) const;
|
||||
|
||||
private:
|
||||
NodeType m_type;
|
||||
std::string m_value;
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
#include "Substitution.hpp"
|
||||
|
||||
namespace sp
|
||||
{
|
||||
namespace fol
|
||||
{
|
||||
/*static*/ std::optional<Substitution>
|
||||
Substitution::concat(Substitution const& lhs,
|
||||
Substitution const& rhs)
|
||||
{
|
||||
auto res = lhs;
|
||||
|
||||
for (auto entry: rhs.m_substs)
|
||||
{
|
||||
if (res.has(entry.first)
|
||||
&& res.get(entry.first) != entry.second)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!res.has(entry.first))
|
||||
{
|
||||
res.set(entry.first, entry.second);
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/*explicit*/ Substitution::Substitution()
|
||||
{
|
||||
}
|
||||
|
||||
/*virtual*/ Substitution::~Substitution()
|
||||
{
|
||||
}
|
||||
|
||||
void Substitution::set(std::string const& from, std::string const& to)
|
||||
{
|
||||
if (auto itr=m_substs.find(from);
|
||||
itr != std::end(m_substs))
|
||||
{
|
||||
throw substitution_error {
|
||||
"'" + from + "' already defined in substitution"
|
||||
};
|
||||
}
|
||||
|
||||
m_substs[from] = to;
|
||||
}
|
||||
|
||||
bool Substitution::has(std::string const& from) const
|
||||
{
|
||||
return m_substs.find(from) != std::end(m_substs);
|
||||
}
|
||||
|
||||
std::string Substitution::get(std::string const& from) const
|
||||
{
|
||||
SP_ASSERT(has(from), "cannot get subtitution for '" + from + "'");
|
||||
|
||||
auto result = m_substs.at(from);
|
||||
|
||||
if (has(result))
|
||||
{
|
||||
return get(result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string Substitution::string() const
|
||||
{
|
||||
std::stringstream ss;
|
||||
std::string sep = "";
|
||||
|
||||
for (auto entry: m_substs)
|
||||
{
|
||||
ss << sep << entry.first << " => " << entry.second;
|
||||
sep = ", ";
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
#ifndef sp_fol_SUBSTITUTION_HPP
|
||||
#define sp_fol_SUBSTITUTION_HPP
|
||||
|
||||
#include "../commons.hpp"
|
||||
|
||||
namespace sp
|
||||
{
|
||||
namespace fol
|
||||
{
|
||||
SP_ERROR(substitution_error);
|
||||
|
||||
class Substitution
|
||||
{
|
||||
public:
|
||||
static std::optional<Substitution>
|
||||
concat(Substitution const& lhs,
|
||||
Substitution const& rhs);
|
||||
|
||||
explicit Substitution();
|
||||
virtual ~Substitution();
|
||||
|
||||
size_t size() const { return m_substs.size(); }
|
||||
|
||||
void set(std::string const& from, std::string const& to);
|
||||
bool has(std::string const& from) const;
|
||||
std::string get(std::string const& from) const;
|
||||
|
||||
std::string string() const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> m_substs;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,277 @@
|
|||
#include <catch2/catch.hpp>
|
||||
#include "../src/fol/Node.hpp"
|
||||
#include "../src/fol/Substitution.hpp"
|
||||
|
||||
using namespace sp::fol;
|
||||
|
||||
class NodeTest
|
||||
{
|
||||
public:
|
||||
explicit NodeTest() {}
|
||||
virtual ~NodeTest() {}
|
||||
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_CASE_METHOD(NodeTest, "Node_equals")
|
||||
{
|
||||
auto a = Node::from_string("Happy(x) -> Happy(y)");
|
||||
auto b = Node::from_string("Happy(x) -> Happy(y)");
|
||||
auto c = Node::from_string("Happy(x) -> Happy(x)");
|
||||
auto d = Node::from_string("Happy(x) & Happy(x)");
|
||||
auto e = Node::from_string("Happy(x) | Happy(x)");
|
||||
auto f = Node::from_string("Happy(x)");
|
||||
auto g = Node::from_string("!Happy(x)");
|
||||
|
||||
REQUIRE(a->equals(*b));
|
||||
REQUIRE(!a->equals(*c));
|
||||
REQUIRE(!d->equals(*e));
|
||||
REQUIRE(f->equals(*f));
|
||||
REQUIRE(!f->equals(*g));
|
||||
}
|
||||
|
||||
TEST_CASE_METHOD(NodeTest, "Node_substitution")
|
||||
{
|
||||
auto node = Node::from_string(" Friend(x, y) -> Friend(y, x) ");
|
||||
|
||||
SECTION("simple")
|
||||
{
|
||||
Substitution sub;
|
||||
sub.set("x", "a");
|
||||
auto oracle = Node::from_string("Friend(a, y) -> Friend(y, a)");
|
||||
|
||||
auto res = Node::apply(node, sub);
|
||||
|
||||
REQUIRE(res->equals(*oracle));
|
||||
}
|
||||
|
||||
SECTION("multiple")
|
||||
{
|
||||
Substitution sub;
|
||||
sub.set("x", "a");
|
||||
sub.set("a", "b");
|
||||
sub.set("b", "c");
|
||||
|
||||
auto oracle = Node::from_string("Friend(c, y) -> Friend(y, c)");
|
||||
|
||||
auto res = Node::apply(node, sub);
|
||||
|
||||
REQUIRE(res->equals(*oracle));
|
||||
}
|
||||
|
||||
SECTION("multiple v2")
|
||||
{
|
||||
Substitution sub;
|
||||
sub.set("x", "a");
|
||||
sub.set("y", "b");
|
||||
sub.set("b", "c");
|
||||
|
||||
auto oracle = Node::from_string("Friend(a, c) -> Friend(c, a)");
|
||||
|
||||
auto res = Node::apply(node, sub);
|
||||
|
||||
REQUIRE(res->equals(*oracle));
|
||||
}
|
||||
|
||||
SECTION("multiple const")
|
||||
{
|
||||
Substitution sub;
|
||||
sub.set("x", "WORLD");
|
||||
|
||||
auto oracle =
|
||||
Node::from_string("Friend(WORLD, y) -> Friend(y, WORLD)");
|
||||
|
||||
auto res = Node::apply(node, sub);
|
||||
INFO("oracle = " << oracle->string());
|
||||
INFO("result = " << res->string());
|
||||
|
||||
REQUIRE(res->equals(*oracle));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE_METHOD(NodeTest, "Node_unify")
|
||||
{
|
||||
SECTION("predicate: base case")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(A, y) ");
|
||||
auto rhs = Node::from_string(" Friend(x, y) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
INFO("result is nullopt ? " << (result == std::nullopt));
|
||||
REQUIRE(std::nullopt != result);
|
||||
INFO("result size: " << result->size());
|
||||
REQUIRE(1 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("A" == result->get("x"));
|
||||
}
|
||||
|
||||
SECTION("predicate: two substitutions")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, A, z) ");
|
||||
auto rhs = Node::from_string(" Friend(x, y, B) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
INFO("result is nullopt ? " << (result == std::nullopt));
|
||||
REQUIRE(std::nullopt != result);
|
||||
INFO("result size: " << result->size());
|
||||
REQUIRE(2 == result->size());
|
||||
REQUIRE(result->has("y"));
|
||||
REQUIRE(result->has("z"));
|
||||
REQUIRE("A" == result->get("y"));
|
||||
REQUIRE("B" == result->get("z"));
|
||||
}
|
||||
|
||||
SECTION("predicate: no substitution found")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, A, z) ");
|
||||
auto rhs = Node::from_string(" Friend(x, C, B) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt == result);
|
||||
}
|
||||
|
||||
SECTION("predicate: wrong name")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friendly(A, y, C) ");
|
||||
auto rhs = Node::from_string(" Friend(x, B, z) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt == result);
|
||||
}
|
||||
|
||||
SECTION("predicate: three substitutions")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(A, y, C) ");
|
||||
auto rhs = Node::from_string(" Friend(x, B, z) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
INFO("result is nullopt ? " << (result == std::nullopt));
|
||||
REQUIRE(std::nullopt != result);
|
||||
INFO("result size: " << result->size());
|
||||
REQUIRE(3 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE(result->has("y"));
|
||||
REQUIRE(result->has("z"));
|
||||
REQUIRE("A" == result->get("x"));
|
||||
REQUIRE("B" == result->get("y"));
|
||||
REQUIRE("C" == result->get("z"));
|
||||
}
|
||||
|
||||
SECTION("predicate: and operation")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, y) & Friend(y, x) ");
|
||||
auto rhs = Node::from_string(" Friend(JOHN, y) & Friend(y, x) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(1 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("JOHN" == result->get("x"));
|
||||
}
|
||||
|
||||
SECTION("predicate: or operation")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, y) | Friend(y, x) ");
|
||||
auto rhs = Node::from_string(" Friend(JOHN, y) | Friend(BOB, x) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(2 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("JOHN" == result->get("x"));
|
||||
REQUIRE(result->has("y"));
|
||||
REQUIRE("BOB" == result->get("y"));
|
||||
}
|
||||
|
||||
SECTION("predicate: imp operation")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, x) ");
|
||||
auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(2 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("JOHN" == result->get("x"));
|
||||
REQUIRE(result->has("y"));
|
||||
REQUIRE("BOB" == result->get("y"));
|
||||
}
|
||||
|
||||
SECTION("predicate: doubles")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, JOHN) ");
|
||||
auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(2 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("JOHN" == result->get("x"));
|
||||
REQUIRE(result->has("y"));
|
||||
REQUIRE("BOB" == result->get("y"));
|
||||
}
|
||||
|
||||
SECTION("doubles inconsistence")
|
||||
{
|
||||
auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, ALICE) ");
|
||||
auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) ");
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt == result);
|
||||
}
|
||||
|
||||
SECTION("compound formula")
|
||||
{
|
||||
auto lhs =
|
||||
Node::from_string(" (Friend(x, y) & Friend(y, x)) -> Ok(x, y)");
|
||||
|
||||
auto rhs =
|
||||
Node::from_string(" (Friend(x, ALICE) & Friend(y, BOB)) "
|
||||
"-> Ok(x, ALICE)");
|
||||
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(2 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("BOB" == result->get("x"));
|
||||
REQUIRE(result->has("y"));
|
||||
REQUIRE("ALICE" == result->get("y"));
|
||||
}
|
||||
|
||||
SECTION("functions I")
|
||||
{
|
||||
auto lhs =
|
||||
Node::from_string(" Happy(x) -> Happy(father(x))");
|
||||
|
||||
auto rhs =
|
||||
Node::from_string(" Happy(BOB) -> Happy(father(x))");
|
||||
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(1 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("BOB" == result->get("x"));
|
||||
}
|
||||
|
||||
SECTION("functions II")
|
||||
{
|
||||
auto lhs =
|
||||
Node::from_string(" Happy(x) -> Happy(father(x))");
|
||||
|
||||
auto rhs =
|
||||
Node::from_string(" Happy(x) -> Happy(father(CLAIRE))");
|
||||
|
||||
auto result = Node::unify(lhs, rhs);
|
||||
|
||||
REQUIRE(std::nullopt != result);
|
||||
|
||||
REQUIRE(1 == result->size());
|
||||
REQUIRE(result->has("x"));
|
||||
REQUIRE("CLAIRE" == result->get("x"));
|
||||
}
|
||||
}
|
Reference in New Issue