from __future__ import division

from collections import OrderedDict
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

def _runtime_seconds(runtime):
    """Return the number of seconds in a wall-clock runtime string.

    Accepts ``'M:SS'`` (minutes may exceed 59, e.g. ``'175:00'``) or
    ``'H:MM:SS'``.

    Raises:
        ValueError: if *runtime* is not two or three ``:``-separated integers.
    """
    fields = [int(field) for field in runtime.split(':')]
    if len(fields) not in (2, 3):
        raise ValueError(
            'expected an M:SS or H:MM:SS runtime, got {!r}'.format(runtime))
    # Left-pad with a zero hours field so both formats unpack the same way.
    hours, minutes, seconds = ([0] + fields)[-3:]
    return timedelta(
        hours=hours, minutes=minutes, seconds=seconds).total_seconds()


def _cpu_count(decomposition):
    """Return the process count of an ``'AxB'`` (or ``'AxBxC'``) label."""
    count = 1
    for factor in decomposition.split('x'):
        count *= int(factor)
    return count


def _apply_ylim(ax, ylim):
    """Extend *ax*'s y-limits to *ylim* without clipping the plotted data.

    *ylim* is a ``(lower, upper)`` pair; either end may be ``None`` to keep
    the autoscaled value. ``None`` for *ylim* leaves the axis untouched.
    """
    if ylim is None:
        return
    lower, upper = ylim
    # Called after plotting, so get_ylim() reflects the autoscaled data extent.
    data_lower, data_upper = ax.get_ylim()
    ax.set_ylim(
        data_lower if lower is None else min(lower, data_lower),
        data_upper if upper is None else max(upper, data_upper),
    )


def plot_runtimes(system, data, ax1_ylim=None, ax2_ylim=None):
    """Plot wall-clock runtime and total CPU time for MPI decompositions.

    Draws a two-panel figure: the left panel shows each run's wall-clock
    runtime in seconds; the right panel shows total CPU time
    (runtime x number of processes) in hours.

    Args:
        system: Machine name, used in the figure title.
        data: Mapping of decomposition label (``'AxB'``) to wall-clock runtime
            (``'M:SS'`` or ``'H:MM:SS'``); points are plotted in iteration
            order.
        ax1_ylim: Optional ``(lower, upper)`` y-limits for the runtime panel,
            e.g. a value returned by an earlier call. The limits are widened
            where necessary so that no data point is clipped.
        ax2_ylim: Same as *ax1_ylim*, for the CPU-time panel.

    Returns:
        A ``(ax1_ylim, ax2_ylim)`` tuple with the final y-limits of the two
        panels.
    """
    labels = list(data.keys())
    seconds = np.array([_runtime_seconds(v) for v in data.values()])
    cpus = np.array([_cpu_count(k) for k in labels])
    positions = range(len(labels))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 4))

    ax1.plot(seconds, 'o')
    _apply_ylim(ax1, ax1_ylim)
    ax1.set_ylabel('Runtime [s]')
    ax1.set_xlabel('MPI Decomposition [CPUs]')
    ax1.set_xticks(positions)
    ax1.set_xticklabels(labels)
    ax1.grid()

    ax2.plot(seconds * cpus / 3600, 'o')
    _apply_ylim(ax2, ax2_ylim)
    ax2.set_ylabel('Total CPU Time [hr]')
    ax2.set_xlabel('Number of CPUs')
    ax2.set_xticks(positions)
    ax2.set_xticklabels(
        ['{num}\n{shape}'.format(num=p, shape=s) for p, s in zip(cpus, labels)])
    ax2.grid()

    fig.suptitle('{system} MPI Decomposition Runtimes'.format(system=system))

    return ax1.get_ylim(), ax2.get_ylim()

# Wall-clock runtimes (minutes:seconds) for the same eleven MPI
# decompositions on Orcinus and Jasper; labels and runtimes pair by position.
orcinus_data = OrderedDict(zip(
    ('6x14', '7x16', '8x18', '9x20', '10x22', '11x25',
     '12x27', '13x29', '14x32', '15x34', '16x36'),
    ('21:36', '17:46', '11:27', '9:21', '8:20', '7:19',
     '5:55', '6:34', '5:30', '4:48', '4:30'),
))

jasper_data = OrderedDict(zip(
    ('6x14', '7x16', '8x18', '9x20', '10x22', '11x25',
     '12x27', '13x29', '14x32', '15x34', '16x36'),
    ('19:34', '14:16', '12:40', '9:08', '8:06', '6:05',
     '6:04', '5:25', '4:42', '14:31', '21:40'),
))

# The y-limits returned for the Orcinus figure are fed into the Jasper plot.
ax1_ylim, ax2_ylim = plot_runtimes('Orcinus', orcinus_data)
ax1_ylim, ax2_ylim = plot_runtimes(
    'Jasper', jasper_data, ax1_ylim=ax1_ylim, ax2_ylim=ax2_ylim)

# West.Cloud wall-clock runtimes (minutes:seconds; minutes may exceed 59)
# per MPI decomposition; labels and runtimes pair by position.
westcloud_data = OrderedDict(zip(
    ('4x8', '4x10', '5x11', '6x13', '7x16',
     '8x18', '9x20', '10x23', '11x24'),
    ('175:00', '146:00', '118:00', '88:00', '73:00',
     '78:00', '62:00', '50:24', '49:00'),
))

# No y-limits are passed in, so this figure uses its own autoscaled range.
ax1_ylim, ax2_ylim = plot_runtimes(system='West.Cloud', data=westcloud_data)