node substitution and unification.

main
bog 2023-10-10 18:32:12 +02:00
parent 93f7430e11
commit cd77e92682
6 changed files with 591 additions and 2 deletions

View File

@ -13,6 +13,7 @@ sp_lib = static_library(
'src/fol/Node.cpp', 'src/fol/Node.cpp',
'src/fol/Lexer.cpp', 'src/fol/Lexer.cpp',
'src/fol/Parser.cpp', 'src/fol/Parser.cpp',
'src/fol/Substitution.cpp',
] ]
) )
@ -32,6 +33,7 @@ executable('sine-patre-tests',
'tests/main.cpp', 'tests/main.cpp',
'tests/Lexer.cpp', 'tests/Lexer.cpp',
'tests/Parser.cpp', 'tests/Parser.cpp',
'tests/Node.cpp',
], ],
dependencies: [ dependencies: [
sp_dep, sp_dep,

View File

@ -1,9 +1,168 @@
#include "Node.hpp" #include "Node.hpp"
#include "Lexer.hpp"
#include "Parser.hpp"
#include "Substitution.hpp"
namespace sp namespace sp
{ {
namespace fol namespace fol
{ {
/*static*/ NodeType Node::type_from_value(std::string const& value)
{
if (std::all_of(std::begin(value),
std::end(value),
[](auto const& el){
return std::islower(el);}))
{
return NODE_VAR;
}
else if (std::all_of(std::begin(value),
std::end(value),
[](auto const& el){
return std::isupper(el);}))
{
return NODE_CONST;
}
else
{
return NODE_PRED;
}
}
/*static*/ std::shared_ptr<Node>
Node::from_string(std::string const& formula)
{
Lexer lexer;
lexer.scan(formula);
Parser parser;
return parser.parse(lexer.all());
}
/*static*/ std::shared_ptr<Node>
Node::apply(std::shared_ptr<Node> node,
Substitution const& subst)
{
std::string value = node->value();
NodeType type = node->type();
if (subst.has(value))
{
value = subst.get(value);
type = Node::type_from_value(value);
}
auto res = std::make_shared<Node>(type,
value);
for (size_t i=0; i<node->size(); i++)
{
auto child = Node::apply(node->child(i), subst);
res->add_child(child);
}
return res;
}
/*static*/ std::optional<Substitution>
Node::unify(std::shared_ptr<Node> lhs,
std::shared_ptr<Node> rhs)
{
if (lhs->type() == NODE_FORMULA)
{
return unify(lhs->child(0), rhs);
}
if (rhs->type() == NODE_FORMULA)
{
return unify(lhs, rhs->child(0));
}
if ((lhs->type() == NODE_AND
|| lhs->type() == NODE_OR
|| lhs->type() == NODE_NOT
|| lhs->type() == NODE_IMP
|| lhs->type() == NODE_FUNC)
&& rhs->type() == lhs->type()
&& lhs->size() == rhs->size())
{
Substitution total_subst;
for (size_t i=0; i<lhs->size(); i++)
{
auto subst = unify(lhs->child(i), rhs->child(i));
if (!subst)
{
return std::nullopt;
}
if (auto cat = Substitution::concat(total_subst, *subst);
cat)
{
total_subst = *cat;
}
else
{
return std::nullopt;
}
}
return total_subst;
}
if (lhs->equals(*rhs))
{
return Substitution {};
}
if (rhs->type() == NODE_VAR)
{
std::swap(lhs, rhs);
}
if (lhs->type() == NODE_PRED
&& rhs->type() == NODE_PRED
&& lhs->value() == rhs->value()
&& lhs->size() == rhs->size())
{
Substitution pred_subst;
for (size_t i=0; i<lhs->size(); i++)
{
auto subst = Node::unify(lhs->child(i),
rhs->child(i));
if (!subst)
{
return std::nullopt;
}
if (auto cat=Substitution::concat(pred_subst, *subst);
cat)
{
pred_subst = *cat;
}
else
{
return std::nullopt;
}
}
return pred_subst;
}
if (lhs->type() == NODE_VAR
&& (rhs->type() == NODE_VAR
|| rhs->type() == NODE_CONST))
{
Substitution subst;
subst.set(lhs->value(), rhs->value());
return subst;
}
return std::nullopt;
}
/*explicit*/ Node::Node(NodeType type, std::string const& value) /*explicit*/ Node::Node(NodeType type, std::string const& value)
: m_type { type } : m_type { type }
, m_value { value } , m_value { value }
@ -19,7 +178,7 @@ namespace sp
m_children.push_back(child); m_children.push_back(child);
} }
std::shared_ptr<Node> Node::child(size_t index) std::shared_ptr<Node> Node::child(size_t index) const
{ {
SP_ASSERT(index < size(), "cannot get child at index '" SP_ASSERT(index < size(), "cannot get child at index '"
+ std::to_string(index) + std::to_string(index)
@ -54,5 +213,25 @@ namespace sp
return ss.str(); return ss.str();
} }
bool Node::equals(Node const& rhs) const
{
if (m_type != rhs.m_type
|| m_value != rhs.m_value
|| size() != rhs.size())
{
return false;
}
for (size_t i=0; i<size(); i++)
{
if ( !child(i)->equals(*rhs.child(i)) )
{
return false;
}
}
return true;
}
} }
} }

View File

@ -27,10 +27,19 @@ namespace sp
namespace fol namespace fol
{ {
SP_ENUM(NodeType, NODE_TYPES); SP_ENUM(NodeType, NODE_TYPES);
class Substitution;
class Node class Node
{ {
public: public:
static NodeType type_from_value(std::string const& value);
static std::shared_ptr<Node> from_string(std::string const& formula);
static std::shared_ptr<Node> apply(std::shared_ptr<Node> node,
Substitution const& subst);
static std::optional<Substitution> unify(std::shared_ptr<Node> lhs,
std::shared_ptr<Node> rhs);
explicit Node(NodeType type, std::string const& value); explicit Node(NodeType type, std::string const& value);
virtual ~Node(); virtual ~Node();
@ -39,10 +48,12 @@ namespace sp
size_t size() const { return m_children.size(); } size_t size() const { return m_children.size(); }
void add_child(std::shared_ptr<Node> child); void add_child(std::shared_ptr<Node> child);
std::shared_ptr<Node> child(size_t index); std::shared_ptr<Node> child(size_t index) const;
std::string string() const; std::string string() const;
bool equals(Node const& rhs) const;
private: private:
NodeType m_type; NodeType m_type;
std::string m_value; std::string m_value;

84
src/fol/Substitution.cpp Normal file
View File

@ -0,0 +1,84 @@
#include "Substitution.hpp"
namespace sp
{
namespace fol
{
/*static*/ std::optional<Substitution>
Substitution::concat(Substitution const& lhs,
Substitution const& rhs)
{
auto res = lhs;
for (auto entry: rhs.m_substs)
{
if (res.has(entry.first)
&& res.get(entry.first) != entry.second)
{
return std::nullopt;
}
if (!res.has(entry.first))
{
res.set(entry.first, entry.second);
}
}
return res;
}
/*explicit*/ Substitution::Substitution()
{
}
/*virtual*/ Substitution::~Substitution()
{
}
void Substitution::set(std::string const& from, std::string const& to)
{
if (auto itr=m_substs.find(from);
itr != std::end(m_substs))
{
throw substitution_error {
"'" + from + "' already defined in substitution"
};
}
m_substs[from] = to;
}
bool Substitution::has(std::string const& from) const
{
return m_substs.find(from) != std::end(m_substs);
}
std::string Substitution::get(std::string const& from) const
{
SP_ASSERT(has(from), "cannot get subtitution for '" + from + "'");
auto result = m_substs.at(from);
if (has(result))
{
return get(result);
}
return result;
}
std::string Substitution::string() const
{
std::stringstream ss;
std::string sep = "";
for (auto entry: m_substs)
{
ss << sep << entry.first << " => " << entry.second;
sep = ", ";
}
return ss.str();
}
}
}

36
src/fol/Substitution.hpp Normal file
View File

@ -0,0 +1,36 @@
#ifndef sp_fol_SUBSTITUTION_HPP
#define sp_fol_SUBSTITUTION_HPP
#include "../commons.hpp"
namespace sp
{
namespace fol
{
SP_ERROR(substitution_error);
class Substitution
{
public:
static std::optional<Substitution>
concat(Substitution const& lhs,
Substitution const& rhs);
explicit Substitution();
virtual ~Substitution();
size_t size() const { return m_substs.size(); }
void set(std::string const& from, std::string const& to);
bool has(std::string const& from) const;
std::string get(std::string const& from) const;
std::string string() const;
private:
std::unordered_map<std::string, std::string> m_substs;
};
}
}
#endif

277
tests/Node.cpp Normal file
View File

@ -0,0 +1,277 @@
#include <catch2/catch.hpp>
#include "../src/fol/Node.hpp"
#include "../src/fol/Substitution.hpp"
using namespace sp::fol;
class NodeTest
{
public:
explicit NodeTest() {}
virtual ~NodeTest() {}
protected:
};
TEST_CASE_METHOD(NodeTest, "Node_equals")
{
auto a = Node::from_string("Happy(x) -> Happy(y)");
auto b = Node::from_string("Happy(x) -> Happy(y)");
auto c = Node::from_string("Happy(x) -> Happy(x)");
auto d = Node::from_string("Happy(x) & Happy(x)");
auto e = Node::from_string("Happy(x) | Happy(x)");
auto f = Node::from_string("Happy(x)");
auto g = Node::from_string("!Happy(x)");
REQUIRE(a->equals(*b));
REQUIRE(!a->equals(*c));
REQUIRE(!d->equals(*e));
REQUIRE(f->equals(*f));
REQUIRE(!f->equals(*g));
}
TEST_CASE_METHOD(NodeTest, "Node_substitution")
{
auto node = Node::from_string(" Friend(x, y) -> Friend(y, x) ");
SECTION("simple")
{
Substitution sub;
sub.set("x", "a");
auto oracle = Node::from_string("Friend(a, y) -> Friend(y, a)");
auto res = Node::apply(node, sub);
REQUIRE(res->equals(*oracle));
}
SECTION("multiple")
{
Substitution sub;
sub.set("x", "a");
sub.set("a", "b");
sub.set("b", "c");
auto oracle = Node::from_string("Friend(c, y) -> Friend(y, c)");
auto res = Node::apply(node, sub);
REQUIRE(res->equals(*oracle));
}
SECTION("multiple v2")
{
Substitution sub;
sub.set("x", "a");
sub.set("y", "b");
sub.set("b", "c");
auto oracle = Node::from_string("Friend(a, c) -> Friend(c, a)");
auto res = Node::apply(node, sub);
REQUIRE(res->equals(*oracle));
}
SECTION("multiple const")
{
Substitution sub;
sub.set("x", "WORLD");
auto oracle =
Node::from_string("Friend(WORLD, y) -> Friend(y, WORLD)");
auto res = Node::apply(node, sub);
INFO("oracle = " << oracle->string());
INFO("result = " << res->string());
REQUIRE(res->equals(*oracle));
}
}
TEST_CASE_METHOD(NodeTest, "Node_unify")
{
SECTION("predicate: base case")
{
auto lhs = Node::from_string(" Friend(A, y) ");
auto rhs = Node::from_string(" Friend(x, y) ");
auto result = Node::unify(lhs, rhs);
INFO("result is nullopt ? " << (result == std::nullopt));
REQUIRE(std::nullopt != result);
INFO("result size: " << result->size());
REQUIRE(1 == result->size());
REQUIRE(result->has("x"));
REQUIRE("A" == result->get("x"));
}
SECTION("predicate: two substitutions")
{
auto lhs = Node::from_string(" Friend(x, A, z) ");
auto rhs = Node::from_string(" Friend(x, y, B) ");
auto result = Node::unify(lhs, rhs);
INFO("result is nullopt ? " << (result == std::nullopt));
REQUIRE(std::nullopt != result);
INFO("result size: " << result->size());
REQUIRE(2 == result->size());
REQUIRE(result->has("y"));
REQUIRE(result->has("z"));
REQUIRE("A" == result->get("y"));
REQUIRE("B" == result->get("z"));
}
SECTION("predicate: no substitution found")
{
auto lhs = Node::from_string(" Friend(x, A, z) ");
auto rhs = Node::from_string(" Friend(x, C, B) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt == result);
}
SECTION("predicate: wrong name")
{
auto lhs = Node::from_string(" Friendly(A, y, C) ");
auto rhs = Node::from_string(" Friend(x, B, z) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt == result);
}
SECTION("predicate: three substitutions")
{
auto lhs = Node::from_string(" Friend(A, y, C) ");
auto rhs = Node::from_string(" Friend(x, B, z) ");
auto result = Node::unify(lhs, rhs);
INFO("result is nullopt ? " << (result == std::nullopt));
REQUIRE(std::nullopt != result);
INFO("result size: " << result->size());
REQUIRE(3 == result->size());
REQUIRE(result->has("x"));
REQUIRE(result->has("y"));
REQUIRE(result->has("z"));
REQUIRE("A" == result->get("x"));
REQUIRE("B" == result->get("y"));
REQUIRE("C" == result->get("z"));
}
SECTION("predicate: and operation")
{
auto lhs = Node::from_string(" Friend(x, y) & Friend(y, x) ");
auto rhs = Node::from_string(" Friend(JOHN, y) & Friend(y, x) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(1 == result->size());
REQUIRE(result->has("x"));
REQUIRE("JOHN" == result->get("x"));
}
SECTION("predicate: or operation")
{
auto lhs = Node::from_string(" Friend(x, y) | Friend(y, x) ");
auto rhs = Node::from_string(" Friend(JOHN, y) | Friend(BOB, x) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(2 == result->size());
REQUIRE(result->has("x"));
REQUIRE("JOHN" == result->get("x"));
REQUIRE(result->has("y"));
REQUIRE("BOB" == result->get("y"));
}
SECTION("predicate: imp operation")
{
auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, x) ");
auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(2 == result->size());
REQUIRE(result->has("x"));
REQUIRE("JOHN" == result->get("x"));
REQUIRE(result->has("y"));
REQUIRE("BOB" == result->get("y"));
}
SECTION("predicate: doubles")
{
auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, JOHN) ");
auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(2 == result->size());
REQUIRE(result->has("x"));
REQUIRE("JOHN" == result->get("x"));
REQUIRE(result->has("y"));
REQUIRE("BOB" == result->get("y"));
}
SECTION("doubles inconsistence")
{
auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, ALICE) ");
auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) ");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt == result);
}
SECTION("compound formula")
{
auto lhs =
Node::from_string(" (Friend(x, y) & Friend(y, x)) -> Ok(x, y)");
auto rhs =
Node::from_string(" (Friend(x, ALICE) & Friend(y, BOB)) "
"-> Ok(x, ALICE)");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(2 == result->size());
REQUIRE(result->has("x"));
REQUIRE("BOB" == result->get("x"));
REQUIRE(result->has("y"));
REQUIRE("ALICE" == result->get("y"));
}
SECTION("functions I")
{
auto lhs =
Node::from_string(" Happy(x) -> Happy(father(x))");
auto rhs =
Node::from_string(" Happy(BOB) -> Happy(father(x))");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(1 == result->size());
REQUIRE(result->has("x"));
REQUIRE("BOB" == result->get("x"));
}
SECTION("functions II")
{
auto lhs =
Node::from_string(" Happy(x) -> Happy(father(x))");
auto rhs =
Node::from_string(" Happy(x) -> Happy(father(CLAIRE))");
auto result = Node::unify(lhs, rhs);
REQUIRE(std::nullopt != result);
REQUIRE(1 == result->size());
REQUIRE(result->has("x"));
REQUIRE("CLAIRE" == result->get("x"));
}
}