import os
import numpy as np
import matplotlib.pyplot as plt
import exo_k as xk

### spectral resolution requested for the output corrk tables
R=500
# Bounds of the wavenumber grid, written as 10000/x with x = 100 and x = 0.3
# (presumably wavelengths in microns converted to cm^-1 -- TODO confirm).
wn_lower = 10000/100
wn_upper = 10000/.3
wnedges=xk.wavenumber_grid_R(wn_lower, wn_upper, R)
print("Spectral resolution asked: wn=", wnedges.min(), wnedges.max(), ", R =", R)

### root data folder and folder where the corrk tables will be saved
### (both keep their trailing slash: file names are appended by concatenation)
data_path="datadir/"
dir_out = f"{data_path}corrk_data/R{R}_from_R15000/"

### name pattern used to look up the xsec file of each molecule
pattern = 'xsec.TauREx.h5'

### optional target p and t grids (when left to None, the grids of the first
### molecule are used)
logpgrid = None
tgrid = None

### molecules to process; comment a line out to leave that molecule aside
mols = [
    # "C2H2",
    "C2H4",
    # "CH4",
    # "CO2",
    "CO",
    "H2O",
    "H2S",
    "HCN",
    "K",
    "MgH",
    "MgO",
    "NaH",
    "Na",
    "NaOH",
    "NH3",
    "O2",
    "OCS",
    "OH",
    "SiH2",
    "SiH4",
    "SiH",
    "SO2",
    "TiO",
    "VO",
]

### exo_k global settings
# folder with xsec data, downloaded from exomol.com for example
xk.Settings().set_search_path(f"{data_path}exomol/taurex_R15000/")
xk.Settings().set_mks(True)
xk.Settings().set_log_interp(False)
xk.Settings().set_case_sensitive(True)


### Build, optionally remap, write and plot one corrk table per molecule.
### A failure on one molecule is reported and the loop moves on to the next.

# The tables and figures are written into dir_out; create it up front so that
# a fresh output folder does not make every molecule fail at write time.
os.makedirs(dir_out, exist_ok=True)

print (f"Generating R{R} in {dir_out}")
for mol in mols:
    try:
        corrk_file = dir_out+mol+f"_R{R}.corrk.h5"
        if os.path.exists(corrk_file):
            print("Already exists: "+mol)
            # continue # uncomment to skip molecules
        else:
            print(f"Reading: {mol}")

        xtable=xk.Xtable(pattern, mol=mol, remove_zeros=True)
        print(f"Initial spectral resolution = {xtable.Nw}")

        tmp_ktab=xk.Ktable(wnedges=wnedges, xtable=xtable)

        # print(f"Binning: {mol}")
        # tmp_ktab.bin_down(wnedges)

        # Any grid the user did not provide is taken from the first molecule
        # that is successfully read. Each grid is handled on its own so that
        # providing only one of them does not get it overwritten.
        if logpgrid is None:
            logpgrid = tmp_ktab.logpgrid
        if tgrid is None:
            tgrid = tmp_ktab.tgrid

        # array_equal also copes with grids of different lengths, where an
        # elementwise "!=" cannot be evaluated.
        same_p = np.array_equal(tmp_ktab.logpgrid, logpgrid)
        same_t = np.array_equal(tmp_ktab.tgrid, tgrid)
        if not (same_p and same_t):
            print("Initial P grid:", tmp_ktab.pgrid)
            print("Initial T grid:", tmp_ktab.tgrid)
            print(f"Remapping {mol}")
            # logpgrid holds log10(p), hence p = 10**logpgrid. The float base
            # keeps this valid for integer-valued (possibly negative) logs.
            print("New P grid:", np.power(10., logpgrid))
            print("New T grid:", tgrid)
            tmp_ktab.remap_logPT(logp_array=logpgrid, t_array=tgrid)

        print(tmp_ktab)
        print(f"Writing: {mol}")
        tmp_ktab.write_hdf5(corrk_file)

        # Control plot of the spectrum at a single (p, t) point.
        p_plot=1
        t_plot=300.
        fig,ax=plt.subplots(figsize=(9,6.5))
        try:
            tmp_ktab.plot_spectrum(ax,p=p_plot,t=t_plot,g=1.,yscale='log',xscale='log',label=mol)
            ax.set_title(f"{mol} @ P   = {p_plot}, T = {t_plot}")
            fig.tight_layout()
            ax.legend()
            fig.savefig(dir_out+f"corrk_{R}_{mol}_p{p_plot}_t{t_plot}.pdf")
        finally:
            # Release the figure even if plotting or saving failed; otherwise
            # pyplot keeps one open figure per molecule.
            plt.close(fig)
    except Exception as e:
        # Deliberate best-effort: report the problem and go on with the
        # remaining molecules.
        print(f"Skipping {mol}. {e}")

