# This script plots parameter vs parameter (one figure per metric and per pair of parameters)
#        + (optional) best simulations or chosen highlighted simulations
#        + (optional) ensembles highlighted in colors
#
# By default it plots the ensembles from iwavemin to iwavemax (set them at the beginning of the script).
# You can also specify ensemble files, their corresponding parameter files and highlighted simulations
# with options.
#
# Ensemble files must be csv files:
#   - containing a header with metric_name=head_metric+"_"+metric (head_metric can be WAVEX)
#   - the header of the first column must be "SIM"
#   - separated by ","
#
# Parameter files must be asc files:
#   - containing a header with the parameter names
#   - the header of the first column must be "t_IDs"
#   - separated by " "

# Adapted from CCCma_scripts and scatter_plot.py (almost the same arguments)
#
# Things to be done:
# - maybe add an option to plot on log scales (when parameters are explored in log scale?)
# - pdfjam by metric? -> install pdfjam on ricard
# - handle the color bars (handled by hand for now)
# - handle log scales for the parameters explored in log scale
#
# Maelle Coulon--Decorzens - Frédéric Hourdin - Najda Villefranque - Mars 2024

import pandas as pd
import re
import matplotlib.pyplot as plt
import os
import argparse
import sys
import matplotlib.colors as mcolors
import math

#------------------------------------------------------------------------------
# Some previous variables the user might change 
#------------------------------------------------------------------------------

# Wave range used to build the default ensemble / obs / score file names.
iwavemin, iwavemax = 1, 1
# Output directory for the figures, relative to the working directory
# (replaced by the -fig option when it is given).
path_fig = "fig/param_vs_param/"
#param_user="param"

#------------------------------------------------------------------------------
# Script optional arguments 
#------------------------------------------------------------------------------
parser=argparse.ArgumentParser()
parser.add_argument("-o",help="Observation csv file",metavar='OBS.csv')
parser.add_argument("-e",help='List of ensemble csv files "ENS1.csv,ENS2.csv,..."',metavar="ENS1.csv,ENS2.csv,...")
parser.add_argument("-p",help='List of param asc files corresponding to ensemble "PARAM1.asc,PARAM2.asc,..."',metavar="PARAM1.asc,PARAM2.asc,...")
parser.add_argument("-param",help='Name of initial parameter file',metavar="param_NAME or param", default="param")
parser.add_argument("-l",help='List of ensemble csv files you want to be highlighted in colors ENS1.csv,ENS10.csv,...' ,metavar="ENS1.csv,ENS10.csv,...")
parser.add_argument("-b",help='List of names of best simulations sim1,sim2,...',metavar="sim1,sim2,...")
parser.add_argument("-n",help='Number of best simulations to be automatically displayed',metavar="N",default='10')
# The -fig value is used directly as a filesystem path (directory creation and
# savefig), so the default must not carry shell-style quotes; it follows the
# path_fig setting at the top of the script.
parser.add_argument("-fig",help='path where figure are saved' ,metavar="dirfig/dir1/dir2/",default=path_fig)
# -x / -y are parsed below as "min,max" floats and passed to set_xlim / set_ylim.
parser.add_argument("-x",help="xlim of the plots, as xmin,xmax", metavar="xlim", default=None)
parser.add_argument("-y",help="ylim of the plots, as ymin,ymax", metavar="ylim", default=None)
#parser.add_argument("-c",help='matplotlib colormap : eg, BrBG PuOr RdYlGn Spectral_r bwr coolwarm_r earth_r ocean_r twilight_shifted Accent',metavar="cmap",default='Paired')
args=parser.parse_args()

# "-x xmin,xmax" / "-y ymin,ymax" -> [min, max]; None lets matplotlib autoscale.
# Each limit is gated on its own option (xlim used to test args.y by mistake).
xlim = [float(u) for u in args.x.split(",")] if args.x is not None else None
ylim = [float(u) for u in args.y.split(",")] if args.y is not None else None

## A few colours for the plots ##
darkblue, lightblue, red = "#1155cc", "#6fa8dcff", "#cc0000ff"
darkgreen, green = "#314e23", "#5b8d43cc"

## User parameter files ##

#if args.param: 
#    param_user = args.param
#    df = pd.read_csv(param_user, sep=r"\s+", header=None, engine="python")
#    #file_param_user = df.set_index(0).T.reset_index(drop=True)
#    file_param_user = df.set_index(0).T.reset_index(drop=True)
#    file_param_user=df.drop(0)


