This Elastica tutorial explains how to set up a Cosserat rod simulation of a slithering snake. It is a more complex use case than the Timoshenko beam example. If you have not done so, we strongly suggest you start with the beam example, as it covers many of the basics of setting up and running simulations with Elastica.
This slithering snake example includes gravitational forces, friction forces, and internal muscle torques. It also introduces callback functions, which log simulation data for post-processing after the simulation is over.
To set up the simulation, the first thing you need to do is import the necessary classes. As with the Timoshenko beam, we need to import wrapper classes, which make it easier to construct different simulation systems. We also need to import a rod class, the forces to be applied, timestepping functions, and callback classes.
import numpy as np
# import wrappers
from elastica.wrappers import BaseSystemCollection, Constraints, Forcing, CallBacks
# import rod class and forces to be applied
from elastica.rod.cosserat_rod import CosseratRod
from elastica.external_forces import GravityForces, MuscleTorques
from elastica.interaction import AnisotropicFrictionalPlane
# import timestepping functions
from elastica.timestepper.symplectic_steppers import PositionVerlet
from elastica.timestepper import integrate
# import call back functions
from elastica.callback_functions import CallBackBaseClass
from collections import defaultdict
Next, we initialize the simulator class by combining all the imported wrappers. After initializing it, we generate a rod and add it to the simulation.
class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks):
    pass
snake_sim = SnakeSimulator()
# Define rod parameters
n_elem = 50
start = np.array([0.0, 0.0, 0.0])
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 0.35
base_radius = base_length * 0.011
base_area = np.pi * base_radius ** 2
density = 1000
nu = 1e-4
E = 1e6
poisson_ratio = 0.5
shear_modulus = E / (poisson_ratio + 1.0)
# Create rod
shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    nu,
    E,
    shear_modulus=shear_modulus,
)
# Add rod to the snake system
snake_sim.append(shearable_rod)
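As a quick sanity check of the rod parameters above, the sketch below recomputes a few derived quantities with plain NumPy: the cross-sectional area, the total mass of the rod, and the shear modulus. It does not call Elastica, and the helper name `derived_rod_properties` is hypothetical. For comparison it also reports the standard isotropic relation G = E / (2(1 + ν)), which differs by a factor of 2 from the expression used in this tutorial.

```python
import numpy as np

def derived_rod_properties(base_length, base_radius, density, E, poisson_ratio):
    """Hypothetical helper: derived quantities for a uniform circular rod."""
    area = np.pi * base_radius ** 2  # cross-sectional area [m^2]
    total_mass = density * area * base_length  # mass of the whole rod [kg]
    return {
        "area": area,
        "total_mass": total_mass,
        # expression used in this tutorial
        "shear_modulus_tutorial": E / (poisson_ratio + 1.0),
        # standard isotropic relation, for comparison
        "shear_modulus_isotropic": E / (2.0 * (1.0 + poisson_ratio)),
    }

base_length = 0.35
props = derived_rod_properties(base_length, base_length * 0.011, 1000, 1e6, 0.5)
for key, value in props.items():
    print(f"{key}: {value:.4g}")
```

With the tutorial's values, the rod weighs roughly 16 g, which puts the force and torque magnitudes used below into perspective.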
With our rod added to the system, we need to specify the forces acting on it. Every force is added in the same way: system_name.add_forcing_to(name_of_rod).using(type_of_force, **kwargs), where **kwargs are the keyword arguments specific to each type of force.
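The add_forcing_to(...).using(...) call is a small builder pattern: the first call records which rod the force targets, and the second records the force class together with its keyword arguments. The sketch below is a hypothetical, stripped-down illustration of that pattern in plain Python, in which the forces are only built when the system is finalized. It is not Elastica's implementation, and the names `ToyForcing`, `_ForcingSpec`, and `ToyGravity` are made up for this example.

```python
class _ForcingSpec:
    """Remembers the target rod, then the force class and its kwargs."""

    def __init__(self, rod):
        self.rod = rod
        self.force_cls = None
        self.kwargs = {}

    def using(self, force_cls, **kwargs):
        # Store the class and its keyword arguments; nothing is built yet
        self.force_cls = force_cls
        self.kwargs = kwargs
        return self


class ToyForcing:
    """Hypothetical stand-in for a simulator that collects forcing specs."""

    def __init__(self):
        self._specs = []

    def add_forcing_to(self, rod):
        spec = _ForcingSpec(rod)
        self._specs.append(spec)
        return spec  # returned so that .using(...) can be chained

    def finalize(self):
        # Instantiate every force with the keyword arguments given to .using()
        return [(spec.rod, spec.force_cls(**spec.kwargs)) for spec in self._specs]


class ToyGravity:
    def __init__(self, acc_gravity):
        self.acc_gravity = acc_gravity


sim = ToyForcing()
sim.add_forcing_to("rod_0").using(ToyGravity, acc_gravity=(0.0, -9.80665, 0.0))
built = sim.finalize()
print(built[0][0], built[0][1].acc_gravity)
```

Deferring construction like this lets each force validate its keyword arguments against the rod only once everything has been registered.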
The first force to add is gravity. We specify it as an acceleration vector, which sets both its strength and the direction in which it points.
# Add gravitational forces
gravitational_acc = -9.80665
snake_sim.add_forcing_to(shearable_rod).using(
    GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
)
print("Gravity now acting on shearable rod")
Gravity now acting on shearable rod
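Conceptually, gravity adds a force m_i g to every node i of the discretized rod. The sketch below computes these nodal forces with NumPy, assuming uniform elements whose mass is split equally between their two end nodes. That mass lumping is an assumption made for illustration, and `nodal_gravity_forces` is a hypothetical helper, not an Elastica function.

```python
import numpy as np

def nodal_gravity_forces(n_elem, total_mass, acc_gravity):
    """Hypothetical helper: gravity force on each node, shape (3, n_elem + 1)."""
    element_mass = total_mass / n_elem
    node_mass = np.zeros(n_elem + 1)
    # Each element contributes half its mass to each of its two nodes
    node_mass[:-1] += 0.5 * element_mass
    node_mass[1:] += 0.5 * element_mass
    # Outer product: one force vector (column) per node
    return np.outer(acc_gravity, node_mass)

forces = nodal_gravity_forces(50, 0.0163, np.array([0.0, -9.80665, 0.0]))
print(forces.shape)  # (3, 51)
print(forces.sum(axis=1))  # total weight of the rod, pointing along -y
```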
A snake generates torque throughout its body through muscle activations. While these muscle activations are generated internally by the snake, it is simpler to treat them as external torques, which lets us apply them to the rod in the same way as the other forces.
You may notice that the muscle torque parameters appear to have special values. These are optimized coefficients for a snake gait. For information about how to do this optimization, see the snake optimization example script.
# Define muscle torque parameters
period = 2.0
wave_length = 1.0
b_coeff = np.array([3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3])
# Add muscle torques to the rod
snake_sim.add_forcing_to(shearable_rod).using(
    MuscleTorques,
    base_length=base_length,
    b_coeff=b_coeff,
    period=period,
    wave_number=2.0 * np.pi / (wave_length),
    phase_shift=0.0,
    rest_lengths=shearable_rod.rest_lengths,
    ramp_up_time=period,
    direction=normal,
    with_spline=True,
)
print("Muscle torques added to the rod")
Muscle torques added to the rod
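The muscle torque acts as a traveling wave along the body. As a rough guide (our reading of the parameters, not a copy of Elastica's source), its magnitude at normalized arc length s and time t can be sketched as A(s)·sin(ωt − ks + φ), with ω = 2π/period, k = wave_number, φ = phase_shift, an amplitude envelope A(s) built from b_coeff, and a linear ramp from 0 to 1 over ramp_up_time. With with_spline=True, Elastica fits A(s) with a spline; the sketch below substitutes linear interpolation between evenly spaced b_coeff values, so the profile will not match exactly. The helper `muscle_torque_profile` is hypothetical.

```python
import numpy as np

def muscle_torque_profile(s, time, b_coeff, period, wave_number,
                          phase_shift=0.0, ramp_up_time=None):
    """Hypothetical sketch of a ramped traveling-wave torque magnitude.

    s: normalized arc-length positions in [0, 1].
    The amplitude envelope linearly interpolates b_coeff at evenly spaced
    points (an assumption; Elastica uses a spline when with_spline=True).
    """
    ramp_up_time = period if ramp_up_time is None else ramp_up_time
    ramp = min(1.0, time / ramp_up_time)  # grows linearly, then stays at 1
    knots = np.linspace(0.0, 1.0, len(b_coeff))
    amplitude = np.interp(s, knots, b_coeff)
    omega = 2.0 * np.pi / period
    return ramp * amplitude * np.sin(omega * time - wave_number * s + phase_shift)

b_coeff = np.array([3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3])
s = np.linspace(0.0, 1.0, 50)
torque = muscle_torque_profile(s, time=3.0, b_coeff=b_coeff, period=2.0,
                               wave_number=2.0 * np.pi / 1.0)
print(torque.shape, np.abs(torque).max() <= b_coeff.max())
```

Because the phase depends on ωt − ks, the torque pattern moves along the body at speed ω/k, which is what drives the undulation.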
The last force to add is the friction between the snake and the ground. Snakes exhibit anisotropic friction, meaning that the friction coefficient differs by direction (here forward, backward, and sideways). You can also define separate static and kinetic friction coefficients. The transition between static and kinetic friction is set by a small velocity threshold, slip_velocity_tol.
# Define friction force parameters
origin_plane = np.array([0.0, -base_radius, 0.0])
normal_plane = normal
slip_velocity_tol = 1e-8
froude = 0.1
mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
kinetic_mu_array = np.array(
    [1.0 * mu, 1.5 * mu, 2.0 * mu]
)  # [forward, backward, sideways]
static_mu_array = 2 * kinetic_mu_array
# Add friction forces to the rod
snake_sim.add_forcing_to(shearable_rod).using(
    AnisotropicFrictionalPlane,
    k=1.0,
    nu=1e-6,
    plane_origin=origin_plane,
    plane_normal=normal_plane,
    slip_velocity_tol=slip_velocity_tol,
    static_mu_array=static_mu_array,
    kinetic_mu_array=kinetic_mu_array,
)
print("Friction forces added to the rod")
Friction forces added to the rod
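The friction coefficient above comes from a Froude-number scaling, mu = L / (T² |g| Fr), which with the tutorial's values evaluates to about 0.089. The sketch below reproduces that arithmetic. It then shows, in a deliberately simplified form, how a static or kinetic, direction-dependent coefficient could be selected from the slip velocity. The selection logic and the helper `pick_friction_coefficient` are our own illustration, not Elastica's AnisotropicFrictionalPlane algorithm.

```python
import numpy as np

base_length, period, gravitational_acc, froude = 0.35, 2.0, -9.80665, 0.1
mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
kinetic_mu_array = np.array([1.0 * mu, 1.5 * mu, 2.0 * mu])  # [forward, backward, sideways]
static_mu_array = 2 * kinetic_mu_array

def pick_friction_coefficient(axial_velocity, lateral_velocity, slip_velocity_tol=1e-8):
    """Hypothetical, simplified selection of (axial_mu, lateral_mu)."""
    # Forward vs. backward depends on the sign of the axial slip velocity
    axial_idx = 0 if axial_velocity >= 0.0 else 1
    # Below the slip threshold the contact counts as sticking (static friction)
    axial_table = kinetic_mu_array if abs(axial_velocity) > slip_velocity_tol else static_mu_array
    lateral_table = kinetic_mu_array if abs(lateral_velocity) > slip_velocity_tol else static_mu_array
    return float(axial_table[axial_idx]), float(lateral_table[2])

print(f"mu = {mu:.4f}")  # mu = 0.0892
print(pick_friction_coefficient(-0.01, 0.0))  # backward kinetic, sideways static
```

The higher sideways coefficient is what lets lateral body undulations push the snake forward instead of simply sliding it sideways.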
The simulation is now set up, but before running it, we want to define a callback function. A callback function allows us to record time-series data throughout the simulation. If you do not define a callback function, you will only have access to the final configuration of the system. If you want to analyze how the system evolves over time, it is critical that you record the appropriate quantities.
To create a callback function, subclass CallBackBaseClass. You can then define which state quantities to record by appending them to the self.callback_params dictionary, and how often to save the data by setting step_skip.
# Add call backs
class ContinuumSnakeCallBack(CallBackBaseClass):
    """
    Call back function for continuum snake
    """
    def __init__(self, step_skip: int, callback_params: dict):
        CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params
    def make_callback(self, system, time, current_step: int):
        if current_step % self.every == 0:
            self.callback_params["time"].append(time)
            self.callback_params["step"].append(current_step)
            self.callback_params["position"].append(system.position_collection.copy())
            self.callback_params["velocity"].append(system.velocity_collection.copy())
            self.callback_params["avg_velocity"].append(
                system.compute_velocity_center_of_mass()
            )
            self.callback_params["center_of_mass"].append(
                system.compute_position_center_of_mass()
            )
            self.callback_params["curvature"].append(system.kappa.copy())
        return
pp_list = defaultdict(list)
snake_sim.collect_diagnostics(shearable_rod).using(
    ContinuumSnakeCallBack, step_skip=1000, callback_params=pp_list
)
print("Callback function added to the simulator")
Callback function added to the simulator
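The callback records data only when current_step is a multiple of step_skip, and defaultdict(list) lets each quantity start as an empty list the first time it is appended to. The sketch below mimics that logic with a dummy system and a plain loop, so you can check how many samples a given step_skip produces without running Elastica. `DummySystem` and `EveryNthStepRecorder` are made-up names, and calling the recorder at step 0 as well as after every step is an assumption about the driver loop, not a statement about Elastica's integrate.

```python
from collections import defaultdict

class DummySystem:
    """Hypothetical stand-in exposing a single state value."""
    def __init__(self):
        self.position = 0.0

class EveryNthStepRecorder:
    """Records time, step, and position every `step_skip` steps."""
    def __init__(self, step_skip, callback_params):
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step):
        if current_step % self.every == 0:
            self.callback_params["time"].append(time)
            self.callback_params["step"].append(current_step)
            self.callback_params["position"].append(system.position)

params = defaultdict(list)
system = DummySystem()
recorder = EveryNthStepRecorder(step_skip=1000, callback_params=params)
dt, total_steps = 8.0e-6, 10_000

# Assumed driver loop: record the initial state, then record after each step
recorder.make_callback(system, 0.0, 0)
for step in range(1, total_steps + 1):
    system.position += 1.0 * dt  # move at unit speed
    recorder.make_callback(system, step * dt, step)

print(len(params["time"]))  # total_steps // step_skip + 1 = 11
```

A larger step_skip saves memory and post-processing time at the cost of temporal resolution in the recorded data.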
With the callback function added, we can now finalize the system and also define the time stepping parameters of the simulation such as the time step, final time, and time stepping algorithm to use.
snake_sim.finalize()
final_time = 5.0 * period
dt = 8.0e-6
total_steps = int(final_time / dt)
print("Total steps", total_steps)
timestepper = PositionVerlet()
Total steps 1250000
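PositionVerlet is a symplectic scheme. In its textbook form, one step of size dt moves the positions by half a step, updates the velocities with the acceleration at that midpoint, and then moves the positions by the remaining half step. The sketch below applies that textbook form to a 1D harmonic oscillator, a test problem chosen for illustration; it is not Elastica's stepper, which advances the full rod state. For Hamiltonian systems, symplectic schemes keep the energy error bounded over long runs instead of letting it drift, which is one reason they are popular for long simulations.

```python
import numpy as np

def position_verlet(x, v, accel, dt, n_steps):
    """Textbook position Verlet: half drift, kick, half drift."""
    for _ in range(n_steps):
        x = x + 0.5 * dt * v  # drift positions by dt/2
        v = v + dt * accel(x)  # kick velocities with the midpoint acceleration
        x = x + 0.5 * dt * v  # drift positions by the remaining dt/2
    return x, v

omega = 2.0 * np.pi  # oscillator with a period of 1

def accel(x):
    return -omega ** 2 * x

def energy(x, v):
    return 0.5 * v ** 2 + 0.5 * omega ** 2 * x ** 2

x0, v0 = 1.0, 0.0
x, v = position_verlet(x0, v0, accel, dt=1e-3, n_steps=100_000)  # 100 periods
rel_energy_error = abs(energy(x, v) - energy(x0, v0)) / energy(x0, v0)
print(rel_energy_error)  # stays small even after many periods
```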
Now all that is left is to run the simulation. With the default parameters, the simulation takes about 2-3 minutes to complete; the exact time depends on your hardware.
integrate(timestepper, snake_sim, final_time, total_steps)
100%|██████████| 1250000/1250000 [02:13<00:00, 9368.55it/s]
Final time of simulation is : 9.999999999982162
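The numbers printed above can be checked with a little arithmetic. Five periods of 2 s at dt = 8e-6 s give 1,250,000 steps, and at the reported rate of roughly 9,369 steps per second that is about 133 s, matching the 2:13 shown by the progress bar. The final time is printed as 9.99999999998... rather than exactly 10 because 8e-6 is not exactly representable in binary floating point and every addition rounds again. The sketch below illustrates the accumulated rounding by adding dt repeatedly. Whether Elastica accumulates time in exactly this way is an assumption, so the last digits may differ from the output above.

```python
final_time = 5.0 * 2.0
dt = 8.0e-6
total_steps = int(final_time / dt)
print(total_steps)  # 1250000, as printed by the tutorial

steps_per_second = 9368.55  # rate reported by the progress bar above
print(f"estimated run time: {total_steps / steps_per_second:.0f} s")

# Accumulate time step by step, as a simple time loop would
time = 0.0
for _ in range(total_steps):
    time += dt
print(time, abs(time - final_time))  # close to, but generally not exactly, 10.0
```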
With the simulation complete, we can analyze the results. Because we added a callback function, we can see how the snake evolves over time. All of the data from the callback function is stored in the pp_list dictionary. Here we use it to compute and plot the velocity of the snake in the forward, lateral, and normal directions, using the analysis function compute_projected_velocity defined below.
In the resulting plot, you can see that it takes about one period for the snake to start moving, after which it quickly settles into a steady gait within just 2-3 periods. The normal velocity stays at zero because the snake is actuated only within a 2D plane.
def compute_projected_velocity(plot_params: dict, period):
    import numpy as np
    time_per_period = np.array(plot_params["time"]) / period
    avg_velocity = np.array(plot_params["avg_velocity"])
    center_of_mass = np.array(plot_params["center_of_mass"])
    # Compute the rod velocity in the rod direction. This is needed because, once
    # the snake starts to move, it picks an arbitrary heading, which need not be
    # the initial tangent direction of the rod. Projecting the snake velocity onto
    # its new heading and the perpendicular (roll) direction gives the correct
    # forward and lateral speeds. After this projection, the lateral velocity of
    # the snake should oscillate between positive and negative values with zero mean.
    # Number of recorded samples in one period.
    period_step = int(1.0 / (time_per_period[-1] - time_per_period[-2])) + 1
    number_of_period = int(time_per_period[-1]) + 1
    # Center of mass displacement over one period, one row per period
    center_of_mass_averaged_over_one_period = np.zeros((number_of_period - 2, 3))
    for i in range(1, number_of_period - 1):
        # average displacement of the center of mass over one period
        center_of_mass_averaged_over_one_period[i - 1] = np.mean(
            center_of_mass[(i + 1) * period_step : (i + 2) * period_step]
            - center_of_mass[(i + 0) * period_step : (i + 1) * period_step],
            axis=0,
        )
    # Average the rod directions over multiple periods and get the direction of the rod.
    direction_of_rod = np.mean(center_of_mass_averaged_over_one_period, axis=0)
    direction_of_rod /= np.linalg.norm(direction_of_rod, ord=2)
    # Compute the projected rod velocity in the direction of the rod
    velocity_mag_in_direction_of_rod = np.einsum(
        "ji,i->j", avg_velocity, direction_of_rod
    )
    velocity_in_direction_of_rod = np.einsum(
        "j,i->ji", velocity_mag_in_direction_of_rod, direction_of_rod
    )
    # Get the lateral or roll velocity of the rod after subtracting its projected
    # velocity in the direction of rod
    velocity_in_rod_roll_dir = avg_velocity - velocity_in_direction_of_rod
    # Compute the average velocity over the simulation; this can be used to optimize
    # the snake for the fastest forward velocity. We skip the first two periods
    # because the muscle torques ramp up during the first period.
    average_velocity_over_simulation = np.mean(
        velocity_in_direction_of_rod[period_step * 2 :], axis=0
    )
    return (
        velocity_in_direction_of_rod,
        velocity_in_rod_roll_dir,
        average_velocity_over_simulation[2],
        average_velocity_over_simulation[0],
    )
def compute_and_plot_velocity(plot_params: dict, period):
    from matplotlib import pyplot as plt
    from matplotlib.colors import to_rgb
    time_per_period = np.array(plot_params["time"]) / period
    avg_velocity = np.array(plot_params["avg_velocity"])
    [
        velocity_in_direction_of_rod,
        velocity_in_rod_roll_dir,
        _,
        _,
    ] = compute_projected_velocity(plot_params, period)
    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    plt.rcParams.update({"font.size": 16})
    ax = fig.add_subplot(111)
    # Pass the visibility flag positionally; recent Matplotlib versions reject `b=`
    ax.grid(True, which="minor", color="k", linestyle="--")
    ax.grid(True, which="major", color="k", linestyle="-")
    ax.plot(
        time_per_period[:], velocity_in_direction_of_rod[:, 2], "r-", label="forward"
    )
    ax.plot(
        time_per_period[:],
        velocity_in_rod_roll_dir[:, 0],
        c=to_rgb("xkcd:bluish"),
        label="lateral",
    )
    ax.plot(time_per_period[:], avg_velocity[:, 1], "k-", label="normal")
    ax.set_ylabel("Velocity [m/s]", fontsize=16)
    # Time is divided by the actuation period above, so label it in periods
    ax.set_xlabel("Time [periods]", fontsize=16)
    fig.legend(prop={"size": 20})
    plt.show()
    plt.close(plt.gcf())
compute_and_plot_velocity(pp_list, period)
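At its core, compute_projected_velocity performs an orthogonal projection: the center-of-mass velocity v is split into a part along the average travel direction d (a unit vector), (v·d)d, and the remainder v − (v·d)d, which is perpendicular to d. The einsum expressions "ji,i->j" and "j,i->ji" are a row-wise dot product and an outer product, respectively. The short check below uses made-up random velocities to confirm that those einsum expressions agree with the more familiar `@` and `np.outer` forms and that the remainder is orthogonal to d.

```python
import numpy as np

rng = np.random.default_rng(0)
avg_velocity = rng.normal(size=(5, 3))  # made-up velocities, one row per sample
direction_of_rod = np.array([1.0, 0.0, 1.0])
direction_of_rod /= np.linalg.norm(direction_of_rod, ord=2)

# Same expressions as in compute_projected_velocity
speed_along = np.einsum("ji,i->j", avg_velocity, direction_of_rod)
v_along = np.einsum("j,i->ji", speed_along, direction_of_rod)
v_rest = avg_velocity - v_along

# Equivalent, more familiar forms
assert np.allclose(speed_along, avg_velocity @ direction_of_rod)
assert np.allclose(v_along, np.outer(speed_along, direction_of_rod))
# The remainder has no component along the travel direction
assert np.allclose(v_rest @ direction_of_rod, 0.0)
print("projection checks passed")
```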
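We can plot the curvature along the snake at different time instances and compare it with the stereotypical snake curvature function $7\cos(2 \pi s)$. Note that the plotted reference curve includes a phase shift of 0.80 rad, i.e. $7\cos(2 \pi s - 0.80)$.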
def plot_curvature(
    plot_params: dict,
    rest_lengths,
    period,
):
    from matplotlib import pyplot as plt
    from matplotlib.colors import to_rgb
    # Normalized arc length at the interior nodes
    s = np.cumsum(rest_lengths)
    L0 = s[-1]
    s = s / L0
    s = s[:-1].copy()
    x = np.linspace(0, 1, 100)
    curvature = np.array(plot_params["curvature"])
    time = np.array(plot_params["time"])
    peak_time = period * 0.125
    dt = time[1] - time[0]
    peak_idx = int(peak_time / (dt))
    plt.rcParams.update({"font.size": 16})
    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    try:
        # Eight snapshots, 1/8 of a period apart, starting about one period in
        for i in range(peak_idx * 8, peak_idx * 8 * 2, peak_idx):
            ax.plot(s, curvature[i, 0, :] * L0, "k")
    except IndexError:
        print("Simulation time not long enough to plot curvature")
    ax.plot(
        x, 7 * np.cos(2 * np.pi * x - 0.80), "--", label="stereotypical snake curvature"
    )
    ax.set_ylabel(r"$\kappa$", fontsize=16)
    ax.set_xlabel("s", fontsize=16)
    ax.set_xlim(0, 1)
    ax.set_ylim(-10, 10)
    fig.legend(prop={"size": 16})
    plt.show()
    plt.close(plt.gcf())
plot_curvature(pp_list, shearable_rod.rest_lengths, period)
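The plotted reference curve 7cos(2πs − 0.80) uses a fixed phase offset. If you want to estimate such an offset from your own data instead of fixing it by hand, one simple option is a grid search that minimizes the squared difference between a curvature snapshot and 7cos(2πs − φ). The sketch below does this on a synthetic profile with a known offset. `best_phase_offset` is a hypothetical helper, and the grid search is our choice of method, not something the tutorial prescribes.

```python
import numpy as np

def best_phase_offset(s, kappa, amplitude=7.0, n_grid=3600):
    """Grid-search the phase phi minimizing sum((kappa - A cos(2 pi s - phi))^2)."""
    phis = np.linspace(0.0, 2.0 * np.pi, n_grid, endpoint=False)
    # Each row of `reference` is the model curve for one candidate phase
    reference = amplitude * np.cos(2.0 * np.pi * s[None, :] - phis[:, None])
    errors = np.sum((kappa[None, :] - reference) ** 2, axis=1)
    return phis[np.argmin(errors)]

s = np.linspace(0.0, 1.0, 49)
synthetic_kappa = 7.0 * np.cos(2.0 * np.pi * s - 0.80)  # known offset of 0.80 rad
print(round(best_phase_offset(s, synthetic_kappa), 2))
```

For a real snapshot you would pass the normalized curvature `curvature[i, 0, :] * L0` and the arc-length array `s` from plot_curvature.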
Because we saved data on the snake's behavior, we can make a video of its movement. The easiest way to do this is to plot the snake's position at each recorded time and then stitch these plots together into a video.
Note: matplotlib needs ffmpeg to be installed in order to create a video. See the ffmpeg documentation for installation instructions.
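Before running the video cells, you can check whether ffmpeg is reachable. The sketch below uses shutil.which to look for the ffmpeg executable on your PATH and, if Matplotlib is installed, asks Matplotlib's movie-writer registry whether its ffmpeg writer is available. The helper name `ffmpeg_available` is hypothetical.

```python
import shutil

def ffmpeg_available():
    """Return True if an ffmpeg executable or Matplotlib's ffmpeg writer is usable."""
    on_path = shutil.which("ffmpeg") is not None
    try:
        import matplotlib.animation as manimation
        # Matplotlib's own check for its ffmpeg-based movie writer
        return bool(on_path or manimation.writers.is_available("ffmpeg"))
    except (ImportError, AttributeError):
        return on_path

print("ffmpeg available:", ffmpeg_available())
```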
from IPython.display import Video
from tqdm import tqdm
def plot_video_2D(plot_params: dict, video_name="video.mp4", margin=0.2, fps=15):
    from matplotlib import pyplot as plt
    import matplotlib.animation as manimation
    t = np.array(plot_params["time"])
    positions_over_time = np.array(plot_params["position"])
    total_time = int(np.around(t[..., -1], 1))
    total_frames = fps * total_time
    # Use at least every recorded sample; round() can return 0 for short runs
    step = max(1, round(len(t) / total_frames))
    print("creating video -- this can take a few minutes")
    FFMpegWriter = manimation.writers["ffmpeg"]
    metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")
    writer = FFMpegWriter(fps=fps, metadata=metadata)
    # Matplotlib 3.6+ renamed the bundled seaborn styles to "seaborn-v0_8-*"
    style = (
        "seaborn-v0_8-whitegrid"
        if "seaborn-v0_8-whitegrid" in plt.style.available
        else "seaborn-whitegrid"
    )
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.axis("equal")
    rod_lines_2d = ax.plot(
        positions_over_time[0][2], positions_over_time[0][0], linewidth=3
    )[0]
    ax.set_xlim([0 - margin, 3 + margin])
    ax.set_ylim([-1.5 - margin, 1.5 + margin])
    with writer.saving(fig, video_name, dpi=100):
        with plt.style.context(style):
            for time in range(1, len(t), step):
                rod_lines_2d.set_xdata(positions_over_time[time][2])
                rod_lines_2d.set_ydata(positions_over_time[time][0])
                writer.grab_frame()
    plt.close(fig)
filename_video = "continuum_snake.mp4"
plot_video_2D(pp_list, video_name=filename_video, margin=0.2, fps=125)
Video("continuum_snake.mp4")
creating video -- this can take a few minutes
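The video functions choose how many recorded samples to skip between frames from the requested frame rate: step = round(n_samples / (fps * total_time)). The sketch below isolates that arithmetic in a hypothetical helper, `frame_stride`, so you can see how fps affects the stride and the resulting video length. The sample count of 1251 is an assumption (one record every 1000 steps over 1,250,000 steps, plus the initial state).

```python
def frame_stride(n_samples, total_time, fps):
    """Samples to skip between frames; at least 1 so range() never gets step 0."""
    total_frames = fps * total_time
    return max(1, round(n_samples / total_frames))

n_samples, total_time = 1251, 10
for fps in (15, 60, 125):
    step = frame_stride(n_samples, total_time, fps)
    n_frames = len(range(1, n_samples, step))  # same frame loop as plot_video_2D
    print(f"fps={fps}: step={step}, frames={n_frames}, video length={n_frames / fps:.1f} s")
```

Because the stride is rounded to a whole number of samples, the video length only approximates the simulated time; with fps=125 every sample becomes a frame.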
Finally, you can also plot the position of the snake from a 3D perspective. This is most helpful if your simulation involves more than planar motion.
from IPython.display import Video
def plot_video(plot_params: dict, video_name="video.mp4", margin=0.2, fps=15):
    from matplotlib import pyplot as plt
    import matplotlib.animation as manimation
    from mpl_toolkits import mplot3d
    t = np.array(plot_params["time"])
    positions_over_time = np.array(plot_params["position"])
    total_time = int(np.around(t[..., -1], 1))
    total_frames = fps * total_time
    # Use at least every recorded sample; round() can return 0 for short runs
    step = max(1, round(len(t) / total_frames))
    print("creating video -- this can take a few minutes")
    FFMpegWriter = manimation.writers["ffmpeg"]
    metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")
    writer = FFMpegWriter(fps=fps, metadata=metadata)
    # Matplotlib 3.6+ renamed the bundled seaborn styles to "seaborn-v0_8-*"
    style = (
        "seaborn-v0_8-whitegrid"
        if "seaborn-v0_8-whitegrid" in plt.style.available
        else "seaborn-whitegrid"
    )
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.set_xlim(0 - margin, 3 + margin)
    ax.set_ylim(-1.5 - margin, 1.5 + margin)
    ax.set_zlim(0, 1)
    ax.view_init(elev=20, azim=-80)
    rod_lines_3d = ax.plot(
        positions_over_time[0][2],
        positions_over_time[0][0],
        positions_over_time[0][1],
        linewidth=3,
    )[0]
    with writer.saving(fig, video_name, dpi=100):
        with plt.style.context(style):
            for time in range(1, len(t), step):
                rod_lines_3d.set_xdata(positions_over_time[time][2])
                rod_lines_3d.set_ydata(positions_over_time[time][0])
                rod_lines_3d.set_3d_properties(positions_over_time[time][1])
                writer.grab_frame()
    plt.close(fig)
filename_video = "continuum_snake_3d.mp4"
plot_video(pp_list, video_name=filename_video, margin=0.2, fps=60)
Video("continuum_snake_3d.mp4")
creating video -- this can take a few minutes
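Both video functions use hard-coded axis limits (0 to 3 along the travel axis and -1.5 to 1.5 laterally), which may clip the snake if you change its parameters. A hypothetical alternative, sketched below, derives the limits from the recorded positions. It assumes the same array layout as plot_params["position"], namely one (3, n_nodes) array per recorded time, and the helper name `axis_limits` is made up.

```python
import numpy as np

def axis_limits(positions_over_time, margin=0.2):
    """Return (min, max) per coordinate over all recorded times, padded by margin."""
    positions = np.asarray(positions_over_time)  # shape (n_times, 3, n_nodes)
    lower = positions.min(axis=(0, 2)) - margin
    upper = positions.max(axis=(0, 2)) + margin
    return [(float(lo), float(hi)) for lo, hi in zip(lower, upper)]

# Made-up trajectory: a straight 0.35 m rod translating along z
nodes = np.linspace(0.0, 0.35, 51)
frames = [np.vstack([np.zeros(51), np.zeros(51), nodes + 0.1 * k]) for k in range(5)]
x_lim, y_lim, z_lim = axis_limits(frames, margin=0.2)
print(z_lim)  # limits along the travel axis
```

In plot_video_2D, which plots z horizontally and x vertically, these limits would be applied as `ax.set_xlim(*z_lim)` and `ax.set_ylim(*x_lim)`.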