import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
import pandas as pd
import numpy as np
import argparse
import os

# Switch off the pandas chained-assignment warning for the whole script
# (the option accepts None, 'warn' or 'raise').
pd.options.mode.chained_assignment = None  # default='warn'

"""
2025/08 
N. Villefranque
Parameter scatterplot of the simulations sampled during htexplo experiments.
One experiment = a directory containing score*.csv and WAVE*/Params.asc files.

it can be run in the directory of a given experiment directly
=> only this experiment will be plotted

alternatively, the user can define a list of experiments 
=> data from all experiments will be plotted on the same figure
"""

##########################
# Command line arguments #
##########################

# All options are optional; no type conversion is requested from argparse,
# so a value given on the command line is received as a string.
parser = argparse.ArgumentParser()


def add_arg(a, h, d, m):
  """Declare option `a` on the module-level parser.

  a: option flag, h: help text, d: default value, m: metavar shown in the help.
  """
  parser.add_argument(a, help=h, default=d, metavar=m)


# (flag, help, default, metavar) for every option of the script
for _option in (
    ("-e", "List of experiments", None, "exp1,exp2"),
    ("-d", "Path to experiments", "./", "dir"),
    ("-l", "List of labels for experiments", None, "lab1,lab2"),
    ("-p", "List of parameters to plot", None, "par1,par2,par3"),
    ("-s", "Max score to plot", 999, "<float>"),
    ("-N", "Max number of simulations to plot", 0, "<int>"),
    ("-n", "Number of best simulations to highlight", 5, "<int>"),
    ("-b", "List of names of simulations to highlight", None, "sim1,sim12"),
    ("-o", "Extension in fig name", "", "_myfig"),
    ("-C", "Common prefix to remove from labels", None, "exp_prefix"),
    ("-m", "Map for color", None, "cmap1,cmap2,ncolors"),
    ("-c", "List of named colors", None, "tab:blue,tab:orange"),
    ("-a", "alpha transparency of points", 0.4, "<float>"),
    ("-w", "first wave to plot (list of exp:wave or one number for all waves)", None, "exp1:w5,exp3:w10 | w3"),
):
  add_arg(*_option)

args = parser.parse_args()

# Raw option values; the comma-separated ones are split into lists further down.
list_exps = args.e          # List of experiments to plot
path_exps = args.d          # Path to experiments
list_labels = args.l        # List of labels for experiments
list_params = args.p        # List of parameters to plot
threshold = float(args.s)   # Max score to plot
number_plot = int(args.N)   # Max number of simulations to plot
nbests = int(args.n)        # Number of best simulations to highlight
list_highlight = args.b     # List of simulations to highlight
ext_figname = args.o        # Extension in figure name
common_prefix_exp = args.C  # Common prefix to remove from labels
cmaps = args.m              # Settings for color maps
list_colors = args.c        # List of named colors
# A value given on the command line is a string: convert it like the other
# numeric options, otherwise "-a 0.2" ends up as text in the RGBA color tuples.
alpha = float(args.a)       # alpha transparency of points
dict_keep_only_waves_gt = args.w # first wave to keep (raw option string, parsed further down)

# -c and -m are two ways of choosing colors: the explicit list wins
if list_colors is not None and cmaps is not None : 
  print("warning: both options -c and -m were provided ; priority is given to -c")
  cmaps = None

cmap="Pastel2" # one color per experiment
cmap2="Set2"   # colormap for "best" simulations
ncols=8        # number of different colors in the colormaps

select_best = "max" # how to select best simulations ? max|mean
# marker of the 1st, 2nd, ... best simulation of an experiment (the list is repeated 5 times)
best_markers = ["o", ">", "s", "d", "P", "v", "^", "<", "D", "H"]*5
figname="param_matrix"+ext_figname+".pdf"

###############
# Here we begin 

# if relevant, make lists out of comma-separated strings
list_exps      = ["."] if list_exps      is None else list_exps.split(",")
list_labels    = None  if list_labels    is None else list_labels.split(",")
list_params    = None  if list_params    is None else list_params.split(",")
list_highlight = None  if list_highlight is None else list_highlight.split(",")
list_colors    = None  if list_colors    is None else list_colors.split(",")

nexp = len(list_exps)

# where to find data, based on experiment name and path
dict_exp_dirs = {e:path_exps+"/"+e for e in list_exps}

# make a dictionary for experiment labels
if common_prefix_exp is None:
  common_prefix_exp = os.path.commonprefix(list_exps)
if list_labels is None:
  # Remove the prefix only where it really starts the name: str.replace would
  # also delete any later occurrence of the same text inside the name.
  dict_exp_labels = {
    e: (e[len(common_prefix_exp):] if e.startswith(common_prefix_exp) else e)
    for e in list_exps}
