import os
import numpy as np
import matplotlib.pyplot as plt
import exo_k as xk

### choose spectral resolution for corrk
# Target resolving power of the binned-down tables.
R = 100
# Spectral range in wavelength (micron). 1e4 / lambda[um] converts to
# wavenumber in cm^-1 (presumably the unit exo_k expects for wnedges).
lambda_min_um = 0.3
lambda_max_um = 100.
wnedges = xk.wavenumber_grid_R(1e4 / lambda_max_um, 1e4 / lambda_min_um, R)
print("Spectral resolution asked: wn=", wnedges.min(), wnedges.max(), ", R =", R)

### folder where corrk will be saved
data_path = "datadir/"
dir_out = data_path + f"corrk_data/R{R}_from_R15000/"

### folder with input corrk data
# Resolution of the tables that are read and binned down to R.
R_in = 500
xk.Settings().set_search_path(data_path + f"corrk_data/R{R_in}_from_R15000/")
# Alternative input sources (swap in one of these search paths if needed).
# NOTE: `pattern` is not used anywhere else in this script.
# pattern = 'corrk'
# xk.Settings().set_search_path(data_path+"corrk/R500_ExoREM/")
# pattern = 'ktable.exorem.h5'
# xk.Settings().set_search_path(data_path+"corrk/R10000_dace/")

### you can set your desired p and t grid here (otherwise, first molecule is used)
# logpgrid: log10 of the pressure grid; tgrid: temperature grid.
logpgrid = None
tgrid = None

# Species to bin down, processed in this order. When no P-T grid is given
# explicitly, the first species that loads successfully defines it.
mols = """
    C2H2 C2H4 CH4  CO2  CO   H2O H2S HCN
    K    MgH  MgO  NaH  Na   NaOH NH3 O2
    OCS  OH   SiH2 SiH4 SiH  SO2  TiO VO
""".split()


# exo_k configuration appears to be shared state: the search path set
# earlier via a separate Settings() call is what Ktable() uses when loading.
exo_settings = xk.Settings()
exo_settings.set_mks(True)             # presumably SI (mks) units
exo_settings.set_log_interp(False)     # presumably disables log-space interpolation
exo_settings.set_case_sensitive(True)  # presumably exact-case molecule/file matching

print (f"Bin down to R{R} in {dir_out}")

# The output folder must exist before the first write_hdf5 call.
os.makedirs(dir_out, exist_ok=True)

# Set to True to skip species whose output file already exists.
# Caveat: a skipped species never supplies the reference P-T grid, so when
# logpgrid/tgrid are None the reference comes from the first species that
# is actually processed.
skip_existing = False

# Pressure / temperature at which a diagnostic spectrum is plotted.
p_plot = 1
t_plot = 300.

for mol in mols:
    out_file = dir_out + mol + f"_R{R}.corrk.h5"
    if os.path.exists(out_file):
        print("Already exists: " + mol)
        if skip_existing:
            continue
    else:
        print(f"Reading: {mol}")

    try:
        tmp_ktab = xk.Ktable(mol=mol, remove_zeros=True)
        print(f"Initial number of spectral bins = {tmp_ktab.Nw}")

        if logpgrid is None or tgrid is None:  # use first molecule to set p and t grid
            logpgrid = tmp_ktab.logpgrid
            tgrid = tmp_ktab.tgrid

        print(f"Binning: {mol}")
        tmp_ktab.bin_down(wnedges)

        # np.array_equal also handles grids of different lengths. An
        # elementwise `!=` would raise on a shape mismatch, and that
        # molecule would be skipped instead of remapped.
        if not (np.array_equal(tmp_ktab.logpgrid, logpgrid)
                and np.array_equal(tmp_ktab.tgrid, tgrid)):
            print("Initial P grid:", tmp_ktab.pgrid)
            print("Initial T grid:", tmp_ktab.tgrid)
            print(f"Remapping {mol}")
            # logpgrid holds log10(P), so P = 10**logpgrid.
            print("New P grid:", np.power(10., logpgrid))
            print("New T grid:", tgrid)
            tmp_ktab.remap_logPT(logp_array=logpgrid, t_array=tgrid)

        print(tmp_ktab)
        print(f"Writing: {mol}")
        tmp_ktab.write_hdf5(out_file)
    except ValueError as e:
        print(f"Skipping {mol}. {e}")
        continue

    # The table is already written here, so a plotting failure must not be
    # reported as "Skipping". Always close the figure to avoid accumulating
    # open figures over the loop.
    fig, ax = plt.subplots(figsize=(9, 6.5))
    try:
        tmp_ktab.plot_spectrum(ax, p=p_plot, t=t_plot, g=1., yscale='log', xscale='log', label=mol)
        ax.set_title(f"{mol} @ P = {p_plot}, T = {t_plot}")
        fig.tight_layout()
        ax.legend()
        fig.savefig(dir_out + f"corrk_{R}_{mol}_p{p_plot}_t{t_plot}.pdf")
    except ValueError as e:
        print(f"Could not plot {mol}. {e}")
    finally:
        plt.close(fig)