#if args.param:  
#    param_user = args.param
#    file_param_user = pd.read_csv(param_user, delim_whitespace=True, header=None)
#    file_param_user.columns = file_param_user.iloc[0] 
#    file_param_user = file_param_user[1:].reset_index(drop=True).T

# Read the user parameter file (-param, "param" by default). Parameter names are
# in the first column; after transposing, each column is one parameter and the
# rows are read by the plotting loop by position (presumably min, max, prior
# and exploration scale -- confirm against the param file format).
# Fall back to "param" so param_user is always defined (it used to raise a
# NameError when -param was given an empty value).
param_user = args.param or "param"
# sep=r"\s+" is the documented replacement for the deprecated delim_whitespace=True.
file_param_user = pd.read_csv(param_user, sep=r"\s+", header=None).T
file_param_user.columns = file_param_user.iloc[0]
file_param_user = file_param_user[1:].reset_index(drop=True)

print(file_param_user)
### Highlighted ensembles ###
# Wave numbers (default) or ensemble file names (-l option) to be drawn in colour.
if args.l :
  enstoprint=args.l.split(',')
  print('Highlighted ensemble from arguments : ',enstoprint)
else :
  enstoprint=[1,40,iwavemax] if iwavemax > 3 else [1,iwavemax]
  print(f'Highlighted ensemble from WAVES {iwavemin} to {iwavemax}')


### Obs file ###
obs_file=args.o
if obs_file :
  print('Obs file from arguments : ',obs_file)
else :
  # Default: reference metrics of the first wave, in the working directory.
  obs_file=f"metrics_REF_{iwavemin}.csv"
  print(f'Obs file from WAVE{iwavemin} : {obs_file}')

obs=pd.read_csv(obs_file,sep=",")
# Row 0 is read as the target mean, row 1 as its variance (only used by the
# commented-out target lines in the plotting loop).
obs_mean, obs_var = obs.loc[0], obs.loc[1]


### Metrics and Parameters ensembles ###

# -e and -p go together: every ensemble metrics file needs its parameter file.
# Exit with a non-zero status so calling scripts can detect the usage error.
if (args.e is None) != (args.p is None) :
  print("if you give specific ensemble in argument you must give their parameters files with -p option")
  print("Stop the program here")
  sys.exit(1)

# Build the parallel lists of ensemble metrics files and parameter files, either
# from the -e/-p options or from the default WAVE<i>/ layout for
# iwavemin..iwavemax. The first pair is used as the reference for the metric
# and parameter names.
if args.e :
  metric_files=args.e.split(",")
  print('Ensemble metrics files from arguments : ',metric_files)
  param_files=args.p.split(",")
  print('Ensemble parameter files from arguments : ',param_files)
else :
  waves=range(iwavemin,iwavemax+1)
  metric_files=["WAVE"+str(iwave)+"/metrics_WAVE"+str(iwave)+"_"+str(iwave)+".csv" for iwave in waves]
  param_files=["WAVE"+str(iwave)+"/Par1D_Wave"+str(iwave)+".asc" for iwave in waves]
  print('Ensemble metrics from WAVES : ',metric_files)
  print('Ensemble parameter files from WAVES : ',param_files)

# The plotting loop indexes param_files with the ensemble position, so both
# lists must be non-empty and have the same length.
if not metric_files or len(param_files) != len(metric_files) :
  print("you must give exactly one parameter file per ensemble file (got "
        +str(len(metric_files))+" ensemble files and "+str(len(param_files))+" parameter files)")
  print("Stop the program here")
  sys.exit(1)

nwave=len(metric_files)
file_metrics_ref=metric_files[0]
file_params_ref=param_files[0]

metrics_ref=pd.read_csv(file_metrics_ref,sep=",")
params_ref=pd.read_csv(file_params_ref,sep=' ')

### Bests simulations ###
list_bests=[]
nbests_auto=int(args.n)

if args.b :
  list_bests=args.b.split(",")
