In [1]:
import numpy as np
import xarray as xr
import math
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
import f90nml
from salishsea_tools import metric_tools_5x5 as met

%matplotlib inline
plt.rcParams['image.cmap'] = 'jet'
plt.rc('xtick', labelsize=20) 
plt.rc('ytick', labelsize=20) 
In [2]:
# Namelist holding the unperturbed ("reference") biology parameter values.
# NOTE(review): this is an absolute, user-specific path; the notebook only runs
# where this location exists.
reference_namelist_file = '/data/jpetrie/MEOPAR/SS-run-sets/SS-SMELT/namelists/namelist_pisces_cfg_5x5_NewIC'
# Parsed namelist; the results loop indexes it as [section][parameter] to get
# the value each run's parameter setting is scaled against.
reference_bio_params = f90nml.read(reference_namelist_file)
In [3]:
# Tracer output file to look for in each run directory, keyed by the label
# that ends up in the START_MONTH column.  Iteration order matters: the first
# file found in a run directory decides that run's START_MONTH.
start_month_tracer_file = {
    "Jun": 'SS5x5_1h_20150601_20150829_ptrc_T.nc',
    "Feb": 'SS5x5_1h_20150201_20150501_ptrc_T.nc',
    "Apr": 'SS5x5_1h_20150401_20150629_ptrc_T.nc',
}

# Batch result directories; every sub-directory of these is scanned for a
# tracer file.
batch_directories = [
    '/data/jpetrie/MEOPAR/SalishSea/results/all_params_AprIC_june_22/',
    '/data/jpetrie/MEOPAR/SalishSea/results/all_params_JunIC_june_23/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampiszoo_june_14/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampisopt_june_14/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampismes_june_14/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampissink_june_17/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampisprod_june_16/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampismort_june_17/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampisrem_june_17/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampismezo_june_20/',
]

# Metric functions applied to each run's tracer dataset.  Only the
# uncommented entries are evaluated; the others are kept as a menu of
# alternatives to switch on.
metric_func_list = [
    met.mean_NO3_at_20m,
    # met.peak_3_day_biomass,
    # met.time_surface_NO3_drops_below_4,
    # met.time_of_peak_PHY2,
    # met.mean_NH4_at_depth,
    # met.mean_NO3_at_depth,
    # met.mean_DON_at_depth,
    # met.mean_PON_at_depth,
]

# Build the long-format table of metric values: one row per (run, metric).
#
# Rows are collected in a plain list and turned into a DataFrame once at the
# end.  DataFrame.append was removed in pandas 2.0 and also copied the whole
# frame on every call.
param_metrics_columns = [
    "PARAM_SECTION", "PARAM_NAME", "PARAM_DESC", "PARAM_VAL",
    "PARAM_SCALE", "METRIC_NAME", "METRIC_VAL", "START_MONTH",
]
param_metrics_records = []

for batch_dir in batch_directories:
    for run_name in os.listdir(batch_dir):
        if 'zz_frac_waste' in run_name:
            continue

        # The first tracer file present in the run directory determines the
        # run's START_MONTH; entries without any tracer file are skipped.
        for start_month in start_month_tracer_file:
            tracer_path = os.path.join(batch_dir, run_name,
                                       start_month_tracer_file[start_month])
            if os.path.isfile(tracer_path):
                break
        else:
            continue

        # Run directory names are parsed as <section>_<parameter name>_<value>;
        # the parameter name itself may contain underscores.
        first_underscore = run_name.find('_')
        last_underscore = run_name.rfind('_')
        param_section = run_name[:first_underscore]
        param_name = run_name[(first_underscore + 1):last_underscore]
        param_val = float(run_name[(last_underscore + 1):])
        param_desc = param_section + ' ' + param_name

        # Scale of this run's value relative to the reference namelist value.
        # For list-valued namelist entries the first element is the reference.
        original_val = reference_bio_params[param_section][param_name]
        if isinstance(original_val, list):
            original_val = original_val[0]
        param_scale = round(param_val / original_val, 3)

        # Open the dataset in a context manager so the file handle is released
        # as soon as the metrics for this run have been computed.
        # NOTE(review): assumes each metric function returns an already
        # computed value (not a lazy view into the dataset) - confirm.
        with xr.open_dataset(tracer_path) as grid_t:
            for metric_func in metric_func_list:
                param_metrics_records.append({
                    "PARAM_SECTION": param_section,
                    "PARAM_NAME": param_name,
                    "PARAM_DESC": param_desc,
                    "PARAM_VAL": param_val,
                    "PARAM_SCALE": param_scale,
                    "METRIC_NAME": metric_func.__name__,
                    "METRIC_VAL": metric_func(grid_t),
                    "START_MONTH": start_month,
                })

# Passing the column list keeps the expected columns even when no run was found.
param_metrics = pd.DataFrame(param_metrics_records, columns=param_metrics_columns)
In [4]:
# Select the parameters whose perturbation changes the metrics the most.
num_plot_params = 15  # how many parameters to keep for plotting

# One "plot line" is one parameter varied for one metric and one START_MONTH.
grouped_by_plot_line = param_metrics.groupby(["METRIC_NAME", "PARAM_DESC", "START_MONTH"])

# Variance of the metric along each plot line.  METRIC_VAL is selected
# explicitly: calling .var() on the whole group raises TypeError on
# pandas >= 2.0 because of the string columns (PARAM_SECTION, PARAM_NAME),
# and METRIC_VAL is the only variance used below.
var_by_plot_line = grouped_by_plot_line["METRIC_VAL"].var().reset_index()

