✨ forward chaining v0.
parent
fb41410797
commit
bfd64b20b6
|
@ -1 +1,2 @@
|
|||
__pycache__
|
||||
*.prof
|
|
@ -0,0 +1,5 @@
|
|||
.PHONY: prof
|
||||
|
||||
prof:
|
||||
python -m cProfile -o sine_patre.prof sine_patre.py
|
||||
snakeviz sine_patre.prof
|
|
@ -1,5 +1,6 @@
|
|||
from fol import lexer
|
||||
from fol import parser
|
||||
from fol import kb
|
||||
from fol import pretty
|
||||
from fol import cnf
|
||||
from fol.unify import unify
|
||||
|
|
50
fol/cnf.py
50
fol/cnf.py
|
@ -1,12 +1,28 @@
|
|||
import fol
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
self.values = {}
|
||||
self.counter = 0
|
||||
|
||||
def get(self, name):
|
||||
if name in self.values.keys():
|
||||
return self.values[name]
|
||||
self.values[name] = self.counter
|
||||
self.counter += 1
|
||||
return self.counter - 1
|
||||
|
||||
def begin(self):
|
||||
self.values = {}
|
||||
|
||||
|
||||
def elim_imp(f):
|
||||
res = fol.node.Node(f.name, f.value)
|
||||
|
||||
if f.name == 'IMP':
|
||||
a = fol.node.Node('NOT')
|
||||
a.add_child(f.children[0])
|
||||
a.add_child(elim_imp(f.children[0]))
|
||||
res = fol.node.Node('OR')
|
||||
res.add_child(elim_imp(a))
|
||||
res.add_child(elim_imp(f.children[1]))
|
||||
|
@ -23,9 +39,9 @@ def neg(f):
|
|||
|
||||
if f.name == 'NOT' and f.children[0].name in ['OR', 'AND']:
|
||||
p = fol.node.Node('NOT')
|
||||
p.add_child(f.children[0].children[0])
|
||||
p.add_child(neg(f.children[0].children[0]))
|
||||
q = fol.node.Node('NOT')
|
||||
q.add_child(f.children[0].children[1])
|
||||
q.add_child(neg(f.children[0].children[1]))
|
||||
|
||||
if f.children[0].name == 'OR':
|
||||
res = fol.node.Node('AND')
|
||||
|
@ -38,24 +54,23 @@ def neg(f):
|
|||
res.add_child(neg(q))
|
||||
return res
|
||||
|
||||
res = fol.node.Node(f.name, f.value)
|
||||
|
||||
for c in f.children:
|
||||
res.add_child(neg(c))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def normalize(f, values={}):
|
||||
def normalize(f, counter):
|
||||
res = fol.node.Node(f.name, f.value)
|
||||
|
||||
if f.name == 'VAR':
|
||||
if f.value in values.keys():
|
||||
res.value = values[f.value]
|
||||
else:
|
||||
new_name = f'x{len(values)}'
|
||||
values[res.value] = new_name
|
||||
res.value = new_name
|
||||
|
||||
res.value = f'x{counter.get(f.value)}'
|
||||
return res
|
||||
|
||||
for child in f.children:
|
||||
res.add_child(normalize(child))
|
||||
res.add_child(normalize(child, counter))
|
||||
|
||||
return res
|
||||
|
||||
|
@ -207,7 +222,12 @@ def distrib(f):
|
|||
return f
|
||||
|
||||
|
||||
def cnf(f):
|
||||
return distrib(
|
||||
skolem(prenex(normalize(neg(elim_imp(f)))))
|
||||
def cnf(f, counter=None):
|
||||
if counter is None:
|
||||
counter = Counter()
|
||||
|
||||
g = distrib(
|
||||
skolem(prenex(normalize(neg(elim_imp(f)), counter)))
|
||||
).remove_quant('FORALL')
|
||||
|
||||
return g
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
import fol
|
||||
|
||||
|
||||
class Kb:
|
||||
def __init__(self):
|
||||
self.base = []
|
||||
self.counter = fol.cnf.Counter()
|
||||
|
||||
def make_f(self, h):
|
||||
self.counter.begin()
|
||||
f = h
|
||||
if type(f) is str:
|
||||
f = fol.p(f)
|
||||
self.counter.begin()
|
||||
f = fol.cnf.cnf(f, self.counter)
|
||||
|
||||
if f.name == 'OR':
|
||||
f = f.list_of('OR')
|
||||
else:
|
||||
f = [f]
|
||||
|
||||
return f
|
||||
|
||||
def make_req(self, h):
|
||||
f = h
|
||||
if type(f) is str:
|
||||
f = fol.p(f)
|
||||
|
||||
if f.name == 'OR':
|
||||
f = f.list_of('OR')
|
||||
else:
|
||||
f = [f]
|
||||
|
||||
return f
|
||||
|
||||
def check_request(self, request):
|
||||
for r in request:
|
||||
if r.name not in ['NOT', 'PRED']:
|
||||
raise Exception(f'invalid statement {repr(r)}')
|
||||
|
||||
def conds(self, statement):
|
||||
result = []
|
||||
for f in statement:
|
||||
if f.name == 'NOT':
|
||||
result.append(f.children[0])
|
||||
return result
|
||||
|
||||
def conclusion(self, statement):
|
||||
for f in statement:
|
||||
if f.name != 'NOT':
|
||||
return f
|
||||
return None
|
||||
|
||||
@property
|
||||
def rules(self):
|
||||
res = []
|
||||
for r in self.base:
|
||||
if len(r) > 1:
|
||||
res.append(r)
|
||||
return res
|
||||
|
||||
@property
|
||||
def facts(self):
|
||||
res = []
|
||||
for r in self.base:
|
||||
if len(r) == 1:
|
||||
res.append(r[0])
|
||||
return res
|
||||
|
||||
def exists(self, f):
|
||||
h = self.make_f(f)
|
||||
for g in self.base:
|
||||
if self.clause_equals(h, g):
|
||||
return True
|
||||
return False
|
||||
|
||||
def clause_equals(self, left, right):
|
||||
if len(left) != len(right):
|
||||
return False
|
||||
|
||||
for element in left:
|
||||
if not self.clause_in(element, right):
|
||||
return False
|
||||
|
||||
for element in right:
|
||||
if not self.clause_in(element, left):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def clause_in(self, clause, lst):
|
||||
for other in lst:
|
||||
same = fol.unify(clause, other) is not None
|
||||
if same:
|
||||
return True
|
||||
return False
|
||||
|
||||
def tell(self, request):
|
||||
req = self.make_f(request)
|
||||
self.check_request(req)
|
||||
self.base.append(req)
|
||||
|
||||
def ask(self, request):
|
||||
for i in range(20):
|
||||
n = len(self.base)
|
||||
self.update()
|
||||
if n == len(self.base):
|
||||
break
|
||||
|
||||
req = self.make_req(request)
|
||||
results = []
|
||||
|
||||
for f in self.base:
|
||||
s = fol.unify(f, req)
|
||||
if s is not None:
|
||||
results.append(s)
|
||||
|
||||
return results
|
||||
|
||||
def update(self):
|
||||
for rule in self.rules:
|
||||
conclusion = self.conclusion(rule)
|
||||
solution = []
|
||||
solutions = []
|
||||
self.solve(rule, self.conds(rule), self.facts, solution, solutions)
|
||||
for sol in solutions:
|
||||
concl = conclusion.subst(sol)
|
||||
if not self.exists(concl):
|
||||
self.base.append([concl])
|
||||
return
|
||||
|
||||
def merge_substs(self, substs):
|
||||
res = {}
|
||||
for s in substs:
|
||||
for k in s.keys():
|
||||
if k in res.keys() and not res[k].equals(s[k]):
|
||||
return None
|
||||
res[k] = s[k]
|
||||
return res
|
||||
|
||||
def solve(self, rule, conds, facts, solution, solutions):
|
||||
if len(conds) == 0:
|
||||
substs = []
|
||||
|
||||
for sol in solution:
|
||||
s = fol.unify(sol[0], sol[1])
|
||||
if s is not None:
|
||||
substs.append(s)
|
||||
|
||||
s = self.merge_substs(substs)
|
||||
|
||||
if s is not None:
|
||||
solutions.append(s)
|
||||
return
|
||||
|
||||
for cond in conds:
|
||||
for fact in facts:
|
||||
solution.append((cond, fact))
|
||||
|
||||
self.solve(
|
||||
rule,
|
||||
[c for c in conds if c != cond],
|
||||
[f for f in facts if f != fact],
|
||||
solution,
|
||||
solutions
|
||||
)
|
||||
|
||||
solution.pop()
|
34
fol/node.py
34
fol/node.py
|
@ -1,3 +1,6 @@
|
|||
import fol
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, name, value=None):
|
||||
self.name = name
|
||||
|
@ -8,7 +11,7 @@ class Node:
|
|||
self.children.append(child)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
return fol.pretty.get(self)
|
||||
|
||||
def __str__(self):
|
||||
res = self.name
|
||||
|
@ -68,3 +71,32 @@ class Node:
|
|||
return False
|
||||
|
||||
return True
|
||||
|
||||
def list_of(self, name):
|
||||
result = []
|
||||
|
||||
if self.name == name:
|
||||
for c in self.children:
|
||||
if c.name != name:
|
||||
result.append(c)
|
||||
result.extend(c.list_of(name))
|
||||
|
||||
return result
|
||||
|
||||
def flatten(self):
|
||||
result = self.list_of(self.name)
|
||||
print(result)
|
||||
root = fol.node.Node(self.name, self.value)
|
||||
root.children = result
|
||||
return root
|
||||
|
||||
def find_by_name(self, name):
|
||||
if self.name == name:
|
||||
return [self]
|
||||
|
||||
res = []
|
||||
|
||||
for c in self.children:
|
||||
res.extend(c.find_by_name(name))
|
||||
|
||||
return res
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
import fol
|
||||
|
||||
|
||||
def get(f):
|
||||
if f.name == 'EXISTS':
|
||||
return f'\\exist {get(f.children[0])}({get(f.children[1])})'
|
||||
|
|
|
@ -32,7 +32,8 @@ def test_cnf_neg(f, res):
|
|||
'Pre(f(x0)) -> Post(g(g(x1), x0))'),
|
||||
])
|
||||
def test_cnf_normalize(f, res):
|
||||
assert fol.cnf.normalize(fol.p(f)).equals(fol.p(res))
|
||||
counter = fol.cnf.Counter()
|
||||
assert fol.cnf.normalize(fol.p(f), counter).equals(fol.p(res))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['f', 'res'], [
|
||||
|
@ -116,7 +117,12 @@ def test_distrib(oracle, f):
|
|||
|
||||
@pytest.mark.parametrize(['f', 'oracle'], [
|
||||
('Friend(x, y) -> Friend(y, x)',
|
||||
'~Friend(x0, x1) | Friend(x1, x0)')
|
||||
'~Friend(x0, x1) | Friend(x1, x0)'),
|
||||
|
||||
('(Friend(x, y) & Friend(y, z)) -> Friend(x, z)',
|
||||
'(~Friend(x0, x1) | ~Friend(x1, x2)) | Friend(x0, x2)')
|
||||
])
|
||||
def test_cnf(oracle, f):
|
||||
print(fol.pretty.get(fol.p(f)))
|
||||
print(fol.pretty.get(fol.cnf.cnf(fol.p(f))))
|
||||
assert fol.cnf.cnf(fol.p(f)).equals(fol.p(oracle))
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
import fol
|
||||
|
||||
|
||||
def test_kb_simple():
|
||||
kb = fol.kb.Kb()
|
||||
kb.tell('Friend(ALICE, BOB)')
|
||||
|
||||
res = kb.ask('Friend(ALICE, x)')
|
||||
|
||||
assert len(res) == 1
|
||||
assert res[0]['x'].value == 'BOB'
|
|
@ -6,16 +6,37 @@ def test_unify_simple():
|
|||
fol.p('Hello(x)'),
|
||||
fol.p('Hello(ALICE)')
|
||||
)
|
||||
|
||||
assert subst is not None
|
||||
assert len(subst) == 1
|
||||
assert str(subst['x']) == 'CONST[ALICE]'
|
||||
|
||||
|
||||
def test_unify_equals():
|
||||
subst = fol.unify(
|
||||
fol.p('Hello(ALICE)'),
|
||||
fol.p('Hello(ALICE)')
|
||||
)
|
||||
|
||||
assert subst is not None
|
||||
assert len(subst) == 0
|
||||
|
||||
|
||||
def test_unify_different_predicate_name():
|
||||
subst = fol.unify(
|
||||
fol.p('Hello(x)'),
|
||||
fol.p('World(ALICE)')
|
||||
)
|
||||
|
||||
assert subst is None
|
||||
|
||||
|
||||
def test_unify_fun():
|
||||
subst = fol.unify(
|
||||
fol.p('Hello(x, f(x))'),
|
||||
fol.p('Hello(ALICE, f(ALICE))')
|
||||
)
|
||||
|
||||
assert subst is not None
|
||||
assert len(subst) == 1
|
||||
assert str(subst['x']) == 'CONST[ALICE]'
|
||||
|
|
103
fol/unify.py
103
fol/unify.py
|
@ -1,29 +1,84 @@
|
|||
import fol
|
||||
|
||||
|
||||
def unify(lhs, rhs):
|
||||
res = {}
|
||||
try:
|
||||
unify_res(lhs, rhs, res)
|
||||
except Exception:
|
||||
s = {}
|
||||
return _unify(lhs, rhs, s)
|
||||
|
||||
|
||||
def _unify(lhs, rhs, s):
|
||||
if s is None:
|
||||
return None
|
||||
return res
|
||||
elif is_equal(lhs, rhs):
|
||||
return s
|
||||
elif is_var(lhs):
|
||||
return unify_var(lhs, rhs, s)
|
||||
elif is_var(rhs):
|
||||
return unify_var(rhs, lhs, s)
|
||||
elif is_compound(lhs) and is_compound(rhs) and (
|
||||
lhs.name == rhs.name and lhs.value == rhs.value
|
||||
):
|
||||
|
||||
|
||||
def unify_res(lhs, rhs, result):
|
||||
if lhs.name == 'VAR':
|
||||
unify_var(lhs, rhs, result)
|
||||
elif rhs.name == 'VAR':
|
||||
unify_var(rhs, lhs, result)
|
||||
return _unify(
|
||||
lhs.children,
|
||||
rhs.children,
|
||||
s
|
||||
)
|
||||
elif is_list(lhs) and is_list(rhs):
|
||||
if len(lhs) != len(rhs):
|
||||
return None
|
||||
elif len(lhs) == 0:
|
||||
return s
|
||||
elif len(lhs) == 1:
|
||||
return _unify(lhs[0], rhs[0], s)
|
||||
else:
|
||||
left = lhs.children
|
||||
right = rhs.children
|
||||
if len(left) == len(right):
|
||||
for i in range(len(left)):
|
||||
unify_res(left[i], right[i], result)
|
||||
|
||||
|
||||
def unify_var(var, rhs, result):
|
||||
if rhs.name in ['VAR', 'CONST', 'FUN'] and rhs.value != var.value:
|
||||
if var.value in result.keys() \
|
||||
and not result[var.value].equals(rhs):
|
||||
raise Exception('')
|
||||
return _unify(lhs[1:], rhs[1:], _unify(lhs[0], rhs[0], s))
|
||||
else:
|
||||
result[var.value] = rhs
|
||||
return None
|
||||
|
||||
|
||||
def unify_var(var, x, s):
|
||||
if var.value in s.keys():
|
||||
return _unify(s[var.value], x, s)
|
||||
elif x in s.keys():
|
||||
return _unify(var, s[x], s)
|
||||
elif occur_check(var, x):
|
||||
return None
|
||||
else:
|
||||
s[var.value] = x
|
||||
return s
|
||||
|
||||
|
||||
def occur_check(var, x):
|
||||
if var.name == x.name and var.value == x.value:
|
||||
return True
|
||||
|
||||
for child in x.children:
|
||||
if occur_check(var, child):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_compound(f):
|
||||
if type(f) is fol.node.Node:
|
||||
return len(f.children) > 0
|
||||
return False
|
||||
|
||||
|
||||
def is_equal(lhs, rhs):
|
||||
if type(lhs) is str and type(rhs) is str:
|
||||
return lhs == rhs
|
||||
|
||||
if type(lhs) is fol.node.Node and type(rhs) is fol.node.Node:
|
||||
return lhs.equals(rhs)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_var(f):
|
||||
return type(f) is fol.node.Node and f.name == 'VAR'
|
||||
|
||||
|
||||
def is_list(value):
|
||||
return type(value) is list
|
||||
|
|
|
@ -1,2 +1,13 @@
|
|||
import fol
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Hello, World!')
|
||||
kb = fol.kb.Kb()
|
||||
try:
|
||||
kb.tell('Friend(ALICE, BOB)')
|
||||
kb.tell('Friend(BOB, CLAIRE)')
|
||||
kb.tell('Friend(x, y) -> Friend(y, x)')
|
||||
kb.tell('(Friend(x, y) & Friend(y, z)) -> Friend(x, z)')
|
||||
|
||||
print(kb.ask('Friend(ALICE, x)'))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue