import xarray as xr
import numpy as np
import cartopy.crs as ccrs
import copy
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from WD.plotting import plot_map, add_label_to_axes
# Global matplotlib defaults for every figure in this notebook (normally placed in the
# first cell): 8 pt font and 300 dpi rendering.
plt.rcParams.update({'font.size': 8, 'figure.dpi': 300})
# Upper-case letters "A".."Z", used as row prefixes in subplot labels such as "(A1)".
alphabet_letters = np.array([chr(code) for code in range(ord("A"), ord("Z") + 1)])
# Figure width equals the document's text width: specified in mm (alternative: 144 mm)
# and converted to inches, the unit matplotlib's figsize expects.
textwidth_mm = 170
mm_to_inch = 0.0393701
textwidth = textwidth_mm * mm_to_inch  # text width in inches
# Older, flat output layout (kept for reference):
# ds_id = "96FE8A" # "278771"
# model_id = "290AD8" # "011A3B" "C7A2A3"
# targets = xr.load_dataset(f"/data/compoundx/WeatherDiff/model_output/{ds_id}/{model_id}_target.nc")
# predictions = xr.load_dataset(f"/data/compoundx/WeatherDiff/model_output/{ds_id}/{model_id}_gen.nc")

# Known evaluation runs, keyed by "<experiment_name>/<run_date>". Each value is the tuple
# (data_template_name, experiment_name, run_date, eval_date), i.e. the directory components
# used to locate the run's output files. Switch runs by changing ACTIVE_RUN rather than by
# commenting blocks of assignments in and out.
EVALUATION_RUNS = {
    "fourcastnet_highres/2023-08-18_13-10-40": (
        "geopotential_500_highres", "fourcastnet_highres",
        "2023-08-18_13-10-40", "2023-08-19_11-01-41",
    ),
    "fourcastnet_small/2023-08-18_16-46-00": (
        "geopotential_500", "fourcastnet_small",
        "2023-08-18_16-46-00", "2023-08-18_19-01-32",
    ),
    "fourcastnet_rasp_thuerey/2023-08-18_14-18-40": (
        "rasp_thuerey_geopotential", "fourcastnet_rasp_thuerey",
        "2023-08-18_14-18-40", "2023-08-18_17-47-29",
    ),
    "fourcastnet_highres/2023-08-17_20-18-22": (
        "geopotential_500_highres", "fourcastnet_highres",
        "2023-08-17_20-18-22", "2023-08-18_12-11-04",
    ),
    "fourcastnet/2023-08-18_13-10-44": (
        "geopotential_500", "fourcastnet",
        "2023-08-18_13-10-44", "2023-08-18_16-07-24",
    ),
}
ACTIVE_RUN = "fourcastnet_highres/2023-08-18_13-10-40"
data_template_name, experiment_name, run_date, eval_date = EVALUATION_RUNS[ACTIVE_RUN]
# Load ground truth and generated forecasts of the selected evaluation run.
run_dir = f"/data/compoundx/WeatherDiff/model_output/{data_template_name}/{experiment_name}/{run_date}/{eval_date}"
targets = xr.load_dataset(f"{run_dir}/target.nc")
predictions = xr.load_dataset(f"{run_dir}/gen.nc")
# Error field: target minus prediction.
diff = targets - predictions
# Randomly pick distinct init_time indices to visualise. Sampling without replacement
# avoids showing the same forecast twice; the fixed seed makes the figure reproducible.
n_images = 8
n_init_times = len(predictions["init_time"])
n_images = min(n_images, n_init_times)  # cannot draw more distinct indices than exist
random_seed = 0
timesteps = np.random.default_rng(random_seed).choice(n_init_times, size=n_images, replace=False)
# Plotting configuration passed to plot_map for the target and prediction maps. Similar plots
# can share or derive from it (see config_diff below).
config = dict(
    CMAP="viridis",
    NORM=matplotlib.colors.Normalize(vmin=49000, vmax=59000),  # fixed colour range, m^2/s^2
    RASTERIZED=True,  # draw the map as a raster rather than individual points -> smaller files
    SHOW_COLORBAR=True,
    CBAR_ORIENTATION="horizontal",
    CBAR_EXTEND="both",
    SHOW_COLORBAR_LABEL=False,
    CBAR_LABEL=r"Geopotential [$m^2/s^2$]",
    TITLE="",
    TITLE_FONTSIZE=8,
    PROJECTION=ccrs.Robinson(),  # map projection of the axes created by the figure code
    ASPECT_RATIO=6 / 5,  # panel width/height, used to derive a fitting figure height
)
# Variant for the difference panels: diverging colormap and a symmetric value range.
config_diff = copy.deepcopy(config)
config_diff.update(CMAP="RdBu", NORM=matplotlib.colors.Normalize(vmin=-3000, vmax=3000))
# Grid of maps: one row per sampled init time; columns are target | prediction | difference.
n_rows = n_images
n_cols = 3
figure_width = textwidth
# Derive the height from the grid shape and the per-panel aspect ratio.
figure_height = textwidth * (n_rows / n_cols) / config["ASPECT_RATIO"]
fig = plt.figure(figsize=[figure_width, figure_height])
gs = gridspec.GridSpec(n_rows, n_cols, figure=fig, width_ratios=[1, 1, 1])