else:
  # zip() stops at the shortest list: with too few labels some experiments
  # would be left without a label and fail much later with a KeyError.
  if len(list_labels) < len(list_exps):
    print("error: %d labels (-l) given for %d experiments (-e)"%(len(list_labels), len(list_exps)))
    raise SystemExit(1)
  if len(list_labels) > len(list_exps):
    print("warning: more labels (-l) than experiments (-e) ; extra labels are ignored")
  dict_exp_labels = {e: l for (e,l) in zip(list_exps, list_labels)}
# one label per experiment, in the order of the experiments
list_labels = [dict_exp_labels[e] for e in dict_exp_labels]

# if relevant, parse dictionary for min wave to plot
# Accepted forms of option -w:
#   "exp1:w10,exp2:5" => {exp1: 10, exp2: 5}   (per experiment)
#   "w3" or "3"       => 3 for every experiment
# The leading "w" shown in the help of -w is optional: it is dropped before
# the conversion to int (int("w10") would raise a ValueError).
if dict_keep_only_waves_gt is not None :
  if ":" in dict_keep_only_waves_gt:
    # k = "exp1:w10" => key "exp1", value 10
    exp_wave_pairs = [k.split(":") for k in dict_keep_only_waves_gt.split(",")]
    dict_keep_only_waves_gt = {p[0]: int(p[1].strip().lstrip("wW")) for p in exp_wave_pairs}
  else :
    first_wave_all = int(dict_keep_only_waves_gt.strip().lstrip("wW"))
    dict_keep_only_waves_gt = {e: first_wave_all for e in list_exps}
else: dict_keep_only_waves_gt={}

# Option -m is "cmap1,cmap2,ncolors":
#   cmap1   colormap of the experiments (left unchanged when empty)
#   cmap2   colormap of the best simulations (cmap1 when not given)
#   ncolors number of colors in the colormaps (number of experiments when not given or empty)
if cmaps is not None:
  l_cmaps = cmaps.split(",")
  if l_cmaps[0] != "":
    cmap = l_cmaps[0]
  cmap2 = l_cmaps[1] if len(l_cmaps) > 1 else cmap
  if len(l_cmaps) > 2 and l_cmaps[2] != "":
    ncols = int(l_cmaps[2])
  else:
    ncols = nexp
# no usable second colormap: the best simulations use the colormap of the experiments
if cmap2 in (None, "none", "None", ""):
  cmap2 = cmap

# make a list of colors out of colormap(s)
# Layout of the final palette (one entry per hue level of the plot):
#   [0, nhue)      color of each distinct experiment label (general population)
#   [nhue, end)    nbests colors per experiment, for its best simulations
nhue = len(np.unique(np.array(list_labels)))
if list_colors is not None :
  # A ListedColormap of N colors returns color number int(x*N) for x in [0,1),
  # i.e. the user's colors in order. Unlike LinearSegmentedColormap.from_list,
  # it can also be evaluated when a single color is given.
  from matplotlib.colors import ListedColormap
  LSC = ListedColormap(list_colors, name="custom_cmap")
  LSC2 = ListedColormap(list_colors, name="custom_cmap")
  ncols = len(list_colors)
else:
  LSC = plt.get_cmap(cmap)
  LSC2 = plt.get_cmap(cmap2)
col_loc = [(float(i)/(ncols))  for i in range(nhue)]
palcolors_ = [LSC(x) for x in col_loc]
palcolors2 = [LSC2(float(i)/ncols) for i in range(nexp) for u in range(nbests)]

# change alpha value of points for the general population
# (the colors of the best simulations keep the alpha of the colormap)
palcolors = [(r,g,b,alpha) for (r,g,b,_) in palcolors_]
palcolors += palcolors2
palette = palcolors

def parse_score_line(l):
  """Split one line of a score file into its comma-separated fields.

  Newline characters at both ends of the line are removed first; the fields
  are returned as strings, untouched otherwise.
  """
  without_eol = l.strip("\n")
  fields = without_eol.split(",")
  return fields
def convert_score_line_to_float(p_l):
  """Return a new list: first field of `p_l` unchanged, every other field converted to float."""
  converted = [p_l[0]]
  converted.extend(float(u) for u in p_l[1:])
  return converted

# Read the per-wave score files score<N>.csv of every experiment and build
#   dict_exps_sims[exp] : names of the retained simulations, lowest MAX score first
#   dict_exps_scor[exp] : dataframe of the scores of these simulations, same order
# A simulation is retained when the last field of its line is below `threshold`.
dict_exps_sims = {}
dict_exps_scor = {}
for exp in list_exps:
  exp_dir = dict_exp_dirs[exp]
  if not os.path.isdir(exp_dir):
    print("error: wrong directory %s"%exp_dir); raise SystemExit(1)
  files_all_scores = sorted(glob(exp_dir+"/score*csv"))
  if len(files_all_scores)==0:
    print("error: no score files; run post_scores.sh"); raise SystemExit(1)

  keep_sims = []
  header = None   # reset for each experiment, never inherited from the previous one
  for file_score_wave in files_all_scores:
    # The glob also catches files that are not per-wave scores (scores.csv,
    # score<N>to<M>.csv, ...): whatever stands between "score" and ".csv"
    # must be an integer, otherwise the file is skipped.
    try:
      wave = int(file_score_wave.split("score")[-1].split(".csv")[0])
    except ValueError:
      continue
    if exp in dict_keep_only_waves_gt and wave < dict_keep_only_waves_gt[exp]: continue
    with open(file_score_wave, "r") as f:
      is_header = True
      for l in f:
        if not l.strip(): continue   # blank line (e.g. at the end of the file)
        p_line = parse_score_line(l)
        if is_header: header = p_line; is_header = False; continue
        f_line = convert_score_line_to_float(p_line)
        if f_line[-1] < threshold:
          keep_sims.append(f_line)

  if header is None:
    # no per-wave file was read (none found, or all excluded by option -w)
    print("warning: no wave score file read for experiment %s"%exp)
    dict_exps_sims[exp] = []
    dict_exps_scor[exp] = pd.DataFrame()
    continue

  data = pd.DataFrame(keep_sims, columns=header).sort_values("MAX")
  ###############################################################
  # !!!!  NOT GREAT (author's own warning about this line)  !!!!
  # simulations sharing exactly the same MAX score are reduced
  # to the first one of them
  ###############################################################
  data = data.drop_duplicates(subset=["MAX"])
  pre_list = data["SIM"].tolist()
  if number_plot>0: pre_list = pre_list[:number_plot]
  dict_exps_sims[exp] = pre_list
  dict_exps_scor[exp] = data.set_index('SIM').loc[pre_list].reset_index()

def parse_sim(sim):
  """Split a simulation name on "-" and return (wave, sample) as strings.

  wave is the second dash-separated field; sample is the last field, cut at
  the first ".nc" if there is one.
  """
  fields = sim.split("-")
  wave_txt = fields[1]
  sample_txt = fields[-1].split(".nc")[0]
  return wave_txt, sample_txt

def sim_wave(sim):
  """Wave number of simulation `sim` (first value returned by parse_sim), as an int."""
  return int(parse_sim(sim)[0])

def exp_sim_params_names(exp, sim):
  """Names found on the first line of the Params.asc file of the wave of `sim`.

  The file read is <experiment dir>/WAVE<wave>/Params.asc. Its first line is
  split on single spaces, the first field is dropped and the double quotes
  around the other fields are removed.

  Returns the list of names, or None when the file is empty.
  """
  nwave, _ = parse_sim(sim)
  p = dict_exp_dirs[exp]+"/WAVE"+nwave+"/Params.asc"
  # only the first line is needed; the file is closed as soon as it is read
  with open(p, "r") as f:
    first_line = f.readline()
  if first_line == "":
    return None
  return [m.strip('"') for m in first_line.strip("\n").split(" ")[1:]]

def exp_sim_params(exp, sim):
  """Parameter values of simulation `sim` of experiment `exp`.

  The file <experiment dir>/WAVE<wave>/Params.asc is scanned for the first
  line containing "<wave>-<...>", i.e. the simulation name without its first
  dash-separated field. That line is split on single spaces, its first field
  is dropped and the others are converted to float.

  Returns the list of floats, or None when no line matches.
  """
  nwave, _ = parse_sim(sim)
  p = dict_exp_dirs[exp]+"/WAVE"+nwave+"/Params.asc"
  name = "-".join(sim.split("-")[1:])
  with open(p, "r") as f:
    for l in f:
      # NOTE(review): substring test on the whole line; a name such as "3-1"
      # would also match a line identified by "3-10" if sample numbers are
      # not zero-padded - confirm against the Params.asc format.
      if name in l:
        return [float(u) for u in l.strip("\n").split(" ")[1:]]
  return None

def exp_waves(exp):
  """Array of the wave numbers of the simulations in dict_exps_sims[exp], in the same order."""
  waves = []
  for sim_name in dict_exps_sims[exp]:
    waves.append(sim_wave(sim_name))
  return np.array(waves)

def exp_params(exp):
  """Array with one row of parameter values per simulation in dict_exps_sims[exp]."""
  rows = []
  for sim_name in dict_exps_sims[exp]:
    rows.append(exp_sim_params(exp, sim_name))
  return np.array(rows)

