This notebook outlines a way of using pythreejs to get input from the user with a Picker, and then trigger kernel-side computations based on that input, which in turn update the visualization.
The example case here is an iterative algorithm that minimizes a function of two variables. The visualization will be a grid-based surface plot of this function. Double-clicking the surface starts the algorithm from the clicked point, and the results of the algorithm are live-plotted as a line on top of the surface.
First, we import the dependencies we will need.
from pythreejs import *
import numpy as np
from IPython.display import display
from ipywidgets import HTML, Output, VBox, jslink
view_width = 600
view_height = 400
Here is a simple random-descent method for finding local minima, written as a generator that yields each iterate. This algorithm has many weaknesses, so there is plenty of room for improvement (left as an exercise for the reader)!
def find_minima(f, start=(0, 0), xlim=None, ylim=None):
    rate = 0.1  # Learning rate (scales the random step)
    max_iters = 200  # maximum number of iterations
    iters = 0  # iteration counter
    # Search without bounds if no limits are given
    if xlim is None:
        xlim = (-np.inf, np.inf)
    if ylim is None:
        ylim = (-np.inf, np.inf)
    cur = np.array(start[:2], dtype=float)
    previous_step_size = 1  # Size of the last accepted step (tracked, not used as a stopping criterion)
    cur_val = f(cur[0], cur[1])
    while (iters < max_iters and
           xlim[0] <= cur[0] <= xlim[1] and ylim[0] <= cur[1] <= ylim[1]):
        iters = iters + 1
        candidate = cur - rate * (np.random.rand(2) - 0.5)
        candidate_val = f(candidate[0], candidate[1])
        if candidate_val >= cur_val:
            continue  # Bad guess, try again
        prev = cur
        cur = candidate
        cur_val = candidate_val
        previous_step_size = np.abs(cur - prev)
        yield tuple(cur) + (cur_val,)
    print("The local minimum occurs at", cur)
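One possible improvement is plain gradient descent. It is a swapped-in technique, not the notebook's method, and is offered here as a hedged sketch. The partial derivatives are estimated by central finite differences, so f can remain a black box. The sketch is pure Python (no numpy). The name gradient_descent and its parameters (rate, h, tol) are our own choices. It yields the same (x, y, f(x, y)) tuples and takes the same leading positional arguments as find_minima (f, start, xlim, ylim). Under that assumption, it should work as a drop-in replacement in the click handler.

```python
import math


def gradient_descent(f, start=(0.0, 0.0), xlim=(-1.0, 1.0), ylim=(-1.0, 1.0),
                     rate=0.1, max_iters=200, h=1e-6, tol=1e-8):
    """Yield (x, y, f(x, y)) after each finite-difference gradient-descent step.

    Hypothetical alternative to find_minima; only start[0] and start[1] are used,
    so a 3-component picker point can be passed as `start`.
    """
    x, y = float(start[0]), float(start[1])
    for _ in range(max_iters):
        # Central-difference estimates of the partial derivatives at (x, y)
        dfdx = (f(x + h, y) - f(x - h, y)) / (2 * h)
        dfdy = (f(x, y + h) - f(x, y - h)) / (2 * h)
        if math.hypot(dfdx, dfdy) < tol:
            return  # (near-)stationary point reached
        # Step against the gradient
        x, y = x - rate * dfdx, y - rate * dfdy
        # Stop, without yielding, once the step leaves the plotted domain
        if not (xlim[0] <= x <= xlim[1] and ylim[0] <= y <= ylim[1]):
            return
        yield (x, y, f(x, y))


# Example: follow the path from an arbitrary start point on the bowl x^2 + y^2
def f(x, y):
    return x ** 2 + y ** 2


path = list(gradient_descent(f, start=(0.8, -0.5)))
print(len(path), path[-1])
```

On the bowl, each step with rate=0.1 scales (x, y) by 0.8, so the path converges steadily to the origin. On the saddle x^2 - y^2, the y-coordinate instead grows by a factor of 1.2 per step, and the run stops as soon as it exits the square.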
Define a function to test the minimum finder on. Here we simply use $ f(x, y) = x^2 + y^2 $. You can also try the saddle surface $ f(x, y) = x^2 - y^2 $ to see what happens on an unstable surface.
def f(x, y):
    return x ** 2 + y ** 2
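For reference, here is the calculus behind the two suggested test surfaces. This is a short worked derivation, not part of the notebook's code:

$$
f(x, y) = x^2 + y^2: \quad \nabla f = (2x,\ 2y), \quad H_f = \begin{pmatrix} 2 & 0 \\ 0 & 2 \end{pmatrix}
$$

$$
f(x, y) = x^2 - y^2: \quad \nabla f = (2x,\ -2y), \quad H_f = \begin{pmatrix} 2 & 0 \\ 0 & -2 \end{pmatrix}
$$

In both cases the gradient vanishes only at the origin. For the first surface the Hessian is positive definite, so the origin is the unique minimum, and a descent run started inside the plotted square moves toward it. For the second surface the Hessian has one positive and one negative eigenvalue, so the origin is a saddle point. There, $f$ keeps decreasing as $|y|$ grows, and a descent run drifts toward $y = \pm x_{max}$ until it leaves the plotted domain.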
Evaluate the function on a grid in order to visualize it. This uses numpy helpers to generate the grid and evaluate the function. Note the two evaluations: one at the grid lattice points, and one at the center of each grid square!
nx, ny = (20, 20) # grid resolution
xmax = 1 # grid extent (+/-)
x = np.linspace(-xmax, xmax, nx)
y = np.linspace(-xmax, xmax, ny)
step = x[1] - x[0]
xx, yy = np.meshgrid(x, y)
# Grid lattice values:
grid_z = np.vectorize(f)(xx, yy)
# Grid square center values:
center_z = np.vectorize(f)(0.5 * step + xx[:-1,:-1], 0.5 * step + yy[:-1,:-1])
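The following pure-Python restatement of the numpy code is a sketch for checking the shapes. It shows why there are 19×19 center values for a 20×20 lattice: there is one value per grid square. The linspace helper is a hypothetical stand-in for np.linspace. The row/column order follows np.meshgrid's default 'xy' indexing, where entry [j][i] corresponds to (x[i], y[j]).

```python
def linspace(start, stop, n):
    """Hypothetical stand-in for np.linspace: n evenly spaced points, endpoints included."""
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]


def f(x, y):
    return x ** 2 + y ** 2


nx, ny = 20, 20  # grid resolution
xmax = 1  # grid extent (+/-)
x = linspace(-xmax, xmax, nx)
y = linspace(-xmax, xmax, ny)
step = x[1] - x[0]

# Lattice values: one per grid vertex, indexed [row j][column i] like meshgrid's 'xy' indexing
grid_z = [[f(xi, yj) for xi in x] for yj in y]

# Center values: one per grid square, sampled half a step from its lower-left vertex
center_z = [[f(xi + 0.5 * step, yj + 0.5 * step) for xi in x[:-1]] for yj in y[:-1]]

print(len(grid_z), len(grid_z[0]), len(center_z), len(center_z[0]))
```

With an even nx, square index 9 straddles the origin, so center_z[9][9] is (up to rounding) f(0, 0) = 0. The lattice, by contrast, has no vertex exactly at the origin.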
Setup code for the 3D visualization with mouse-based point picking. Here we use the SurfaceGeometry and SurfaceGrid helper classes (pythreejs helpers, not direct three.js classes).
# Surface geometry:
surf_g = SurfaceGeometry(z=list(grid_z.flat),
width=2 * xmax,
height=2 * xmax,
width_segments=nx - 1,
height_segments=ny - 1)
# Surface material. Note that the map uses the center-evaluated function-values:
surf = Mesh(geometry=surf_g,
material=MeshLambertMaterial(map=height_texture(center_z, 'YlGnBu_r')))
# Grid-lines for the surface:
surfgrid = SurfaceGrid(geometry=surf_g, material=LineBasicMaterial(color='black'),
position=[0, 0, 1e-2]) # Avoid overlap by lifting grid slightly
# Set up scene:
key_light = DirectionalLight(color='white', position=[3, 5, 1], intensity=0.4)
c = PerspectiveCamera(position=[0, 3, 3], up=[0, 0, 1], aspect=view_width / view_height,
children=[key_light])
scene = Scene(children=[surf, c, surfgrid, AmbientLight(intensity=0.8)])
We will now plot our figure. Note that initially, this will not have the interactive features, but we will add the interactivity below.
renderer = Renderer(camera=c, scene=scene,
width=view_width, height=view_height,
controls=[OrbitControls(controlling=c)])
out = Output() # An Output for displaying captured print statements
box = VBox([renderer])
display(box)
First, let us add a simple position tracker. This will find the point on the surface that the mouse is hovering over. We will represent this point with a pink sphere.
# Picker object
hover_picker = Picker(controlling=surf, event='mousemove')
renderer.controls = renderer.controls + [hover_picker]
# A sphere for representing the current point on the surface
hover_point = Mesh(geometry=SphereGeometry(radius=0.05),
material=MeshLambertMaterial(color='hotpink'))
scene.add(hover_point)
# Have sphere follow picker point:
jslink((hover_point, 'position'), (hover_picker, 'point'));
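The Picker does its hit-testing in the browser, relying on three.js raycasting against the mesh, so the code above never has to deal with the geometry. For readers curious about that geometry, here is a pure-Python sketch of one standard ray/triangle test (Möller–Trumbore). It illustrates the idea and is not pythreejs or three.js code. The names ray_triangle_intersect, _sub, _dot and _cross are hypothetical.

```python
def _sub(a, b):
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def _dot(a, b):
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def ray_triangle_intersect(origin, direction, v0, v1, v2, eps=1e-12):
    """Return the point where origin + t * direction (t >= 0) hits triangle (v0, v1, v2), or None."""
    e1 = _sub(v1, v0)
    e2 = _sub(v2, v0)
    p = _cross(direction, e2)
    det = _dot(e1, p)
    if abs(det) < eps:
        return None  # ray is parallel to the triangle's plane
    inv_det = 1.0 / det
    s = _sub(origin, v0)
    u = _dot(s, p) * inv_det  # first barycentric coordinate
    if u < 0.0 or u > 1.0:
        return None
    q = _cross(s, e1)
    v = _dot(direction, q) * inv_det  # second barycentric coordinate
    if v < 0.0 or u + v > 1.0:
        return None
    t = _dot(e2, q) * inv_det  # distance along the ray, in units of |direction|
    if t < 0.0:
        return None  # triangle lies behind the ray origin
    return tuple(o + t * d for o, d in zip(origin, direction))


# A ray pointing straight down onto a triangle in the z = 0 plane
hit = ray_triangle_intersect((0.25, 0.25, 5.0), (0.0, 0.0, -1.0),
                             (0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0))
print(hit)
```

A full mesh pick would run such a test against every triangle of the surface (two per grid square here) and keep the hit closest to the camera. three.js adds optimizations on top, such as a bounding-sphere pre-check, so treat this only as a model of the idea.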
Next, we will observe the changes to the hover point, and display its coordinates in a label which we add to the figure above:
coord_label = HTML() # A label for showing hover picker coordinates
def on_hover_change(change):
    coord_label.value = 'Pink point at (%.3f, %.3f, %.3f)' % tuple(change['new'])
on_hover_change({'new': hover_point.position})
hover_picker.observe(on_hover_change, names=['point'])
box.children = (coord_label,) + box.children
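The handler above relies on the observe protocol that widgets inherit from traitlets. Each change is delivered as a dict whose 'new' key holds the updated value. To exercise on_hover_change without a browser, here is a minimal sketch with a hypothetical stand-in class, PointSource, which is not a real widget. It builds a change dict with the same keys traitlets uses ('name', 'old', 'new', 'owner', 'type'). Like traitlets, it only notifies when the value actually changes.

```python
class PointSource:
    """Hypothetical stand-in for a widget with an observable 'point' trait."""

    def __init__(self, point=(0.0, 0.0, 0.0)):
        self._point = tuple(point)
        self._handlers = []

    def observe(self, handler):
        self._handlers.append(handler)

    @property
    def point(self):
        return self._point

    @point.setter
    def point(self, value):
        old, self._point = self._point, tuple(value)
        if old != self._point:  # notify only on an actual change
            change = {'name': 'point', 'old': old, 'new': self._point,
                      'owner': self, 'type': 'change'}
            for handler in self._handlers:
                handler(change)


labels = []  # collects what would be written to coord_label.value


def on_hover_change(change):
    labels.append('Pink point at (%.3f, %.3f, %.3f)' % tuple(change['new']))


src = PointSource()
src.observe(on_hover_change)
src.point = (0.5, -0.25, 0.3)
src.point = (0.5, -0.25, 0.3)  # same value again: no notification
print(labels)
```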
Finally, we set up a picker for when the user double-clicks on the surface. This triggers the execution and visualization of the algorithm.
# Create our picker for the double-click event ("dblclick")
click_picker = Picker(controlling=surf, event='dblclick')
def on_click(change):
    value = change['new']
    with out:
        print('Clicked on %s' % (value,))
    # Add a red sphere on the picked point
    point = Mesh(geometry=SphereGeometry(radius=0.05),
                 material=MeshLambertMaterial(color='red'),
                 position=value)
    scene.add(point)
    # Plot solution as a red line; it starts out with only the clicked point
    points = [value]
    line = Line2(geometry=LineGeometry(positions=points), material=LineMaterial(color='red', linewidth=2))
    scene.add(line)
    with out:  # Pick up any print statements in the algorithm
        for pt in find_minima(f, value, [-xmax, xmax], [-xmax, xmax]):
            # For each point, update the line:
            pt = list(pt)
            pt[2] += 1e-2  # offset to clear surface
            line.geometry = LineGeometry(positions=np.vstack([line.geometry.positions, pt]))
# When the point selected by the picker changes, trigger our function:
click_picker.observe(on_click, names=['point'])
# Update figure:
renderer.controls = renderer.controls + [click_picker]
box.children = box.children + (out,)
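As written, on_click builds a brand-new LineGeometry with the full position array on every accepted step. The data synced to the browser therefore presumably grows quadratically with the path length. One possible mitigation, sketched here under stated assumptions, is to collect points and push an update only every few steps. The helper batched_paths is hypothetical and not part of pythreejs. It only prepares the position lists; the assumption is that you still assign LineGeometry(positions=...) yourself.

```python
def batched_paths(points, start, every=10, z_offset=1e-2):
    """Yield the growing polyline (a list of [x, y, z]) every `every` points, plus once at the end.

    Hypothetical helper: `points` is any iterable of (x, y, z) tuples, e.g. the
    output of find_minima; `start` is the clicked point (kept without offset,
    as in on_click).
    """
    path = [list(start)]
    pending = 0
    for x, y, z in points:
        path.append([x, y, z + z_offset])  # lift each vertex slightly above the surface
        pending += 1
        if pending == every:
            yield [p[:] for p in path]  # copy so later appends don't alter what was yielded
            pending = 0
    if pending:
        yield [p[:] for p in path]  # flush the remaining points


# Example with 25 dummy points: updates after 10, 20 and finally 25 points
updates = list(batched_paths([(i, i, 0.0) for i in range(25)], (0, 0, 0), every=10))
print([len(u) for u in updates])
```

Inside on_click, the inner loop would then iterate over batched_paths(find_minima(f, value, [-xmax, xmax], [-xmax, xmax]), value) and assign LineGeometry(positions=positions) to line.geometry for each yielded list. The trade-off is fewer redraws, so the line appears in coarser increments.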
Final note: This notebook tries to explain the visualization code for a specific scenario. If you are more interested in understanding how different iterative minimization algorithms work, you should extract this visualization code to an external function that you can import. Then you can use it to: