###################################################################
# This script plots each metric against each parameter:
# emulator prediction and uncertainty for a new collection of
# sample_size points.
# It also plots the learning database.
# You must run htune_emulator_predictions.R -wave w beforehand.
#
# Najda Villefranque - Maelle Coulon--Decorzens
# August 2025
###################################################################

import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import custom_legend as cl 
import argparse
import glob

# Command-line interface. Every option is received as a string (or its
# default) and converted to its working type just below.
parser = argparse.ArgumentParser()
parser.add_argument("-w", metavar="wavesList", default="1",
                    help="list of waves to plot (by number)")
parser.add_argument("-p", metavar="paramList", default="all",
                    help="list of params to plot (by name)")
parser.add_argument("-y", metavar="ylim", default=None,
                    help="ylim of the plots")
parser.add_argument("-n", metavar="npoints", default=None,
                    help="plot every n points")
parser.add_argument("-N", metavar="Nstd", default=3,
                    help="Nstd for confidence interval")
args = parser.parse_args()

# Convert the raw options:
#  - npoints: stride used to thin the plotted points (None = keep all)
#  - ylim: comma-separated y-axis limits (None = let matplotlib decide)
#  - paramList / wavesList: comma-separated names / wave numbers
#  - nstd: number of standard deviations shown by the error bars
npoints = None if args.n is None else int(args.n)
ylim = None if args.y is None else [float(u) for u in args.y.split(",")]
paramList = args.p.split(",")
wavesList = list(map(int, args.w.split(",")))
maxwave = max(wavesList)
nstd = float(args.N)

# Files holding the data (references and emulator predictions), one entry
# per requested wave.
f_references  = ["metrics_REF_%i.csv"%nwave for nwave in wavesList]
f_sondes      = ["WAVE%i/Metrics.csv"%nwave for nwave in wavesList]
p_sondes      = ["WAVE%i/Params.asc"%nwave for nwave in wavesList]
all_f_predictions = glob.glob("Predictions_Wave*.asc") #%maxwave
if len(all_f_predictions) == 0 :
  print("error: could not find any Predictions_Wave*asc file.")
  print("Consider running\n  Rscript htune_emulator_predictions.R -wave %i\nand retry"%maxwave)
  raise SystemExit(1)
all_f_predictions.sort()

# Select the prediction file with the smallest wave number >= maxwave.
# Wave numbers are compared as integers: a lexical sort would rank
# "Wave10" before "Wave2" and could select the wrong file.
candidates = []
for f in all_f_predictions:
  try:
    wave_id = int(f.split("_Wave")[-1].split(".asc")[0])
  except ValueError:
    continue  # the glob matched a name that carries no integer wave number
  if wave_id >= maxwave: candidates.append((wave_id, f))

if not candidates:
  # Without this guard f_predictions would stay undefined and the script
  # would fail later with a NameError.
  print("error: no Predictions_Wave*.asc file covers wave %i."%maxwave)
  print("Consider running\n  Rscript htune_emulator_predictions.R -wave %i\nand retry"%maxwave)
  raise SystemExit(1)
f_predictions = min(candidates)[1]

print("Will plot file %s"%f_predictions)

def read_csv_file(f, sep=",", exclude_first_col=1):
  """Read a delimited text file and return its rows as lists of strings.

  f                 -- path of the file to read
  sep               -- field delimiter handed to csv.reader
  exclude_first_col -- number of leading fields dropped from every row
                       (row[exclude_first_col:] is kept)

  Returns a list with one list of strings per row of the file.
  """
  with open(f) as handle:
    return [record[exclude_first_col:]
            for record in csv.reader(handle, delimiter=sep)]

# Read the references of the different metrics, one file per wave.
# Each reference file must contain exactly three rows, unpacked as
# (names, values, uncertainties); read_csv_file drops their first column.
met_names  = []
met_values = []
met_uncs   = []
for f in f_references:
  nam,val,unc = read_csv_file(f)
  met_names.append(nam)
  met_values.append(val)
  met_uncs.append(unc)

# Read the prediction file.
# Its first line holds two integers (npar, nmet) after a one-character
# prefix; the file is closed as soon as that header line has been read.
with open(f_predictions) as pred_file:
  header_line = pred_file.readline()