def exp_max_scores(exp):
  """List of the "MAX" column of the score dataframe of experiment `exp`."""
  return dict_exps_scor[exp]["MAX"].tolist()

def exp_mean_scores(exp):
  """List of the "AVE" column of the score dataframe of experiment `exp`."""
  return dict_exps_scor[exp]["AVE"].tolist()

def plot_params():
  """Draw the corner pair-plot of the parameters of all retained simulations
  and save it to `figname`.

  One dataframe is built per non-empty experiment (one row per simulation:
  parameter values, SIM, hue, size, mean_score, max_score). When nbests > 0,
  the rows of the best simulations of each experiment (lowest max or mean
  score, according to `select_best`) are added a second time with their own
  hue label "best <exp label> <wave>-<sample>", so that each of them gets its
  own color and marker.

  The colors and markers handed to seaborn are assembled here from the
  module-level `palette`, so that there is exactly one color and one marker
  per hue level really present in the data, even when an experiment has no
  simulation or fewer than nbests simulations.

  NOTE(review): when -p is not given, vars=None lets seaborn use every numeric
  column, which includes mean_score and max_score - confirm this is intended.

  Raises ValueError when nbests > 0 and select_best is neither "max" nor "mean".
  Exits with status 1 when no experiment has any simulation to plot.
  """
  frames = []        # general population, one dataframe per experiment
  frames_bests = []  # best simulations, one dataframe per experiment
  base_labels = []   # hue levels of the general population, in order of appearance
  best_colors = []   # one color per best simulation, same order as frames_bests
  best_mks = []      # one marker per best simulation, same order as frames_bests
  columns = None

  # Layout of `palette`: color number i belongs to the i-th distinct experiment
  # label, then come nbests colors for each experiment of list_exps in turn.
  ordered_labels = list(dict.fromkeys(list_labels))

  for iexp, exp in enumerate(list_exps):
    exp_sims = np.array(dict_exps_sims[exp])
    exp_label = dict_exp_labels[exp]
    nsims = len(exp_sims)
    if nsims==0: continue
    # first time, get parameter names
    if columns is None :
      columns = exp_sim_params_names(exp, exp_sims[0])

    # parameter vectors of this exp simulations, as a dataframe, with the
    # simulation name, the "hue" column used for coloring (= label of the
    # exp), the "size" category and the mean and max scores
    this_df = pd.DataFrame(exp_params(exp), columns=columns)
    this_df = this_df.assign(SIM=exp_sims, hue=exp_label, size="normal",
                             mean_score=exp_mean_scores(exp),
                             max_score=exp_max_scores(exp))
    frames.append(this_df)
    if exp_label not in base_labels: base_labels.append(exp_label)

    # special symbol for bests of the exp
    # need to create new datasets to handle label/color/symbol
    if nbests > 0:
      if select_best == "mean"  : key="mean_score"
      elif select_best == "max" : key="max_score"
      else: raise ValueError("select_best must be 'max' or 'mean', got %r"%(select_best,))
      nselect = min(nbests, nsims)

      # the nselect lowest scores; a stable sort keeps the incoming order of
      # simulations with equal scores, and exactly nselect rows are returned
      select_data = this_df.sort_values(key, kind="mergesort").head(nselect).reset_index(drop=True)
      select_data["size"] = "bests"
      select_data["hue"] = ["best %s %s"%(exp_label, "-".join(parse_sim(s)))
                            for s in select_data["SIM"]]
      frames_bests.append(select_data)

      # colors reserved for the bests of this experiment, and their markers
      first_color = nhue + iexp*nbests
      best_colors += palette[first_color:first_color+nselect]
      best_mks += [best_markers[i % len(best_markers)] for i in range(nselect)]

  if not frames:
    print("error: no simulation to plot"); raise SystemExit(1)

  # make one dataframe from the list of dataframes (general population first,
  # then bests); the index is rebuilt so that every row has its own label
  all_df = pd.concat(frames + frames_bests, ignore_index=True)

  # one color and one marker per hue level, in order of appearance in all_df
  plot_palette = [palette[ordered_labels.index(lab)] for lab in base_labels] + best_colors
  plot_markers = ["o"]*len(base_labels) + best_mks

  # custom markers symbol and size
  sizes = all_df["size"]
  all_df = all_df.drop(columns="size")

  # plot dataframe using pairplot, coloring according to "hue" column
  sns.pairplot(all_df, vars=list_params, kind="scatter",
          hue="hue", palette=plot_palette, markers=plot_markers,
          plot_kws={"size":sizes, "sizes":(80,20)},
          diag_kind="hist",
          diag_kws={"element":"step"}, corner=True)
  plt.savefig(figname, bbox_inches="tight")

# Script entry point: build the figure and write it to `figname`.
plot_params()