# (dataset, plotting config, panel title) for each column, left to right.
panels = [
    (targets, config, "Target"),
    (predictions, config, "Prediction"),
    (diff, config_diff, "Difference:"),
]
for row, i_t in enumerate(timesteps):
    # Same slice for every column: this init time, first lead time, first ensemble member.
    selection = {"init_time": i_t, "lead_time": 0, "ensemble_member": 0}
    for col, (dataset, panel_config, title) in enumerate(panels):
        ax = fig.add_subplot(gs[row, col], projection=config["PROJECTION"])
        plot_map(ax, data=dataset.isel(selection)[list(dataset.keys())], plotting_config=panel_config, title=title)
        # Panel label such as "(A1)": letter = row, digit = 1-based column.
        add_label_to_axes(ax, "({}{})".format(alphabet_letters[row], col + 1))
fig.canvas.draw()
fig.tight_layout()
plt.show()
from benchmark.bm.score import compute_weighted_rmse, compute_weighted_mae, compute_weighted_acc
# Scores at native resolution, evaluated on the first ensemble member only.
member_predictions = predictions.isel({"ensemble_member": 0})
member_targets = targets.isel({"ensemble_member": 0})

rmse = compute_weighted_rmse(member_predictions, member_targets)
print("RMSE is {:.1f}".format(rmse.z_500.values))

mae = compute_weighted_mae(member_predictions, member_targets)
print("MAE is {:.1f}".format(mae.z_500.values))

acc = compute_weighted_acc(member_predictions, member_targets)
print("ACC is {:.2f}".format(acc.z_500.values))
RMSE is 457.4 MAE is 283.7 ACC is 0.90
# regrid_to_res is not imported at the top of the notebook, so this import must be active;
# with it commented out the cell fails with NameError on a fresh kernel.
from WD.regridding import regrid_to_res
# Regrid predictions and targets to the coarser "5.625deg" grid so the scores below can be
# compared with those of coarse-resolution models.
# NOTE(review): both calls rebuild the same weight file; reuse_weights=True on the second call
# would save time if both datasets are confirmed to share a grid.
interpolated_predictions = regrid_to_res(predictions, "5.625deg", reuse_weights=False)
interpolated_targets = regrid_to_res(targets, "5.625deg", reuse_weights=False)
Create weight file: bilinear_64x128_32x64_peri.nc using dimensions ('lat', 'lon') from data variable z_500 as the horizontal dimensions for this dataset. Overwrite existing file: bilinear_64x128_32x64_peri.nc You can set reuse_weights=True to save computing time. using dimensions ('lat', 'lon') from data variable z_500 as the horizontal dimensions for this dataset.
/home/wider/.conda/envs/WD_eval/lib/python3.7/site-packages/xesmf/frontend.py:391: FutureWarning: ``output_sizes`` should be given in the ``dask_gufunc_kwargs`` parameter. It will be removed as direct parameter in a future version. temp_horiz_dims[1]: self.shape_out[1] /home/wider/.conda/envs/WD_eval/lib/python3.7/site-packages/xesmf/frontend.py:391: FutureWarning: ``output_sizes`` should be given in the ``dask_gufunc_kwargs`` parameter. It will be removed as direct parameter in a future version. temp_horiz_dims[1]: self.shape_out[1]
# Compare the spatially and lead-time averaged z_500 of targets and predictions along the
# remaining dimension (presumably init_time). This is a plain mean, not latitude-weighted.
# Both lines share one axes, so label them and add a legend to tell them apart.
interpolated_targets.mean(("lat", "lon", "lead_time")).z_500.plot(label="Target")
interpolated_predictions.mean(("lat", "lon", "lead_time")).z_500.plot(label="Prediction")
plt.legend()
[<matplotlib.lines.Line2D at 0x2b435165e7d0>]
# Same scores as at native resolution, computed on the regridded fields (first ensemble member).
member_predictions_interp = interpolated_predictions.isel({"ensemble_member": 0})
member_targets_interp = interpolated_targets.isel({"ensemble_member": 0})

