#!/usr/bin/env python3
"""Plot stratification evolution across multiple PEM restartevo NetCDF files.

Combines time-series diagnostics from restartevo files with optional orbital
forcing traces loaded from an ASCII table.
"""

import os
import sys
import numpy as np
from glob import glob
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import LinearSegmentedColormap, LogNorm
from matplotlib.ticker import FuncFormatter
from scipy.interpolate import interp1d


def get_user_inputs():
    """
    Interactively collect the input locations for the plots.

    Prompts, in order, for:
      - folder_path: directory containing NetCDF files (default: "starts")
      - base_name:   base filename (default: "restartevo")
      - infofile:    name of the PEM info file (default: "pem_workflow.sts")
      - orbfile:     name of the orbital parameters ASCII file
                     (default: "obl_ecc_lsp_mars.asc")
    An empty answer selects the default. The user is re-prompted until
    folder_path is an existing directory and infofile/orbfile are existing
    files (base_name is not validated).

    Returns (folder_path, base_name, infofile, orbfile).
    """

    def ask_until_valid(prompt, retry_prompt, default, is_valid, kind):
        # Ask once, then keep re-asking (same default) until is_valid accepts.
        answer = input(prompt).strip() or default
        while not is_valid(answer):
            print(f"  » \"{answer}\" does not exist or is not a {kind}.")
            answer = input(retry_prompt).strip() or default
        return answer

    folder_path = ask_until_valid(
        "Enter the folder path containing the NetCDF files "
        "(press Enter for default [starts]): ",
        "Enter a valid folder path (press Enter for default [starts]): ",
        "starts",
        os.path.isdir,
        "directory",
    )

    base_name = input(
        "Enter the base name of the NetCDF files "
        "(press Enter for default [restartevo]): "
    ).strip() or "restartevo"

    infofile = ask_until_valid(
        "Enter the name of the PEM info file "
        "(press Enter for default [pem_workflow.sts]): ",
        "Enter a valid PEM info filename (press Enter for default [pem_workflow.sts]): ",
        "pem_workflow.sts",
        os.path.isfile,
        "file",
    )

    # The retry default must be the advertised orbital-file default (it used
    # to fall back to the PEM info file name by mistake).
    orbfile = ask_until_valid(
        "Enter the name of the orbital parameters ASCII file "
        "(press Enter for default [obl_ecc_lsp_mars.asc]): ",
        "Enter a valid orbital parameters ASCII filename (press Enter for default [obl_ecc_lsp_mars.asc]): ",
        "obl_ecc_lsp_mars.asc",
        os.path.isfile,
        "file",
    )

    return folder_path, base_name, infofile, orbfile


def list_netcdf_files(folder_path, base_name):
    """
    List and sort all NetCDF files named {base_name}<N>.nc in folder_path,
    where <N> consists only of decimal digits.

    Files are returned in increasing numeric order of <N> (so "...10.nc"
    comes after "...2.nc"); equal indices (e.g. "01" vs "1") are ordered by
    path. Files whose suffix is not purely numeric (e.g.
    "{base_name}1_backup.nc") are ignored so they cannot be mistaken for extra
    time steps. folder_path and base_name are matched literally, even if they
    contain glob metacharacters such as "[".

    Returns a sorted list of full file paths (empty if nothing matches).
    """
    from glob import escape as glob_escape

    pattern = os.path.join(
        glob_escape(folder_path), f"{glob_escape(base_name)}[0-9]*.nc"
    )
    indexed = []
    for path in glob(pattern):
        suffix = os.path.basename(path)[len(base_name):-len(".nc")]
        # isdecimal() (unlike isdigit()) guarantees int() accepts the string.
        if suffix.isdecimal():
            indexed.append((int(suffix), path))
    indexed.sort()
    return [path for _, path in indexed]


def open_sample_dataset(file_path):
    """
    Read the horizontal layout of a single NetCDF file.

    Returns (ngrid, nslope, longitude_array, latitude_array): the sizes of the
    'physical_points' and 'nslope' dimensions, followed by copies of the
    'longitude' and 'latitude' variables taken before the file is closed.
    """
    with Dataset(file_path, 'r') as ds:
        dims = ds.dimensions
        coords = ds.variables
        sizes = (dims['physical_points'].size, dims['nslope'].size)
        lon = coords['longitude'][:].copy()
        lat = coords['latitude'][:].copy()
    return (*sizes, lon, lat)


def _get_axis(dim_names, aliases):
    """Return axis index matching aliases in a dimension-name list."""
    for i, name in enumerate(dim_names):
        lname = name.lower()
        for alias in aliases:
            if lname == alias or alias in lname:
                return i
    return None


def _var_to_tsgi(nc_var):
    """
    Convert a NetCDF stratification variable to shape (time, slope, strata, grid).
    Handles both slope-specific variables and generic variables with an nslope dimension.
    """
    arr = np.asarray(nc_var[:], dtype=np.float32)
    dims = [d.lower() for d in nc_var.dimensions]

    ax_time = _get_axis(dims, ('time',))
    ax_slope = _get_axis(dims, ('nslope', 'slope'))
    ax_str = _get_axis(dims, ('nb_str_max', 'nb_str'))
    ax_grid = _get_axis(dims, ('physical_points', 'ngrid', 'grid'))

    if arr.ndim == 4 and None not in (ax_time, ax_slope, ax_str, ax_grid):
        return np.transpose(arr, (ax_time, ax_slope, ax_str, ax_grid))

    if arr.ndim == 3 and None not in (ax_time, ax_str, ax_grid):
        arr_tsg = np.transpose(arr, (ax_time, ax_str, ax_grid))
        return arr_tsg[:, np.newaxis, :, :]

    if arr.ndim == 3 and None not in (ax_slope, ax_str, ax_grid):
        arr_ssg = np.transpose(arr, (ax_slope, ax_str, ax_grid))
        return arr_ssg[np.newaxis, :, :, :]

    if arr.ndim == 3:
        # Historical fallback used by this script: (time, strata, grid)
        return arr[:, np.newaxis, :, :]

    if arr.ndim == 2:
        return arr[np.newaxis, np.newaxis, :, :]

    raise ValueError(f"Unsupported variable shape {arr.shape} for '{nc_var.name}'")


def _slope_index_from_var(vname):
    """Return 0-based slope index from a slope-specific variable name, else None."""
    if 'slope' not in vname:
        return None
    try:
        return int(vname.split('slope')[1].split('_')[0]) - 1
    except (ValueError, IndexError):
        return None


def collect_stratification_variables(files, base_name):
    """
    Scan all files to collect:
      - variable names for each stratification property
      - max number of strata (max_nb_str)
      - global min and max of the stratum top elevations
    Only variables whose name contains 'stratif' are considered. A variable is
    assigned to a property when that property's marker substring (e.g.
    'top_elevation', 'h_dust') occurs in its name. The elevation extrema use
    only finite, unmasked values of the 'top_elevation' variables, so NetCDF
    fill values and NaNs are ignored.

    base_name is currently unused (kept for interface compatibility).

    Returns:
      - var_info: dict mapping each property_name -> sorted list of var names
      - max_nb_str: int (largest 'nb_str_max' dimension size; 0 if none)
      - min_base_elev: float (np.inf if no finite elevation was found)
      - max_top_elev: float (-np.inf if no finite elevation was found)
    """
    max_nb_str = 0
    min_base_elev = np.inf
    max_top_elev = -np.inf

    property_markers = {
        'heights':   'top_elevation',
        'co2_ice':   'h_co2ice',
        'h2o_ice':   'h_h2oice',
        'dust':      'h_dust',
        'pore':      'h_pore',
        'poreice_coef1': 'poreice_coef1',
        'poreice_coef2': 'poreice_coef2',
        'poreice_coef3': 'poreice_coef3',
        'poreice_coef4': 'poreice_coef4'
    }
    var_info = {prop: set() for prop in property_markers}

    for file_path in files:
        with Dataset(file_path, 'r') as ds:
            if 'nb_str_max' in ds.dimensions:
                max_nb_str = max(max_nb_str, ds.dimensions['nb_str_max'].size)

            for full_var in ds.variables:
                if 'stratif' not in full_var:
                    continue

                if property_markers['heights'] in full_var:
                    # Masked entries (fill values) become NaN and are then
                    # dropped together with any genuine NaN; np.asarray would
                    # have exposed the raw fill values as elevations.
                    heights = np.ma.filled(
                        np.ma.asarray(ds.variables[full_var][:], dtype=np.float64),
                        np.nan,
                    )
                    finite = heights[np.isfinite(heights)]
                    if finite.size > 0:
                        min_base_elev = min(min_base_elev, float(finite.min()))
                        max_top_elev = max(max_top_elev, float(finite.max()))
                    var_info['heights'].add(full_var)

                for prop, marker in property_markers.items():
                    if prop != 'heights' and marker in full_var:
                        var_info[prop].add(full_var)

    for prop in var_info:
        var_info[prop] = sorted(var_info[prop])

    return var_info, max_nb_str, min_base_elev, max_top_elev


