"""Syntactic unification of first-order-logic terms built from ``fol.node.Node``.

A substitution is a plain ``dict`` that maps a variable node's ``value`` to
the term it is bound to. Bindings are stored as given: a bound term may itself
mention other bound variables, so resolving a term means chasing bindings
through the dict.
"""
import fol


def unify(lhs, rhs):
    """Unify two terms and return a substitution that makes them equal.

    Args:
        lhs: A term: a ``fol.node.Node``, a ``str``, or a list of terms.
        rhs: A term of the same kinds as ``lhs``.

    Returns:
        A new substitution ``dict`` (variable value -> term), or ``None`` if
        the terms do not unify.
    """
    return _unify(lhs, rhs, {})


def _unify(lhs, rhs, s):
    """Extend substitution ``s`` so that ``lhs`` and ``rhs`` become equal.

    ``s`` is mutated in place and returned on success. On failure ``None`` is
    returned, and ``s`` may already hold partial bindings, so callers must
    discard it. Passing ``s=None`` returns ``None`` immediately, which lets
    failures short-circuit through nested calls.
    """
    if s is None:
        return None
    if is_equal(lhs, rhs):
        return s
    if is_var(lhs):
        return unify_var(lhs, rhs, s)
    if is_var(rhs):
        return unify_var(rhs, lhs, s)
    if (is_compound(lhs) and is_compound(rhs)
            and lhs.name == rhs.name and lhs.value == rhs.value):
        # Same node name and value: unify the children pairwise.
        return _unify(lhs.children, rhs.children, s)
    if is_list(lhs) and is_list(rhs):
        if len(lhs) != len(rhs):
            return None
        # Walk the pairs left to right. This avoids recursing on
        # lhs[1:]/rhs[1:], which copies O(n^2) elements and adds one stack
        # frame per element.
        for left, right in zip(lhs, rhs):
            s = _unify(left, right, s)
            if s is None:
                return None
        return s
    return None


def unify_var(var, x, s):
    """Unify variable node ``var`` with term ``x`` under substitution ``s``.

    Existing bindings of ``var``, then of ``x`` (when ``x`` is a variable),
    are followed first. Otherwise ``var`` is bound to ``x``, unless ``var``
    occurs inside ``x``, since that binding would describe a cyclic term.

    Returns:
        The extended substitution ``s``, or ``None`` on failure.
    """
    if var.value in s:
        return _unify(s[var.value], x, s)
    # Bindings are keyed by a variable's ``value`` (see the assignment
    # below), so a bound variable ``x`` must be looked up the same way.
    if is_var(x) and x.value in s:
        return _unify(var, s[x.value], s)
    if occur_check(var, x, s):
        return None
    s[var.value] = x
    return s


def occur_check(var, x, s=None):
    """Return True if variable node ``var`` occurs anywhere inside term ``x``.

    Args:
        var: The variable node about to be bound.
        x: The term to search: a ``fol.node.Node``, a list of terms, or any
            other value (e.g. ``str``), which is treated as an atom.
        s: Optional substitution. When given, bound variables met inside
            ``x`` are followed, so cycles created through earlier bindings
            are detected. When omitted, ``x`` is searched structurally only.
    """
    if is_list(x):
        return any(occur_check(var, term, s) for term in x)
    if type(x) is not fol.node.Node:
        # Strings and other non-node atoms cannot contain a variable.
        return False
    if var.name == x.name and var.value == x.value:
        return True
    if (s is not None and is_var(x) and x.value in s
            and occur_check(var, s[x.value], s)):
        return True
    return any(occur_check(var, child, s) for child in x.children)


def is_compound(f):
    """Return True if ``f`` is a ``fol.node.Node`` with at least one child."""
    return type(f) is fol.node.Node and len(f.children) > 0


def is_equal(lhs, rhs):
    """Return True if ``lhs`` and ``rhs`` are equal atoms or equal nodes.

    Two strings are compared with ``==``. Two nodes are compared with
    ``Node.equals`` (its exact semantics are defined in ``fol``). Any other
    combination is reported as unequal.
    """
    if type(lhs) is str and type(rhs) is str:
        return lhs == rhs
    if type(lhs) is fol.node.Node and type(rhs) is fol.node.Node:
        return lhs.equals(rhs)
    return False


def is_var(f):
    """Return True if ``f`` is a ``fol.node.Node`` whose name is ``'VAR'``."""
    return type(f) is fol.node.Node and f.name == 'VAR'


def is_list(value):
    """Return True if ``value`` is exactly a ``list`` (not a subclass)."""
    return type(value) is list