npar, nmet = [int(u) for u in header_line[1:].split()]
prediction_df = pd.read_csv(f_predictions, sep=" ", skiprows=1)
# Column layout: npar parameter columns, then the metric columns, and a
# last column that is compared to wave numbers when plotting.
tab_parameters = prediction_df.iloc[:,:npar].to_numpy()
tab_metrics    = prediction_df.iloc[:,npar:-1]
vec_iwave_ruled_out = prediction_df.iloc[:,-1].to_numpy()
par_names = prediction_df.columns[:npar]

# Read, for every wave, the metrics and the parameters of the learning
# database. The tables are transposed so that each entry is one column of
# the file; the first column is dropped.
sondes_met = []
sondes_par = []
for f,p in zip(f_sondes,p_sondes):
  sondes_met.append(np.genfromtxt(f, skip_header=1, dtype=float, delimiter=",").transpose()[1:,:].tolist())
  sondes_par.append(np.genfromtxt(p, skip_header=1, dtype=float, delimiter=" ").transpose()[1:,:].tolist())

# Figure and style setup shared by every wave, metric and parameter.

# "all" anywhere in the -p list selects every parameter of the prediction file
if "all" in paramList : paramList = par_names
npar_toplot = len(paramList)
# Former layout, one subplot per parameter (disabled):
#fig,axes = plt.subplots(ncols=npar_toplot, figsize=(8*npar_toplot,6))
# Current layout: one single square axes that receives every curve
fig,axes = plt.subplots(figsize=(5.5,5.1))
axes.set_box_aspect(1)
#fig,axes = plt.subplots(1, figsize=(8*1,6))
#if npar_toplot == 1 : axes = [axes] ; axes[0].set_box_aspect(1)

# plot customisation
fontsize=12
# markers is indexed by metric and markers_waves by wave position in the
# plotting loop
markers = [".","s","d"]
markers_waves = ["*","D","^","v"]
# NOTE: cols is rebound several times; only the last assignment (the
# matplotlib default colour cycle) survives. The mycolors imports are
# still executed, so that module must be importable.
from mycolors import rainbow6 as cols
cols = cols[2:]
from mycolors import basic5 as cols
cols = cols[1:]
cols = plt.rcParams['axes.prop_cycle'].by_key()['color']
## Maelle's colours (for a single wave) ##
col_predt="#185B8C"   # prediction + emulator std, retained points
col_predf="#4CA3E1"    # prediction + emulator std, rejected points
col_totstdt="#6FAD52"  # prediction + (std**2+tol**2)**0.5, retained points
col_totstdf="#F5C2C2"   # prediction + (std**2+tol**2)**0.5, rejected points


# Legend labels and proxy handles do not depend on the wave, the metric or
# the parameter, so they are built once. Raw strings keep the LaTeX
# backslashes literal (no invalid-escape warnings); the text is unchanged.
leg_pred   = r"$\mu_{OLR} \pm %i \sigma_{OLR}$ "%nstd
leg_totstd = r"$\mu_{OLR} \pm %i \sqrt{\sigma_{OLR}^2+T_{OLR}^2}$ "%nstd
leg_ref    = r"$r_{OLR} \pm %i T_{OLR}$ "%nstd
handle_predf   = Line2D([0], [0], color=col_predf, marker='o', linestyle='')
handle_totstdf = Line2D([0], [0], color=col_totstdf, marker='o', linestyle='', alpha=1.)
handle_predt   = Line2D([0], [0], color=col_predt, marker='o', linestyle='')
handle_totstdt = Line2D([0], [0], color=col_totstdt, alpha=1., marker='o', linestyle='')
handle_ref     = Line2D([0], [0], color="black", linestyle='-', lw=1)

