###################################################################
# This script plots each metric against each parameter:
# emulator prediction and uncertainty for a new collection of
# sample_size points.
# It also plots the learning database.
# You must run htune_emulator_predictions.R -wave w beforehand.
# To do : 
# - clean and merge with plot_emulator_prediction.py
#
# Najda Villefranque - Maelle Coulon--Decorzens
# aout 2025
###################################################################

import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import custom_legend as cl 
import argparse
import glob

# Command-line interface
parser=argparse.ArgumentParser()
parser.add_argument("-w",help="list of waves to plot (by number)", metavar="wavesList", default="1")
parser.add_argument("-p",help="list of params to plot (by name)", metavar="paramList", default="all")
parser.add_argument("-y",help="ylim of the plots", metavar="ylim", default=None)
parser.add_argument("-n",help="plot every n points", metavar="npoints", default=None)
parser.add_argument("-N",help="Nstd for confidence interval", metavar="Nstd", default=3)
args=parser.parse_args()

# Retrieve and convert the arguments (lists are given comma-separated)
npoints = int(args.n) if args.n is not None else None
ylim = [float(u) for u in args.y.split(",")] if args.y is not None else None
paramList = args.p.split(",")
wavesList = [int(i) for i in args.w.split(",")]
maxwave = max(wavesList)
nstd=float(args.N)

# Files holding the data (references and emulator predictions)
f_references  = ["metrics_REF_%i.csv"%nwave for nwave in wavesList]
f_sondes      = ["WAVE%i/Metrics.csv"%nwave for nwave in wavesList]
p_sondes      = ["WAVE%i/Params.asc"%nwave for nwave in wavesList]
all_f_predictions = glob.glob("Predictions_Wave*.asc")
if len(all_f_predictions) == 0 :
  print("error: could not find any Predictions_Wave*asc file.")
  print("Consider running\n  Rscript htune_emulator_predictions.R -wave %i\nand retry"%maxwave)
  raise SystemExit(1)
all_f_predictions.sort()

def _prediction_wave_id(fname):
  """Return the wave number found in a Predictions_Wave<N>.asc name, or None."""
  try:
    return int(fname.split("_Wave")[-1].split(".asc")[0])
  except ValueError:
    # a file matching the glob but without a plain integer wave number
    return None

# Keep the prediction files whose wave number is at least the largest
# requested wave, and pick the one with the smallest wave number.
# Wave numbers are compared as integers: a plain string sort would put
# "Wave10" before "Wave2".
_wave_of = {fname: _prediction_wave_id(fname) for fname in all_f_predictions}
_eligible = [fname for fname in all_f_predictions
             if _wave_of[fname] is not None and _wave_of[fname] >= maxwave]
if not _eligible:
  # without this check f_predictions would be undefined further down
  print("error: no Predictions_Wave*.asc file covers wave %i (found: %s)."
        %(maxwave, ", ".join(all_f_predictions)))
  print("Consider running\n  Rscript htune_emulator_predictions.R -wave %i\nand retry"%maxwave)
  raise SystemExit(1)
f_predictions = min(_eligible, key=_wave_of.get)

print("Will plot file %s"%f_predictions)

def read_csv_file(f, sep=",", exclude_first_col=1):
  """Read a delimited text file and return its rows as lists of strings.

  Parameters
  ----------
  f : path of the file to read.
  sep : field delimiter handed to csv.reader.
  exclude_first_col : number of leading columns dropped from every row
      (the default, 1, drops the first column only; 0 keeps everything).

  Returns
  -------
  A list with one entry per non-blank line, in file order; each entry is
  the list of the remaining fields of that line, as strings.
  """
  # newline="" leaves newline handling to the csv module, as recommended by
  # its documentation.
  with open(f, newline="") as csvfile:
    reader = csv.reader(csvfile, delimiter=sep)
    # csv.reader yields [] for a blank line: skip those so that e.g. a
    # trailing empty line does not add a spurious empty row.
    return [row[exclude_first_col:] for row in reader if row]

# Read the references of the metrics, one file per requested wave.
# Each file must provide exactly three rows (first column dropped by
# read_csv_file); they are stored, in order, in met_names, met_values and
# met_uncs, each holding one list per wave.
met_names, met_values, met_uncs = [], [], []
for ref_file in f_references:
  names_row, values_row, uncs_row = read_csv_file(ref_file)
  met_names.append(names_row)
  met_values.append(values_row)
  met_uncs.append(uncs_row)

