Draw a 0 or a 1 on the drawing pad, click "save", then click "generate" to see how the RBM generates an image from the digit you drew.
import jupyter_drawing_pad as jd
import ipywidgets as widgets
from IPython.display import display,clear_output
from PIL import Image
import numpy as np
# Drawing surface plus the three action buttons (created in save/clear/generate order).
widget = jd.CustomBox()
btn_save, btn_clear, btn_generate = (
    widgets.Button(description=label) for label in ('save', 'clear', 'generate')
)
# `output` previews the rasterized drawing; `output2` shows the RBM's result.
output, output2 = widgets.Output(), widgets.Output()

# The save/clear buttons use a column flex flow, centered, beneath the pad.
box_layout = widgets.Layout(
    display='flex',
    flex_flow='column',
    align_items='center',
)
box = widgets.HBox(children=[btn_save, btn_clear], layout=box_layout)
vbox = widgets.VBox([widget.drawing_pad, box, output])

# Module-level state shared by the click handlers; both are reassigned by "save".
PILimage = []
dilation = np.zeros((28, 28), dtype='int64')
def on_button_save_clicked(b):
    """Rasterize the current drawing into a 28x28 grid and preview it.

    Click handler for ``btn_save``. Maps every stroke point in
    ``widget.drawing_pad.data`` onto a 28x28 grid and dilates the strokes
    with a 2x2 kernel. The result goes into the module-level ``dilation``
    (later consumed by the generate handler). A 300x300 black/white preview
    is shown in ``output`` and kept in ``PILimage``.

    Args:
        b: the clicked button (supplied by ipywidgets; unused).
    """
    global PILimage, dilation
    import cv2 as cv

    image = np.zeros((28, 28))
    xs, ys = widget.drawing_pad.data[0], widget.drawing_pad.data[1]
    for x, y in zip(xs, ys):
        # The 27/100 scale factor assumes pad coordinates span [0, 100]
        # (TODO confirm). The y axis is inverted: y=100 maps to row 0.
        # Clip so that a point slightly outside the pad cannot raise
        # IndexError or, through a negative index, wrap to the opposite edge.
        row = min(max(27 - round(y * 27 / 100), 0), 27)
        col = min(max(round(x * 27 / 100), 0), 27)
        image[row, col] = 1

    # Thicken the single-pixel strokes with a 2x2 dilation.
    kernel = np.ones((2, 2), np.uint8)
    dilation = cv.dilate(image, kernel, iterations=1)
    # NOTE(review): cv.dilate keeps the float64 dtype of `image`, while the
    # initial module-level `dilation` is int64. The generate handler
    # therefore sends floats after a save and ints before one. Confirm which
    # MSQRBM.generate expects.

    PILdilation = dilation.astype('uint8') * 255
    PILimage = Image.fromarray(PILdilation).resize((300, 300))
    with output:
        clear_output()
        display(PILimage)
def on_button_clear_clicked(b):
    """Erase every stroke on the drawing pad (click handler for ``btn_clear``).

    Only the pad is cleared. The previously saved ``dilation`` grid and the
    preview in ``output`` are left as they are.
    """
    pad = widget.drawing_pad
    pad.clear()
def load_model(filename, n_visible=784, n_hidden=30, qpu=False,
               directory='./pretrained'):
    """Construct an MSQRBM and load pretrained parameters into it.

    The defaults reproduce the original hard-coded call
    ``MSQRBM(784, 30, qpu=False)`` and the ``./pretrained/`` path.

    Args:
        filename: name of the parameter file inside *directory*.
        n_visible: first positional argument to MSQRBM (presumably the
            visible-layer size; 784 = 28 * 28 matches the saved drawing grid).
        n_hidden: second positional argument to MSQRBM (presumably the
            hidden-layer size).
        qpu: forwarded as MSQRBM's ``qpu`` keyword argument.
        directory: folder containing the file. A relative path is resolved
            against the current working directory.

    Returns:
        The MSQRBM instance after ``load`` has been called on it.
    """
    from qrbm.MSQRBM import MSQRBM
    bm = MSQRBM(n_visible, n_hidden, qpu=qpu)
    bm.load(f'{directory}/{filename}')
    return bm
def on_button_generate_clicked(b):
    """Run the pretrained RBM on the last saved drawing and display the result.

    Click handler for ``btn_generate``. Flattens the module-level 28x28
    ``dilation`` grid into a list, passes it to ``generate`` as ``test_img``,
    and shows the output reshaped to 28x28 and upscaled to 300x300 in
    ``output2``.

    Args:
        b: the clicked button (supplied by ipywidgets; unused).
    """
    image_height = 28
    # NOTE(review): the model is reloaded from disk on every click. Consider
    # caching it once generate() is confirmed not to mutate model state.
    bm = load_model('01.txt')
    generated_pic = bm.generate(test_img=dilation.flatten().tolist())
    # NOTE(review): assumes generate() returns 0/1 values. Fractional outputs
    # would be truncated by the uint8 cast below -- confirm.
    res = np.array(generated_pic, dtype='uint8') * 255
    res = np.reshape(res, (image_height, image_height))
    PILres = Image.fromarray(res).resize((300, 300))
    with output2:
        clear_output()
        display(PILres)
# Connect each button to its click handler.
btn_save.on_click(on_button_save_clicked)
btn_clear.on_click(on_button_clear_clicked)
btn_generate.on_click(on_button_generate_clicked)
# Render the UI explicitly. A bare expression is only auto-displayed when it
# is the last line of a notebook cell. If these lines share one cell, only
# the last would appear; run as a plain script, none would.
display(vbox, btn_generate, output2)