import exo_k as xk
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
import astropy.constants as cst
import glob
from chemistry import isotopologue_to_species, species_name_to_common_isotopologue_name

### choose spectral resolution for corrk
R = 10000
# The grid bounds are written as 10000/x; presumably x is a wavelength in
# micron (100 and 0.3), which would give wavenumbers in cm^-1 -- TODO confirm.
wn_min = 10000/100
wn_max = 10000/.3
wnedges = xk.wavenumber_grid_R(wn_min, wn_max, R)
print("Spectral resolution asked: wn=", wnedges.min(), wnedges.max(), ", R =", R)

### set data paths
# Switch between the two input datasets by (un)commenting one of these lines.
# dace_path = "datadir/cold/"
dace_path = "datadir/hot/"
corrk_dir = "datadir/corrk_data/"
# Output folder is named after the requested spectral resolution.
dir_out = f"{corrk_dir}/R{R}_from_dace/"

### optional target pressure / temperature grid
### (leave both as None to set no explicit grid)
logpgrid = None
tgrid = None
# Example: borrow the grids of an existing k-table.
# ref=xk.Ktable(filename=corrk_dir + '/R500_from_R15000/CH4_R500.corrk.h5', mol="CH4")
# logpgrid = ref.logpgrid
# tgrid = ref.tgrid



### output molecule name -> input isotopologue name
### a value of None means "use the common isotopologue"
molecules = {
    "CO": "12C-16O",
    "CO2": None,
    # "H2O":  "1H2-16O",
    # # "C2H2":"12C2-1H2",
    # # "C2H4":"12C2-1H4",
    # "CH4": "12C-1H4",
    # "TiO": "48Ti-16O",
    # "VO": "51V-16O",
}

### global exo_k settings
# xk.Settings().set_mks(True)
xk.Settings().set_log_interp(False)
xk.Settings().set_case_sensitive(True)

# Pressure/temperature grid of the first molecule converted without an explicit
# grid. Later calls that do not request a grid are remapped onto it, so that
# all tables produced in one run share the same (logP, T) grid.
_first_molecule_grid = {}


def dace_to_corrk(molecule=None, isotopologue=None, order=17, logpgrid=logpgrid, tgrid=tgrid):
    """Convert one folder of high-resolution opacity files to a corr-k table.

    The first folder matching ``{dace_path}{isotopologue}__*/`` is converted
    with ``xk.hires_to_ktable`` on the module-level ``wnedges`` grid, rescaled,
    optionally remapped onto a common (logP, T) grid, written to ``dir_out``
    as HDF5 and plotted to a PDF in the same folder.

    Args:
        molecule: Output molecule name. If None it is derived from
            ``isotopologue`` with ``isotopologue_to_species``.
        isotopologue: Input isotopologue name. If None it is derived from
            ``molecule`` with ``species_name_to_common_isotopologue_name``.
        order: Passed through as ``order`` to ``xk.hires_to_ktable``.
        logpgrid: Target log-pressure grid, or None.
        tgrid: Target temperature grid, or None.
            If either grid is None, both are replaced by the grid of the
            first molecule converted without an explicit grid.

    Returns:
        None. If no matching folder or no ``.bin`` file is found, a message is
        printed and nothing is written.
    """
    import os

    if molecule is None:
        molecule = isotopologue_to_species(isotopologue)
    elif isotopologue is None:
        isotopologue = species_name_to_common_isotopologue_name(molecule)
    print(molecule, isotopologue)

    # use first-found linelist
    linelist_dirs = glob.glob(dace_path + f'{isotopologue}__*/')
    if not linelist_dirs:
        print(f"Isotopologue {isotopologue} not found in {dace_path}")
        return
    mol_corrk = linelist_dirs[0]

    # A folder without any .bin file is reported instead of raising IndexError.
    bin_files = glob.glob(mol_corrk + "/*.bin")
    if not bin_files:
        print(f"No .bin file found in {mol_corrk}")
        return

    # Use the function's own argument here: the previous code read a global
    # `mol` that only exists when the file is run as a script.
    print(f"Reading: {molecule}")
    hr = xk.Hires_spectrum(bin_files[0], helios=True)
    print(f"Initial spectral resolution = {hr.Nw}")

    # conversion of the whole folder to a temporary k-table
    tmp_ktab = xk.hires_to_ktable(path=mol_corrk, helios=True, wnedges=wnedges, mol=molecule, order=order)

    # cm^2/g * 10000/1000 * kg/mol / (molecule/mol) -> m^2/molecule
    # `.value` keeps kdata a plain array instead of an astropy Quantity.
    # NOTE(review): cm^2/g -> m^2/kg is normally a factor 0.1 (1e-4 m^2 per
    # cm^2, 1e3 g per kg); the factor 10 below should be double-checked
    # against the unit actually delivered for these files.
    tmp_ktab.kdata = tmp_ktab.kdata * 10 * xk.Molar_mass().fetch(molecule) / cst.N_A.value
    tmp_ktab.kdata_unit = "m^2/molecule"
    tmp_ktab.remove_zeros()

    if logpgrid is None or tgrid is None:  # use first molecule to set p and t grid
        if not _first_molecule_grid:
            # np.array copies, so a later remap of this table cannot alter
            # the stored reference grid.
            _first_molecule_grid["logpgrid"] = np.array(tmp_ktab.logpgrid)
            _first_molecule_grid["tgrid"] = np.array(tmp_ktab.tgrid)
        logpgrid = _first_molecule_grid["logpgrid"]
        tgrid = _first_molecule_grid["tgrid"]

    # array_equal also copes with grids of different lengths, where an
    # element-wise `!=` cannot be evaluated.
    same_grid = (np.array_equal(tmp_ktab.logpgrid, logpgrid)
                 and np.array_equal(tmp_ktab.tgrid, tgrid))
    if not same_grid:
        print("Initial P grid:", tmp_ktab.pgrid)
        print("Initial T grid:", tmp_ktab.tgrid)
        print(f"Remapping {molecule}")
        print("New P grid:", np.power(10, logpgrid))  # 10**logP, not logP**10
        print("New T grid:", tgrid)
        tmp_ktab.remap_logPT(logp_array=logpgrid, t_array=tgrid)

    print(tmp_ktab)
    print(f"Writing: {molecule}")
    os.makedirs(dir_out, exist_ok=True)
    tmp_ktab.write_hdf5(dir_out + f"{molecule}_corrk_dace_{R}.h5")

    # %matplotlib notebook
    p_plot = 1e5  # in Pa
    t_plot = 300
    fig, ax = plt.subplots(figsize=(9, 6.5))
    tmp_ktab.plot_spectrum(ax, p=p_plot, t=t_plot, g=1., yscale='log', xscale='log', label='g=1')
    fig.tight_layout()
    fig.savefig(dir_out + f"corrk_dace_R{R}_{molecule}_p{p_plot}_t{t_plot}.pdf")
    # Close the figure so that one open figure per molecule does not pile up.
    plt.close(fig)

if __name__ == '__main__':
    # Convert every selected molecule in turn; a None isotopologue makes
    # dace_to_corrk look up the common isotopologue name itself.
    for mol, selected_isotopologue in molecules.items():
        dace_to_corrk(molecule=mol, isotopologue=selected_isotopologue, order=17)