# Read the predictions file.
# Its first line holds, after one leading character that is skipped, two
# integers: the number of parameters and the number of metrics.
# The file is opened in a with-block so that the handle is closed right away.
with open(f_predictions) as pred_file:
  npar, nmet = [int(u) for u in pred_file.readline()[1:].split()]
# The rest of the file is a space-separated table whose header is line 2
prediction_df = pd.read_csv(f_predictions, sep=" ", skiprows=1)
# first npar columns: parameter values
tab_parameters = prediction_df.iloc[:,:npar].to_numpy()
# columns between the parameters and the last one
tab_metrics    = prediction_df.iloc[:,npar:-1]
# last column: compared with wave numbers further down (0 = never ruled out)
vec_iwave_ruled_out = prediction_df.iloc[:,-1].to_numpy()
par_names = prediction_df.columns[:npar]

# Load, for every requested wave, the content of WAVE<n>/Metrics.csv and
# WAVE<n>/Params.asc: the header line is skipped, the table is transposed
# (one list per file column) and the first file column is dropped.
sondes_met, sondes_par = [], []
for metrics_file, params_file in zip(f_sondes, p_sondes):
  met_table = np.genfromtxt(metrics_file, skip_header=1, dtype=float, delimiter=",")
  sondes_met.append(met_table.T[1:, :].tolist())
  par_table = np.genfromtxt(params_file, skip_header=1, dtype=float, delimiter=" ")
  sondes_par.append(par_table.T[1:, :].tolist())

# "all" selects every parameter listed in the predictions file
if "all" in paramList : paramList = par_names
npar_toplot = len(paramList)
# A single square panel receives every curve
fig,axes = plt.subplots(figsize=(5.5,5.1))
axes.set_box_aspect(1)

# Plot styling
fontsize=12
markers = [".","s","d"]            # indexed by metric
markers_waves = ["*","D","^","v"]  # indexed by wave
# Colours indexed by wave: `cols` for the error bars, `cols_mark` (darker
# shade of the same hue) for the markers. The three-colour palette is
# repeated so that up to 15 waves can be indexed.
cols = ["#dea73a", "#d92120", "#404096"]*5
cols_mark =["#8C6518", "#8E1515", "#151532"]*5
fact=3
size_mark =[fact*2, fact*1, fact*1,fact*1,fact*1]*5
## Maelle's colours (for a single wave) ##
col_predt="#185B8C"   # prediction + emulator std, retained points
col_predf="#4CA3E1"    # prediction + emulator std, rejected points
col_totstdt="#6FAD52"  # prediction + (std**2+tol**2)**0.5, retained points
col_totstdf="#F5C2C2"   # prediction + (std**2+tol**2)**0.5, rejected points

# Legend handles, one slot per requested wave, filled while plotting
handle_f=[None]*len(wavesList)
handle_t=[None]*len(wavesList)

