from clawpack import riemann
from clawpack import pyclaw
import numpy as np

# --- Run 1: Gaussian pulse in a periodically layered medium ----------------
# f-wave solver for the 1D nonlinear elasticity equations.
riemann_solver = riemann.nonlinear_elasticity_fwave_1D
solver = pyclaw.ClawSolver1D(riemann_solver)
solver.fwave = True

# Boundary conditions: 'extrap' at both ends, for the solution (q) and for the
# coefficient arrays (aux).
solver.bc_lower[0] = solver.bc_upper[0] = pyclaw.BC.extrap
solver.aux_bc_lower[0] = solver.aux_bc_upper[0] = pyclaw.BC.extrap

# Grid: `cells_per_layer` cells for each unit-width layer in [xlower, xupper].
xlower = 0.0
xupper = 1000.0
cells_per_layer = 12
mx = int(round(xupper - xlower)) * cells_per_layer
x = pyclaw.Dimension('x', xlower, xupper, mx)
domain = pyclaw.Domain(x)
state = pyclaw.State(domain, solver.num_eqn, 3)   # three aux fields (set below)
xc = state.grid.x.centers

# Material: each unit cell is half material A (xfrac < 0.5), half material B.
KA, rhoA = 1.0, 1.0
KB, rhoB = 4.0, 4.0
xfrac = xc - np.floor(xc)   # fractional position inside the current unit cell

state.aux[0, :] = np.where(xfrac < 0.5, rhoA, rhoB)   # density
state.aux[1, :] = np.where(xfrac < 0.5, KA, KB)       # bulk modulus
state.aux[2, :] = 0.                                  # unused slot

# Initial data: Gaussian bump `sigma` centred at x = 500, medium at rest.
# The strain is chosen so that exp(K*strain) - 1 == sigma.
sigma = 0.5 * np.exp(-((xc - 500.) / 5.)**2.)
state.q[0, :] = np.log(sigma + 1.) / state.aux[1, :]   # strain
state.q[1, :] = 0.                                     # momentum

# Controller: 100 in-memory output frames (claw.frames) up to t = 550.
claw = pyclaw.Controller()
claw.solution = pyclaw.Solution(state, domain)

claw.output_style = 1
claw.num_output_times = 100
claw.tfinal = 550.
claw.solver = solver
claw.keep_copy = True      # retain frames in claw.frames for the plots below
claw.output_format = None  # no output file format

import os

import plotly

# Plotly credentials come from the environment instead of being hard-coded:
# an API key committed to source is readable by anyone who sees the file
# (the key previously written here should be treated as leaked and revoked).
# Set PLOTLY_API_KEY (and optionally PLOTLY_USERNAME) before running.
py = plotly.plotly(username_or_email=os.environ.get('PLOTLY_USERNAME', 'DavidKetcheson'),
                   key=os.environ['PLOTLY_API_KEY'])

# Initial strain over the whole domain, with an inset on cells 5800:6200
# (x ~ 483..517, around the pulse centred at x = 500).
# '\\epsilon' (not '\epsilon'): '\e' is an invalid string escape.
data = [{'x': xc, 'y': state.q[0, :]},
        {'x': xc[5800:6200], 'y': state.q[0, 5800:6200], "xaxis": "x2", "yaxis": "y2"}]
layout = {"title": "$\\text{Strain }(\\epsilon)$", 'showlegend': False,
          "xaxis2": {"domain": [0.6, 0.95], "anchor": "y2"},
          "yaxis2": {"domain": [0.2, 0.6], "anchor": "x2"}}
py.iplot(data, layout=layout)

# Impedance Z = sqrt(rho*K) of the layered medium, with the same inset window.
Z = np.sqrt(state.aux[0, :] * state.aux[1, :])
data = [{'x': xc, 'y': Z},
        {'x': xc[5800:6200], 'y': Z[5800:6200], "xaxis": "x2", "yaxis": "y2"}]
layout = {"title": "$\\text{Impedance }(Z)$", 'showlegend': False,
          "xaxis2": {"domain": [0.6, 0.95], "anchor": "y2"},
          "yaxis2": {"domain": [0.2, 0.6], "anchor": "x2"}}
py.iplot(data, layout=layout)

claw.run()

# Notebook cell: enable inline plotting and build an empty figure whose line
# artist is redrawn by fplot for each stored frame.
# NOTE(review): %pylab is an IPython magic (not plain Python) that also
# star-imports numpy/matplotlib names; `%matplotlib inline` plus the explicit
# imports below would be cleaner if no later cell relies on pylab names.
%pylab inline
from matplotlib import animation
import matplotlib.pyplot as plt
from clawpack.visclaw.JSAnimation import IPython_display  # not referenced below; presumably imported for its display side effect — verify