# Rank the variances within each metric.  rank() is ascending, so the largest
# variance receives the largest rank.
rank_by_plot_line = var_by_plot_line[["PARAM_DESC"]].copy()
rank_by_plot_line["VAR_RANK"] = var_by_plot_line.groupby("METRIC_NAME")["METRIC_VAL"].rank()

# A parameter's score is its best rank over all metrics and start months;
# parameters whose variance is NaN (e.g. only one run) sort last.
max_rank_by_param = (
    rank_by_plot_line
    .groupby("PARAM_DESC")
    .max()
    .sort_values("VAR_RANK", ascending=False, na_position='last')
)

plot_params = list(max_rank_by_param.index[:num_plot_params])  # PARAM_DESC for highest variance params
In [5]:
# Point plots of metric value against parameter scale: one facet per
# (metric, start month) and one line per selected parameter.
plot_data = (
    param_metrics
    .query("PARAM_SCALE < 10 and PARAM_DESC in @plot_params")
    .sort_values(["PARAM_DESC", "PARAM_SCALE"])
)

sns.set(font_scale = 3)
# Re-applied after sns.set (presumably because sns.set resets rcParams - confirm).
plt.rcParams['image.cmap'] = 'jet'

# Repeat the base marker list so there is exactly one marker per plotted parameter.
markers = ['>','o', 'v', '8', 's', '<', 'p', '*', 'h', '^', 'H', 'D', 'd']
num_plotted_params = len(np.unique(plot_data.PARAM_DESC))
whole_cycles, leftover = divmod(num_plotted_params, len(markers))
markers = markers*whole_cycles + markers[:leftover]

# NOTE(review): newer seaborn releases rename FacetGrid's `size` argument to
# `height`; confirm against the installed seaborn version.
fg = sns.FacetGrid(
    data=plot_data,
    row = "METRIC_NAME",
    col = "START_MONTH",
    sharex=False,
    sharey=False,
    size = 16,
    col_order = ["Feb","Apr","Jun"],
)
fg.map(
    sns.pointplot,
    "PARAM_SCALE", "METRIC_VAL", "PARAM_DESC",
    scale = 3,
    markers = markers,
    palette="gnuplot",
)
fg.add_legend()
fg.set_titles("{row_name} using {col_name} ICs", size = 50)
plt.subplots_adjust(top=0.9)
In [6]:
# Pivot to one row per (parameter, metric), with one column per
# "<start month>_PARAM_SCALE_<scale>" combination.
param_metrics["PARAM_SCALE_STR"] = (
    param_metrics["START_MONTH"]
    + "_PARAM_SCALE_"
    + param_metrics["PARAM_SCALE"].astype(str)
)
wide_format_metrics = pd.pivot_table(
    param_metrics,
    values='METRIC_VAL',
    index=['PARAM_SECTION','PARAM_NAME', 'PARAM_DESC', 'METRIC_NAME'],
    columns=['PARAM_SCALE_STR'],
).reset_index()

# Central-difference slope of the metric with respect to parameter scale,
# from the 0.9x and 1.1x runs (0.2 is the width of that interval).
for start_month in np.unique(param_metrics.START_MONTH):
    scaled_up = wide_format_metrics[start_month + "_PARAM_SCALE_1.1"]
    scaled_down = wide_format_metrics[start_month + "_PARAM_SCALE_0.9"]
    wide_format_metrics[start_month + "IC_SLOPE"] = (scaled_up - scaled_down)/0.2
In [7]:
# Horizontal bar charts of the slopes: one figure per metric, one panel per
# start month, bars coloured by namelist section.

# Short column names, used as the panel titles.
for short_name, slope_column in (("Feb", "FebIC_SLOPE"),
                                 ("Apr", "AprIC_SLOPE"),
                                 ("Jun", "JunIC_SLOPE")):
    wide_format_metrics[short_name] = wide_format_metrics[slope_column]

sns.set(font_scale = 2)
cmap = plt.get_cmap('Set2')
plot_columns = ["Feb", "Apr", "Jun"]

for metric_name in np.unique(wide_format_metrics["METRIC_NAME"]):
    metric_rows = wide_format_metrics.query("METRIC_NAME == @metric_name").reset_index()

    # Spread the namelist sections evenly over [0, 1] for the colormap lookup.
    categories = np.unique(metric_rows["PARAM_SECTION"])
    colordict = dict(zip(categories, np.linspace(0, 1, len(categories))))
    metric_rows["COLOR"] = metric_rows["PARAM_SECTION"].map(colordict)

    axes = metric_rows.plot(
        "PARAM_DESC", plot_columns,
        kind = 'barh',
        figsize = (7, 0.5*len(metric_rows[plot_columns[0]])),
        title = metric_name,
        subplots = True, sharex = True, sharey = True,
        layout = (1, len(plot_columns)),
        color = [cmap(metric_rows.COLOR)]*len(plot_columns),
        legend = False,
    )
    plt.subplots_adjust(top=0.96)

    for column, ax in zip(plot_columns, axes[0]):
        for tick in ax.get_xticklabels():
            tick.set_rotation(-45)

        # Fill in missing values with "Null" on the barplots
        null_labels = ["Null" if is_null else "" for is_null in pd.isnull(metric_rows[column])]
        for idx, patch in enumerate(ax.patches):
            ax.annotate(null_labels[idx], (patch.get_x() - 0.8, patch.get_y()))