# Main loop: for every wave, every metric and every selected parameter, draw
# on the single shared axes the emulator predictions with their error bars,
# the learning database and the reference with its tolerance band. Two
# figures are saved per wave, with and without the legend.
for iwave,wave in enumerate(wavesList):
  sondes_par_w = sondes_par[iwave]
  sondes_met_w = sondes_met[iwave]
  # cycle through the marker lists instead of failing when there are more
  # waves / metrics than markers
  marker_wave = markers_waves[iwave % len(markers_waves)]
  handle_sonde = Line2D([0], [0], color='black', marker=marker_wave, markersize=7, linestyle='')

  for imet,metric in enumerate(met_names[iwave]):
    label_mw = "%s_WAVE%i"%(metric,wave)
    # emulator expectation (E_ column) and standard deviation (square root
    # of the V_ column) of this metric at this wave
    expect = prediction_df["E_"+label_mw].to_numpy()
    emul_std = np.sqrt(prediction_df["V_"+label_mw].to_numpy())
    sondes_met_mw = sondes_met_w[imet]
    marker_met = markers[imet % len(markers)]

    # the reference and the standard deviation of its uncertainty
    ref = float(met_values[iwave][imet])
    std = np.sqrt(float(met_uncs[iwave][imet]))

    # points ruled out exactly at this wave / points not ruled out so far
    # (a value of 0 in the last column means "never ruled out")
    ruled_out_now = vec_iwave_ruled_out==wave
    not_ruled_out_yet = ~((vec_iwave_ruled_out!=0) & (vec_iwave_ruled_out<=wave))

    if npoints is not None :
      # thin every per-point array with the same stride
      expect            = expect[::npoints]
      emul_std          = emul_std[::npoints]
      ruled_out_now     = ruled_out_now[::npoints]
      not_ruled_out_yet = not_ruled_out_yet[::npoints]

    # half-widths of the two kinds of error bars
    emul_halfwidth  = nstd*emul_std
    total_halfwidth = nstd*((emul_std**2+std**2)**0.5)

    # (mask, half-width, colour, zorder, label) of the four error-bar
    # layers, in drawing order: points ruled out at this wave in light
    # colours, then points not ruled out yet in dark colours
    layers = (
      (ruled_out_now,     emul_halfwidth,  col_predf,   2, None),
      (ruled_out_now,     total_halfwidth, col_totstdf, 1, None),
      (not_ruled_out_yet, emul_halfwidth,  col_predt,   3, leg_pred),
      (not_ruled_out_yet, total_halfwidth, col_totstdt, 2, leg_totstd),
    )

    # icol is the position of the parameter in par_names, hence its column
    # in tab_parameters; it must not be a counter of the *selected*
    # parameters, otherwise "-p" with a subset reads the wrong columns.
    # NOTE(review): assumes Params.asc lists the parameters in the same
    # order as the prediction file -- confirm.
    for icol,param in enumerate(par_names):
      if param not in paramList: continue
      sondes_par_pw = sondes_par_w[icol]

      param_vals = tab_parameters[:,icol]
      if npoints is not None : param_vals = param_vals[::npoints]

      for mask,halfwidth,color,zorder,label in layers:
        plt.errorbar(param_vals[mask], expect[mask], yerr=halfwidth[mask],
                     ls="", marker=marker_met, elinewidth=1.5,
                     alpha=1, color=color, label=label, zorder=zorder)

      # learning database of this wave
      plt.scatter(sondes_par_pw, sondes_met_mw, color="black", s=49, marker=marker_wave, zorder=10)

      # the reference and its uncertainty range
      plt.axhline(ref, color="black", ls="-", lw=1, zorder=-1)
      plt.axhline(ref-nstd*std, color="black", ls="--", lw=1.0, zorder=-1)
      plt.axhline(ref+nstd*std, color="black", ls="--", lw=1.0, zorder=-1 )
      plt.ylabel("$f_{OLR}$",fontsize=fontsize)

      if iwave==0:
        plt.plot([],[], color="black")
        plt.xlabel(param, fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        plt.ylim(ylim)
        # Legend: two grouped handle pairs + two single handles
        plt.legend([(handle_predt,handle_predf), (handle_totstdt,
                  handle_totstdf), handle_sonde, handle_ref],
                  [leg_pred, leg_totstd, "$F_{OLR}^1$", leg_ref],
                  handler_map={tuple: cl.HandlerErrorBar()},
                  loc="upper center", bbox_to_anchor=(0.5,-0.15),
                  ncol=2, fontsize=fontsize, frameon=False)

  plt.savefig("predictions_up_to_wave%iwleg.png"%wave, bbox_inches='tight')
  # The legend is only created while plotting the first wave, so there may
  # be none to remove (get_legend() then returns None).
  legend = axes.get_legend()
  if legend is not None:
    legend.remove()
  plt.savefig("predictions_up_to_wave%inoleg.png"%wave, bbox_inches='tight')
#plt.show()