elif args.e is None : # controlled by waves
  if nbests_auto > 0 :
    # Merged copy of the per-wave score files score<i>.csv. The name encodes the
    # wave range so that a cache built for other waves is not silently reused
    # (it was hard-coded to score41to42.csv while filled from iwavemin..iwavemax).
    file_score="score"+str(iwavemin)+"to"+str(iwavemax)+".csv"
    if not os.path.exists(file_score) :
      # Concatenate the per-wave tables with a single header (was head/tail via the shell).
      per_wave=[pd.read_csv("score"+str(iwave)+".csv", sep=",") for iwave in range(iwavemin,iwavemax+1)]
      pd.concat(per_wave, ignore_index=True).to_csv(file_score, index=False)

    score=pd.read_csv(file_score, sep=",")
    # Lowest "MAX" first: keep the nbests_auto first simulations.
    sorted_score=score.sort_values(by="MAX", ascending=True)
    list_bests.extend(sorted_score['SIM'].head(nbests_auto))
else :
  print("if you specify ensembles you must specify with option -b the list of simulations you want to highlight, or it will be none of them")

### Dico bests -> file metrics et bests -> file params ###
# For each best simulation, remember the ensemble metrics file that contains it
# and the matching parameter file. If several files contain it, the last one in
# metric_files wins.
dico_bests={"metrics": {}, "params": {}}
for imet, fm in enumerate(metric_files) :
  # Exact match on the SIM column, which is the index the plotting code uses
  # with .loc. The previous shell grep also matched substrings (SIM-1 inside
  # SIM-10) and built a shell command from file names.
  sims_in_file=set(pd.read_csv(fm, sep=",", usecols=["SIM"])["SIM"].astype(str))
  for simu in list_bests :
    if str(simu) in sims_in_file :
      dico_bests["metrics"][simu]=fm
      dico_bests["params"][simu]=param_files[imet]
print("list_bests = ", list_bests)

### Plot colors ###
# Alternative presets kept for quick switching; only the uncommented values are used.
#colors_ens=['#d0d0d0', '#3bafba', '#d56a6a', '#cdc25b']
#colors_ens=['lightgrey','mistyrose','lemonchiffon','palegreen','paleturquoise']
#colors_ens=['lightgrey','#5fb6be', '#d56a6a','#cdc25b','#72BB72','#3bafba']
#colors_ens=[red,lightblue,darkblue,'darkblue',darkgreen,green,lightblue, '#d56a6a','#cdc25b','#72BB72','#3bafba']
#marker_ens=["o","o","D","*","*","*"]
#marker_ens=["o","o","o","o","*","*"]
#ms_ens=[fact*10,fact*10+2,fact*10,fact*36]
#ms_ens=[fact*10,fact*10,fact*10,fact*10]
listcoul=['red', 'green','fuchsia','blue','lime','darkviolet','cyan','darkorange','slateblue','brown','gold']

# Per-ensemble styles, indexed by the ensemble position in metric_files.
# NOTE: the shortest list (label_ens, 3 entries) bounds how many ensembles can be plotted.
colors_ens=[lightblue,darkblue,'darkblue',darkblue,lightblue, '#d56a6a','#cdc25b','#72BB72','#3bafba']
markeredgecolor='k'
marker_ens=["o","D","*","*","*"]
fact=1.5
ms_ens=[fact*10+2,fact*10,fact*40,fact*40]

# Raw strings: "\L" is not a valid escape sequence (SyntaxWarning on recent Pythons).
label_ens=[r"Membres de $\Lambda_{10}^1$", r"Membres de $\Lambda_{10}^2$", r"Membres de $\Lambda_{10}^3$"]
if args.fig :
  path_fig=args.fig+"/"
  # Create the output directory without going through the shell.
  os.makedirs(path_fig, exist_ok=True)
#------------------------------------------------------------------------------
# Loop on metrics and parameters
#------------------------------------------------------------------------------

# Memoized reads: the same ensemble parameter files used to be re-read for every
# (metric, param, param2) combination.
_csv_cache = {}


def _read_indexed(path, sep, index_col):
  """Return pd.read_csv(path, sep=sep, index_col=index_col), read at most once per file."""
  key = (path, sep, index_col)
  if key not in _csv_cache:
    _csv_cache[key] = pd.read_csv(path, sep=sep, index_col=index_col)
  return _csv_cache[key]


# Best simulations not located in any ensemble file are reported once and
# skipped, instead of raising a KeyError in the middle of the plotting loop.
bests_to_plot = [simu for simu in list_bests if simu in dico_bests["params"]]
bests_missing = [simu for simu in list_bests if simu not in dico_bests["params"]]
if bests_missing:
  print("Warning : best simulations not found in any ensemble file, not plotted : ", bests_missing)

