import numpy as np
import xarray as xr
import math
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
import f90nml
from salishsea_tools import metric_tools_5x5 as met
%matplotlib inline
# Global matplotlib defaults for every figure produced below: 'jet' as the
# image colormap and size-20 tick labels on both axes.
plt.rcParams.update({
    'image.cmap': 'jet',
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
})

# Namelist holding the reference parameter values; the value found in each run
# directory name is divided by the matching entry here to get its scale.
reference_namelist_file = '/data/jpetrie/MEOPAR/SS-run-sets/SS-SMELT/namelists/namelist_pisces_cfg_5x5_NewIC'
reference_bio_params = f90nml.read(reference_namelist_file)

# Name of the tracer output file to look for in a run directory, keyed by the
# start-month label that is recorded for the run.
start_month_tracer_file = {
    "Jun": 'SS5x5_1h_20150601_20150829_ptrc_T.nc',
    "Feb": 'SS5x5_1h_20150201_20150501_ptrc_T.nc',
    "Apr": 'SS5x5_1h_20150401_20150629_ptrc_T.nc',
}

# Result directories to scan; each one is listed and every entry in it is
# treated as a candidate run directory.
batch_directories = [
    '/data/jpetrie/MEOPAR/SalishSea/results/all_params_AprIC_june_22/',
    '/data/jpetrie/MEOPAR/SalishSea/results/all_params_JunIC_june_23/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampiszoo_june_14/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampisopt_june_14/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampismes_june_14/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampissink_june_17/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampisprod_june_16/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampismort_june_17/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampisrem_june_17/',
    '/data/jpetrie/MEOPAR/SalishSea/results/nampismezo_june_20/',
]

# Metric functions evaluated on every run; each is called with the opened
# tracer dataset.  Uncomment an entry to include that metric as well.
metric_func_list = [
    met.mean_NO3_at_20m,
    # met.peak_3_day_biomass,
    # met.time_surface_NO3_drops_below_4,
    # met.time_of_peak_PHY2,
    # met.mean_NH4_at_depth,
    # met.mean_NO3_at_depth,
    # met.mean_DON_at_depth,
    # met.mean_PON_at_depth,
]
# Build `param_metrics`: one row per (run directory, metric function).
#
# Every entry of every batch directory is treated as a run directory named
# "<section>_<param_name>_<value>".  The section is the text before the first
# underscore and the value is the text after the last one, so the parameter
# name itself may contain underscores.  Entries whose name contains
# 'zz_frac_waste', and entries that hold none of the known tracer files, are
# skipped.  Runs that show up in more than one batch directory are all kept.
metric_rows = []
for batch_dir in batch_directories:
    for run_dir in os.listdir(batch_dir):
        if 'zz_frac_waste' in run_dir:
            continue

        # The first tracer file that exists decides the start month of the run.
        tracer_path = None
        for start_month, tracer_name in start_month_tracer_file.items():
            candidate = os.path.join(batch_dir, run_dir, tracer_name)
            if os.path.isfile(candidate):
                tracer_path = candidate
                break
        if tracer_path is None:
            continue

        first_underscore = run_dir.find('_')
        last_underscore = run_dir.rfind('_')
        param_section = run_dir[:first_underscore]
        param_name = run_dir[first_underscore + 1:last_underscore]
        param_val = float(run_dir[last_underscore + 1:])
        param_desc = param_section + ' ' + param_name

        # Scale of this run relative to the reference namelist; for list-valued
        # parameters only the first element is used as the reference.
        original_val = reference_bio_params[param_section][param_name]
        if isinstance(original_val, list):
            original_val = original_val[0]
        param_scale = round(param_val / original_val, 3)

        # Evaluate every metric while the dataset is open, then release the
        # file handle before moving on to the next run.
        with xr.open_dataset(tracer_path) as grid_t:
            for metric_func in metric_func_list:
                metric_rows.append({
                    "PARAM_SECTION": param_section,
                    "PARAM_NAME": param_name,
                    "PARAM_DESC": param_desc,
                    "PARAM_VAL": param_val,
                    "PARAM_SCALE": param_scale,
                    "METRIC_NAME": metric_func.__name__,
                    "METRIC_VAL": metric_func(grid_t),
                    "START_MONTH": start_month,
                })

# Build the frame once from the collected rows (DataFrame.append is no longer
# available in current pandas).  Passing `columns` keeps the expected schema
# even when no run was found.
param_metrics = pd.DataFrame(
    metric_rows,
    columns=[
        "PARAM_SECTION",
        "PARAM_NAME",
        "PARAM_DESC",
        "PARAM_VAL",
        "PARAM_SCALE",
        "METRIC_NAME",
        "METRIC_VAL",
        "START_MONTH",
    ],
)
# Pick the parameters whose metric values vary the most and plot metric value
# against parameter scale for them, one facet per (metric, start month).
num_plot_params = 15

# Variance of the metric along each plotted line, i.e. for each
# (metric, parameter, start month) combination.  Only METRIC_VAL is aggregated:
# it is the only column used below, and the string columns cannot be reduced
# with var().
grouped_by_plot_line = param_metrics.groupby(["METRIC_NAME", "PARAM_DESC", "START_MONTH"])
var_by_plot_line = grouped_by_plot_line[["METRIC_VAL"]].var()
var_by_plot_line.reset_index(inplace=True)

