"""
Generic PCM Photochemistry post-processing library
Written by Maxime Maurice in 2024 anno domini
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as tk
import xarray as xr
from scipy.constants import R, N_A

import warnings
# Silence the warnings whose message reports that the t, lon or lat keyword
# arguments were not used by contour (one "ignore" filter per keyword).
for _unused_contour_kw_msg in (
        "The following kwargs were not used by contour: 't'",
        "The following kwargs were not used by contour: 'lon'",
        "The following kwargs were not used by contour: 'lat'",
        ):
    warnings.filterwarnings("ignore", message=_unused_contour_kw_msg)
del _unused_contour_kw_msg

class GPCM_simu:
    """ Generic PCM Simulation

    Stores the netCDF file and path to the simulation

    Attributes
    ----------
    data : xr.Dataset
        NetCDF file of the simulation
    path : str
        Path to simulation directory
    tracers : dict
        Tracer parameters read from the tracer file (set by read_tracfile)
    network : object
        Chemical network (species and reactions). Not set by this class:
        assign it before calling compute_rates, to_chempath or
        to_chemical_pathway_analyzer
    background_species : str
        Name of the background gas in the tracer file. Not set by this
        class: assign it before calling compute_rates

    Methods
    -------
    get_subset(field,**kw)
        Get a subset of the data at fixed coordinates and/or averaged along dimensions
    plot_meridional_slice(field,t,lon,logcb,labelcb,**plt_kw)
        Plot a meridional slice of a field
    plot_time_evolution(field,lat,lon,logcb,labelcb,**plt_kw)
        Plot time evolution of a field (as a time vs altitude contour plot)
    plot_time_series(field,lat,lon,alt,logy,**plt_kw)
        Plot time series of a field (at a specific location (lon, lat, alt), averaged, or a combination thereof)
    plot_atlas(field,t,alt,logcb,labelcb,**plt_kw)
        Plot atlas of a field
    plot_profile(field,t,lon,lat,logx,**plt_kw)
        Plot a profile of a field
    read_tracfile(filename)
        Read the tracer file (traceur.def)
    compute_rates()
        Compute reaction rates, rate constants, vmr and densities of the network species
    to_chempath(t,dt,filename_suffix,lon,lat,alt)
        Write input files for the chempath analyzer
    to_chemical_pathway_analyzer(t,dt,filename_suffix,lon,lat,alt)
        Write input files in "cpa" format
    """

    # Photolysis rate fields of species whose photodissociation has several
    # branches: reactant -> (product identifying the branch,
    #                        field when that product is formed, field otherwise)
    _BRANCHED_PHOTOLYSIS = {
        'co2' : ('o1d','jco2_o1d' ,'jco2_o'  ),
        'o2'  : ('o1d','jo2_o1d'  ,'jo2_o'   ),
        'o3'  : ('o1d','jo3_o1d'  ,'jo3_o'   ),
        'ch2o': ('cho','jch2o_cho','jch2o_co'),
    }

    def __init__(self,path,filename='diagfi.c',verbose=False):
        """
        Parameters
        ----------
        path : str
            Path to simulation directory
        filename : str (optional)
            Name of the netCDF file (works with diagfi, start, startfi,
            concat etc.). The '.nc' suffix may be omitted: if no file named
            exactly `filename` exists but `filename`+'.nc' does, the latter
            is opened.
        verbose : bool (optional)
            Currently unused

        Raises
        ------
        OSError
            If the file cannot be found or opened (the underlying error is chained)
        """
        import os  # local import: only needed to locate the file

        self.path = path
        filepath  = path+'/'+filename
        if not os.path.exists(filepath) and os.path.exists(filepath+'.nc'):
            filepath += '.nc'

        try:
            self.data = xr.open_dataset(filepath,decode_times=False)
        except (OSError, ValueError) as err:
            # OSError: missing/unreadable file; xarray may also raise
            # ValueError for files it cannot identify
            raise OSError('Data not found: '+filepath) from err

        # Not every netCDF file (e.g. start files) necessarily has a usable Time axis
        if 'Time' in self.data and self.data['Time'].size > 0:
            print(filepath,'loaded, simulations lasts',self.data['Time'].values[-1],'sols')
        else:
            print(filepath,'loaded')

    def __getitem__(self,field):
        """ Return the field `field` of the dataset """
        return self.data[field]

    def __setitem__(self,field,value):
        """ Add or replace the field `field` of the dataset """
        self.data[field] = value

    def __area_weight__(self,data_array):
        """ Multiply data_array by the cell area divided by its meridional mean

        Taking .mean('latitude') of the result gives the area-weighted
        meridional average of data_array.
        """
        return self['aire']*data_array/self['aire'].mean('latitude')

    @staticmethod
    def _is_avg(value):
        """ True if value is the string 'avg' (request to average along a dimension) """
        return isinstance(value,str) and value == 'avg'

    def _units(self,field):
        """ Units attribute of a field ('' if it has none) """
        return self[field].attrs.get('units','')

    @staticmethod
    def _contourf(x,y,z,field,logcb,labelcb,plt_kw):
        """ Filled contour plot of z (shape (len(y), len(x))) with a colorbar

        The colorbar is labelled labelcb, or field if labelcb is None;
        logcb selects a logarithmic colour scale.
        """
        if logcb:
            plt.contourf(x,y,z,locator=tk.LogLocator(),**plt_kw)
        else:
            plt.contourf(x,y,z,**plt_kw)
        plt.colorbar(label=field if labelcb is None else labelcb)

    def get_subset(self,field='all',**kw):
        """ Get a subset of the data at fixed coordinates and/or averaged along dimensions

        For each keyword below, pass a coordinate value to select the
        nearest point along that dimension, or 'avg' to average along it.
        Omitted keywords, and dimensions the data does not have, are left
        untouched. The meridional average is area-weighted.

        Parameters
        ----------
        field : str (optional, default = 'all')
            Field name. If nothing or 'all' specified, return all fields.
        t : float or 'avg' (optional)
            Time of the slice, or 'avg' for the time average
        lon : float or 'avg' (optional)
            Longitude of the slice, or 'avg' for the zonal average
        lat : float or 'avg' (optional)
            Latitude of the slice, or 'avg' for the area-weighted meridional average
        alt : float or 'avg' (optional)
            Altitude of the slice, or 'avg' for the (unweighted) average over altitude levels

        Returns
        -------
        xr.DataArray or xr.Dataset
            The subset (a Dataset when field is 'all')
        """
        data_subset = self.data if field == 'all' else self[field]

        # Processed in a fixed order: time, latitude, longitude, altitude
        for key, dim in (('t','Time'),('lat','latitude'),('lon','longitude'),('alt','altitude')):
            if key not in kw or dim not in data_subset.dims:
                continue
            if self._is_avg(kw[key]):
                if dim == 'latitude':
                    data_subset = self.__area_weight__(data_subset).mean(dim=dim)
                else:
                    data_subset = data_subset.mean(dim=dim)
            else:
                data_subset = data_subset.sel({dim: kw[key]},method='nearest')

        return data_subset

    def plot_meridional_slice(self,field,t='avg',lon='avg',logcb=False,labelcb=None,**plt_kw):
        """ Plot a meridional slice of a field (latitude vs altitude contour plot)

        Parameters
        ----------
        field : str
            Field name to plot
        t : float or 'avg' (optional)
            Time at which to plot (default: time average)
        lon : float or 'avg' (optional)
            Longitude at which to plot (default: zonal average)
        logcb : bool (optional)
            Use logarithmic colorscale
        labelcb : str (optional)
            Use custom colorbar label
        matplotlib contourf keyword arguments

        Raises
        ------
        Exception
            If the simulation has a single latitude (1D simulation)
        """
        if self['latitude'].size == 1:
            # safety check
            raise Exception('Trying to plot a meridional slice of a 1D simulation')

        meridional_slice = self.get_subset(field,t=t,lon=lon)
        self._contourf(meridional_slice['latitude'],meridional_slice['altitude'],meridional_slice,
                       field,logcb,labelcb,plt_kw)
        plt.xlabel('latitude [°]')
        plt.ylabel('altitude [km]')

    def plot_time_evolution(self,field,lat='avg',lon='avg',logcb=False,labelcb=None,**plt_kw):
        """ Plot time evolution of a field (as a time vs altitude contour plot)

        Parameters
        ----------
        field : str
            Field name to plot
        lat : float or 'avg' (optional)
            Latitude at which to plot (default: area-weighted meridional average)
        lon : float or 'avg' (optional)
            Longitude at which to plot (default: zonal average)
        logcb : bool (optional)
            Use logarithmic colorscale
        labelcb : str (optional)
            Use custom colorbar label
        matplotlib contourf keyword arguments
        """
        time_evolution = self.get_subset(field,lon=lon,lat=lat)
        # transpose so that rows follow altitude and columns follow time
        self._contourf(time_evolution['Time'],time_evolution['altitude'],time_evolution.T,
                       field,logcb,labelcb,plt_kw)
        plt.xlabel('time [day]')
        plt.ylabel('altitude [km]')

    def plot_time_series(self,field,lat='avg',lon='avg',alt='avg',logy=False,**plt_kw):
        """ Plot time series of a field (at a specific location (lon, lat, alt), averaged, or a combination thereof)

        Parameters
        ----------
        field : str
            Field name to plot
        lat : float or 'avg' (optional)
            Latitude at which to plot (default: area-weighted meridional average)
        lon : float or 'avg' (optional)
            Longitude at which to plot (default: zonal average)
        alt : float or 'avg' (optional)
            Altitude at which to plot (default: average over altitude levels)
        logy : bool (optional)
            Use logarithmic y-axis
        matplotlib plot keyword arguments (the line label defaults to the field units)
        """
        time_series = self.get_subset(field,lon=lon,lat=lat,alt=alt)
        units       = self._units(field)

        plt_kw.setdefault('label',units)
        plot = plt.semilogy if logy else plt.plot
        plot(time_series['Time'],time_series,**plt_kw)

        plt.xlabel('time [day]')
        plt.ylabel(field+' ['+units+']')

    def plot_atlas(self,field,t='avg',alt='avg',logcb=False,labelcb=None,**plt_kw):
        """ Plot atlas of a field (longitude vs latitude contour plot)

        Parameters
        ----------
        field : str
            Field name to plot
        t : float or 'avg' (optional)
            Time at which to plot (default: time average)
        alt : float or 'avg' (optional)
            Altitude at which to plot (default: average over altitude levels);
            ignored for fields without an altitude dimension
        logcb : bool (optional)
            Use logarithmic colorscale
        labelcb : str (optional)
            Use custom colorbar label
        matplotlib contourf keyword arguments
        """
        # get_subset leaves the data untouched along dimensions it does not have
        atlas = self.get_subset(field,t=t,alt=alt)
        self._contourf(atlas['longitude'],atlas['latitude'],atlas,field,logcb,labelcb,plt_kw)
        plt.xlabel('longitude [°]')
        plt.ylabel('latitude [°]')

    def plot_profile(self,field,t='avg',lon='avg',lat='avg',logx=False,**plt_kw):
        """ Plot a vertical profile of a field

        Parameters
        ----------
        field : str
            Field name to plot
        t : float or 'avg' (optional)
            Time at which to select (default: time average)
        lon : float or 'avg' (optional)
            Longitude at which to plot (default: zonal average)
        lat : float or 'avg' (optional)
            Latitude at which to plot (default: area-weighted meridional average)
        logx : bool (optional)
            Use logarithmic x axis
        matplotlib's plot / semilogx keyword arguments
        """
        profile = self.get_subset(field,t=t,lon=lon,lat=lat)

        plot = plt.semilogx if logx else plt.plot
        plot(profile,profile['altitude'],**plt_kw)

        plt.xlabel(field+' ['+self._units(field)+']')
        plt.ylabel('altitude [km]')

    def read_tracfile(self,filename='traceur.def'):
        """ Read the tracers of a simulation from a modern tracer file

        The first line must contain '#ModernTrac-v1'. The second line (number
        of tracers) is skipped, and so are blank lines and lines whose first
        non-blank character is '!'. Every other line is a tracer name followed
        by key=value (or key value) pairs. Values are converted to float when
        possible. The result is stored in self.tracers as {name: {key: value}}.

        Parameters
        ----------
        filename : string (optional)
            Name of the tracer file in the simulation directory. Default: 'traceur.def'

        Raises
        ------
        Exception
            If the file is not in the modern format
        """
        self.tracers = {}
        self.M       = {}
        with open(self.path+'/'+filename) as tracfile:
            for iline,line in enumerate(tracfile):

                # First line: format tag
                if iline == 0:
                    if '#ModernTrac-v1' not in line:
                        raise Exception('Can only read modern traceur.def')
                    continue

                # Second line (number of tracers): not needed
                if iline == 1:
                    continue

                # Empty or commented line
                stripped = line.strip()
                if not stripped or stripped.startswith('!'):
                    continue

                # Regular entry: name key1 value1 key2 value2 ...
                name, *params = stripped.replace('=',' ').split()
                if len(params) % 2:
                    warnings.warn(f'{filename}: ignoring dangling parameter {params[-1]!r} of tracer {name}')
                tracparams = {}
                for key, value in zip(params[0::2],params[1::2]):
                    try:
                        tracparams[key] = float(value)
                    except ValueError:
                        tracparams[key] = value
                self.tracers[name] = tracparams

    @classmethod
    def _photolysis_field(cls,reactant,products):
        """ Name of the photolysis rate field of the photolysis of `reactant` into `products` """
        if reactant in cls._BRANCHED_PHOTOLYSIS:
            branch_product, field_if_formed, field_otherwise = cls._BRANCHED_PHOTOLYSIS[reactant]
            return field_if_formed if branch_product in products else field_otherwise
        if reactant == 'h2o_vap':
            return 'jh2o'
        # General case
        return 'j'+reactant

    def compute_rates(self):
        """ Compute reaction rates for a simulation

        Reads the tracer file (read_tracfile) for the molar masses, then adds
        the following fields to the dataset:
            'total density'     total molecular density [cm^-3]
            '<sp> vmr'          volume mixing ratio of each network species
            '<sp> density'      molecular density of each network species [cm^-3]
            'k (<formula>)'     rate constant of each non-photolysis reaction [cm^3.s^-1]
            'rate (<formula>)'  rate of each reaction [cm^-3.s^-1]

        Requires self.network (with .species and a .reactions mapping of
        reaction objects) and self.background_species, which are not set by
        this class. The dataset is modified in place; nothing is returned.

        Raises
        ------
        AttributeError
            If self.network or self.background_species is not set
        Exception
            If the network contains species that are not in the tracer file
        """
        for attribute in ('network','background_species'):
            if not hasattr(self,attribute):
                raise AttributeError(f'compute_rates requires self.{attribute} to be set')

        self.read_tracfile() # we need to read the traceur.def to know the molar mass of species
        reactions = self.network.reactions
        missing   = set(self.network.species) - set(self.tracers)
        if missing:
            raise Exception('Chemical network contains species that are not in the traceur.def file: '
                            +', '.join(sorted(missing)))

        # Total density: p / (R T) [mol.m^-3] * N_A, and 1e6 converts m^-3 to cm^-3
        self['total density'] = (self['p'] / R / self['temp'] / 1e6 * N_A).assign_attrs({'units':'cm^-3'})

        densities = {}
        for sp in self.network.species:
            # volume mixing ratios (tracer field scaled by the molar mass ratio background / species)
            self[sp+' vmr'] = (self[sp] * self.tracers[self.background_species]['mmol'] / self.tracers[sp]['mmol']).assign_attrs({'units':'m^3/m^3'})
            # molecular densities (1e6 converts m^-3 to cm^-3)
            self[sp+' density'] = (self[sp+' vmr'] * self['p'] / R / self['temp'] / 1e6 * N_A).assign_attrs({'units':'cm^-3'})
            densities[sp] = self[sp+' density']

        for r in reactions:
            reac    = reactions[r]
            formula = reac.formula

            if isinstance(reac,photolysis):
                # rate = j * density of the photolysed species
                j_field = self._photolysis_field(reac.reactants[0],reac.products)
                self['rate ('+formula+')'] = reac.rate(densities,j=self[j_field])
            else:
                k = reac.constant(self['temp'],densities[self.background_species])
                # reaction.rate multiplies this constant by the reactant densities (two for
                # a two-body reaction) to get a rate in cm^-3.s^-1, so the stored value is an
                # effective two-body constant. This also holds for three-body reactions, whose
                # constant receives and includes the background (third-body) density.
                self['k ('+formula+')'] = k.assign_attrs({'units':'cm^3.s^-1'})
                self['rate ('+formula+')'] = reac.rate(self['temp'],densities,self.background_species)

            self['rate ('+formula+')'] = self['rate ('+formula+')'].assign_attrs({'units':'cm^-3.s^-1'})

    @staticmethod
    def _flat_float128(values):
        """ Flatten and concatenate values into one np.float128 array (empty if there are none) """
        return np.concatenate([np.empty(0,dtype=np.float128)]+[np.ravel(v) for v in values])

    def to_chempath(self,t,dt,filename_suffix='_chempath',lon='avg',lat='avg',alt='avg'):
        """ Create files readable by the chemical path analyzer chempath (DOI: 10.5194/gmd-2024-163)

        Writes, in the simulation directory (<suffix> = filename_suffix):
            species<suffix>.txt         one species per line
            reactions<suffix>.txt       one reaction per line (chempath format)
            model_time<suffix>.dat      times t and t+dt (raw binary, np.float128)
            rates<suffix>.dat           reaction rates at t (raw binary, np.float128)
            concentrations<suffix>.dat  species vmr at t then at t+dt (raw binary, np.float128)
        Requires compute_rates to have been run.

        Parameters
        ----------
        t : float
            Initial time
        dt : float
            Timestep
        filename_suffix : str (optional)
            Suffix appended to the names of the files being created
        lon : float or 'avg' (optional)
            Longitude at which to extract the data (default: zonal average)
        lat : float or 'avg' (optional)
            Latitude at which to extract the data (default: area-weighted meridional average)
        alt : float or 'avg' (optional)
            Altitude at which to extract the data (default: average over altitude levels)
        """
        reactions = self.network.reactions

        # Save species list in chempath format
        with open(self.path+'/species'+filename_suffix+'.txt', 'w') as fp:
            for sp in self.network.species:
                fp.write("%s\n" % sp)

        # Save reactions list in chempath format
        with open(self.path+'/reactions'+filename_suffix+'.txt', 'w') as fp:
            for r in self.network:
                fp.write("%s\n" % r.to_string(format='chempath'))

        # Save rates, concentrations and times in chempath format
        times = np.array([t,t+dt],dtype=np.float128)

        out   = self.get_subset(t=t,lon=lon,lat=lat,alt=alt)
        # field names follow compute_rates: 'rate (<reaction formula>)'
        rates = self._flat_float128([out['rate ('+reactions[r].formula+')'] for r in reactions])
        conc  = self._flat_float128([out[f'{sp} vmr'] for sp in self.network.species])

        out     = self.get_subset(t=t+dt,lon=lon,lat=lat,alt=alt)
        conc_dt = self._flat_float128([out[f'{sp} vmr'] for sp in self.network.species])
        conc    = np.vstack((conc,conc_dt))

        times.tofile(self.path+'/model_time'+filename_suffix+'.dat')
        rates.tofile(self.path+'/rates'+filename_suffix+'.dat')
        conc.tofile(self.path+'/concentrations'+filename_suffix+'.dat')

    def to_chemical_pathway_analyzer(self,t,dt,filename_suffix='_cpa',lon='avg',lat='avg',alt='avg'):
        """ Create input files in "cpa" format

        Writes, in the simulation directory (<suffix> = filename_suffix):
            concentrations<suffix>.txt  "species = vmr" lines
            model<suffix>.txt           "formula R rate" lines, the rate being divided
                                        by the total density
        Requires compute_rates to have been run.

        Parameters
        ----------
        t : float
            Time at which to extract the data
        dt : float
            Timestep (currently unused)
        filename_suffix : str (optional)
            Suffix appended to the names of the files being created
        lon : float or 'avg' (optional)
            Longitude at which to extract the data (default: zonal average)
        lat : float or 'avg' (optional)
            Latitude at which to extract the data (default: area-weighted meridional average)
        alt : float or 'avg' (optional)
            Altitude at which to extract the data (default: average over altitude levels)
        """
        where = dict(t=t,lon=lon,lat=lat,alt=alt)

        # Save species list and concentrations in cpa format
        with open(self.path+'/concentrations'+filename_suffix+'.txt', 'w') as fp:
            for sp in self.network.species:
                fp.write("%s = %e\n" % (sp, self.get_subset(f'{sp} vmr',**where).values))

        # Same for every reaction: computed once
        total_density = self.get_subset('total density',**where).values

        # Save reactions list ("model") in cpa format
        with open(self.path+'/model'+filename_suffix+'.txt', 'w') as fp:
            for r in self.network:
                fp.write("%s R %e\n" % (r.to_string(format='cpa'),
                                        self.get_subset(f'rate ({r.formula})',**where).values / total_density))
            
            
class reaction:
    """ Instantiates a basic (mass-action) reaction, typically two-body

    Attributes
    ----------
    formula : str
        Reaction formula (e.g. "A + B -> C + D")
    reactants : list
        Reacting molecules formulae (e.g. ["A", "B"])
    products : list
        Produced molecules formulae (e.g. ["C", "D"])
    constant : callable
        Reaction rate constant, called as constant(T, background_density)

    Methods
    -------
    rate(T,densities,third_body)
        Reaction rate for given temperature and densities
    from_string(line,format)
        Set up from an ASCII string
    to_string(format)
        Return an ASCII line readable by a photochemical model
    """
    def __init__(self,reactants,products,constant):
        """
        Parameters
        ----------
        reactants : list(string)
            Reacting molecules formulae (e.g. ["A", "B"])
        products : list(string)
            Produced molecules formulae (e.g. ["C", "D"])
        constant : fun
            Reaction rate constant, called as constant(T, background_density)
        """
        self.formula   = ' + '.join(reactants)+' -> '+' + '.join(products)
        self.products  = products
        self.reactants = reactants
        self.constant  = constant

    def rate(self,T,densities,third_body):
        """ Computes reaction rate (law of mass action)

        rate = constant(T, densities[third_body]) * product of the reactant densities

        Parameters
        ----------
        T : float
            Temperature [K]
        densities : dict
            Molecular densities [cm^-3]
        third_body : string
            Third body molecule (its density is passed to the rate constant)

        Returns
        -------
        float
            Value of the reaction rate [cm^-3.s^-1]
        """
        rate = self.constant(T,densities[third_body])
        for molecule in self.reactants:
            rate = rate*densities[molecule]
        return rate

    @classmethod
    def from_string(cls,line,format='GPCM',high_pressure_term=False):
        """ Set up from an ASCII string

        Format
        ------
        GPCM
            A               B (... to col. 50) B               C (... to col. 100) + cst string
        vulcan
            [ A + B -> C + D (... to col. 40) ] + cst string

        Photolyses computed from cross sections are returned with a None constant.
        The third body 'M' is removed from reactants and products.

        Parameter
        ---------
        line : string
            ASCII string (usually formula and rate constant parameter)
        format : string (optional)
            Model format to read from (default: GPCM, options: GPCM, vulcan)
        high_pressure_term : bool (optional)
            Does the rate constant include a high-pressure term? (default: False)
            In GPCM format this is read from column 101 instead (type 2) for non-photolyses.

        Raises
        ------
        ValueError
            If format is not supported
        """
        if format == 'GPCM':

            reactants  = line[:50].split()
            products   = line[50:100].split()
            cst_params = line[101:]

            if 'hv' in reactants:
                reactants.remove('hv')
                if int(line[100]) == 0:
                    # photolysis calculated with cross sections
                    return cls(reactants,products,None)
            else:
                high_pressure_term = int(line[100]) == 2

        elif format == 'vulcan':

            # [::2] skips the '+' signs between species
            reactants  = line[line.index('[')+1:line.index('->')].split()[::2]
            products   = line[line.index('->')+2:line.index(']')].split()[::2]
            cst_params = line[line.index(']')+1:]

            if cst_params.split()[0][0].isalpha():
                # photolysis calculated with cross sections
                return cls(reactants,products,None)

        else:
            raise ValueError(f"Unknown format '{format}' (options: GPCM, vulcan)")

        if 'M' in reactants:
            reactants.remove('M')
        if 'M' in products:
            products.remove('M')

        if high_pressure_term:
            return cls(reactants,products,reaction_constant_dens_dep.from_string(cst_params,format))
        return cls(reactants,products,reaction_constant.from_string(cst_params,format))

    @staticmethod
    def _vulcan_name(molecule):
        """ Species name in vulcan convention: upper case, and names containing
        '_vap' lose their last four characters (the '_vap' suffix) """
        return molecule[:-4].upper() if '_vap' in molecule else molecule.upper()

    def to_string(self,format='GPCM'):
        """ Return an ASCII line readable by a photochemical model

        Format
        ------
        GPCM
            A               B (... to col. 50) B               C (... to col. 100) + cst string
        vulcan
            [ A + B -> C + D (... to col. 40) ] + cst string
        chempath
            A+B=C+D
        cpa
            A + B => C + D

        Parameter
        ---------
        format : string (optional)
            Model format to write in (default: GPCM, options: GPCM, vulcan, chempath, cpa)

        Returns
        -------
        string
            ASCII line readable by a photochemical model

        Raises
        ------
        ValueError
            If format is not supported
        """
        if format == 'chempath':
            return self.formula.replace(' ','').replace('->','=')

        if format == 'cpa':
            return self.formula.replace('->','=>')

        if format == 'GPCM':
            # reactants (characters 1 to 50), products (characters 51 to 100)
            line  = ''.join(molecule.lower().ljust(16,' ') for molecule in self.reactants).ljust(50,' ')
            line += ''.join(molecule.lower().ljust(16,' ') for molecule in self.products)
            line  = line.ljust(100,' ')

        elif format == 'vulcan':
            formula = ' + '.join(map(self._vulcan_name,self.reactants)) \
                      +' -> '+' + '.join(map(self._vulcan_name,self.products))
            line = ('[ '+formula+' ').ljust(40,' ')+' ]   '

        else:
            raise ValueError(f"Unknown format '{format}' (options: GPCM, vulcan, chempath, cpa)")

        # constant (only the library's constant classes know how to write themselves)
        if type(self.constant) in [reaction_constant,reaction_constant_dens_dep]:
            line += self.constant.to_string(format)
        else:
            print(self.formula,'has a custom reaction constant: add it manually')
        return line
            

class termolecular_reaction(reaction):
    """ Instantiates a three-body reaction

    Attributes
    ----------
    formula : str
        Reaction formula (e.g. "A + B + M -> C + D + M")
    reactants : list
        Reacting molecules formulae, without the third body (e.g. ["A", "B"])
    products : list
        Produced molecules formulae, without the third body (e.g. ["C", "D"])
    constant : callable
        Reaction rate constant, called as constant(T, background_density)

    Methods
    -------
    from_string(line,format)
        Set up from an ASCII string
    to_string(format)
        Return an ASCII line readable by a photochemical model
    """
    def __init__(self,reactants,products,constant):
        super().__init__(reactants,products,constant)
        # The third body M appears in the formula only
        self.formula = ' + '.join(reactants)+' + M -> '+' + '.join(products)+' + M'

    @classmethod
    def from_string(cls,line,format='GPCM',high_pressure_term=False):
        """ Set up from an ASCII string

        Format
        ------
        GPCM
            A               B (... to col. 50) B               C (... to col. 100) + cst string
        vulcan
            [ A + B -> C + D (... to col. 40) ] + cst string

        Parameter
        ---------
        line : string
            ASCII string (usually formula and rate constant parameter)
        format : string (optional)
            Model format to read from (default: GPCM, options: GPCM, vulcan)
        high_pressure_term : bool (optional)
            Does the rate constant include a high-pressure term? (default: False)
            In GPCM format this is read from column 101 instead (type 2).
        """
        new_instance = super().from_string(line,format,high_pressure_term)
        if format == 'GPCM':
            # reaction.from_string reads the high-pressure flag from column 101
            # instead of using the argument: use the same value here
            high_pressure_term = int(line[100]) == 2
        if not high_pressure_term and new_instance.constant is not None:
            # In case we have a 3-body reaction without
            # a high pressure term, enforce density dependence
            new_instance.constant.params['d'] = 1.
        return new_instance

    def to_string(self,format='GPCM'):
        """ Return an ASCII line readable by a photochemical model

        Format
        ------
        GPCM
            A               B               M (... to col. 50) C               D (... to col. 100) + cst string
        vulcan
            [ A + B + M -> C + D + M (... to col. 40) ] + cst string
        chempath
            A+B+M=C+D+M
        cpa
            A + B + M => C + D + M

        Parameter
        ---------
        format : string (optional)
            Model format to write in (default: GPCM, options: GPCM, vulcan, chempath, cpa)

        Returns
        -------
        string
            ASCII line readable by a photochemical model

        Raises
        ------
        ValueError
            If format is not supported
        """
        if format == 'GPCM':
            # reactants + M (characters 1 to 50), products (characters 51 to 100)
            line  = ''.join(molecule.lower().ljust(16,' ') for molecule in self.reactants)
            line  = (line+'M').ljust(50,' ')
            line += ''.join(molecule.lower().ljust(16,' ') for molecule in self.products)
            line  = line.ljust(100,' ')

            # constant (characters 101 to end of line)
            return line+self.constant.to_string(format)

        if format == 'vulcan':
            def vulcan_name(molecule):
                # upper case; names containing '_vap' lose their last four characters
                return molecule[:-4].upper() if '_vap' in molecule else molecule.upper()

            formula = ' + '.join([vulcan_name(m) for m in self.reactants]+['M']) \
                      +' -> '+' + '.join([vulcan_name(m) for m in self.products]+['M'])
            line = ('[ '+formula).ljust(40,' ')+' ]   '

            # constant
            return line+self.constant.to_string(format)

        if format == 'chempath':
            return self.formula.replace(' ','').replace('->','=')

        if format == 'cpa':
            return self.formula.replace('->','=>')

        raise ValueError(f"Unknown format '{format}' (options: GPCM, vulcan, chempath, cpa)")

class photolysis(reaction):
    """ Instantiates a photolysis reaction (A + hv -> products)

    The constant is None when the photolysis is computed from cross sections
    (the rate then needs the photolysis rate j). Otherwise it is a rate-constant
    callable used to prescribe the photolysis rate.
    """

    def __init__(self,reactants,products,constant=None):
        super().__init__(reactants,products,constant)
        self.formula = ' + '.join(reactants)+' + hv -> '+' + '.join(products)

    def rate(self,densities,**kw):
        """ Computes reaction rate

        Parameters
        ----------
        densities : dict
            Molecular densities [cm^-3]
        j : float (keyword, optional)
            Photolysis rate [s^-1]. If given, rate = j * density of the photolysed species
        background : str (keyword, optional)
            Background species, required when j is not given: its density is
            passed to the prescribed rate constant

        Returns
        -------
        float
            Value of the reaction rate [cm^-3.s^-1]

        Raises
        ------
        ValueError
            If j is not given and either the reaction has no prescribed
            constant or no background species is given
        """
        if 'j' in kw:
            return kw['j']*densities[self.reactants[0]]

        if self.constant is None:
            raise ValueError(self.formula+': photolysis computed from cross sections, the photolysis rate j is needed')
        if kw.get('background') is None:
            raise ValueError(self.formula+': a prescribed photolysis rate needs the background species (background=...)')

        # if a photolysis is prescribed with an Arrhenius constant, it is a trick to give
        # it a constant rate. We can simply call the constant with an arbitrary temperature.
        return self.constant(1.,densities[kw['background']])*densities[self.reactants[0]]

    def to_string(self,format='GPCM'):
        """ Return an ASCII line readable by a photochemical model

        Format
        ------
        GPCM
            A               hv (... to col. 50) B               C (... to col. 100) + cst string
        vulcan
            [ A -> C + D (... to col. 40) ]   A    (branching ratio to add manually)
        chempath
            A=C+D
        cpa
            A + hv => C + D

        Parameter
        ---------
        format : string (optional)
            Model format to write in (default: GPCM, options: GPCM, vulcan, chempath, cpa)

        Returns
        -------
        string
            ASCII line readable by a photochemical model

        Raises
        ------
        ValueError
            If format is not supported
        """
        if format == 'GPCM':
            # reactants + hv (characters 1 to 50), products (characters 51 to 100)
            line  = ''.join(molecule.lower().ljust(16,' ') for molecule in self.reactants)
            line  = (line+'hv').ljust(50,' ')
            line += ''.join(molecule.lower().ljust(16,' ') for molecule in self.products)
            line  = line.ljust(100,' ')

            # constant (characters 101 to end of line)
            if self.constant is None:
                line += '0'
                print(self.formula,'is a photolysis: you need to add the cross section files manually')
            else:
                line += self.constant.to_string(format)
            return line

        if format == 'vulcan':
            def vulcan_name(molecule):
                # upper case; names containing '_vap' lose their last four characters
                return molecule[:-4].upper() if '_vap' in molecule else molecule.upper()

            formula = ' + '.join(map(vulcan_name,self.reactants))+' -> '+' + '.join(map(vulcan_name,self.products))
            # the photolysed species follows the formula (no trailing ' + ', even for '_vap' species)
            line = ('[ '+formula+' ').ljust(40,' ')+' ]   '+vulcan_name(self.reactants[0])
            print(self.formula,'is a photolysis: you need to add the branching ratio manually')
            return line

        if format == 'chempath':
            # drop the photon together with its '+' sign: 'A+hv->C+D' -> 'A=C+D'
            return self.formula.replace(' ','').replace('+hv','').replace('->','=')

        if format == 'cpa':
            return self.formula.replace('->','=>')

        raise ValueError(f"Unknown format '{format}' (options: GPCM, vulcan, chempath, cpa)")

class reaction_constant:
    """ Basic (Arrhenius) reaction rate constant
    
    Instantiates type 1 rate constant for a particular reaction
    (https://lmdz-forge.lmd.jussieu.fr/mediawiki/Planets/index.php/Photochemistry#Reaction_rate_formulae)

        k(T, N) = a * (T/T0)**c * exp(-b/T) * N**d

    with N the background gas density.
    
    Attributes
    ----------
    params : dict
        Reaction-specific set of parameters for the rate constant: a,T0,c,b,d

    Methods
    -------
    call(T,background_density)
        Compute the reaction rate for given temperature and background density
    from_string(line,format)
        Create an instance from an ASCII line (GPCM or vulcan format)
    to_string(format)
        Write the parameters as an ASCII line (GPCM or vulcan format)
    """

    def __init__(self,params):
        """
        Parameters
        ----------
        params : dict
            Reaction-specific set of parameters for the rate constant: a,T0,c,b,d
        """
        self.params = params

    def __call__(self,T,background_density):
        """
        Parameters
        ----------
        T : float
            Temperature [K]
        background_density : float
            Background gas density [cm^-3]

        Returns
        -------
        float
            Value of the reaction rate constant [cm^3.s^-1]
        """
        p = self.params
        return p['a']*(T/p['T0'])**p['c']*np.exp(-p['b']/T)*background_density**p['d']

    @classmethod
    def from_string(cls,line,format='GPCM'):
        """ Creates an instance from an ASCII string in a variety of formats

        Currently read formats: Generic PCM, vulcan

        Format
        ------
        GPCM
            a  b  c  T0  d      (whitespace-separated)
        vulcan
            A  B  C             (k = A T^B exp(-C/T))

        Parameters
        ----------
        line : string
            Rate constant parameters
        format : string (optional)
            Format in which parameters are written (default: Generic PCM)

        Returns
        -------
        reaction_constant or None
            The instance of reaction_constant created,
            or None if the format is not recognized
        """
        if format == 'GPCM':
            cst_param = line.split()
            return cls({'a':float(cst_param[0]),'b':float(cst_param[1]),
                        'c':float(cst_param[2]),'T0':float(cst_param[3]),
                        'd':float(cst_param[4])})

        elif format == 'vulcan':
            cst_param = line.split()
            # vulcan writes k = A T^B exp(-C/T); in the GPCM form
            # a (T/T0)^c exp(-b/T) this is a = A*T0^B, c = B, b = C
            T0 = 300.  # assumed default reference temperature
            return cls({'a':float(cst_param[0])*T0**float(cst_param[1]),'b':float(cst_param[2]),
                        'c':float(cst_param[1]),'T0':T0,'d':0.})

        # Unrecognized format: nothing is built
        return None

    def to_string(self,format='GPCM'):
        """ Return an ASCII line readable by a photochemical model

        Format
        ------
        GPCM
            1    a           b           c           T0          d
        vulcan
            A            B        C

        Each field is padded to 12 characters, but is always followed by at
        least one blank so that long values never merge with the next field.

        Parameter
        ---------
        format : string (optional)
            Model format to write in (default: GPCM, options: GPCM, vulcan)

        Returns
        -------
        string or None
            ASCII line readable by a photochemical model,
            or None if the format is not recognized
        """
        def column(text):
            # Same as ljust(12) for values up to 11 characters, but keeps a
            # separating blank for longer ones
            return text.ljust(11,' ')+' '

        p = self.params
        if format == 'GPCM':
            # Same parameter order as read back by from_string: a b c T0 d
            return '1    '+column('{:1.2e}'.format(p['a']))+column(str(p['b'])) \
                          +column(str(p['c']))+column(str(p['T0']))             \
                          +str(p['d'])
        elif format == 'vulcan':
            # Convert a back to vulcan's A = a / T0^c
            return column('{:1.2e}'.format(p['a']/p['T0']**p['c'])) \
                  +column(str(p['c']))+str(p['b'])

        # Unrecognized format: nothing is written
        return None

class reaction_constant_dens_dep(reaction_constant):
    """ Type 2 reaction rate constant (Arrhenius with high pressure term)
    
    Instantiates type 2 rate constant for a particular reaction
    (https://lmdz-forge.lmd.jussieu.fr/mediawiki/Planets/index.php/Photochemistry#Reaction_rate_formulae)

        k0(T)   = k0   * (T/T0)**n * exp(-a0/T)
        kinf(T) = kinf * (T/T0)**m * exp(-b0/T)
        k(T, N) = g*exp(-h/T)
                  + k0(T)*N**dup / (1 + k0(T)/kinf(T)*N**ddown)
                    * fc**(1/(1 + log10(k0(T)/kinf(T)*N)**2))

    with N the background gas density.
    
    Attributes
    ----------
    params : dict
        Reaction-specific set of parameters for the rate constant: k0,T0,n,a0,kinf,m,b0,g,h,dup,ddown,fc

    Methods
    -------
    call(T,background_density)
        Computes the reaction rate for given temperature and background density
    from_string(line,format)
        Create an instance from an ASCII line (GPCM or vulcan format)
    to_string(format)
        Write the parameters as an ASCII line (GPCM or vulcan format)
    """

    # Order of the parameters in a GPCM type 2 line (shared by the reader and the writer)
    _GPCM_ORDER = ('k0','n','a0','kinf','m','b0','T0','fc','g','h','dup','ddown')

    def __call__(self,T,background_density):
        """
        Parameters
        ----------
        T : float
            Temperature [K]
        background_density : float
            Background gas density [cm^-3]

        Returns
        -------
        float
            The value of the reaction rate constant [cm^3.s^-1]
        """
        # low-pressure (num) and high-pressure (den) limits
        num = self.params['k0']*(T/self.params['T0'])**self.params['n']*np.exp(-self.params['a0']/T)
        den = self.params['kinf']*(T/self.params['T0'])**self.params['m']*np.exp(-self.params['b0']/T)
    
        return self.params['g']*np.exp(-self.params['h']/T)+num*background_density**self.params['dup']/(1+num/den*background_density**self.params['ddown'])*self.params['fc']**(1/(1+np.log10(num/den*background_density)**2))

    @classmethod
    def from_string(cls,line,format='GPCM'):
        """ Creates an instance from an ASCII string in a variety of formats

        Currently read formats: Generic PCM, vulcan

        Format
        ------
        GPCM
            k0  n  a0  kinf  m  b0  T0  fc  g  h  dup  ddown   (whitespace-separated)
        vulcan
            A_0  B_0  C_0  A_inf  B_inf  C_inf

        Parameters
        ----------
        line : string
            Rate constant parameters
        format : string (optional)
            Format in which parameters are written (default: Generic PCM)

        Returns
        -------
        reaction_constant_dens_dep or None
            The instance created, or None if the format is not recognized
        """
        if format == 'GPCM':
            cst_param = line.split()
            # One value per parameter, in _GPCM_ORDER (ddown is the 12th field)
            return cls({name:float(cst_param[i]) for i,name in enumerate(cls._GPCM_ORDER)})

        elif format == 'vulcan':
            cst_param = line.split()
            # vulcan limits are A T^B exp(-C/T): rescale A to the GPCM (T/T0)^B form
            T0 = 300.  # assumed default reference temperature
            return cls({'k0':float(cst_param[0])*T0**float(cst_param[1]),  'n':float(cst_param[1]),'a0':float(cst_param[2]),
                        'kinf':float(cst_param[3])*T0**float(cst_param[4]),'m':float(cst_param[4]),'b0':float(cst_param[5]),
                        'T0':T0,'fc':0.6,'g':0.,'h':0.,'dup':1,'ddown':1})

        # Unrecognized format: nothing is built
        return None

    def to_string(self,format='GPCM'):
        """ Return an ASCII line readable by a photochemical model

        Format
        ------
        GPCM
            2    k0          n           a0          kinf        m           b0          T0          fc          g           h           dup         ddown
        vulcan
            A_0         B_0          C_0      A_inf       B_inf        C_inf

        Each field is padded to 12 characters, but is always followed by at
        least one blank so that long values never merge with the next field.

        Parameter
        ---------
        format : string (optional)
            Model format to write in (default: GPCM, options: GPCM, vulcan)

        Returns
        -------
        string or None
            ASCII line readable by a photochemical model,
            or None if the format is not recognized
        """
        def column(text):
            # Same as ljust(12) for values up to 11 characters, but keeps a
            # separating blank for longer ones
            return text.ljust(11,' ')+' '

        p = self.params
        if format == 'GPCM':
            # Same order as read back by from_string; rate prefactors in exponent notation
            fields = ['{:1.2e}'.format(p[name]) if name in ('k0','kinf') else str(p[name])
                      for name in self._GPCM_ORDER]
            return '2    '+''.join(column(f) for f in fields[:-1])+fields[-1]
        elif format == 'vulcan':
            # Convert the prefactors back to vulcan's A = k / T0^B
            return column('{:1.2e}'.format(p['k0']/p['T0']**p['n']))+column(str(p['n'])) \
                  +column(str(p['a0']))+column('{:1.2e}'.format(p['kinf']/p['T0']**p['m'])) \
                  +column(str(p['m']))+str(p['b0'])

        # Unrecognized format: nothing is written
        return None

class network:
    """ Reaction network object
    
    Attributes
    ----------
    reactions : dict
        Chemical reactions in the network, keyed by formula
    species : list(string)
        Chemical species in the network, in order of first appearance

    Methods
    -------
    append(to_append)
        Append a reaction, or every reaction of another network
    update_species()
        Update list of species from list of reactions
    from_file(path,format)
        Instantiates a network from a file
    to_file(path,format)
        Save the network into a file
    change_species_name(old_name,new_name)
        Rename a species throughout the whole network
    save_traceur_file(path)
        Generate a traceur.def file for the Generic PCM
    get_subnetwork(criteria,**kw)
        Generate a subnetwork based on a dictionnary of given criteria
    """

    # Headers opening the sections of a VULCAN network file, with the name of
    # the section they open. Checked in order: the "without high-pressure
    # rates" header also starts with "# 3-body", so it must be tested first.
    _VULCAN_SECTIONS = (
        ("# 3-body reactions without high-pressure rates", '3body_k0'),
        ("# 3-body", '3body'),
        ("# special", 'special'),
        ("# condensation", 'condensation'),
        ("# radiative", 'recombination'),
        ("# photo", 'photo'),
        ("# ionisation", 'ionisation'),
    )

    def __init__(self,reactions=None):
        """
        Parameters
        ----------
        reactions : dict (optional)
            Reactions keyed by formula (default: empty network).
            The dict is copied, so it is never shared with the caller
            or between networks.
        """
        self.reactions = dict(reactions) if reactions else {}
        self.species   = []
        if self.reactions:
            self.update_species()

    def __getitem__(self,formula):
        """ Return the reaction with the given formula (KeyError if absent) """
        return self.reactions[formula]

    def __iter__(self):
        """ Iterate over the reactions of the network, in insertion order

        Each call returns an independent iterator over a snapshot of the
        reactions: an empty network yields nothing, and nested loops over
        the same network work.
        """
        return iter(list(self.reactions.values()))
    
    def append(self,to_append):
        """ Append a reaction to the network

        Updates list of species to account for new species.
        A reaction whose formula is already in the network replaces it.

        Parameter
        ---------
        to_append : reaction or network
            Reaction or network of reactions to append
        """
        if isinstance(to_append, network):
            new_reactions = {r.formula:r for r in to_append}
        else:
            new_reactions = {to_append.formula:to_append}
        # Rebind instead of updating in place, so that appending while
        # iterating over this network (e.g. net.append(net)) is safe
        self.reactions = self.reactions | new_reactions
        if self.reactions:
            self.update_species()

    def update_species(self):
        """ Update list of species from list of reactions

        Species already listed are kept; new ones are appended in order of
        appearance (reactants first, then products).

        Raises
        ------
        Exception
            If the network has no reaction
        """
        if not self.reactions:
            raise Exception('Network empty')
            
        for r in self:
            for group in (r.reactants, r.products):
                for sp in group:
                    if sp not in self.species:
                        self.species.append(sp)

    @classmethod
    def from_file(cls,path,format='GPCM'):
        """ Instantiates a network from a file

        Currently read formats: Generic PCM, VULCAN
        
        Parameters
        ----------
        path : str
            Path to the network file
        format : string (optional)
            Format of the network file to read: 'GPCM' (default) or 'vulcan'
    
        Returns
        -------
        network
            The network instance created

        Raises
        ------
        ValueError
            If the format is not recognized
        """
        if format == 'GPCM':
            reactions = cls._read_gpcm_file(path)
        elif format == 'vulcan':
            reactions = cls._read_vulcan_file(path)
        else:
            raise ValueError(f"Unknown network format '{format}' (expected 'GPCM' or 'vulcan')")
        return cls(reactions)

    @staticmethod
    def _read_gpcm_file(path):
        """ Parse a Generic PCM network file into a dict of reactions keyed by formula """
        reactions = {}
        with open(path) as reactfile:
            for line in reactfile:
                # Blank line: nothing to parse
                if not line.strip():
                    continue
                # Commented line
                if line.startswith('!'):
                    # Hard-coded reaction: only reported, not imported
                    if 'hard' in line and 'coded' in line:
                        hard_coded_reaction = reaction(line[1:51].split(),line[51:101].split(),None)
                        print('reaction ',hard_coded_reaction.formula,'seems to be hard-coded. Add it manually if needed.')
                    continue
                if 'hv' in line:
                    # Photolysis
                    new_reaction = photolysis.from_string(line,'GPCM')
                elif 'M' in line:
                    # Three-body reaction
                    if line[line.index('M')+2] != ' ':
                        # if third body is not the background gas, treat it as a simple reaction
                        new_reaction = reaction.from_string(line.replace('M',' '),'GPCM')
                    else:
                        new_reaction = termolecular_reaction.from_string(line,'GPCM')
                else:
                    # Two-body reaction
                    new_reaction = reaction.from_string(line,'GPCM')
                reactions[new_reaction.formula] = new_reaction
        return reactions

    @classmethod
    def _read_vulcan_file(cls,path):
        """ Parse a VULCAN network file into a dict of reactions keyed by formula

        Two-body, 3-body (with or without high-pressure rates) and photolysis
        reactions are imported; special, condensation, radiative recombination
        and ionisation reactions are only reported.
        """
        reactions = {}
        section = None  # None: two-body reactions, before any recognized header
        with open(path) as reactfile:
            for line in reactfile:
                for prefix, name in cls._VULCAN_SECTIONS:
                    if line.startswith(prefix):
                        section = name
                        break

                if line.startswith("#") or line.split() == []:
                    continue
                # Drop the reaction id: keep from the "[ formula ]" onwards
                line = line[line.index('['):]

                if section == '3body':
                    new_reaction = termolecular_reaction.from_string(line,'vulcan',False)
                elif section == '3body_k0':
                    new_reaction = termolecular_reaction.from_string(line,'vulcan',True)
                elif section == 'photo':
                    new_reaction = photolysis.from_string(line,'vulcan')
                elif section is None:
                    new_reaction = reaction.from_string(line,'vulcan')
                else:
                    # Not convertible: report it and move on without storing anything
                    print(section+' reaction :',line[line.index('[')+1:line.index(']')])
                    continue

                reactions[new_reaction.formula] = new_reaction
        return reactions

    def to_file(self,path,format='GPCM'):
        """ Save the network into a file

        Currently written formats: Generic PCM, VULCAN
        
        Parameters
        ----------
        path : str
            Path to the network file to create
        format : string (optional)
            Format of the network file to write: 'GPCM' (default) or 'vulcan'

        Raises
        ------
        ValueError
            If the format is not recognized (no file is created)
        """
        if format == 'GPCM':
            with open(path, 'w') as reactfile:
                for i,r in enumerate(self):
                    reactfile.write(f'! Reaction {str(i+1)}: {r.formula}\n')
                    reactfile.write(r.to_string(format)+'\n')
                    
        elif format == 'vulcan':
            with open(path, 'w') as reactfile:
                
                # header
                reactfile.write('# VULCAN photochemical network\n')
                reactfile.write('# Generated by the photochemical tools of the Generic PCM\n')
                reactfile.write('#########################################################\n')
                reactfile.write('# in the form of k = A T^B exp(-C/T)\n')
                # 2-body reactions
                reactfile.write('# Two-body Reactions\n')
                reactfile.write('# id	Reactions                                    A           B           C\n')
                reactfile.write('\n')
                n = 1
                for r in self:
                    if type(r) == reaction:
                        reactfile.write(' '+str(n).ljust(7,' ')+r.to_string(format)+'\n')
                        n += 1

                # 3-body reactions with high-pressure term
                reactfile.write('\n')
                reactfile.write('# 3-body and Disscoiation Reactions\n')
                reactfile.write('# id	# Reactions                                  A_0         B_0         C_0         A_inf       B_inf       C_inf\n')
                reactfile.write('\n')
                for r in self:
                    if type(r) == termolecular_reaction and type(r.constant) == reaction_constant_dens_dep:
                        reactfile.write(' '+str(n).ljust(7,' ')+r.to_string(format)+'\n')
                        n += 1

                # 3-body reactions without high-pressure term
                reactfile.write('\n')
                reactfile.write('# 3-body reactions without high-pressure rates\n')
                reactfile.write('# id	# Reactions                                  A_0         B_0         C_0\n')
                reactfile.write('\n')
                for r in self:
                    if type(r) == termolecular_reaction and type(r.constant) == reaction_constant:
                        reactfile.write(' '+str(n).ljust(7,' ')+r.to_string(format)+'\n')
                        n += 1

                # photolysis
                reactfile.write('\n')
                reactfile.write('# reverse stops\n')
                reactfile.write('# photo disscoiation (no reversals) 		            # use sp to link br_index to RXXX\n')
                reactfile.write('# id	# Reactions                                  sp	    br_index #(starting from 1)\n')
                reactfile.write('\n')
                for r in self:
                    if type(r) == photolysis:
                        reactfile.write(' '+str(n).ljust(7,' ')+r.to_string(format)+'\n')
                        n += 1

        else:
            raise ValueError(f"Unknown network format '{format}' (expected 'GPCM' or 'vulcan')")

    def change_species_name(self,old_name,new_name):
        """ Change the name of a species throughout the whole network

        Every occurrence is renamed (e.g. both reactants of o + o). If
        new_name is already a species of the network, the two are merged.

        Parameter
        ---------
        old_name : string
            Name of the species to change
        new_name : string
            New name to put instead

        Raises
        ------
        ValueError
            If old_name is not a species of the network
        """
        if old_name not in self.species:
            raise ValueError(f'{old_name} is not a species of the network')
        if old_name == new_name:
            return
        for r in self:
            # In-place slice assignment keeps the same list objects
            r.reactants[:] = [new_name if sp == old_name else sp for sp in r.reactants]
            r.products[:]  = [new_name if sp == old_name else sp for sp in r.products]
            # No need to change it in the formula:
            # the goal is to convert between the
            # naming conventions of various models
        if new_name in self.species:
            # Merging into an existing species: do not list it twice
            self.species.remove(old_name)
        else:
            self.species[self.species.index(old_name)] = new_name

    def save_traceur_file(self,path):
        """ Generate a traceur.def file readable by the Generic PCM

        Almost readable: you have to add manually the molecular masses

        Parameter
        ---------
        path : string
            Path to the file to be created

        """
        with open(path, 'w') as tracfile: 
            tracfile.write('#ModernTrac-v1\n')
            tracfile.write(str(len(self.species))+'\n')
            for sp in self.species:
                tracfile.write(sp.lower().ljust(24,' ')+'mmol='.ljust(10,' ')+'is_chim=1\n')

    def get_subnetwork(self,criteria,**kw):
        """ Generate a subnetwork based on a dictionnary of given criteria

        Parameter
        ---------
        criteria : dict
            Selection criteria to include reactions in the subnetwork.
            Criteria are: - 'species'  : list of any of the species from the network
                          - 'elements' : list of elements to include
                          - 'type'     : list of reaction classes (reaction, termolecular_reaction, photolysis)
            Criteria are evaluated in the order type, species, elements; see
            the code for how they combine.
        only : bool (optional keyword, default False)
            If True, keep only reactions whose species (resp. single-letter
            elements) are all in criteria['species'] (resp. criteria['elements']).
            Otherwise keep reactions involving at least one of them.

        Return
        ------
        network : A subnetwork composed following the input criteria
        """
        only = kw.get('only', False)
        selected = {}

        for r in self:
            involved = r.reactants + r.products
            keep = False

            if 'type' in criteria and type(r) in criteria['type']:
                keep = True

            if 'species' in criteria:
                if only:
                    keep = all(sp in criteria['species'] for sp in involved)
                elif any(sp in involved for sp in criteria['species']):
                    keep = True

            if 'elements' in criteria:
                if only:
                    # This will only work for single-letter elements
                    keep = all(char in criteria['elements']
                               for sp in involved for char in sp if char.isalpha())
                elif any(elem in sp for elem in criteria['elements'] for sp in involved):
                    keep = True

            if keep:
                selected[r.formula] = r

        return network(selected)
