From cd77e92682acaa0dec29225176f191f2e2c562ca Mon Sep 17 00:00:00 2001 From: bog Date: Tue, 10 Oct 2023 18:32:12 +0200 Subject: [PATCH] :sparkles: node substitution and unification. --- meson.build | 2 + src/fol/Node.cpp | 181 ++++++++++++++++++++++++- src/fol/Node.hpp | 13 +- src/fol/Substitution.cpp | 84 ++++++++++++ src/fol/Substitution.hpp | 36 +++++ tests/Node.cpp | 277 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 591 insertions(+), 2 deletions(-) create mode 100644 src/fol/Substitution.cpp create mode 100644 src/fol/Substitution.hpp create mode 100644 tests/Node.cpp diff --git a/meson.build b/meson.build index 47e7edf..a115be9 100644 --- a/meson.build +++ b/meson.build @@ -13,6 +13,7 @@ sp_lib = static_library( 'src/fol/Node.cpp', 'src/fol/Lexer.cpp', 'src/fol/Parser.cpp', + 'src/fol/Substitution.cpp', ] ) @@ -32,6 +33,7 @@ executable('sine-patre-tests', 'tests/main.cpp', 'tests/Lexer.cpp', 'tests/Parser.cpp', + 'tests/Node.cpp', ], dependencies: [ sp_dep, diff --git a/src/fol/Node.cpp b/src/fol/Node.cpp index 76345da..724ac03 100644 --- a/src/fol/Node.cpp +++ b/src/fol/Node.cpp @@ -1,9 +1,168 @@ #include "Node.hpp" +#include "Lexer.hpp" +#include "Parser.hpp" +#include "Substitution.hpp" namespace sp { namespace fol { + /*static*/ NodeType Node::type_from_value(std::string const& value) + { + if (std::all_of(std::begin(value), + std::end(value), + [](auto const& el){ + return std::islower(el);})) + { + return NODE_VAR; + } + else if (std::all_of(std::begin(value), + std::end(value), + [](auto const& el){ + return std::isupper(el);})) + { + return NODE_CONST; + } + else + { + return NODE_PRED; + } + } + + /*static*/ std::shared_ptr + Node::from_string(std::string const& formula) + { + Lexer lexer; + lexer.scan(formula); + Parser parser; + return parser.parse(lexer.all()); + } + + /*static*/ std::shared_ptr + Node::apply(std::shared_ptr node, + Substitution const& subst) + { + std::string value = node->value(); + NodeType type = node->type(); + + if (subst.has(value)) + { + value = subst.get(value); + type = Node::type_from_value(value); + } + + auto res = std::make_shared(type, + value); + + for (size_t i=0; isize(); i++) + { + auto child = Node::apply(node->child(i), subst); + res->add_child(child); + } + + return res; + } + + /*static*/ std::optional + Node::unify(std::shared_ptr lhs, + std::shared_ptr rhs) + { + if (lhs->type() == NODE_FORMULA) + { + return unify(lhs->child(0), rhs); + } + + if (rhs->type() == NODE_FORMULA) + { + return unify(lhs, rhs->child(0)); + } + + if ((lhs->type() == NODE_AND + || lhs->type() == NODE_OR + || lhs->type() == NODE_NOT + || lhs->type() == NODE_IMP + || lhs->type() == NODE_FUNC) + && rhs->type() == lhs->type() + && lhs->size() == rhs->size()) + { + Substitution total_subst; + + for (size_t i=0; isize(); i++) + { + auto subst = unify(lhs->child(i), rhs->child(i)); + + if (!subst) + { + return std::nullopt; + } + + if (auto cat = Substitution::concat(total_subst, *subst); + cat) + { + total_subst = *cat; + } + else + { + return std::nullopt; + } + } + + return total_subst; + } + + if (lhs->equals(*rhs)) + { + return Substitution {}; + } + + if (rhs->type() == NODE_VAR) + { + std::swap(lhs, rhs); + } + + if (lhs->type() == NODE_PRED + && rhs->type() == NODE_PRED + && lhs->value() == rhs->value() + && lhs->size() == rhs->size()) + { + Substitution pred_subst; + + for (size_t i=0; isize(); i++) + { + auto subst = Node::unify(lhs->child(i), + rhs->child(i)); + if (!subst) + { + return std::nullopt; + } + + if (auto cat=Substitution::concat(pred_subst, *subst); + cat) + { + pred_subst = *cat; + } + else + { + return std::nullopt; + } + } + + return pred_subst; + } + + if (lhs->type() == NODE_VAR + && (rhs->type() == NODE_VAR + || rhs->type() == NODE_CONST)) + { + Substitution subst; + subst.set(lhs->value(), rhs->value()); + + return subst; + } + + return std::nullopt; + } + /*explicit*/ Node::Node(NodeType type, std::string const& value) : m_type { type } , m_value { value } @@ -19,7 +178,7 @@ namespace sp m_children.push_back(child); } - std::shared_ptr Node::child(size_t index) + std::shared_ptr Node::child(size_t index) const { SP_ASSERT(index < size(), "cannot get child at index '" + std::to_string(index) @@ -54,5 +213,25 @@ namespace sp return ss.str(); } + + bool Node::equals(Node const& rhs) const + { + if (m_type != rhs.m_type + || m_value != rhs.m_value + || size() != rhs.size()) + { + return false; + } + + for (size_t i=0; iequals(*rhs.child(i)) ) + { + return false; + } + } + + return true; + } } } diff --git a/src/fol/Node.hpp b/src/fol/Node.hpp index 881c41c..14916f5 100644 --- a/src/fol/Node.hpp +++ b/src/fol/Node.hpp @@ -27,10 +27,19 @@ namespace sp namespace fol { SP_ENUM(NodeType, NODE_TYPES); + class Substitution; class Node { public: + static NodeType type_from_value(std::string const& value); + static std::shared_ptr from_string(std::string const& formula); + static std::shared_ptr apply(std::shared_ptr node, + Substitution const& subst); + + static std::optional unify(std::shared_ptr lhs, + std::shared_ptr rhs); + explicit Node(NodeType type, std::string const& value); virtual ~Node(); @@ -39,10 +48,12 @@ namespace sp size_t size() const { return m_children.size(); } void add_child(std::shared_ptr child); - std::shared_ptr child(size_t index); + std::shared_ptr child(size_t index) const; std::string string() const; + bool equals(Node const& rhs) const; + private: NodeType m_type; std::string m_value; diff --git a/src/fol/Substitution.cpp b/src/fol/Substitution.cpp new file mode 100644 index 0000000..6d71208 --- /dev/null +++ b/src/fol/Substitution.cpp @@ -0,0 +1,84 @@ +#include "Substitution.hpp" + +namespace sp +{ + namespace fol + { + /*static*/ std::optional + Substitution::concat(Substitution const& lhs, + Substitution const& rhs) + { + auto res = lhs; + + for (auto entry: rhs.m_substs) + { + if (res.has(entry.first) + && res.get(entry.first) != entry.second) + { + return std::nullopt; + } + + if (!res.has(entry.first)) + { + res.set(entry.first, entry.second); + } + } + + return res; + } + + /*explicit*/ Substitution::Substitution() + { + } + + /*virtual*/ Substitution::~Substitution() + { + } + + void Substitution::set(std::string const& from, std::string const& to) + { + if (auto itr=m_substs.find(from); + itr != std::end(m_substs)) + { + throw substitution_error { + "'" + from + "' already defined in substitution" + }; + } + + m_substs[from] = to; + } + + bool Substitution::has(std::string const& from) const + { + return m_substs.find(from) != std::end(m_substs); + } + + std::string Substitution::get(std::string const& from) const + { + SP_ASSERT(has(from), "cannot get subtitution for '" + from + "'"); + + auto result = m_substs.at(from); + + if (has(result)) + { + return get(result); + } + + return result; + } + + std::string Substitution::string() const + { + std::stringstream ss; + std::string sep = ""; + + for (auto entry: m_substs) + { + ss << sep << entry.first << " => " << entry.second; + sep = ", "; + } + + return ss.str(); + } + } +} diff --git a/src/fol/Substitution.hpp b/src/fol/Substitution.hpp new file mode 100644 index 0000000..88f51bb --- /dev/null +++ b/src/fol/Substitution.hpp @@ -0,0 +1,36 @@ +#ifndef sp_fol_SUBSTITUTION_HPP +#define sp_fol_SUBSTITUTION_HPP + +#include "../commons.hpp" + +namespace sp +{ + namespace fol + { + SP_ERROR(substitution_error); + + class Substitution + { + public: + static std::optional + concat(Substitution const& lhs, + Substitution const& rhs); + + explicit Substitution(); + virtual ~Substitution(); + + size_t size() const { return m_substs.size(); } + + void set(std::string const& from, std::string const& to); + bool has(std::string const& from) const; + std::string get(std::string const& from) const; + + std::string string() const; + + private: + std::unordered_map m_substs; + }; + } +} + +#endif diff --git a/tests/Node.cpp b/tests/Node.cpp new file mode 100644 index 0000000..6a9c4dd --- /dev/null +++ b/tests/Node.cpp @@ -0,0 +1,277 @@ +#include +#include "../src/fol/Node.hpp" +#include "../src/fol/Substitution.hpp" + +using namespace sp::fol; + +class NodeTest +{ +public: + explicit NodeTest() {} + virtual ~NodeTest() {} + +protected: +}; + +TEST_CASE_METHOD(NodeTest, "Node_equals") +{ + auto a = Node::from_string("Happy(x) -> Happy(y)"); + auto b = Node::from_string("Happy(x) -> Happy(y)"); + auto c = Node::from_string("Happy(x) -> Happy(x)"); + auto d = Node::from_string("Happy(x) & Happy(x)"); + auto e = Node::from_string("Happy(x) | Happy(x)"); + auto f = Node::from_string("Happy(x)"); + auto g = Node::from_string("!Happy(x)"); + + REQUIRE(a->equals(*b)); + REQUIRE(!a->equals(*c)); + REQUIRE(!d->equals(*e)); + REQUIRE(f->equals(*f)); + REQUIRE(!f->equals(*g)); +} + +TEST_CASE_METHOD(NodeTest, "Node_substitution") +{ + auto node = Node::from_string(" Friend(x, y) -> Friend(y, x) "); + + SECTION("simple") + { + Substitution sub; + sub.set("x", "a"); + auto oracle = Node::from_string("Friend(a, y) -> Friend(y, a)"); + + auto res = Node::apply(node, sub); + + REQUIRE(res->equals(*oracle)); + } + + SECTION("multiple") + { + Substitution sub; + sub.set("x", "a"); + sub.set("a", "b"); + sub.set("b", "c"); + + auto oracle = Node::from_string("Friend(c, y) -> Friend(y, c)"); + + auto res = Node::apply(node, sub); + + REQUIRE(res->equals(*oracle)); + } + + SECTION("multiple v2") + { + Substitution sub; + sub.set("x", "a"); + sub.set("y", "b"); + sub.set("b", "c"); + + auto oracle = Node::from_string("Friend(a, c) -> Friend(c, a)"); + + auto res = Node::apply(node, sub); + + REQUIRE(res->equals(*oracle)); + } + + SECTION("multiple const") + { + Substitution sub; + sub.set("x", "WORLD"); + + auto oracle = + Node::from_string("Friend(WORLD, y) -> Friend(y, WORLD)"); + + auto res = Node::apply(node, sub); + INFO("oracle = " << oracle->string()); + INFO("result = " << res->string()); + + REQUIRE(res->equals(*oracle)); + } +} + +TEST_CASE_METHOD(NodeTest, "Node_unify") +{ + SECTION("predicate: base case") + { + auto lhs = Node::from_string(" Friend(A, y) "); + auto rhs = Node::from_string(" Friend(x, y) "); + auto result = Node::unify(lhs, rhs); + INFO("result is nullopt ? " << (result == std::nullopt)); + REQUIRE(std::nullopt != result); + INFO("result size: " << result->size()); + REQUIRE(1 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("A" == result->get("x")); + } + + SECTION("predicate: two substitutions") + { + auto lhs = Node::from_string(" Friend(x, A, z) "); + auto rhs = Node::from_string(" Friend(x, y, B) "); + auto result = Node::unify(lhs, rhs); + INFO("result is nullopt ? " << (result == std::nullopt)); + REQUIRE(std::nullopt != result); + INFO("result size: " << result->size()); + REQUIRE(2 == result->size()); + REQUIRE(result->has("y")); + REQUIRE(result->has("z")); + REQUIRE("A" == result->get("y")); + REQUIRE("B" == result->get("z")); + } + + SECTION("predicate: no substitution found") + { + auto lhs = Node::from_string(" Friend(x, A, z) "); + auto rhs = Node::from_string(" Friend(x, C, B) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt == result); + } + + SECTION("predicate: wrong name") + { + auto lhs = Node::from_string(" Friendly(A, y, C) "); + auto rhs = Node::from_string(" Friend(x, B, z) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt == result); + } + + SECTION("predicate: three substitutions") + { + auto lhs = Node::from_string(" Friend(A, y, C) "); + auto rhs = Node::from_string(" Friend(x, B, z) "); + auto result = Node::unify(lhs, rhs); + INFO("result is nullopt ? " << (result == std::nullopt)); + REQUIRE(std::nullopt != result); + INFO("result size: " << result->size()); + REQUIRE(3 == result->size()); + REQUIRE(result->has("x")); + REQUIRE(result->has("y")); + REQUIRE(result->has("z")); + REQUIRE("A" == result->get("x")); + REQUIRE("B" == result->get("y")); + REQUIRE("C" == result->get("z")); + } + + SECTION("predicate: and operation") + { + auto lhs = Node::from_string(" Friend(x, y) & Friend(y, x) "); + auto rhs = Node::from_string(" Friend(JOHN, y) & Friend(y, x) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(1 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("JOHN" == result->get("x")); + } + + SECTION("predicate: or operation") + { + auto lhs = Node::from_string(" Friend(x, y) | Friend(y, x) "); + auto rhs = Node::from_string(" Friend(JOHN, y) | Friend(BOB, x) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(2 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("JOHN" == result->get("x")); + REQUIRE(result->has("y")); + REQUIRE("BOB" == result->get("y")); + } + + SECTION("predicate: imp operation") + { + auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, x) "); + auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(2 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("JOHN" == result->get("x")); + REQUIRE(result->has("y")); + REQUIRE("BOB" == result->get("y")); + } + + SECTION("predicate: doubles") + { + auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, JOHN) "); + auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(2 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("JOHN" == result->get("x")); + REQUIRE(result->has("y")); + REQUIRE("BOB" == result->get("y")); + } + + SECTION("doubles inconsistence") + { + auto lhs = Node::from_string(" Friend(x, y) -> Friend(y, ALICE) "); + auto rhs = Node::from_string(" Friend(JOHN, y) -> Friend(BOB, x) "); + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt == result); + } + + SECTION("compound formula") + { + auto lhs = + Node::from_string(" (Friend(x, y) & Friend(y, x)) -> Ok(x, y)"); + + auto rhs = + Node::from_string(" (Friend(x, ALICE) & Friend(y, BOB)) " + "-> Ok(x, ALICE)"); + + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(2 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("BOB" == result->get("x")); + REQUIRE(result->has("y")); + REQUIRE("ALICE" == result->get("y")); + } + + SECTION("functions I") + { + auto lhs = + Node::from_string(" Happy(x) -> Happy(father(x))"); + + auto rhs = + Node::from_string(" Happy(BOB) -> Happy(father(x))"); + + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(1 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("BOB" == result->get("x")); + } + + SECTION("functions II") + { + auto lhs = + Node::from_string(" Happy(x) -> Happy(father(x))"); + + auto rhs = + Node::from_string(" Happy(x) -> Happy(father(CLAIRE))"); + + auto result = Node::unify(lhs, rhs); + + REQUIRE(std::nullopt != result); + + REQUIRE(1 == result->size()); + REQUIRE(result->has("x")); + REQUIRE("CLAIRE" == result->get("x")); + } +}