# Rank the variances within each metric, keep the best rank each parameter
# reaches on any line, and take the top `num_plot_params` parameters
# (parameters without a rank are sorted last).
rank_by_plot_line = var_by_plot_line[["PARAM_DESC"]].copy()
rank_by_plot_line["VAR_RANK"] = var_by_plot_line.groupby("METRIC_NAME")["METRIC_VAL"].rank()
max_rank_by_param = rank_by_plot_line.groupby("PARAM_DESC").max()
max_rank_by_param = max_rank_by_param.sort_values("VAR_RANK", ascending=False, na_position='last')
plot_params = list(max_rank_by_param.index[:num_plot_params])  # PARAM_DESC for highest variance params

# Runs scaled by a factor of 10 or more are left out of the plot.
plot_data = param_metrics.query("PARAM_SCALE < 10 and PARAM_DESC in @plot_params")
plot_data = plot_data.sort_values(["PARAM_DESC", "PARAM_SCALE"])

sns.set(font_scale=3)
plt.rcParams['image.cmap'] = 'jet'

# One marker per plotted parameter: repeat the base list as often as needed
# and top it up so its length equals the number of distinct PARAM_DESC values.
base_markers = ['>', 'o', 'v', '8', 's', '<', 'p', '*', 'h', '^', 'H', 'D', 'd']
num_plot_lines = len(np.unique(plot_data.PARAM_DESC))
full_cycles, remainder = divmod(num_plot_lines, len(base_markers))
markers = base_markers * full_cycles + base_markers[:remainder]

# NOTE(review): seaborn 0.9 renamed FacetGrid's `size` keyword to `height`;
# switch the keyword when running against a newer seaborn release.
fg = sns.FacetGrid(data=plot_data, row="METRIC_NAME", col="START_MONTH",
                   sharex=False, sharey=False, size=16,
                   col_order=["Feb", "Apr", "Jun"])
fg.map(sns.pointplot, "PARAM_SCALE", "METRIC_VAL", "PARAM_DESC",
       scale=3, markers=markers, palette="gnuplot")
fg.add_legend()
fg.set_titles("{row_name} using {col_name} ICs", size=50)
plt.subplots_adjust(top=0.9)
# Reshape the metrics to wide format and estimate, per start month, the slope
# of each metric with respect to the parameter scale around the reference.

# Column label combining start month and scale, e.g. "Apr_PARAM_SCALE_1.1".
param_metrics["PARAM_SCALE_STR"] = (
    param_metrics["START_MONTH"]
    + "_PARAM_SCALE_"
    + param_metrics["PARAM_SCALE"].astype(str)
)

# One row per (parameter, metric), one column per month/scale label.
wide_format_metrics = pd.pivot_table(
    param_metrics,
    values='METRIC_VAL',
    index=['PARAM_SECTION', 'PARAM_NAME', 'PARAM_DESC', 'METRIC_NAME'],
    columns=['PARAM_SCALE_STR'],
)
wide_format_metrics.reset_index(inplace=True)

# Central difference between the runs at scale 1.1 and scale 0.9; 0.2 is the
# distance between those two scales.
for start_month in np.unique(param_metrics.START_MONTH):
    metric_at_high = wide_format_metrics[start_month + "_PARAM_SCALE_1.1"]
    metric_at_low = wide_format_metrics[start_month + "_PARAM_SCALE_0.9"]
    wide_format_metrics[start_month + "IC_SLOPE"] = (metric_at_high - metric_at_low) / 0.2

# Short aliases ("Feb", "Apr", "Jun") for the slope columns.  A month without
# any runs has no slope column and is skipped instead of raising KeyError.
for start_month in ("Feb", "Apr", "Jun"):
    slope_column = start_month + "IC_SLOPE"
    if slope_column in wide_format_metrics.columns:
        wide_format_metrics[start_month] = wide_format_metrics[slope_column]
# For every metric, draw one row of horizontal bar charts (one chart per start
# month) showing the slope of the metric for each parameter.  Bars are coloured
# by namelist section, and parameters without a slope are labelled "Null".
sns.set(font_scale=2)
cmap = plt.get_cmap('Set2')
for metric_name in np.unique(wide_format_metrics["METRIC_NAME"]):
    # reset_index() returns a new frame, so adding the COLOR column below does
    # not write into a slice of wide_format_metrics.
    metric_slopes = wide_format_metrics.query("METRIC_NAME == @metric_name").reset_index()

    # Spread the sections evenly over [0, 1] so each one gets its own colour
    # from the colormap.
    categories = np.unique(metric_slopes["PARAM_SECTION"])
    colordict = dict(zip(categories, np.linspace(0, 1, len(categories))))
    metric_slopes["COLOR"] = metric_slopes["PARAM_SECTION"].map(colordict)

    # Plot only the month columns that are actually present.
    plot_columns = [month for month in ["Feb", "Apr", "Jun"] if month in metric_slopes.columns]

    axes = metric_slopes.plot(
        "PARAM_DESC", plot_columns, kind='barh',
        figsize=(7, 0.5 * len(metric_slopes)),  # height grows with the number of bars
        title=metric_name, subplots=True, sharex=True, sharey=True,
        layout=(1, len(plot_columns)),
        color=[cmap(metric_slopes.COLOR)] * len(plot_columns), legend=False,
    )
    plt.subplots_adjust(top=0.96)

    for col_idx, column in enumerate(plot_columns):
        ax = axes[0, col_idx]
        for tick in ax.get_xticklabels():
            tick.set_rotation(-45)

        # Fill in missing values with "Null" on the barplots
        null_labels = ["Null" if is_null else "" for is_null in pd.isnull(metric_slopes[column])]
        for patch, label in zip(ax.patches, null_labels):
            if label:
                ax.annotate(label, (patch.get_x() - 0.8, patch.get_y()))