fig = plt.figure(figsize=[8,4])
# x-range spans the cell centres of the current grid; strain shown in [0, 0.4].
ax = plt.axes(xlim=(xc[0], xc[-1]), ylim=(0, 0.4))
# Empty line artist; fplot fills in its data frame by frame.
line, = ax.plot([], [], lw=1)

def fplot(i):
    """Redraw the strain curve for stored frame ``i`` (FuncAnimation callback).

    Uses the module-level ``claw`` (its ``frames`` list), ``xc`` (cell
    centres), ``line`` (the artist to update) and ``ax`` (for the title).
    Returns the updated artist wrapped in a 1-tuple.
    """
    snapshot = claw.frames[i]
    line.set_data(xc, snapshot.q[0, :])
    ax.set_title('Strain at t=' + str(snapshot.t))
    return (line,)

# Animate every frame kept in memory by the controller.
animation.FuncAnimation(fig, fplot, frames=len(claw.frames))

# Strain held in `state` after the run (presumably the final state, since the
# solver advances it in place), with an inset on cells 11000:11700
# (x ~ 917..975). The title is wrapped entirely in $...$ like the other
# plotly titles in this file, and uses '\\epsilon' because '\e' is an invalid
# string escape.
data = [{'x': xc, 'y': state.q[0, :]},
        {'x': xc[11000:11700], 'y': state.q[0, 11000:11700], "xaxis": "x2", "yaxis": "y2"}]
layout = {"title": "$\\text{Strain }(\\epsilon)$", 'showlegend': False,
          "xaxis2": {"domain": [0.4, 0.85], "anchor": "y2"},
          "yaxis2": {"domain": [0.15, 0.65], "anchor": "x2"}}

py.iplot(data, layout=layout)

def set_bc_periodic(solver,state):
    """Switch every boundary to periodic once the wall-driven pulse is over.

    Intended as a ``solver.before_step`` hook. While ``state.t`` is at most
    ``5 * problem_data['tw1']`` nothing changes. After that, the lower/upper
    boundaries for both q and aux become ``pyclaw.BC.periodic``, and the hook
    unregisters itself by clearing ``solver.before_step``.
    """
    if not state.t > 5*state.problem_data['tw1']:
        return
    periodic = pyclaw.BC.periodic
    solver.bc_lower[0] = solver.bc_upper[0] = periodic
    solver.aux_bc_lower[0] = solver.aux_bc_upper[0] = periodic
    # One-shot: stop being called on subsequent steps.
    solver.before_step = None

def moving_wall_bc(state,dim,t,qbc,num_ghost):
    """Generate the initial pulse by prescribing wall motion at the left boundary.

    Custom lower-boundary routine (used as ``solver.user_bc_lower``). The wall
    velocity is a raised-cosine pulse in time::

        vwall = -amp * (1 + cos(pi*(t - t1)/tw1))   if |t - t1| <= tw1
        vwall = 0                                    otherwise

    Here ``amp``, ``t1`` and ``tw1`` come from ``state.problem_data``.

    Ghost cells are filled by reflection about the wall:
      * strain (q[0]) is copied from the first interior cell;
      * momentum (q[1] = rho*u) is mirrored so the velocity averages to the
        wall velocity: rho*u_ghost = 2*rho*vwall - rho*u_interior, with rho
        the density ``state.aux[0]`` of the mirrored interior cell.

    Args:
        state: pyclaw State; ``t``, ``aux`` (no ghost cells) and
            ``problem_data`` are read.
        dim: only ``dim.on_lower_boundary`` is consulted; nothing is done
            otherwise.
        t: ignored; ``state.t`` is used instead (kept for the callback
            signature).
        qbc: solution array including ghost cells; modified in place.
        num_ghost: number of ghost cells at each end.
    """
    if not dim.on_lower_boundary:
        return

    qbc[0, :num_ghost] = qbc[0, num_ghost]

    t = state.t
    t1 = state.problem_data['t1']
    tw1 = state.problem_data['tw1']
    amp = state.problem_data['amp']
    t0 = (t - t1) / tw1
    vwall = -amp * (1. + np.cos(t0 * np.pi)) if abs(t0) <= 1. else 0.

    # Ghost cell num_ghost-1-k mirrors interior cell k (qbc index num_ghost+k),
    # for every k in 0..num_ghost-1 -- all ghost cells are filled, including
    # the outermost one. Density is aux[0]; aux[1] is the bulk modulus.
    rho = state.aux[0, :num_ghost]
    qbc[1, num_ghost-1::-1] = 2 * vwall * rho - qbc[1, num_ghost:2*num_ghost]

# --- Run 2: pulse launched by a moving wall at the left boundary -----------
riemann_solver = riemann.nonlinear_elasticity_fwave_1D
solver = pyclaw.ClawSolver1D(riemann_solver)
solver.fwave = True
# Hook that switches all boundaries to periodic after t > 5*tw1.
solver.before_step = set_bc_periodic

# Boundary conditions: the left end is driven by moving_wall_bc; the right
# end and the aux arrays at both ends use 'extrap'.
solver.bc_lower[0] = pyclaw.BC.custom
solver.user_bc_lower = moving_wall_bc

solver.bc_upper[0] = pyclaw.BC.extrap
solver.aux_bc_lower[0] = pyclaw.BC.extrap
solver.aux_bc_upper[0] = pyclaw.BC.extrap


# Grid: 24 cells per unit-width layer on [0, 300].
xlower = 0.0; xupper = 300.0
cells_per_layer = 24; mx = int(round(xupper - xlower)) * cells_per_layer
x = pyclaw.Dimension('x', xlower, xupper, mx)
domain = pyclaw.Domain(x)
state = pyclaw.State(domain, solver.num_eqn, 3)
xc = state.grid.x.centers

# Initialize q and aux: same A/B half-layer medium as run 1.
KA    = 1.0; rhoA  = 1.0
KB    = 4.0; rhoB  = 4.0
xfrac = xc - np.floor(xc)

state.aux[0, :] = rhoA*(xfrac < 0.5) + rhoB*(xfrac >= 0.5)  # Density
state.aux[1, :] = KA  *(xfrac < 0.5) + KB  *(xfrac >= 0.5)  # Bulk modulus
state.aux[2, :] = 0.  # not used

# Medium starts unstrained and at rest; all motion comes from the wall.
state.q[0, :] = 0.  # Strain
state.q[1, :] = 0.  # Momentum

# Wall-pulse parameters read by moving_wall_bc and set_bc_periodic.
state.problem_data = {}
state.problem_data['t1']  = 10.0
state.problem_data['tw1'] = 10.0
state.problem_data['amp'] = 0.1

claw = pyclaw.Controller()
claw.solution = pyclaw.Solution(state, domain)
claw.solver = solver

claw.num_output_times = 100
claw.tfinal = 1000.
claw.keep_copy = True
claw.output_format = None

claw.run()

fig = plt.figure(figsize=[8,4])
ax = plt.axes(xlim=(xlower, xupper), ylim=(0, 0.6))
line, = ax.plot([], [])

animation.FuncAnimation(fig, fplot, frames=len(claw.frames))

# Final strain. Pass a list of trace dicts, as every other py.iplot call in
# this notebook does, instead of bare x/y arrays as positional arguments.
py.iplot([{'x': xc, 'y': claw.frames[-1].q[0, :]}])

# --- Run 3: same wall-driven pulse in a smoothly (sinusoidally) varying medium
import copy
claw.solution = copy.deepcopy(claw.frames[0])  # Reset simulation to the stored initial frame

# Coefficients oscillate between the A and B values with period 1 in x.
claw.solution.state.aux[0, :] = 0.5*(rhoA+rhoB) + 0.5*(rhoA-rhoB)*np.sin(2*np.pi*xc)  # Density
claw.solution.state.aux[1, :] = 0.5*(KA+KB) + 0.5*(KA-KB)*np.sin(2*np.pi*xc)  # Bulk modulus
claw.solution.state.aux[2, :] = 0.  # not used

# A fresh solver is needed: during run 2, set_bc_periodic switched the old
# solver's boundaries to periodic and cleared its before_step hook.
solver = pyclaw.ClawSolver1D(riemann_solver)
solver.fwave = True

solver.bc_lower[0] = pyclaw.BC.custom
solver.bc_upper[0] = pyclaw.BC.extrap

solver.aux_bc_lower[0] = pyclaw.BC.extrap
solver.aux_bc_upper[0] = pyclaw.BC.extrap

solver.user_bc_lower = moving_wall_bc
solver.before_step = set_bc_periodic
claw.solver = solver
claw.tfinal = 1000
claw.frames = []  # drop run-2 frames so the animation shows only this run
claw.run()

fig = plt.figure(figsize=[8,4])
ax = plt.axes(xlim=(xlower, xupper), ylim=(0, 0.6))
line, = ax.plot([], [])

animation.FuncAnimation(fig, fplot, frames=len(claw.frames))

# Final strain. Pass a list of trace dicts, as every other py.iplot call in
# this notebook does, instead of bare x/y arrays as positional arguments.
py.iplot([{'x': xc, 'y': claw.frames[-1].q[0, :]}])