from pysat.examples.hitman import Hitman
import numpy as np
import cv2
from PIL import Image
from IPython.display import display
def get_sat_clauses(n, d):
    """Build the hitting-set instance for covering an n x n grid with d x d squares.

    A square is identified by the integer grid coordinates ``(left, top)`` of
    its corner.  For every horizontal unit edge starting at ``(x, y)``
    (``0 <= x < n``, ``0 <= y <= n``) one clause lists the squares whose top
    side (``top == y``) or bottom side (``top == y - d``) contains that edge.
    A second, transposed clause (each coordinate pair swapped) does the same
    for the corresponding vertical edge.  A hitting set of all clauses is
    therefore a set of squares whose outlines cover every unit edge.

    Args:
        n: grid side length, in unit cells.
        d: square side length; must be >= 1.

    Returns:
        ``(clauses, id_map)``.  ``clauses`` is a list of lists of positive
        integer ids, two lists per (x, y) pair.  ``id_map`` maps each 1-based
        id back to its ``(left, top)`` position.  Ids are assigned in sorted
        order of position.

    Raises:
        ValueError: if ``d < 1``.  Every clause would then be empty, and no
            hitting set could exist.
    """
    if d < 1:
        raise ValueError(f"square size d must be >= 1, got {d}")

    position_clauses = []
    for y in range(n + 1):
        for x in range(n):
            # Squares whose top or bottom side contains the edge (x, y)-(x+1, y).
            horizontal = []
            for shift in range(d):
                horizontal.append((x - shift, y))
                horizontal.append((x - shift, y - d))
            # Transposed positions.  Use fresh names so the parameter `n`
            # is not shadowed.
            vertical = [(b, a) for a, b in horizontal]
            position_clauses.append(horizontal)
            position_clauses.append(vertical)

    # Give every distinct square position a positive id 1..k.
    positions = sorted({pos for clause in position_clauses for pos in clause})
    id_map = {i: pos for i, pos in enumerate(positions, start=1)}
    pos_to_id = {pos: i for i, pos in id_map.items()}
    clauses = [[pos_to_id[pos] for pos in clause] for clause in position_clauses]
    return clauses, id_map
def get_solution_helper(n, d, return_id_map=False):
    """Solve the covering instance for an n x n grid and d x d squares.

    Args:
        n: grid side length.
        d: square side length.
        return_id_map: if true, also return the id -> (left, top) mapping
            needed to interpret the returned ids.

    Returns:
        The hitting set returned by ``Hitman.get()``, i.e. ids produced by
        ``get_sat_clauses``.  If ``return_id_map`` is true, returns
        ``(hitting_set, id_map)`` instead.
    """
    clauses, id_map = get_sat_clauses(n, d)
    # NOTE(review): htype='lbx' makes Hitman return a subset-minimal hitting
    # set, not necessarily a minimum-size one.  Confirm this is intended,
    # since get_solution compares these sizes across different d.
    solver = Hitman(solver='m22', htype='lbx')
    try:
        for clause in clauses:
            solver.hit(clause)
        out = solver.get()
    finally:
        # Release the underlying SAT oracle deterministically instead of
        # waiting for garbage collection.
        solver.delete()
    if return_id_map:
        return out, id_map
    return out
def get_solution(n):
    """Pick the square size d in 1..n whose solution uses the fewest squares.

    Every d in 1..n is solved with ``get_solution_helper``.  If several d tie,
    the smallest one wins.

    Returns:
        ``(d, size)``: the best square size and the number of squares in
        its solution.

    Raises:
        ValueError: if ``n < 1`` (there is no candidate d).
    """
    size_by_d = {d: len(get_solution_helper(n, d)) for d in range(1, n + 1)}
    # min() returns the first minimal key in insertion (ascending d) order.
    best_d = min(size_by_d, key=size_by_d.__getitem__)
    return best_d, size_by_d[best_d]
def plot_solution(n, d, solution, id_map):
    """Render a set of squares as a PIL image.

    The canvas is white and ``3 * n * 100`` pixels square.  One grid unit is
    100 pixels, and coordinates are shifted by ``n`` units so that squares
    with negative positions stay on the canvas.

    Each square is drawn as a randomly coloured d x d outline with a filled
    red dot at its centre.  Its corner is jittered by a few pixels so that
    coincident edges of different squares remain distinguishable.

    Args:
        n: grid side length used to build the solution.
        d: square side length.
        solution: iterable of ids, each a key of ``id_map``.
        id_map: mapping from id to ``(left, top)`` grid coordinates.

    Returns:
        A ``PIL.Image.Image`` built from the RGB canvas.
    """
    unit = 100
    offset = n * unit
    side = 3 * n * unit
    # Every drawn value is an integer in 0..255.  A uint8 canvas renders the
    # same pixels as a float64 one at 1/8 of the memory.
    img = np.full((side, side, 3), 255, dtype=np.uint8)
    for i in solution:
        left, top = id_map[i]
        # Jitter comes from np.random.randint(-7, 7), i.e. [-7, 6].  The
        # draw order (left, top, colour) is unchanged.  Cast to plain int
        # for OpenCV, as is done for the colour below.
        left = int(left * unit + offset + np.random.randint(-7, 7))
        top = int(top * unit + offset + np.random.randint(-7, 7))
        right, bottom = left + d * unit, top + d * unit
        color = [int(c) for c in np.random.randint(0, 256, size=(3,))]
        img = cv2.rectangle(img, (left, top), (right, bottom), thickness=4, color=color)
        centre = ((left + right) // 2, (top + bottom) // 2)
        img = cv2.circle(img, centre, radius=20, thickness=-1, color=(255, 0, 0))
    return Image.fromarray(img)
if __name__ == "__main__":
    # For each grid size: find the best square size d, re-solve for that d
    # to get the square positions, then report and display the result.
    # The guard keeps the expensive solving out of plain imports of this module.
    for n in range(10, 12):
        d, _ = get_solution(n)
        solution, id_map = get_solution_helper(n, d, return_id_map=True)
        print(n, d, len(solution))
        img = plot_solution(n, d, solution, id_map)
        display(img)