# Main plotting loop.
# For every requested wave and every metric of that wave, the emulator
# expectation (column E_<metric>_WAVE<wave>) is drawn against each selected
# parameter with error bars of +/- nstd * sqrt(V_<metric>_WAVE<wave>).
# Rejected points are drawn faint, the others opaque. Everything goes to the
# same axes, which are never cleared, so the figure saved for a wave also
# shows the waves plotted before it. Two files are written per wave: with
# and without the legend.
# NOTE: the alternative renderings that were kept here as commented-out code
# (sorted envelopes with fill_between, emulator+tolerance error bars, overlay
# of the learning-database points) were dropped to keep the loop readable.
for iwave,wave in enumerate(wavesList):
  for imet,metric in enumerate(met_names[iwave]):
    label_mw = "%s_WAVE%i"%(metric,wave)
    expect = prediction_df["E_"+label_mw].to_numpy()
    emul_std = np.sqrt(prediction_df["V_"+label_mw].to_numpy())
    print(min(emul_std),max(emul_std))

    # 0 in the last column of the predictions file means "never ruled out";
    # a point counts as rejected once its value is non-zero and <= wave.
    not_ruled_out_yet = ~((vec_iwave_ruled_out!=0) & (vec_iwave_ruled_out<=wave))
    # exact complement, so that no point is drawn both as rejected and kept
    ruled_out = ~not_ruled_out_yet

    if npoints is not None :
      # every array used together below must be thinned the same way,
      # otherwise the boolean masks no longer match the data length
      expect            = expect[::npoints]
      emul_std          = emul_std[::npoints]
      not_ruled_out_yet = not_ruled_out_yet[::npoints]
      ruled_out         = ruled_out[::npoints]

    # reference value and its uncertainty range (do not depend on the parameter)
    ref = float(met_values[iwave][imet])
    std = np.sqrt(float(met_uncs[iwave][imet]))
    # cycle through the markers if there are more metrics than markers
    marker = markers[imet % len(markers)]

    # icol is the position of the parameter among the columns of
    # tab_parameters; it must not be confused with a counter of the
    # selected parameters, or a partial -p selection reads the wrong column
    for icol,param in enumerate(par_names):
      if param not in paramList: continue

      param_vals = tab_parameters[:,icol]
      if npoints is not None : param_vals = param_vals[::npoints]

      # points already rejected at this wave, in a light shade
      plt.errorbar(param_vals[ruled_out],
              expect[ruled_out], yerr=nstd*emul_std[ruled_out],
              ls="", marker=marker, elinewidth=1.,  fmt="none",
              alpha=0.1, color=cols[iwave], zorder=2+iwave*10, markersize=4,
              label="Points rejetés vague "+str(iwave+1))
      plt.scatter(param_vals[ruled_out], expect[ruled_out],
              marker=marker,s=4**2,
              alpha=0.1, color=cols_mark[iwave], zorder=3+iwave*10)
      if handle_f[iwave] is None :
        handle_f[iwave]   = Line2D([0], [0], markerfacecolor=cols_mark[iwave],
                                   markeredgecolor=cols[iwave],marker='o',
                                   linestyle='', alpha=0.5)

      # points not rejected yet, in a dark shade
      plt.errorbar(param_vals[not_ruled_out_yet],
              expect[not_ruled_out_yet], yerr=nstd*emul_std[not_ruled_out_yet],
              ls="", marker=marker, elinewidth=1.,
              alpha=1., color=cols[iwave], zorder=2+iwave*10, markersize=4,
              label="Points acceptés vague "+str(iwave+1))
      plt.scatter(param_vals[not_ruled_out_yet], expect[not_ruled_out_yet],
              marker=marker,s=4**2,
              alpha=1., color=cols_mark[iwave], zorder=3+iwave*10)
      if handle_t[iwave] is None :
        handle_t[iwave]   = Line2D([0], [0], markerfacecolor=cols_mark[iwave],
                            markeredgecolor=cols[iwave],marker='o',
                            linestyle='', alpha=1.)

      # the reference and its uncertainty range
      plt.axhline(ref, color="black", ls="-", lw=1, zorder=-1)
      plt.axhline(ref-nstd*std, color="black", ls="--", lw=1.0, zorder=-1)
      plt.axhline(ref+nstd*std, color="black", ls="--", lw=1.0, zorder=-1 )
      plt.ylabel("$f_{OLR}$",fontsize=fontsize)

      if iwave==0:
        # axis decoration only needs to be set while plotting the first wave
        plt.xlabel(param, fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        plt.ylim(ylim)

  # One legend entry per wave plotted so far: accepted / rejected handles
  # are grouped in a tuple and rendered by the custom handler.
  drawn = [i for i in range(iwave+1)
           if handle_t[i] is not None and handle_f[i] is not None]
  if drawn:
    plt.legend([(handle_t[i], handle_f[i]) for i in drawn],
               [f"Points acceptés / rejetés vague {i+1}" for i in drawn],
               handler_map={tuple:cl.HandlerErrorBarbicol()},
               loc="upper center", bbox_to_anchor=(0.5,-0.15),
               frameon=False, ncol=1, fontsize=fontsize)
  plt.savefig("predictions_up_to_wave%iwleg.png"%wave, bbox_inches='tight')
  # the legend may be missing when nothing was plotted
  legend = axes.get_legend()
  if legend is not None: legend.remove()
  plt.savefig("predictions_up_to_wave%inoleg.png"%wave, bbox_inches='tight')
#plt.show()