rmse = compute_weighted_rmse(member_predictions_interp, member_targets_interp)
print("RMSE of interpolated data is {:.1f}".format(rmse.z_500.values))

mae = compute_weighted_mae(member_predictions_interp, member_targets_interp)
print("MAE of interpolated data is {:.1f}".format(mae.z_500.values))

# compute_weighted_acc(da_fc, da_true, ...) applies np.sum to its inputs. With the regridded
# Datasets this raised "ValueError: passing 'axis' to Dataset reduce methods is ambiguous",
# so pass the z_500 DataArrays its argument names suggest it expects.
# NOTE(review): confirm against benchmark/bm/score.py that DataArray input is supported.
acc = compute_weighted_acc(member_predictions_interp.z_500, member_targets_interp.z_500)
print("ACC of interpolated data is {:.2f}".format(acc.values))
RMSE of interpolated data is 435.7 MAE of interpolated data is 275.3
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) /tmp/ipykernel_227907/754361577.py in <module> 5 print("MAE of interpolated data is {:.1f}".format(mae.z_500.values)) 6 ----> 7 acc = compute_weighted_acc(interpolated_predictions.isel({"ensemble_member": 0}), interpolated_targets.isel({"ensemble_member": 0})) 8 print("ACC of interpolated data is {:.2f}".format(acc.z_500.values)) /gpfs1/schlecker/home/wider/Projects/diffusion-models-for-weather-prediction/benchmark/bm/score.py in compute_weighted_acc(da_fc, da_true, mean_dims) 84 a_prime = a - a.mean() 85 ---> 86 acc = np.sum(w * fa_prime * a_prime) / np.sqrt( 87 np.sum(w * fa_prime**2) * np.sum(w * a_prime**2) 88 ) <__array_function__ internals> in sum(*args, **kwargs) ~/.conda/envs/WD_eval/lib/python3.7/site-packages/numpy/core/fromnumeric.py in sum(a, axis, dtype, out, keepdims, initial, where) 2258 2259 return _wrapreduction(a, np.add, 'sum', axis, dtype, out, keepdims=keepdims, -> 2260 initial=initial, where=where) 2261 2262 ~/.conda/envs/WD_eval/lib/python3.7/site-packages/numpy/core/fromnumeric.py in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs) 82 return reduction(axis=axis, dtype=dtype, out=out, **passkwargs) 83 else: ---> 84 return reduction(axis=axis, out=out, **passkwargs) 85 86 return ufunc.reduce(obj, axis, dtype, out, **passkwargs) ~/.conda/envs/WD_eval/lib/python3.7/site-packages/xarray/core/common.py in wrapped_func(self, dim, skipna, **kwargs) ~/.conda/envs/WD_eval/lib/python3.7/site-packages/xarray/core/dataset.py in reduce(self, func, dim, keep_attrs, keepdims, numeric_only, **kwargs) ValueError: passing 'axis' to Dataset reduce methods is ambiguous. Please use 'dim' instead.