def load_full_datasets(files):
    """
    Open all NetCDF files and return a list of Dataset objects, in the order
    of files. (They should be closed by the caller after use.)

    If opening any file fails, the datasets already opened by this call are
    closed before the exception propagates, so no file handle is leaked.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        datasets = [stack.enter_context(Dataset(fp, 'r')) for fp in files]
        # Success: transfer ownership to the caller instead of closing on exit.
        stack.pop_all()
    return datasets


def _iter_slope_slabs(var_name, first_record, nslope):
    """Yield (slope_index, (nstrata, ngrid) slab) pairs for one variable.

    A variable whose name encodes a slope number and that carries a single
    slope slice is routed to that slope only (dropped if out of range). Any
    other variable is spread over its leading slope axis, truncated to nslope.
    """
    target = _slope_index_from_var(var_name)
    if target is not None and first_record.shape[0] == 1:
        if 0 <= target < nslope:
            yield target, first_record[0]
        return
    for isl in range(min(nslope, first_record.shape[0])):
        yield isl, first_record[isl]


def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
    """
    Gather per-slope stratification data from every dataset (one per time step).

    Only the first time record of each variable is used (see _var_to_tsgi for
    the (time, slope, strata, grid) layout).

    Returns:
      - heights_data: list (ntime) of lists (nslope); each entry is a 2D array
        (ngrid, n_strata_current) of top elevations, or None when no heights
        variable provided that slope.
      - raw_prop_arrays: dict property_name -> float32 array
        (ngrid, ntime, nslope, max_nb_str), zero where no data was found.
      - ntime: number of time steps (datasets)
    """
    ntime = len(datasets)

    heights_data = [[None] * nslope for _ in range(ntime)]
    for t_idx, ds in enumerate(datasets):
        present = [name for name in var_info['heights'] if name in ds.variables]
        for var_name in present:
            first_record = _var_to_tsgi(ds.variables[var_name])[0]
            for isl, slab in _iter_slope_slabs(var_name, first_record, nslope):
                heights_data[t_idx][isl] = slab.T

    raw_prop_arrays = {
        prop: np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
        for prop in var_info
        if prop != 'heights'
    }

    for prop, target in raw_prop_arrays.items():
        for t_idx, ds in enumerate(datasets):
            for var_name in var_info[prop]:
                if var_name not in ds.variables:
                    continue
                first_record = _var_to_tsgi(ds.variables[var_name])[0]
                # Strata beyond max_nb_str cannot be stored and are dropped.
                depth = min(first_record.shape[1], max_nb_str)
                for isl, slab in _iter_slope_slabs(var_name, first_record, nslope):
                    target[:, t_idx, isl, :depth] = slab[:depth, :].T

    return heights_data, raw_prop_arrays, ntime


def normalize_to_fractions(raw_prop_arrays):
    """
    Given raw_prop_arrays for 'co2_ice', 'h2o_ice', 'dust', 'pore' (in meters),
    normalize each set of strata so that the sum of those four = 1 per cell.

    Returns:
      - frac_arrays: dict mapping the same four keys -> 4D float32 arrays of
        fractions (0..1); cells whose total is not positive stay 0.
      - frac_arrays also holds a 'pore_ice' fraction: the pore fraction times
        the mean over [0, 1] of the cubic pore-ice profile
        c1 + c2*z + c3*z**2 + c4*z**3 (clipped to [0, 1]). It is 0 when any of
        'poreice_coef1'..'poreice_coef4' is missing.
    """
    co2 = raw_prop_arrays['co2_ice']
    h2o = raw_prop_arrays['h2o_ice']
    dust = raw_prop_arrays['dust']
    pore = raw_prop_arrays['pore']

    total = co2 + h2o + dust + pore
    mask = total > 0.0

    frac_co2 = np.zeros_like(co2, dtype=np.float32)
    frac_h2o = np.zeros_like(h2o, dtype=np.float32)
    frac_dust = np.zeros_like(dust, dtype=np.float32)
    frac_pore = np.zeros_like(pore, dtype=np.float32)
    frac_pore_ice = np.zeros_like(pore, dtype=np.float32)

    frac_co2[mask] = co2[mask] / total[mask]
    frac_h2o[mask] = h2o[mask] / total[mask]
    frac_dust[mask] = dust[mask] / total[mask]
    frac_pore[mask] = pore[mask] / total[mask]

    coef_keys = ('poreice_coef1', 'poreice_coef2', 'poreice_coef3', 'poreice_coef4')
    if all(key in raw_prop_arrays for key in coef_keys):
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # short-circuit so np.trapz is only touched on older NumPy.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        z_norm = np.linspace(0.0, 1.0, 33, dtype=np.float32)
        c1, c2, c3, c4 = (raw_prop_arrays[key] for key in coef_keys)
        profile = (c1[..., None] + c2[..., None] * z_norm
                   + c3[..., None] * z_norm**2 + c4[..., None] * z_norm**3)
        profile = np.clip(profile, 0.0, 1.0)
        poreice_mean = np.clip(trapezoid(profile, z_norm, axis=-1), 0.0, 1.0).astype(np.float32)
    else:
        poreice_mean = np.zeros_like(pore, dtype=np.float32)

    frac_pore_ice[mask] = pore[mask] * poreice_mean[mask] / total[mask]

    return {
        'co2_ice': frac_co2,
        'h2o_ice': frac_h2o,
        'dust':     frac_dust,
        'pore':     frac_pore,
        'pore_ice': frac_pore_ice
    }


def read_infofile(file_name):
    """
    Parse a PEM info file such as "pem_workflow.sts".

    The first line holds whitespace-separated parameters; its third field is
    the martian_to_earth conversion factor. On every following line, the first
    field is taken as a timestamp (Mars years). Blank lines and lines whose
    first field is not a number are skipped.

    Returns (date_time, martian_to_earth): a float64 numpy array of timestamps
    and the conversion factor as a float.
    """
    with open(file_name, 'r') as handle:
        header = handle.readline().split()
        martian_to_earth = float(header[2])
        timestamps = []
        for raw_line in handle:
            fields = raw_line.split()
            if not fields:
                continue
            try:
                stamp = float(fields[0])
            except ValueError:
                continue
            timestamps.append(stamp)
    return np.array(timestamps, dtype=np.float64), martian_to_earth


def get_yes_no_input(prompt: str) -> bool:
    """
    Ask a yes/no question until a recognised answer is given.

    Accepts 'y'/'yes' (True) and 'n'/'no' (False), ignoring case and
    surrounding whitespace; anything else prints a hint and asks again.
    """
    accepted = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input(f"{prompt} (y/n): ").strip().lower()
        if reply in accepted:
            return accepted[reply]
        print("Please respond with y or n.")


def prompt_discretization_step(max_top_elev):
    """
    Prompt for a finite float dz such that 0 < dz <= max_top_elev.

    Re-prompts on non-numeric input, on non-finite values ("nan", "inf"),
    on non-positive values and on values above max_top_elev.
    Returns the accepted dz as a float.
    """
    while True:
        entry = input(
            "Enter the discretization step of the reference grid for the elevation [m]: "
        ).strip()
        try:
            dz = float(entry)
        except ValueError:
            print("  » Invalid numeric value. Please try again.")
            continue
        # float() accepts "nan"/"inf"; NaN would slip through both range checks
        # below (every comparison with NaN is False).
        if not np.isfinite(dz):
            print("  » Discretization step must be a finite number!")
            continue
        if dz <= 0:
            print("  » Discretization step must be strictly positive!")
            continue
        if dz > max_top_elev:
            print(
                f"  » {dz:.3e} m is greater than the maximum top elevation "
                f"({max_top_elev:.3e} m). Please enter a smaller value."
            )
            continue
        return dz


def interpolate_data_on_refgrid(
    heights_data,
    prop_arrays,
    min_base_for_interp,
    max_top_elev,
    dz,
    exclude_sub=False
):
    """
    Build a reference elevation grid and interpolate strata fractions onto it.

    Parameters:
      - heights_data: list (ntime) of lists (nslope) of 2D arrays
        (ngrid, n_strata) of stratum top elevations, or None where a slope
        has no data (such entries are skipped).
      - prop_arrays: dict property_name -> 4D array
        (ngrid, ntime, nslope, max_nb_str) of per-stratum values; the shape of
        its first entry defines ngrid, ntime, nslope and max_nb_str.
      - min_base_for_interp: lowest elevation of the reference grid [m].
      - max_top_elev: highest elevation of the reference grid [m].
      - dz: reference grid spacing [m].
      - exclude_sub: if True, keep strata whose top is >= -1e-6; otherwise
        keep strata whose top is neither NaN nor exactly 0.0.

    Returns:
      - ref_grid: 1D array of elevations (nz,)
      - gridded_data: dict mapping each property_name to 4D array
        (ngrid, ntime, nslope, nz) with interpolated fractions; -1.0 marks
        levels without data.
      - top_index: 3D array (ngrid, ntime, nslope) of ints: number of
        reference levels at or below the highest kept stratum top
        (0 where no stratum was kept).
    """
    if exclude_sub and (dz > max_top_elev):
        # A single step would overshoot the top: use only the two end points.
        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
    else:
        # The + dz/2 lets arange include max_top_elev when it lies on the grid
        # (up to floating-point rounding).
        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz/2, dz)
    nz = len(ref_grid)
    print(f"> Number of reference grid points = {nz}")

    sample_prop = next(iter(prop_arrays.values()))
    ngrid, ntime, nslope, max_nb_str = sample_prop.shape

    # -1.0 is the "no data" marker (also the interp1d fill_value below).
    gridded_data = {
        prop: np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
        for prop in prop_arrays
    }
    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)

    for ig in range(ngrid):
        for t_idx in range(ntime):
            for isl in range(nslope):
                h_mat = heights_data[t_idx][isl]
                if h_mat is None:
                    continue

                raw_h = h_mat[ig, :]
                # Pad with NaN to max_nb_str so the mask lines up with the
                # strata axis of prop_arrays.
                h_all = np.full((max_nb_str,), np.nan, dtype=np.float32)
                n_strata_current = raw_h.shape[0]
                h_all[:n_strata_current] = raw_h

                if exclude_sub:
                    epsilon = 1e-6
                    # NaN padding fails this comparison, so it is excluded too.
                    valid_mask = (h_all >= -epsilon)
                else:
                    # presumably 0.0 marks unused stratum slots — TODO confirm.
                    valid_mask = (~np.isnan(h_all)) & (h_all != 0.0)

                if not np.any(valid_mask):
                    continue

                h_valid = h_all[valid_mask]
                top_h = np.max(h_valid)
                # Number of reference levels <= the highest kept top.
                i_zmax = np.searchsorted(ref_grid, top_h, side='right')
                top_index[ig, t_idx, isl] = i_zmax
                if i_zmax == 0:
                    continue

                for prop, arr in prop_arrays.items():
                    prop_profile_all = arr[ig, t_idx, isl, :]
                    prop_profile = prop_profile_all[valid_mask]
                    if prop_profile.size == 0:
                        continue

                    # kind='next': each level takes the value of the first kept
                    # stratum whose top is at or above it. NOTE(review): levels
                    # below the lowest kept top are outside the data range and
                    # get -1.0 — confirm this is intended.
                    f_interp = interp1d(
                        h_valid,
                        prop_profile,
                        kind='next',
                        bounds_error=False,
                        fill_value=-1.0
                    )
                    gridded_data[prop][ig, t_idx, isl, :i_zmax] = f_interp(ref_grid[:i_zmax])

    return ref_grid, gridded_data, top_index


def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
    """
    Install ax.format_coord so the cursor read-out shows x, y and the cell value.

    mat is either a 2D value grid (ny, nx) or an RGB(A) image (ny, nx, 3|4);
    any other shape raises ValueError. With is_pcolormesh=True, x and y are
    cell-edge coordinates searched with np.searchsorted. Otherwise the cell is
    found by scaling the cursor position linearly over [min, max] of x and y.
    Outside the grid only the coordinates are shown.
    """
    if mat.ndim == 2:
        ny, nx = mat.shape
    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
        ny, nx = mat.shape[:2]
    else:
        raise ValueError(f"Unsupported mat shape {mat.shape}")

    if is_pcolormesh:
        def locate(xp, yp):
            return np.searchsorted(x, xp) - 1, np.searchsorted(y, yp) - 1
    else:
        x_lo, x_hi = x.min(), x.max()
        y_lo, y_hi = y.min(), y.max()

        def locate(xp, yp):
            return (
                int((xp - x_lo) / (x_hi - x_lo) * nx),
                int((yp - y_lo) / (y_hi - y_lo) * ny),
            )

    def format_coord(xp, yp):
        col, row = locate(xp, yp)
        position = f"x={xp:.3g}, y={yp:.3g}"
        if not (0 <= row < ny and 0 <= col < nx):
            return position
        cell = mat[row, col]
        if mat.ndim == 2:
            return f"{position}, val={cell:.3g}"
        channels = ", ".join(f"{c:.3g}" for c in cell[:3])
        return f"{position}, val=({channels})"

    ax.format_coord = format_coord


def plot_stratification_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    For each grid point and slope, generate a 2×2 figure of:
      - CO2 ice fraction
      - H2O ice fraction
      - Dust fraction
      - Pore fraction

    Parameters:
      - gridded_data: dict property_name -> 4D array (ngrid, ntime, nslope, nz)
        of values on ref_grid; negative values mark cells without data.
      - ref_grid: 1D array (nz,) of reference elevations [m].
      - top_index: 3D int array (ngrid, ntime, nslope); reference levels at
        index >= top_index lie above the top stratum and are left blank.
      - heights_data: accepted for interface compatibility; not used here.
      - date_time: 1D array (ntime,) of times [Mars years] for the x axis.
      - exclude_sub: if True, only elevations >= 0 are shown.
      - output_folder: directory receiving one PNG per (grid point, slope).

    Figures are saved but not closed here (presumably so the caller can
    display them afterwards — confirm with the caller).
    """
    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
    cmap = plt.get_cmap('turbo').copy()
    cmap.set_under('white')
    vmin, vmax = 0.0, 1.0

    def with_last_edge(values):
        # Append one coordinate past the end (same spacing as the last
        # interval, or 1 when there is a single value) for the read-out lookup.
        step = values[-1] - values[-2] if len(values) > 1 else 1.0
        return np.concatenate([values, [values[-1] + step]])

    sample_prop = next(iter(gridded_data.values()))
    ngrid, ntime, nslope, nz = sample_prop.shape

    if exclude_sub:
        positive_indices = np.where(ref_grid >= 0.0)[0]
    else:
        positive_indices = np.arange(nz)
    if positive_indices.size == 0:
        print("  » No reference elevation >= 0 m: skipping content plots.")
        return
    sub_ref_grid = ref_grid[positive_indices]

    x_edges = with_last_edge(date_time)
    y_edges = with_last_edge(sub_ref_grid)
    level_ids = np.arange(positive_indices.size)[:, np.newaxis]

    for ig in range(ngrid):
        for isl in range(nslope):
            # Number of displayed levels strictly below top_index at each time
            # (positive_indices is sorted); everything above is blanked.
            n_below_top = np.searchsorted(
                positive_indices, top_index[ig, :, isl], side='left'
            )
            above_top = level_ids >= n_below_top[np.newaxis, :]

            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
            fig.suptitle(
                f"Content variation over time for (Grid point {ig+1}, Slope {isl+1})",
                fontsize=14,
                fontweight='bold'
            )

            for ax, prop, title in zip(axes.flat, prop_names, titles):
                # (levels shown, ntime); integer-array indexing returns a copy.
                mat = gridded_data[prop][ig, :, isl, :].T[positive_indices, :]
                mat[mat < 0.0] = np.nan
                mat[above_top] = np.nan

                im = ax.pcolormesh(
                    date_time,
                    sub_ref_grid,
                    mat,
                    cmap=cmap,
                    shading='auto',
                    vmin=vmin,
                    vmax=vmax
                )
                attach_format_coord(ax, mat, x_edges, y_edges, is_pcolormesh=True)
                ax.set_title(title, fontsize=12)
                ax.set_xlabel("Time (Mars years)")
                ax.set_ylabel("Elevation (m)")

            fig.subplots_adjust(right=0.88)
            fig.tight_layout(rect=[0, 0, 0.88, 1.0])
            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
            fig.colorbar(im, cax=cbar_ax, orientation='vertical', label="Content")

            fname = os.path.join(
                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(fname, dpi=1200, bbox_inches='tight')


def plot_stratification_rgb_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot stratification over time colored using RGB ternary mix of H2O ice (blue), CO2 ice (violet), and dust (orange).
    Includes a triangular legend showing the mix proportions.

    Parameters:
      - gridded_data: dict with at least 'h2o_ice', 'co2_ice' and 'dust', each
        a 4D array (ngrid, ntime, nslope, nz) on ref_grid; negative values are
        clipped to 0 before mixing.
      - ref_grid: 1D array (nz,) of reference elevations [m].
      - top_index: 3D int array (ngrid, ntime, nslope) giving how many levels
        are drawn; a non-positive entry reuses the previous time step's value.
      - heights_data: accepted for interface compatibility; not used here.
      - date_time: 1D array (ntime,) of times [Mars years].
      - exclude_sub: if True, only elevations >= 0 are shown.
      - output_folder: directory receiving one PNG per (grid point, slope).

    Each figure's cursor read-out is bound to that figure's own data.
    """
    # Define constant colors
    violet = np.array([255,   0, 255], dtype=float) / 255
    blue   = np.array([  0,   0, 255], dtype=float) / 255
    orange = np.array([255, 165,   0], dtype=float) / 255

    # Elevation mask and array
    if exclude_sub:
        elevation_mask = (ref_grid >= 0.0)
        elev = ref_grid[elevation_mask]
    else:
        elevation_mask = np.ones_like(ref_grid, dtype=bool)
        elev = ref_grid
    if elev.size == 0:
        print("  » No reference elevation >= 0 m: skipping ternary mix plots.")
        return

    # Pre-compute legend triangle
    res = 300
    u = np.linspace(0, 1, res)
    v = np.linspace(0, np.sqrt(3)/2, res)
    X, Y = np.meshgrid(u, v)
    V_bary = 2 * Y / np.sqrt(3)
    U_bary = X - 0.5 * V_bary
    W_bary = 1 - U_bary - V_bary
    mask_triangle = (U_bary >= 0) & (V_bary >= 0) & (W_bary >= 0)
    legend_rgb = (
        U_bary[..., None] * violet
        + V_bary[..., None] * orange
        + W_bary[..., None] * blue
    )
    legend_rgb = np.clip(legend_rgb, 0.0, 1.0)
    legend_rgba = np.zeros((res, res, 4))
    legend_rgba[..., :3] = legend_rgb
    legend_rgba[..., 3] = mask_triangle.astype(float)

    def make_main_format(x_edges, y_edges, display_frac):
        # A factory gives each figure its own arrays; a closure defined
        # directly in the loop would late-bind and always show the data of
        # the last (grid point, slope) processed.
        n_rows, n_cols = display_frac.shape[:2]

        def main_format(x, y):
            if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
                return ''
            i = np.clip(np.searchsorted(x_edges, x) - 1, 0, n_cols - 1)
            j = np.clip(np.searchsorted(y_edges, y) - 1, 0, n_rows - 1)
            fH2O, fCO2, fDust = display_frac[j, i]
            return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.4f}, CO2={fCO2:.4f}, Dust={fDust:.4f}"

        return main_format

    def legend_format(x, y):
        # Barycentric coordinates of the cursor inside the legend triangle.
        V = 2 * y / np.sqrt(3)
        U = x - 0.5 * V
        W = 1 - U - V
        if U >= 0 and V >= 0 and W >= 0:
            return f"H2O: {W:.2f}, Dust: {V:.2f}, CO2: {U:.2f}"
        return ''

    # Extract data arrays
    h2o = gridded_data['h2o_ice']
    co2 = gridded_data['co2_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Fill missing depths forward in time
    ti = top_index.copy().astype(int)
    for ig in range(ngrid):
        for isl in range(nslope):
            for t in range(1, ntime):
                if ti[ig, t, isl] <= 0:
                    ti[ig, t, isl] = ti[ig, t-1, isl]

    # Edges for pcolormesh (same for every figure)
    dt = date_time[1] - date_time[0] if len(date_time) > 1 else 1
    x_edges = np.concatenate([date_time, [date_time[-1] + dt]])
    d_e = np.diff(elev)
    last_e = elev[-1] + (d_e[-1] if len(d_e) > 0 else 1)
    y_edges = np.concatenate([elev, [last_e]])

    u_edges = np.linspace(0, 1, res+1)
    v_edges = np.linspace(0, np.sqrt(3)/2, res+1)
    triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2], [0, 0]])

    # Loop over grid and slope
    for ig in range(ngrid):
        for isl in range(nslope):
            # Compute RGB stratification over time (white where no data)
            rgb = np.ones((nz, ntime, 3), dtype=float)
            frac_all = np.full((nz, ntime, 3), np.nan, dtype=float)  # fH2O, fCO2, fDust
            for t in range(ntime):
                depth = ti[ig, t, isl]
                if depth <= 0:
                    continue
                cH2O = np.clip(h2o[ig, t, isl, :depth], 0, None)
                cCO2 = np.clip(co2[ig, t, isl, :depth], 0, None)
                cDust = np.clip(dust[ig, t, isl, :depth], 0, None)
                total = cH2O + cCO2 + cDust
                total[total == 0] = 1.0  # avoid 0/0; such cells render black
                fH2O = cH2O / total
                fCO2 = cCO2 / total
                fDust = cDust / total
                frac_all[:depth, t, 0] = fH2O
                frac_all[:depth, t, 1] = fCO2
                frac_all[:depth, t, 2] = fDust
                rgb[:depth, t, :] = (
                    fH2O[:, None] * blue
                    + fCO2[:, None] * violet
                    + fDust[:, None] * orange
                )

            # Mask elevation
            display_rgb = rgb[elevation_mask, :, :]
            display_frac = frac_all[elevation_mask, :, :]

            # Create figure with legend
            fig, (ax_main, ax_leg) = plt.subplots(
                1, 2, figsize=(8, 4), dpi=150,
                gridspec_kw={'width_ratios': [5, 1]}
            )

            # Main stratification panel
            ax_main.pcolormesh(
                x_edges,
                y_edges,
                display_rgb,
                shading='auto',
                edgecolors='none'
            )
            ax_main.format_coord = make_main_format(x_edges, y_edges, display_frac)
            ax_main.set_facecolor('white')
            ax_main.set_title(f"Ternary mix over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax_main.set_xlabel("Time (Mars years)")
            ax_main.set_ylabel("Elevation (m)")

            # Legend panel using proper edges
            ax_leg.pcolormesh(
                u_edges,
                v_edges,
                legend_rgba,
                shading='auto',
                edgecolors='none'
            )
            ax_leg.set_aspect('equal')
            ax_leg.format_coord = legend_format

            # Draw triangle border and gridlines
            ax_leg.plot(triangle[:, 0], triangle[:, 1], 'k-', linewidth=1, clip_on=False, zorder=10)
            for f in np.linspace(0.25, 0.75, 3):
                ax_leg.plot([1 - f, 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
                ax_leg.plot([f, f + 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
                y_line = (np.sqrt(3)/2) * f
                ax_leg.plot([0.5 * f, 1 - 0.5 * f], [y_line, y_line], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)

            # Legend labels
            ax_leg.text(0, -0.05, 'H2O ice', ha='center', va='top', fontsize=8)
            ax_leg.text(1, -0.05, 'CO2 ice', ha='center', va='top', fontsize=8)
            ax_leg.text(0.5, np.sqrt(3)/2 + 0.05, 'Dust', ha='center', va='bottom', fontsize=8)
            ax_leg.axis('off')

            # Save figure
            plt.tight_layout()
            fname = os.path.join(output_folder, f"layering_rgb_evolution_ig{ig+1}_is{isl+1}.png")
            fig.savefig(fname, dpi=1200, bbox_inches='tight')


def plot_dust_to_ice_ratio_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot the dust-to-ice ratio in the stratification over time,
    using a blue-to-orange colormap:
    - blue: ice-dominated (low dust-to-ice ratio)
    - orange: dust-dominated (high dust-to-ice ratio)

    "Ice" is H2O ice only. The ratio is shown on a log scale clipped to
    [1:100, 10:1]. Levels with dust but no H2O ice are drawn at the
    dust-dominated end. Levels with neither dust nor H2O ice (including
    no-data cells, stored as negative values) are left blank.

    Parameters:
      - gridded_data: dict with at least 'h2o_ice' and 'dust', each a 4D
        array (ngrid, ntime, nslope, nz) on ref_grid.
      - ref_grid: 1D array (nz,) of reference elevations [m].
      - top_index: 3D int array (ngrid, ntime, nslope) of levels to draw; a
        non-positive entry reuses the previous time step's value.
      - heights_data: accepted for interface compatibility; not used here.
      - date_time: 1D array (ntime,) of times [Mars years].
      - exclude_sub: if True, only elevations >= 0 are shown.
      - output_folder: directory receiving one PNG per (grid point, slope).
    """
    h2o = gridded_data['h2o_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Prefer Matplotlib's 'managua_r' (blue -> orange, added in Matplotlib
    # 3.10); fall back to a plain blue-to-orange ramp on older versions.
    if 'managua_r' in plt.colormaps():
        cmap = 'managua_r'
    else:
        blue = np.array([0, 0, 255], dtype=float) / 255
        orange = np.array([255, 165, 0], dtype=float) / 255
        cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)

    # Log-ratio bounds and small epsilon to avoid log(0)
    vmin, vmax = -2, 1
    epsilon = 1e-6

    # Elevations actually displayed
    if exclude_sub:
        elevation_mask = ref_grid >= 0.0
    else:
        elevation_mask = np.ones(nz, dtype=bool)
    elev = ref_grid[elevation_mask]
    if elev.size == 0:
        print("  » No reference elevation >= 0 m: skipping dust-to-ice plots.")
        return

    def with_last_edge(values):
        # Append one coordinate past the end (same spacing as the last
        # interval, or 1 when there is a single value) for the read-out lookup.
        step = values[-1] - values[-2] if len(values) > 1 else 1.0
        return np.concatenate([values, [values[-1] + step]])

    # Read-out edges use the same units as the axes (Mars years, metres).
    x_edges = with_last_edge(date_time)
    y_edges = with_last_edge(elev)

    # Loop over grids and slopes
    for ig in range(ngrid):
        for isl in range(nslope):
            ti = top_index[ig, :, isl].astype(int)  # astype returns a copy

            # Compute log10(dust/ice) profile at each time step
            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
            for t in range(ntime):
                if t > 0 and ti[t] <= 0:
                    ti[t] = ti[t-1]
                zmax = ti[t]
                if zmax <= 0:
                    continue

                cH2O = np.clip(h2o[ig, t, isl, :zmax], 0, None)
                cDust = np.clip(dust[ig, t, isl, :zmax], 0, None)

                with np.errstate(divide='ignore', invalid='ignore'):
                    ratio = np.where(
                        cH2O > 0,
                        cDust / cH2O,
                        10**(vmax + 1)
                    )
                    log_ratio = np.log10(ratio + epsilon)
                    log_ratio = np.clip(log_ratio, vmin, vmax)

                # No dust and no H2O ice (or no data at all): the ratio is
                # undefined, so do not paint it as dust-dominated.
                log_ratio[(cH2O <= 0) & (cDust <= 0)] = np.nan
                log_ratio_array[:zmax, t] = log_ratio

            ratio_array = 10**log_ratio_array[elevation_mask, :]

            # Plot
            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
            im = ax.pcolormesh(
                date_time,
                elev,
                ratio_array,
                shading='auto',
                cmap=cmap,
                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
            )
            attach_format_coord(ax, ratio_array, x_edges, y_edges, is_pcolormesh=True)
            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax.set_xlabel('Time (Mars years)')
            ax.set_ylabel('Elevation (m)')

            # Add colorbar
            cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.15)
            cbar.set_label('Dust / H₂O ice (ratio)')
            cbar.set_ticks([1e-2, 1e-1, 1, 1e1])
            cbar.set_ticklabels(['1:100', '1:10', '1:1', '10:1'])

            # Save figure
            plt.tight_layout()
            outname = os.path.join(
                output_folder,
                f"dust_to_ice_ratio_grid{ig+1}_slope{isl+1}.png"
            )
            fig.savefig(outname, dpi=1200, bbox_inches='tight')


def plot_strata_count_and_total_height(heights_data, date_time, output_folder="."):
    """
    Plot, for every (grid point, slope) pair, two stacked time series: the
    number of strata and the total deposit height.

    A stratum is counted only when its height entry is neither NaN nor exactly
    0.0; the "total height" at a time step is the largest such height (0 when
    no stratum qualifies).

    Parameters
    ----------
    heights_data : sequence
        Indexed as ``heights_data[t][isl][ig, :]``; the first dimension of
        ``heights_data[0][0]`` gives the number of grid points.
    date_time : array-like
        X-axis values, one per entry of ``heights_data``.
    output_folder : str
        Directory receiving one PNG per (grid point, slope).
    """
    n_times = len(heights_data)
    n_slopes = len(heights_data[0])
    n_points = heights_data[0][0].shape[0]

    for ig in range(n_points):
        for isl in range(n_slopes):
            counts = np.zeros(n_times, dtype=int)
            tops = np.zeros(n_times, dtype=float)

            for t_idx, per_slope in enumerate(heights_data):
                column = per_slope[isl][ig, :]
                # Ignore NaN padding and zero-height placeholders
                keep = ~np.isnan(column) & (column != 0.0)
                if not keep.any():
                    continue
                kept = column[keep]
                counts[t_idx] = kept.size
                tops[t_idx] = kept.max()

            fig, (ax_count, ax_height) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
            fig.suptitle(
                f"Strata count & total height over time for (Grid point {ig+1}, Slope {isl+1})",
                fontsize=14,
                fontweight='bold',
            )

            ax_count.plot(date_time, counts, marker='+', linestyle='-')
            ax_count.set_ylabel("Number of strata")
            ax_count.grid(True)

            ax_height.plot(date_time, tops, marker='+', linestyle='-')
            ax_height.set_xlabel("Time (Mars years)")
            ax_height.set_ylabel("Total height (m)")
            ax_height.grid(True)

            fig.tight_layout(rect=[0, 0, 1, 0.95])
            png_path = os.path.join(
                output_folder, f"strata_count_height_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(png_path, dpi=150)


def read_orbital_data(orb_file, martian_to_earth):
    """
    Read an ASCII table of orbital parameters and rescale its time column.

    Columns (whitespace separated, parsed with ``np.loadtxt``):
      0 = time, in thousands of years
      1 = obliquity (deg)
      2 = eccentricity
      3 = Ls of perihelion (deg)

    The time column is multiplied by 1e3 and divided by ``martian_to_earth``.
    Elsewhere in this script ``date_time * martian_to_earth`` is what gets
    plotted in Myr, so this division maps that time unit back onto the
    ``date_time`` axis (Mars years). The table's time column is therefore
    presumably in thousands of Earth years. TODO confirm against the file's
    provenance.

    Parameters
    ----------
    orb_file : str
        Path to the ASCII table.
    martian_to_earth : float
        Conversion factor the time column is divided by.

    Returns
    -------
    tuple of ndarray
        ``(dates_yr, obliquity, eccentricity, lsp)``, each 1-D with one entry
        per table row.
    """
    # ndmin=2 keeps a single-row table 2-D so the column slices below work.
    data = np.loadtxt(orb_file, ndmin=2)
    dates_kyr = data[:, 0]
    dates_yr = dates_kyr * 1e3 / martian_to_earth
    obliquity = data[:, 1]
    eccentricity = data[:, 2]
    lsp = data[:, 3]
    return dates_yr, obliquity, eccentricity, lsp


def plot_orbital_parameters(infofile, orb_file, date_time, output_folder="."):
    """
    Plot obliquity, eccentricity and Ls of perihelion from an ASCII orbital
    table against simulated time.

    Parameters
    ----------
    infofile : str
        PEM info file; only its conversion factor is used here.
    orb_file : str
        ASCII table parsed by ``read_orbital_data``.
    date_time : ndarray
        Simulation times (Mars years) at which the parameters are evaluated.
    output_folder : str
        Directory receiving ``orbital_parameters_laskar.png``.
    """
    # Read conversion factor from infofile
    _, martian_to_earth = read_infofile(infofile)

    # Read orbital data
    dates_yr, obl, ecc, lsp = read_orbital_data(orb_file, martian_to_earth)

    # Interpolate orbital parameters at simulation dates (date_time)
    fargs = dict(kind='linear', bounds_error=False, fill_value="extrapolate")
    obl_interp = interp1d(dates_yr, obl, **fargs)(date_time)
    ecc_interp = interp1d(dates_yr, ecc, **fargs)(date_time)
    # Ls of perihelion is an angle. Interpolating it linearly across the
    # 360 -> 0 wrap would invent values near 180 deg, so unwrap it into a
    # continuous series first and fold the result back into [0, 360).
    lsp_unwrapped = np.degrees(np.unwrap(np.radians(lsp)))
    lsp_interp = np.mod(interp1d(dates_yr, lsp_unwrapped, **fargs)(date_time), 360.0)

    # Plot
    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')

    axes[0].plot(date_time, obl_interp, 'r-', marker='+')
    axes[0].set_ylabel("Obliquity (°)")
    axes[0].grid(True)

    axes[1].plot(date_time, ecc_interp, 'b-', marker='+')
    axes[1].set_ylabel("Eccentricity")
    axes[1].grid(True)

    axes[2].plot(date_time, lsp_interp, 'g-', marker='+')
    axes[2].set_ylabel("Ls of perihelion  (°)")
    axes[2].set_xlabel("Time (Mars years)")
    axes[2].grid(True)

    # Lay out this figure explicitly rather than whichever one is "current"
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fname = os.path.join(output_folder, "orbital_parameters_laskar.png")
    fig.savefig(fname, dpi=150)


def mars_ls(pday, peri_day, e_elips, year_day, lsperi=0.0):
    """
    Return the solar longitude Ls, in degrees within [0, 360), of date(s) ``pday``.

    The mean anomaly is the fraction of a year elapsed since ``peri_day``,
    brought into [-pi, pi]. Kepler's equation is solved for the eccentric
    anomaly with at most 10 Newton-Raphson steps, stopping early once every
    correction is <= 1e-7. The resulting true anomaly is offset by ``lsperi``
    and wrapped to [0, 2*pi) before conversion to degrees.

    Parameters
    ----------
    pday : float or ndarray
        Date(s), in the same unit as ``peri_day`` and ``year_day``.
    peri_day : float or ndarray
        Date of perihelion.
    e_elips : float or ndarray
        Orbital eccentricity (must be < 1 for the square root to be real).
    year_day : float or ndarray
        Length of the year, same unit as ``pday``.
    lsperi : float, optional
        Offset added to the true anomaly, in radians.
    """
    phase = (pday - peri_day) / year_day
    mean_anom = 2 * np.pi * (phase - np.round(phase))
    # Solve on |M| and restore the sign afterwards (Kepler's equation is odd)
    target = np.abs(mean_anom)

    ecc_anom = target + e_elips * np.sin(target)
    for _ in range(10):
        residual = ecc_anom - e_elips * np.sin(ecc_anom) - target
        slope = 1 - e_elips * np.cos(ecc_anom)
        step = -residual / slope
        ecc_anom += step
        if np.all(np.abs(step) <= 1e-7):
            break

    ecc_anom = np.where(mean_anom < 0, -ecc_anom, ecc_anom)
    half_tan = np.tan(ecc_anom / 2)
    true_anom = 2 * np.arctan(np.sqrt((1 + e_elips) / (1 - e_elips)) * half_tan)
    return np.degrees(np.mod(true_anom + lsperi, 2 * np.pi))


def read_orbital_data_nc(starts_folder, infofile=None):
    """
    Read orbital parameters from the restartfi_postpem*.nc files in starts_folder.

    Files are processed in order of the integer suffix of their name
    (``restartfi_postpem<N>.nc``); names without a purely numeric suffix sort
    last. From each file's ``controle`` variable, entries 13, 14, 15, 16 and 17
    are collected as year length, perihelion, aphelion, perihelion date and
    obliquity respectively. The index layout is taken on trust from the
    restart format; verify against the PCM if it changes.

    Parameters
    ----------
    starts_folder : str
        Directory holding the restart files.
    infofile : str, optional
        PEM info file passed to ``read_infofile`` for the time axis and the
        year-conversion factor. When omitted, both are returned as None.

    Returns
    -------
    tuple
        ``(dates_yr, obliquity, eccentricity, ls_perihelion, martian_to_earth)``.
        Eccentricity is ``(aphelion - perihelion) / (aphelion + perihelion)``;
        ``ls_perihelion`` comes from ``mars_ls`` and is in degrees.

    Raises
    ------
    ValueError
        If ``starts_folder`` is not a directory.
    FileNotFoundError
        If no matching restart file exists.
    """
    if not os.path.isdir(starts_folder):
        raise ValueError(f"Invalid starts_folder '{starts_folder}': not a directory.")

    # Read simulation time mapping if provided. Both values default to None so
    # the return below is valid without an infofile (previously
    # martian_to_earth was left unbound and the return raised).
    dates_yr, martian_to_earth = None, None
    if infofile:
        dates_yr, martian_to_earth = read_infofile(infofile)

    pattern = os.path.join(starts_folder, "restartfi_postpem*.nc")
    files = glob(pattern)
    if not files:
        raise FileNotFoundError(f"No NetCDF restart files found matching {pattern}")

    prefix = 'restartfi_postpem'

    def extract_number(path):
        # Numeric suffix for ordering; non-numeric names go to the end.
        name = os.path.basename(path)
        if name.startswith(prefix) and name.endswith('.nc'):
            num_str = name[len(prefix):-3]
            if num_str.isdigit():
                return int(num_str)
        return float('inf')

    files = sorted(files, key=extract_number)

    all_year_day, all_peri, all_aphe, all_date_peri, all_obl = [], [], [], [], []
    for nc_path in files:
        with Dataset(nc_path, 'r') as nc:
            ctrl = nc.variables['controle'][:]
            all_year_day.append(ctrl[13])
            all_peri.append(ctrl[14])
            all_aphe.append(ctrl[15])
            all_date_peri.append(ctrl[16])
            all_obl.append(ctrl[17])

    year_day      = np.array(all_year_day)
    perihelion    = np.array(all_peri)
    aphelion      = np.array(all_aphe)
    date_peri_day = np.array(all_date_peri)
    obliquity     = np.array(all_obl)

    eccentricity  = (aphelion - perihelion) / (aphelion + perihelion)
    ls_perihelion = mars_ls(date_peri_day, 0., eccentricity, year_day)

    return dates_yr, obliquity, eccentricity, ls_perihelion, martian_to_earth


def plot_orbital_parameters_nc(starts_folder, infofile, date_time, output_folder="."):
    """
    Plot obliquity, eccentricity and Ls of perihelion read from the simulation's
    restart files against simulated time, plus a second figure of
    eccentricity * sin(Ls of perihelion).

    Parameters
    ----------
    starts_folder : str
        Folder holding the restartfi_postpem*.nc files.
    infofile : str
        PEM info file giving the restart-file times and the conversion factor.
    date_time : ndarray
        Times (Mars years) at which the parameters are evaluated; the x axis
        shows ``date_time * martian_to_earth / 1e6`` labelled in Myr.
    output_folder : str
        Directory receiving ``orbital_parameters_simu.png`` and
        ``sin_ecc_times_Lsp.png``.
    """
    # Read orbital data
    times_yr, obl, ecc, lsp, martian_to_earth = read_orbital_data_nc(starts_folder, infofile)

    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
    obl_i = interp1d(times_yr, obl, **fargs)(date_time)
    ecc_i = interp1d(times_yr, ecc, **fargs)(date_time)
    # Ls of perihelion arrives wrapped to [0, 360). Interpolating it directly
    # across the 360 -> 0 jump would invent values near 180 deg, so unwrap it
    # into a continuous series first and fold the result back afterwards.
    lsp_unwrapped = np.degrees(np.unwrap(np.radians(lsp)))
    lsp_i = np.mod(interp1d(times_yr, lsp_unwrapped, **fargs)(date_time), 360.0)

    time_myr = date_time * martian_to_earth / 1e6

    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')

    # Plot
    axes[0].plot(time_myr, obl_i, 'r-', marker='+')
    axes[0].set_ylabel("Obliquity (°)")
    axes[0].grid(True)

    axes[1].plot(time_myr, ecc_i, 'b-', marker='+')
    axes[1].set_ylabel("Eccentricity")
    axes[1].grid(True)

    axes[2].plot(time_myr, lsp_i, 'g-', marker='+')
    axes[2].set_ylabel("Ls of perihelion (°)")
    axes[2].set_xlabel("Time (Myr)")
    axes[2].grid(True)

    # Lay out each figure explicitly instead of relying on pyplot's "current" one
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    outname = os.path.join(output_folder, "orbital_parameters_simu.png")
    fig.savefig(outname, dpi=150)

    # Eccentricity times sine of the perihelion longitude
    eps_sin_lsp = ecc_i * np.sin(np.radians(lsp_i))

    fig2, ax2 = plt.subplots(figsize=(8, 5))
    fig2.suptitle(r"$\epsilon \times \sin(L_{sp})$", fontweight='bold')

    ax2.plot(time_myr, eps_sin_lsp, 'm-', marker='+')
    ax2.set_ylabel(r"$\epsilon \cdot \sin(L_{sp})$")
    ax2.set_xlabel("Time (Myr)")
    ax2.grid(True)

    fig2.tight_layout(rect=[0, 0, 1, 0.95])
    outname2 = os.path.join(output_folder, "sin_ecc_times_Lsp.png")
    fig2.savefig(outname2, dpi=150)


def _left_aligned_edges(values, default_step=1.0):
    """
    Return ``len(values) + 1`` cell edges for ``pcolormesh``.

    Each value is used as the left/lower edge of its cell. The last cell is
    closed by repeating the final spacing, or by ``default_step`` when only
    one value exists (a single sample has no spacing to repeat).
    """
    values = np.asarray(values)
    step = values[-1] - values[-2] if values.size > 1 else default_step
    return np.concatenate([values, [values[-1] + step]])


def _make_dust_obliquity_format_coord(ax, ax2, x_edges, y_edges, frac_all,
                                      date_time, obliquity, martian_to_earth):
    """
    Build the hover ``format_coord`` callback installed on the twin axis ``ax2``.

    Every figure-specific object is passed in explicitly so each figure
    reports its own data. A closure defined directly in the plotting loop
    would late-bind these names and always read the last figure's values.
    """
    n_rows, n_cols = frac_all.shape[0], frac_all.shape[1]

    def format_coord(x_input, y_input):
        # Matplotlib hands ax2 (obliquity) data coordinates to the callback
        # installed on ax2; map them onto the main elevation axis through
        # display space.
        x_pix, y_pix = ax2.transData.transform((x_input, y_input))
        x, y = ax.transData.inverted().transform((x_pix, y_pix))
        # check bounds
        if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
            return ''
        # locate cell
        i = int(np.clip(np.searchsorted(x_edges, x) - 1, 0, n_cols - 1))
        j = int(np.clip(np.searchsorted(y_edges, y) - 1, 0, n_rows - 1))
        # get fractions and obliquity
        fH2O, _fCO2, fDust = frac_all[j, i]
        obl = np.interp(x / martian_to_earth, date_time, obliquity)
        return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.4f}, Dust={fDust:.4f}, Obl={obl:.2f}°"

    return format_coord


def plot_dust_to_ice_ratio_with_obliquity(
    starts_folder,
    infofile,
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot the dust-to-ice ratio over time as a heatmap, and overlay the evolution of
    obliquity on a secondary y-axis.

    One figure is written per (grid point, slope).
    - The heatmap shows dust / H2O ice on a log colour scale clipped to
      [1e-2, 1e1]. Cells without H2O ice are saturated at the upper bound.
    - The obliquity curve is coloured per time interval by the sign of the
      change in total deposit height: green = growth, red = loss,
      orange = unchanged.
    - Hovering reports the local H2O and dust fractions and the interpolated
      obliquity.

    Parameters
    ----------
    starts_folder : str
        Folder holding the restartfi_postpem*.nc files read for obliquity.
    infofile : str
        PEM info file providing the time axis and the year-conversion factor.
    gridded_data : dict
        Must contain 'h2o_ice', 'co2_ice' and 'dust' arrays shaped
        (ngrid, ntime, nslope, nz).
    ref_grid : array-like
        Elevations (m) of the nz reference levels.
    top_index : ndarray
        Indexed as ``top_index[ig, t, isl]``. It is the upper slice bound of
        filled levels in each column; non-positive values mean "empty" and
        are replaced by the previous time step's value.
    heights_data : sequence
        Indexed as ``heights_data[t][isl][ig, :]``; strata heights.
    date_time : ndarray
        Simulation times (Mars years), one per time step.
    exclude_sub : bool
        Not used by this function; kept for signature compatibility.
    output_folder : str
        Directory receiving one PNG per (grid point, slope).
    """
    h2o = gridded_data['h2o_ice']
    co2 = gridded_data['co2_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Obliquity from the restart files, evaluated at the simulation dates
    times_yr, obl, _, _, martian_to_earth = read_orbital_data_nc(starts_folder, infofile)
    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
    obliquity = interp1d(times_yr, obl, **fargs)(date_time)

    # Colour of each obliquity segment, keyed by the sign of the height change
    color_map = {1: 'green', -1: 'red', 0: 'orange'}

    # Log-ratio bounds and small epsilon to avoid log(0)
    vmin, vmax = -2, 1
    epsilon = 1e-6

    # Cell edges are identical for every grid point and slope, so build them
    # once. The helper also copes with a single time step or level.
    x_edges = _left_aligned_edges(date_time) * martian_to_earth
    y_edges = _left_aligned_edges(ref_grid)

    # Custom formatter for millions of Earth years
    def millions_formatter(x, pos):
        return f"{x/1e6:.1f}"

    for ig in range(ngrid):
        for isl in range(nslope):
            # Total deposit height per time step: highest valid stratum
            total_height_t = np.zeros(ntime, dtype=float)
            for t_idx in range(ntime):
                raw_h = heights_data[t_idx][isl][ig, :]
                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
                if np.any(valid_mask):
                    total_height_t[t_idx] = np.max(raw_h[valid_mask])

            # Per-interval sign of the height change (empty for one time step)
            signs = np.sign(np.diff(total_height_t)).astype(int)

            # Fractions and clipped log10 dust/ice ratio on the reference grid
            ti = top_index[ig, :, isl].copy().astype(int)
            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
            frac_all = np.full((nz, ntime, 3), np.nan, dtype=float)  # fH2O, fCO2, fDust
            for t in range(ntime):
                # An empty column reuses the previous time step's top
                if t > 0 and ti[t] <= 0:
                    ti[t] = ti[t - 1]
                zmax = ti[t]
                if zmax <= 0:
                    continue

                cH2O = np.clip(h2o[ig, t, isl, :zmax], 0, None)
                cCO2 = np.clip(co2[ig, t, isl, :zmax], 0, None)
                cDust = np.clip(dust[ig, t, isl, :zmax], 0, None)
                total = cH2O + cCO2 + cDust
                total[total == 0] = 1.0  # all-zero cells get fractions of 0, not NaN
                frac_all[:zmax, t, 0] = cH2O / total
                frac_all[:zmax, t, 1] = cCO2 / total
                frac_all[:zmax, t, 2] = cDust / total

                with np.errstate(divide='ignore', invalid='ignore'):
                    # No H2O ice: push above vmax so the clip saturates the cell
                    ratio = np.where(cH2O > 0, cDust / cH2O, 10**(vmax + 1))
                    log_ratio = np.clip(np.log10(ratio + epsilon), vmin, vmax)

                log_ratio_array[:zmax, t] = log_ratio

            ratio_array = 10**log_ratio_array

            # Plot
            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
            im = ax.pcolormesh(
                x_edges,
                y_edges,
                ratio_array,
                shading='auto',
                cmap='managua_r',
                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
            )
            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax.xaxis.set_major_formatter(FuncFormatter(millions_formatter))
            ax.set_xlabel('Time (Myr)')
            ax.set_ylabel('Elevation (m)')

            # Add colorbar
            cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.15)
            cbar.set_label('Dust / H₂O ice (ratio)')
            cbar.set_ticks([1e-2, 1e-1, 1, 1e1])
            cbar.set_ticklabels(['1:100', '1:10', '1:1', '10:1'])

            # Overlay obliquity on secondary y-axis, one segment per interval
            ax2 = ax.twinx()
            for i, sign in enumerate(signs):
                ax2.plot(
                    [date_time[i] * martian_to_earth, date_time[i + 1] * martian_to_earth],
                    [obliquity[i], obliquity[i + 1]],
                    color=color_map[sign],
                    marker='+',
                    linewidth=1.5
                )
            # Bind this figure's own axes and data to its hover callback
            ax2.format_coord = _make_dust_obliquity_format_coord(
                ax, ax2, x_edges, y_edges, frac_all,
                date_time, obliquity, martian_to_earth,
            )
            ax2.set_ylabel('Obliquity (°)')
            ax2.tick_params(axis='y')
            ax2.grid(False)

            # Save figure
            fig.tight_layout()
            outname = os.path.join(
                output_folder,
                f'dust_ice_obliquity_grid{ig+1}_slope{isl+1}.png'
            )
            fig.savefig(outname, dpi=1200, bbox_inches='tight')


def main():
    """
    Interactive driver for the script.

    It asks for inputs, loads and grids the restartevo stratification data,
    writes every enabled figure, then shows them all.
    """
    # --- Inputs and file discovery -------------------------------------
    folder_path, base_name, infofile, orbfile = get_user_inputs()

    nc_files = list_netcdf_files(folder_path, base_name)
    if not nc_files:
        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\".")
        sys.exit(1)
    print(f"> Found {len(nc_files)} NetCDF file(s).")

    # --- Grid dimensions (taken from the first file) --------------------
    ngrid, nslope, _longitude, _latitude = open_sample_dataset(nc_files[0])
    print(f"> ngrid  = {ngrid}, nslope = {nslope}")

    # --- Variable inventory and global elevation bounds -----------------
    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(nc_files, base_name)
    print(f"> max strata per slope = {max_nb_str}")
    print(f"> min base elev = {min_base_elev:.3f} m, max top elev = {max_top_elev:.3f} m")

    # --- Load everything, extract strata, release the files -------------
    datasets = load_full_datasets(nc_files)
    heights_data, raw_prop_arrays, ntime = extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str)
    for handle in datasets:
        handle.close()

    frac_arrays = normalize_to_fractions(raw_prop_arrays)

    # --- Vertical extent and resolution of the reference grid -----------
    exclude_sub = not get_yes_no_input("Show subsurface layers?")
    if exclude_sub:
        min_base_for_interp = 0.0
        print("> Interpolating only elevations >= 0 m (surface strata).")
    else:
        min_base_for_interp = min_base_elev
        print(f"> Interpolating full depth down to {min_base_elev:.3f} m.")

    dz = prompt_discretization_step(max_top_elev)
    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
        heights_data, frac_arrays, min_base_for_interp, max_top_elev, dz, exclude_sub=exclude_sub
    )

    # --- Time axis --------------------------------------------------------
    date_time, _martian_to_earth = read_infofile(infofile)
    if date_time.size != ntime:
        print(f"Warning: {date_time.size} timestamps vs {ntime} NetCDF files.")

    # --- Stratification figures -----------------------------------------
    common = (gridded_data, ref_grid, top_index, heights_data, date_time)
    plot_stratification_over_time(*common, exclude_sub=exclude_sub, output_folder=".")
    plot_stratification_rgb_over_time(*common, exclude_sub=exclude_sub, output_folder=".")
    plot_dust_to_ice_ratio_with_obliquity(
        folder_path, infofile, *common,
        exclude_sub=exclude_sub, output_folder="."
    )
    # Optional diagnostics, currently disabled:
    # plot_dust_to_ice_ratio_over_time(*common, exclude_sub=exclude_sub, output_folder=".")
    # plot_strata_count_and_total_height(heights_data, date_time, output_folder=".")

    # --- Orbital parameter figures --------------------------------------
    # plot_orbital_parameters(infofile, orbfile, date_time, output_folder=".")
    plot_orbital_parameters_nc(folder_path, infofile, date_time, output_folder=".")

    plt.show()


if __name__ == "__main__":
    main()

