#!/usr/bin/env python
# coding: utf-8

# In[1]:


import holoviews as hv
import numpy as np
import pandas as pd
import colorcet as cc


# In[2]:


# Activate HoloViews' Bokeh plotting backend so the hv.Curve / hv.Overlay
# objects built below render as Bokeh figures.
hv.extension("bokeh")


def get_data():
    """Return a demo DataFrame of uniform random values, one column per year.

    Columns are the string labels "1998", "1999", "2000", "2002" and "2003",
    each holding 365 draws from ``np.random.rand`` (values in [0, 1)). The
    index is the integer day number 0..364.
    """
    day_count = 365
    # NOTE(review): "2001" is absent from the labels — confirm the gap is
    # intentional.
    year_labels = ("1998", "1999", "2000", "2002", "2003")
    # Columns are drawn in label order, so the global NumPy RNG is consumed
    # in the same sequence for a given seed.
    columns = {label: np.random.rand(day_count) for label in year_labels}
    return pd.DataFrame(columns, index=range(day_count))


# In[3]:


# Utilities that help place each month's tick label around the second week of the month

def split_list(a, n):
    """Split the sliceable sequence *a* into *n* contiguous, near-equal chunks.

    The first ``len(a) % n`` chunks receive one extra element. Every chunk is
    materialised as a plain ``list``; for example, a pandas Index becomes a
    list of its labels.

    Raises:
        ZeroDivisionError: if ``n`` is 0.
    """
    base, extra = divmod(len(a), n)
    chunks = []
    for idx in range(n):
        start = idx * base + min(idx, extra)
        # Chunks before index `extra` absorb one leftover element each.
        stop = start + base + (1 if idx < extra else 0)
        chunks.append(list(a[start:stop]))
    return chunks


def get_ticks(df, pos):
    """Build ``(tick_position, month_label)`` pairs for the x axis.

    ``df.index`` is split into 12 contiguous, near-equal chunks with
    ``split_list``, one per month. The index label at offset *pos* inside each
    chunk becomes that month's tick position.

    Args:
        df: DataFrame whose index is presumably one year of day numbers
            (``get_data`` builds 0..364) — TODO confirm for other data sources.
        pos: Offset within each month chunk at which to place the label.
            An offset past the end of a chunk is clamped to the chunk's last
            element, so short frames no longer raise IndexError.

    Returns:
        A list of ``(index_label, month_abbreviation)`` tuples in month order.
        Months whose chunk is empty (frames with fewer than 12 rows) are
        omitted.
    """
    # Literal English labels; calendar.month_abbr would depend on the locale.
    months = [
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    ]
    chunks = split_list(df.index, len(months))
    xticks_map = []
    for chunk, month in zip(chunks, months):
        if not chunk:
            # No rows fall in this month, so there is nowhere to put its label.
            continue
        # Clamp so a month shorter than pos + 1 rows still gets a tick.
        xticks_map.append((chunk[min(pos, len(chunk) - 1)], month))
    return xticks_map


# In[4]:


def get_mplot(df, cols=None):
    """Overlay one HoloViews Curve per column of *df*.

    Args:
        df: DataFrame whose index supplies the x values and whose columns are
            the series to draw. Month ticks come from ``get_ticks(df, 15)``.
        cols: Optional column label, or list of labels, to keep. A falsy value
            (``None`` or an empty list) plots every column.

    Returns:
        An ``hv.Overlay`` of the styled curves, or ``None`` (after printing a
        message) when there are no columns to plot.
    """
    if isinstance(cols, str):
        # A bare label would select a Series, which has no ``.columns``.
        # Wrap it so it selects a one-column DataFrame instead.
        cols = [cols]
    if cols:
        df = df[cols]
    if len(df.columns) == 0:
        print("No columns selected")
        return None
    grid_style = {
        "grid_line_color": "black",
        "grid_line_width": 1.1,
        "minor_ygrid_line_color": "lightgray",
        "minor_xgrid_line_color": "lightgray",
        "xgrid_line_dash": [4, 4],
    }
    palette = cc.glasbey_light
    xticks_map = get_ticks(df, 15)
    multi_curve = [
        hv.Curve((df.index, df[name]), label=str(name)).opts(
            xticks=xticks_map,
            xrotation=45,
            width=900,
            height=400,
            # Wrap around the palette so frames with more columns than
            # palette entries still plot, instead of raising IndexError.
            line_color=palette[i % len(palette)],
            gridstyle=grid_style,
            show_grid=True,
        )
        for i, name in enumerate(df.columns)
    ]
    return hv.Overlay(multi_curve)


# In[5]:


import panel as pn

# Initialise Panel before any widgets or layouts are created or rendered.
pn.extension()

# Module-level demo dataset; its columns become the selectable years below.
df = get_data()

# Year selector listing every column of `df`. No `value` is passed, so the
# initial selection is the widget's default (presumably empty — confirm).
# margin is presumably (top, right, bottom, left), i.e. 20px on the right.
years = pn.widgets.MultiChoice(
    name="Years", options=list(df.columns), margin=(0, 20, 0, 0)
)


# In[6]:


# Notebook preview of the overlay for the widget's current selection. If the
# selection is falsy (nothing chosen), get_mplot plots every column. When run
# as a plain script, this expression's result is simply discarded.
get_mplot(df, years.value)


# In[7]:


@pn.depends(years)
def get_plot(years):
    """Rebuild the year overlay whenever the ``years`` widget changes.

    Args:
        years: Current value of the ``years`` MultiChoice widget, presumably
            the list of selected column labels. An empty selection plots
            every year.

    Returns:
        The ``hv.Overlay`` built by ``get_mplot``, or ``None`` when it has
        nothing to draw.
    """
    # Reuse the module-level ``df`` rather than calling get_data() again.
    # get_data() draws fresh random numbers, so regenerating here made every
    # curve jump to new values each time the selection changed. get_mplot
    # already applies the column filter, so no pre-filtering is needed here.
    return get_mplot(df, years)


# In[8]:


# Assemble the app: a title, the reactive plot function (bound to the `years`
# widget through pn.depends) and the year selector. width_policy="max"
# presumably stretches the column to the available width. .servable() marks
# the layout as the content served by `panel serve`.
pn.Column("Plot!", get_plot, pn.Row(years), width_policy="max").servable()


# In[ ]: