This is the classic Smokes/Friends/Cancer example from relational learning, implemented with Logic Tensor Networks (LTN); every individual is grounded as a vector in the domain $\mathbb{R}^n$.
import logging;logging.basicConfig(level=logging.INFO)
import tensorflow as tf
# adding parent directory to sys.path (ltnw is in parent dir)
import sys, os
sys.path.insert(0,os.path.normpath(os.path.join(os.path.abspath(''),os.path.pardir)))
import logictensornetworks as ltn
import logictensornetworks_wrapper as ltnw
Some initial settings, including the dimensionality of the embedding (grounding) space.
# Library-wide LTN settings (module attributes of logictensornetworks).
# NOTE(review): presumably LAYERS is the number of tensor layers used in each
# predicate's grounding and BIAS_factor scales the regularising ltn.BIAS term
# added by the formula aggregator -- confirm against logictensornetworks.
ltn.BIAS_factor = 1e-3
ltn.LAYERS = 4

# Experiment settings.
embedding_size = 10  # embedding space dimensionality, i.e. n in R^n
max_epochs = 1000    # maximum number of training epochs passed to ltnw.train
We use three predicates: two unary predicates (Smokes and Cancer) and one binary predicate (Friends).
# Declare the predicates in the same order as before. Friends is binary, so its
# input size is twice the embedding size (presumably the two argument
# embeddings side by side -- confirm in ltnw.predicate).
for pred_name, input_size in (("Smokes", embedding_size),
                              ("Friends", embedding_size * 2),
                              ("Cancer", embedding_size)):
    ltnw.predicate(pred_name, input_size)
Generate the constants a to n, split into two groups: a–h and i–n.
g1 = 'abcdefgh'  # first group of individuals
g2 = 'ijklmn'    # second group of individuals
g = g1 + g2      # every individual, a..n

# One constant per individual. min_value/max_value are presumably the range
# the grounding is initialised in (trained embeddings are not confined to it);
# fresh lists are passed on every call.
for label in g:
    ltnw.constant(label,
                  min_value=[0.] * embedding_size,
                  max_value=[1.] * embedding_size)
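First, we define the facts known about the world, i.e., who is friends with whom, who smokes, and who has cancer. Every friendship within a group that is not listed, and every individual not listed as smoking or having cancer, is asserted to be false.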
# Friendship facts. Every pair is listed with x < y; the symmetry axiom
# "forall p,q:Friends(p,q) -> Friends(q,p)" supplies the reverse direction.
friends = [('a','b'),('a','e'),('a','f'),('a','g'),('b','c'),('c','d'),('e','f'),('g','h'),
           ('i','j'),('j','m'),('k','l'),('m','n')]
for x, y in friends:
    ltnw.formula("Friends(%s,%s)" % (x, y))

# Closed-world negatives for friendship, within each group only: every other
# pair x < y of the same group is asserted not to be friends. Pairs spanning
# the two groups get no Friends fact either way.
friend_pairs = set(friends)  # O(1) membership test instead of scanning the list
for group in (g1, g2):
    for x in group:
        for y in group:
            if x < y and (x, y) not in friend_pairs:
                ltnw.formula("~Friends(%s,%s)" % (x, y))

# Smoking facts: listed individuals smoke, all other individuals do not.
smokes = ['a','e','f','g','j','n']
for x in smokes:
    ltnw.formula("Smokes(%s)" % x)
for x in g:
    if x not in smokes:
        ltnw.formula("~Smokes(%s)" % x)

# Cancer facts: listed individuals have cancer, all other individuals do not.
# NOTE(review): this also asserts ~Cancer for the whole second group (i..n).
# In the commonly cited version of this example, Cancer is left unknown for
# that group so it can be inferred. Confirm the closed-world choice is intended.
cancer = ['a','e']
for x in cancer:
    ltnw.formula("Cancer(%s)" % x)
for x in g:
    if x not in cancer:
        ltnw.formula("~Cancer(%s)" % x)

# Show every formula registered so far, alphabetically.
print("\n".join(sorted(ltnw.FORMULAS.keys())))
Cancer(a) Cancer(e) Friends(a,b) Friends(a,e) Friends(a,f) Friends(a,g) Friends(b,c) Friends(c,d) Friends(e,f) Friends(g,h) Friends(i,j) Friends(j,m) Friends(k,l) Friends(m,n) Smokes(a) Smokes(e) Smokes(f) Smokes(g) Smokes(j) Smokes(n) ~Cancer(b) ~Cancer(c) ~Cancer(d) ~Cancer(f) ~Cancer(g) ~Cancer(h) ~Cancer(i) ~Cancer(j) ~Cancer(k) ~Cancer(l) ~Cancer(m) ~Cancer(n) ~Friends(a,c) ~Friends(a,d) ~Friends(a,h) ~Friends(b,d) ~Friends(b,e) ~Friends(b,f) ~Friends(b,g) ~Friends(b,h) ~Friends(c,e) ~Friends(c,f) ~Friends(c,g) ~Friends(c,h) ~Friends(d,e) ~Friends(d,f) ~Friends(d,g) ~Friends(d,h) ~Friends(e,g) ~Friends(e,h) ~Friends(f,g) ~Friends(f,h) ~Friends(i,k) ~Friends(i,l) ~Friends(i,m) ~Friends(i,n) ~Friends(j,k) ~Friends(j,l) ~Friends(j,n) ~Friends(k,m) ~Friends(k,n) ~Friends(l,m) ~Friends(l,n) ~Smokes(b) ~Smokes(c) ~Smokes(d) ~Smokes(h) ~Smokes(i) ~Smokes(k) ~Smokes(l) ~Smokes(m)
Next, we add general knowledge that should hold for every individual. For this we use two variables, p and q, that range over all the constants.
# Both variables range over all constants. Each variable is given its own,
# freshly built concatenation (along axis 0) of every constant's grounding.
for var_name in ("p", "q"):
    ltnw.variable(var_name, tf.concat(list(ltnw.CONSTANTS.values()), axis=0))

# General axioms. These strings double as keys of ltnw.FORMULAS, so their exact
# spelling (including spacing) matters.
general_axioms = (
    "forall p: ~Friends(p,p)",                            # nobody is their own friend
    "forall p,q:Friends(p,q) -> Friends(q,p)",            # friendship is symmetric
    "forall p: exists q: Friends(p,q)",                   # everybody has a friend
    "forall p,q:Friends(p,q) -> (Smokes(p)->Smokes(q))",  # smoking spreads among friends
    "forall p: Smokes(p) -> Cancer(p)",                   # smoking causes cancer
)
for axiom in general_axioms:
    ltnw.formula(axiom)

# Show only the quantified (general) formulas, alphabetically.
print("\n".join(sorted(key for key in ltnw.FORMULAS.keys() if key.startswith("forall"))))
forall p,q:Friends(p,q) -> (Smokes(p)->Smokes(q)) forall p,q:Friends(p,q) -> Friends(q,p) forall p: Smokes(p) -> Cancer(p) forall p: exists q: Friends(p,q) forall p: ~Friends(p,p)
Initialize the knowledge base and optimize it.
optimizer = tf.train.RMSPropOptimizer(learning_rate=.01, decay=.9)


def formula_aggregator(*formula_sats):
    """Combine the satisfaction tensors of all formulas into one KB-level value.

    The tensors are concatenated along axis 0 and averaged, then ltn.BIAS is
    added. ltn.BIAS is read each time this function is called, not when it is
    defined, exactly as with the former lambda.
    """
    return tf.reduce_mean(tf.concat(formula_sats, axis=0)) + ltn.BIAS


ltnw.initialize_knowledgebase(initial_sat_level_threshold=.1,
                              optimizer=optimizer,
                              formula_aggregator=formula_aggregator)
# NOTE(review): track_sat_levels presumably sets the logging interval in epochs;
# with max_epochs=1000 the recorded log shows only epoch 0. The trailing ';'
# suppresses notebook display of the return value.
ltnw.train(track_sat_levels=1000, sat_level_epsilon=.99, max_epochs=max_epochs);
INFO:logictensornetworks_wrapper:Initializing knowledgebase INFO:logictensornetworks_wrapper:Initializing optimizer INFO:logictensornetworks_wrapper:Assembling feed dict INFO:logictensornetworks_wrapper:Initializing Tensorflow session INFO:logictensornetworks_wrapper:INITIALIZED with sat level = 0.4619353 INFO:logictensornetworks_wrapper:TRAINING 0 sat level -----> 0.4619353 INFO:logictensornetworks_wrapper:TRAINING finished after 999 epochs with sat level 0.9683184
Check some formulas: print each constant's learned embedding and the truth values of the atoms about it.
# For every individual: its learned embedding, its unary atoms, and its
# Friends atom with every individual (including itself).
for ind in g:
    print(ind, " = ", ltnw.ask(ind).squeeze())
    for unary in ("Cancer", "Smokes"):
        atom = "%s(%s)" % (unary, ind)
        print("%s: %.2f" % (atom, ltnw.ask(atom).squeeze()))
    for other in g:
        atom = "Friends(%s,%s)" % (ind, other)
        print("%s: %.2f" % (atom, ltnw.ask(atom).squeeze()))
a = [ 0.40694648 0.13683926 1.1569124 1.1818535 0.14294973 0.9392728 1.7961706 -0.82798356 0.61106527 1.3495604 ] Cancer(a): 1.00 Smokes(a): 1.00 Friends(a,a): 0.00 Friends(a,b): 0.96 Friends(a,c): 0.00 Friends(a,d): 0.00 Friends(a,e): 1.00 Friends(a,f): 1.00 Friends(a,g): 1.00 Friends(a,h): 0.00 Friends(a,i): 0.00 Friends(a,j): 1.00 Friends(a,k): 0.96 Friends(a,l): 0.00 Friends(a,m): 0.26 Friends(a,n): 0.26 b = [ 0.9147028 1.0019453 0.5234545 0.3568315 0.9653716 -0.7044046 0.29647025 1.2177726 0.7527934 0.6989008 ] Cancer(b): 0.00 Smokes(b): 0.00 Friends(b,a): 0.26 Friends(b,b): 0.00 Friends(b,c): 0.96 Friends(b,d): 0.00 Friends(b,e): 0.00 Friends(b,f): 0.00 Friends(b,g): 0.00 Friends(b,h): 0.00 Friends(b,i): 0.00 Friends(b,j): 0.00 Friends(b,k): 0.00 Friends(b,l): 0.00 Friends(b,m): 0.00 Friends(b,n): 0.00 c = [ 0.17684561 0.10682261 -0.3228023 0.18780398 2.2169807 0.07184859 0.67084223 0.1279161 0.03808275 0.2801455 ] Cancer(c): 0.00 Smokes(c): 0.00 Friends(c,a): 0.00 Friends(c,b): 0.00 Friends(c,c): 0.00 Friends(c,d): 1.00 Friends(c,e): 0.00 Friends(c,f): 0.00 Friends(c,g): 0.00 Friends(c,h): 0.00 Friends(c,i): 0.00 Friends(c,j): 0.00 Friends(c,k): 0.00 Friends(c,l): 0.00 Friends(c,m): 0.00 Friends(c,n): 0.00 d = [ 1.8529487 1.5305904 -0.00983319 -0.43959376 1.1610166 -0.19935457 0.8084609 0.8635864 -0.37106708 -0.21966477] Cancer(d): 0.00 Smokes(d): 0.00 Friends(d,a): 0.00 Friends(d,b): 0.00 Friends(d,c): 0.96 Friends(d,d): 0.00 Friends(d,e): 0.00 Friends(d,f): 0.00 Friends(d,g): 0.00 Friends(d,h): 0.00 Friends(d,i): 0.00 Friends(d,j): 0.00 Friends(d,k): 0.96 Friends(d,l): 0.00 Friends(d,m): 0.00 Friends(d,n): 0.00 e = [0.6416602 0.66157794 0.12425633 0.67886037 0.09569763 1.7586013 1.1241968 0.92352915 0.47239357 0.8050211 ] Cancer(e): 1.00 Smokes(e): 1.00 Friends(e,a): 1.00 Friends(e,b): 0.00 Friends(e,c): 0.00 Friends(e,d): 0.00 Friends(e,e): 0.00 Friends(e,f): 1.00 Friends(e,g): 0.00 Friends(e,h): 0.00 Friends(e,i): 0.00 Friends(e,j): 0.96 Friends(e,k): 
0.00 Friends(e,l): 0.00 Friends(e,m): 0.26 Friends(e,n): 0.00 f = [-0.08516412 1.3672957 0.72731215 0.5337677 0.24218766 -0.17402795 0.7505686 1.0381958 -0.35125494 1.0763618 ] Cancer(f): 0.00 Smokes(f): 1.00 Friends(f,a): 1.00 Friends(f,b): 0.00 Friends(f,c): 0.00 Friends(f,d): 0.00 Friends(f,e): 1.00 Friends(f,f): 0.00 Friends(f,g): 0.00 Friends(f,h): 0.00 Friends(f,i): 0.00 Friends(f,j): 0.00 Friends(f,k): 0.00 Friends(f,l): 0.00 Friends(f,m): 0.26 Friends(f,n): 0.04 g = [ 0.8539798 0.761432 0.75761944 0.69546074 0.30572608 0.59105855 -0.343103 1.1347185 0.53520316 -0.08765783] Cancer(g): 0.00 Smokes(g): 1.00 Friends(g,a): 1.00 Friends(g,b): 0.00 Friends(g,c): 0.00 Friends(g,d): 0.00 Friends(g,e): 0.00 Friends(g,f): 0.00 Friends(g,g): 0.00 Friends(g,h): 0.96 Friends(g,i): 1.00 Friends(g,j): 0.00 Friends(g,k): 0.00 Friends(g,l): 0.00 Friends(g,m): 0.26 Friends(g,n): 0.00 h = [ 0.33865392 0.84384173 0.70018744 0.1345554 1.1172969 0.09839492 0.81501096 -0.15736319 0.56537557 0.7202723 ] Cancer(h): 0.00 Smokes(h): 0.00 Friends(h,a): 0.00 Friends(h,b): 0.00 Friends(h,c): 0.00 Friends(h,d): 0.00 Friends(h,e): 0.00 Friends(h,f): 0.00 Friends(h,g): 0.00 Friends(h,h): 0.00 Friends(h,i): 0.00 Friends(h,j): 0.00 Friends(h,k): 0.00 Friends(h,l): 0.00 Friends(h,m): 0.00 Friends(h,n): 0.00 i = [ 0.8514193 1.4208581 0.40041497 -0.19588503 -0.200309 -0.2671704 1.292986 -0.57946455 1.1746347 0.6758825 ] Cancer(i): 0.00 Smokes(i): 0.00 Friends(i,a): 0.00 Friends(i,b): 0.00 Friends(i,c): 0.00 Friends(i,d): 0.00 Friends(i,e): 0.00 Friends(i,f): 0.00 Friends(i,g): 0.96 Friends(i,h): 0.00 Friends(i,i): 0.00 Friends(i,j): 0.96 Friends(i,k): 0.00 Friends(i,l): 0.00 Friends(i,m): 0.00 Friends(i,n): 0.00 j = [ 0.66317457 0.4312353 0.76770747 0.26138163 -0.5081676 0.71022123 0.3749792 0.8656672 1.5517639 0.17871857] Cancer(j): 0.00 Smokes(j): 1.00 Friends(j,a): 1.00 Friends(j,b): 0.00 Friends(j,c): 0.00 Friends(j,d): 0.00 Friends(j,e): 0.96 Friends(j,f): 0.00 Friends(j,g): 0.00 
Friends(j,h): 0.00 Friends(j,i): 0.96 Friends(j,j): 0.00 Friends(j,k): 0.00 Friends(j,l): 0.00 Friends(j,m): 1.00 Friends(j,n): 0.00 k = [ 0.8996744 0.8115943 0.30769414 0.25508854 1.6205107 -1.1135386 -0.31606957 0.27544066 1.0825695 0.51878643] Cancer(k): 0.00 Smokes(k): 0.00 Friends(k,a): 0.26 Friends(k,b): 0.00 Friends(k,c): 0.00 Friends(k,d): 0.96 Friends(k,e): 0.00 Friends(k,f): 0.00 Friends(k,g): 0.00 Friends(k,h): 0.00 Friends(k,i): 0.00 Friends(k,j): 0.00 Friends(k,k): 0.00 Friends(k,l): 1.00 Friends(k,m): 0.00 Friends(k,n): 0.00 l = [ 1.0598446 0.3733673 1.0997576 -0.74925214 0.9516005 0.44306272 0.15078382 -0.45237944 0.21813329 0.13396993] Cancer(l): 0.00 Smokes(l): 0.00 Friends(l,a): 0.00 Friends(l,b): 0.00 Friends(l,c): 0.00 Friends(l,d): 0.00 Friends(l,e): 0.00 Friends(l,f): 0.00 Friends(l,g): 0.00 Friends(l,h): 0.00 Friends(l,i): 0.00 Friends(l,j): 0.00 Friends(l,k): 0.96 Friends(l,l): 0.00 Friends(l,m): 0.00 Friends(l,n): 0.00 m = [-0.2774569 1.4202287 1.1293546 0.9845873 0.48729423 0.41340557 1.2946917 0.7260294 0.6903299 1.0546893 ] Cancer(m): 0.00 Smokes(m): 1.00 Friends(m,a): 0.26 Friends(m,b): 0.00 Friends(m,c): 0.00 Friends(m,d): 0.00 Friends(m,e): 0.26 Friends(m,f): 0.26 Friends(m,g): 0.26 Friends(m,h): 0.00 Friends(m,i): 0.00 Friends(m,j): 1.00 Friends(m,k): 0.00 Friends(m,l): 0.00 Friends(m,m): 0.26 Friends(m,n): 1.00 n = [0.50134546 0.961617 0.84858876 0.2242519 0.9684452 1.192804 0.06254435 0.4865449 0.0592797 0.5465894 ] Cancer(n): 0.00 Smokes(n): 1.00 Friends(n,a): 0.26 Friends(n,b): 0.00 Friends(n,c): 0.00 Friends(n,d): 0.00 Friends(n,e): 0.00 Friends(n,f): 0.00 Friends(n,g): 0.00 Friends(n,h): 0.00 Friends(n,i): 0.00 Friends(n,j): 0.00 Friends(n,k): 0.00 Friends(n,l): 0.00 Friends(n,m): 0.26 Friends(n,n): 0.00
Check the truth values of some of the general axioms.
# Query the trained model for a selection of the general axioms.
axioms_to_check = (
    "forall p: ~Friends(p,p)",
    "forall p,q: Friends(p,q) -> Friends(q,p)",
    "forall p: exists q: Friends(p,q)",
    "forall p,q: Friends(p,q) -> (Smokes(p)->Smokes(q))",
)
for axiom in axioms_to_check:
    truth = ltnw.ask(axiom).squeeze()
    print("%s : %.2f" % (axiom, truth))
forall p: ~Friends(p,p) : 0.97 forall p,q: Friends(p,q) -> Friends(q,p) : 0.77 forall p: exists q: Friends(p,q) : 0.99 forall p,q: Friends(p,q) -> (Smokes(p)->Smokes(q)) : 0.00