#------------------------------------------------------------------------------
# For every metric and every (param, param2) pair, scatter the ensembles'
# parameter values plus the best simulations and save the figure twice
# (with and without legend).
# NOTE(review): the metric values themselves are not drawn (the colour-coded
# scatter is commented out), so each metric only changes the output file name.
#------------------------------------------------------------------------------
for imetric, metric_ in enumerate(metrics_ref.columns.values[1:]):
  print(imetric, metric_)
  # "HEAD_rest_of_name" -> metric_head="HEAD", metric="rest_of_name";
  # a column name without "_" has no head.
  head, sep, rest = metric_.partition('_')
  if sep:
    metric, metric_head = rest, head
  else:
    metric, metric_head = metric_, ""
  print("metric = ", metric)
  # Manual colour normalisation (only used by the commented-out colour-coded scatter).
  if "precip" in metric:
    min_norm, max_norm = 0, 10
  elif "conv" in metric:
    min_norm, max_norm = -140, -30
  else:
    min_norm, max_norm = -50, -15
  print(min_norm, max_norm)
  norm = mcolors.Normalize(vmin=min_norm, vmax=max_norm)
  cmap = plt.get_cmap("magma")
  param_names = params_ref.columns.values[1:]
  print(param_names)
  for param in param_names:
    for param2 in param_names:
      fig, ax = plt.subplots(figsize=(5, 5))
      ax.set_box_aspect(1)
      print("param = ", param)
      # Row 3 of the user parameter file selects a log axis when it reads "log".
      if file_param_user[param].iloc[3] == "log":
        plt.xscale("log")
      if file_param_user[param2].iloc[3] == "log":
        plt.yscale("log")
      xmin = file_param_user[param].iloc[0]
      xmax = file_param_user[param].iloc[1]

      #-------------------------------------------------------------------
      # Plotting ensembles (style chosen by ensemble position)
      #-------------------------------------------------------------------
      for iens, _file_ens in enumerate(metric_files):
        params = _read_indexed(param_files[iens], " ", "t_IDs")
        plt.scatter(params[param], params[param2],
                    marker=marker_ens[iens], color=colors_ens[iens], s=ms_ens[iens],
                    label=label_ens[iens], zorder=20 - iens)

      # Fixed axis limits from -x / -y (None keeps autoscaling). The former
      # log10(abs(xmin)) computation was unused and crashed when xmin == 0.
      ax.set_xlim(xlim)
      print(xmin, xmax)
      ax.set_ylim(ylim)
      # Hard-coded vertical reference lines (opt-3T / opt+3T in the original notes).
      # NOTE(review): drawn on every figure whatever `param` is -- confirm intended.
      lw = 2
      plt.axvline(x=0.81, color='k', ls='--', lw=lw)
      plt.axvline(x=1.69, color='k', ls='--', lw=lw)

      ##-------------------------------------------------------------------
      ## Plotting bests simulations
      ##-------------------------------------------------------------------
      for simu in bests_to_plot:
        params = _read_indexed(dico_bests["params"][simu], " ", "t_IDs")
        plt.scatter(params.loc[simu][param], params.loc[simu][param2],
                    s=20, marker='o', color="tab:red", zorder=100, edgecolor='k')

      #-------------------------------------------------------------------
      # Closing graphics
      #-------------------------------------------------------------------
      fontsize = 12
      plt.xlabel(param, fontsize=fontsize)
      plt.ylabel(param2, fontsize=fontsize)
      ax.tick_params(labelsize=fontsize)
      plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
                 fontsize=fontsize, frameon=False, ncol=2)
      plt.savefig(path_fig + metric + "_" + param + "_" + param2 + "wleg.png", bbox_inches='tight')
      legend = ax.get_legend()
      if legend is not None:
        legend.remove()
      plt.savefig(path_fig + metric + "_" + param + "_" + param2 + "noleg.png", bbox_inches='tight')

      plt.close(fig)
  ## end loop on parameters
  # TODO: group the figures per metric with pdfjam (see header).
##end loop on metrics

# on regroupe les figures par paramètres
#for iparam,param in enumerate(params_ref.columns.values[1:]) :
#  os.system("pdfjam --nup 3x3 --landscape "+path_fig+"*"+param+"*.pdf --outfile "+path_fig++param+"_vs_metrics.pdf")
#  print(path_fig+param+"_vs_metrics.pdf")
  

