#!/usr/bin/env python3
#######################################################################################
### Python script to output the stratification data from the "restartpem#.nc" files ###
#######################################################################################

import os
import sys
import numpy as np
import netCDF4 as nc
import matplotlib.pyplot as plt


def get_int_input(prompt: str, min_val: int, max_val: int) -> int:
    """
    Ask the user for an integer in the closed range [min_val, max_val].

    When the range holds a single value, that value is announced and returned
    without reading from stdin. Otherwise the question is repeated until the
    user types an integer inside the range; non-integer input and
    out-of-range values each get their own message.
    """
    if min_val == max_val:
        print(f"{prompt} (only possible value: {min_val})")
        return min_val

    question = f"{prompt} (integer between {min_val} and {max_val}): "
    while True:
        try:
            answer = int(input(question))
        except ValueError:
            print("Invalid input. Please enter a valid integer.")
            continue
        if min_val <= answer <= max_val:
            return answer
        print(f"Invalid value! Please enter a number between {min_val} and {max_val}.")


def get_yes_no_input(prompt: str) -> bool:
    """
    Ask a yes/no question until the user gives a recognised answer.

    Answers are compared case-insensitively after stripping surrounding
    whitespace: 'y' or 'yes' gives True, 'n' or 'no' gives False, and
    anything else prints a reminder and asks again.
    """
    answers = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input(f"{prompt} (y/n): ").strip().lower()
        if reply in answers:
            return answers[reply]
        print("Please respond with y or n.")


def load_slope_variables(nc_dataset: nc.Dataset, slope_index: int) -> dict:
    """
    Read the stratification variables of one slope from an open NetCDF dataset.

    Variable names follow the pattern 'stratif_slopeNN_<suffix>', where NN is
    the 1-based slope number zero-padded to two digits.

    Args:
      nc_dataset: the open dataset to read from.
      slope_index: 0-based slope index.

    Returns:
      dict mapping 'top_elev', 'h_co2', 'h_h2o', 'h_dust', 'h_pore' and
      'poreice_volfrac' to the full contents of the matching variable
      (read with [:]).

    Exits the program (sys.exit) with an error message if any of the six
    variables is missing. Variables are checked in the order listed above.
    """
    prefix = f"stratif_slope{slope_index + 1:02d}_"
    suffix_by_key = (
        ('top_elev', 'top_elevation'),
        ('h_co2', 'h_co2ice'),
        ('h_h2o', 'h_h2oice'),
        ('h_dust', 'h_dust'),
        ('h_pore', 'h_pore'),
        ('poreice_volfrac', 'poreice_volfrac'),
    )

    available = nc_dataset.variables
    loaded = {}
    for key, suffix in suffix_by_key:
        name = prefix + suffix
        if name not in available:
            sys.exit(f"Error: Variable '{name}' not found in the NetCDF file.")
        loaded[key] = available[name][:]
    return loaded


def calculate_contents(
    data: dict,
    grid_index: int,
    exclude_subsurface: bool,
    time_index: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build the vertical composition profile of one grid point.

    For a slope's stratification data (as returned by load_slope_variables)
    and a 0-based grid index, compute the layer-boundary elevations, the
    layer thicknesses and the layer-wise volume fractions of each component.

    Every array in `data` is indexed as [time, stratum, grid point]. Only the
    record `time_index` is used (default 0, the record always used before
    this parameter existed). Values are promoted to float64 before any
    arithmetic. A layer's thickness is h_co2 + h_h2o + h_dust + h_pore, and
    its bottom elevation is its top elevation minus that thickness.

    Args:
      data: dict with keys 'top_elev', 'h_co2', 'h_h2o', 'h_dust', 'h_pore'
        and 'poreice_volfrac'.
      grid_index: 0-based grid point index.
      exclude_subsurface: if True, omit layers whose top elevation is <= 0.
        A layer that straddles the surface (bottom <= 0 < top) is kept whole.
      time_index: 0-based index of the time record to read.

    Returns:
      - height: 1D array of boundary elevations (length = number_of_layers + 1).
        height[k + 1] is the top of kept layer k. height[0] is the bottom of
        the first kept layer when subsurface layers are dropped and the
        profile starts at or below the surface; otherwise it is the bottom of
        layer 0.
      - contents: 2D array of shape (5, number_of_layers + 1). Rows are CO2
        ice, H2O ice, dust, air (pore space without ice) and pore ice.
        Column k + 1 holds kept layer k. Column 0 repeats column 1 so that
        step plots start at height[0]. Layers with thickness <= 0 stay zero.
      - thicknesses: 1D array of layer thicknesses (length = number_of_layers).

    Raises:
      SystemExit: if exclude_subsurface is True, no layer top is > 0 and the
        profile bottom is <= 0.
    """
    top, h_co2, h_h2o, h_dust, h_pore, volfrac = (
        np.asarray(data[key][time_index, :, grid_index], dtype=float)
        for key in ('top_elev', 'h_co2', 'h_h2o', 'h_dust', 'h_pore', 'poreice_volfrac')
    )

    # raw_height[0] is the bottom of stratum 0; raw_height[k + 1] is the top
    # of stratum k.
    raw_thickness = h_co2 + h_h2o + h_dust + h_pore
    raw_height = np.concatenate(([top[0] - raw_thickness[0]], top))

    if exclude_subsurface:
        include_mask = top > 0
        if not include_mask.any() and raw_height[0] <= 0:
            sys.exit("Error: No layers remain above the surface (elevation > 0).")
    else:
        include_mask = np.ones(top.shape[0], dtype=bool)

    kept = np.flatnonzero(include_mask)
    thicknesses = raw_thickness[kept]

    height = np.empty(kept.size + 1)
    height[1:] = top[kept]
    if exclude_subsurface and raw_height[0] <= 0:
        # Start at the bottom of the lowest kept layer, which may itself
        # straddle the surface.
        height[0] = top[kept[0]] - raw_thickness[kept[0]]
    else:
        height[0] = raw_height[0]

    contents = np.zeros((5, kept.size + 1))
    # Skip layers with thickness <= 0 to avoid dividing by zero. The test is
    # written as ~(<= 0) so a NaN thickness is still divided through and
    # shows up as NaN rather than silently as zero.
    valid = ~(thicknesses <= 0)
    layers = kept[valid]
    t = thicknesses[valid]
    pore = h_pore[layers]
    frac = volfrac[layers]
    columns = np.flatnonzero(valid) + 1
    contents[:, columns] = np.vstack((
        h_co2[layers] / t,
        h_h2o[layers] / t,
        h_dust[layers] / t,
        pore * (1.0 - frac) / t,
        pore * frac / t,
    ))

    if kept.size > 0:
        contents[:, 0] = contents[:, 1]

    return height, contents, thicknesses


def plot_profiles(height: np.ndarray, contents: np.ndarray, thicknesses: np.ndarray, labels: tuple[str, ...]) -> None:
    """
    Create and save three figures of the stratification profile:
      1. Simple step profiles, one subplot per component
         -> 'layering_simple_profiles.png'
      2. Overlaid step profiles, all components on one plot
         -> 'layering_overlaid_profiles.png'
      3. Stacked fill-between profile of the volume fractions
         -> 'layering_stacked_profiles.png'

    Files are written to the current working directory. The figures are left
    open so the caller can display them with plt.show().

    Args:
      height: layer-boundary elevations, one per column of `contents`.
      contents: array of shape (n_components, len(height)) of volume
        fractions. The value in column i is drawn between height[i - 1] and
        height[i]; column 0 only anchors the bottom of the step plots.
      thicknesses: not used; kept for interface compatibility.
      labels: one title/legend label per row of `contents`.
    """
    n_components = contents.shape[0]
    palette = ['r', 'b', 'y', 'violet', 'c']

    def color_of(idx: int) -> str:
        # Cycle the palette so more than five components do not raise IndexError.
        return palette[idx % len(palette)]

    # 1. Simple subplots. squeeze=False keeps `axes` 2-D, so a single
    # component still yields an iterable row of Axes instead of a bare Axes.
    fig, axes = plt.subplots(1, n_components, sharey=True, constrained_layout=True, squeeze=False)
    fig.suptitle('Simple Content Profiles for Stratification', fontweight='bold')
    for idx, ax in enumerate(axes[0]):
        ax.step(contents[idx, :], height, where='post', color=color_of(idx))
        ax.set_xlim(0, 1)
        ax.set_xlabel('Volume Fraction [m^3/m^3]')
        ax.set_title(labels[idx])
        if idx == 0:
            ax.set_ylabel('Elevation [m]')
        # Major and minor grids on the elevation axis
        ax.grid(which='major', axis='y', linestyle='--', linewidth=0.5, color='0.8')
        ax.minorticks_on()
        ax.grid(which='minor', axis='y', linestyle=':', linewidth=0.3, color='0.9')
    fig.savefig('layering_simple_profiles.png')

    # 2. Overlaid step profiles
    fig, ax = plt.subplots()
    for idx, label in enumerate(labels):
        ax.step(contents[idx, :], height, where='post', color=color_of(idx), label=label)
    ax.grid(axis='x', color='0.95')
    ax.grid(axis='y', color='0.95')
    ax.minorticks_on()
    ax.grid(which='minor', axis='y', linestyle=':', linewidth=0.3, color='0.9')
    ax.set_xlim(0, 1)
    ax.set_xlabel('Volume Fraction [m^3/m^3]')
    ax.set_ylabel('Elevation [m]')
    ax.set_title('Content Profiles for Stratification', fontweight='bold')
    ax.legend()
    fig.savefig('layering_overlaid_profiles.png')

    # 3. Stacked fill-between profile of the fractions
    fig, ax = plt.subplots()
    lower = np.zeros(contents.shape[1])
    for idx, label in enumerate(labels):
        # Bind a new array rather than updating `lower` in place, so arrays
        # already handed to fill_betweenx are never mutated afterwards.
        upper = lower + contents[idx, :]
        ax.fill_betweenx(
            height,
            lower,
            upper,
            step='pre',
            label=label,
            color=color_of(idx),
        )
        lower = upper
    ax.vlines(x=0., ymin=height[0], ymax=height[-1], color='k', linestyle='-')
    ax.vlines(x=1., ymin=height[0], ymax=height[-1], color='k', linestyle='-')
    # One dashed horizontal line per layer boundary
    ax.hlines(y=height, xmin=0.0, xmax=1.0, color='k', linestyle='--', linewidth=0.5)
    ax.set_xlabel('Volume Fraction [m^3/m^3]')
    ax.set_ylabel('Elevation [m]')
    ax.set_title('Stacked Content Profiles for Stratification', fontweight='bold')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    fig.tight_layout()
    fig.savefig('layering_stacked_profiles.png')


def plot_stratum_layer(contents: np.ndarray, istr_index: int, labels: tuple[str, ...]) -> None:
    """
    Draw one stratum's composition as a bar chart and save it to
    'stratum_<N>_composition.png', where N = istr_index + 1.

    Args:
      contents: volume-fraction array with one row per component. Column
        k + 1 holds stratum k; column 0 is the bottom-boundary copy that
        calculate_contents fills in.
      istr_index: 0-based stratum index.
      labels: one tick label per component row.

    The figure is left open so the caller can display it.
    """
    stratum_number = istr_index + 1
    fractions = contents[:, stratum_number]
    ticks = np.arange(len(labels))

    fig, ax = plt.subplots()
    ax.bar(ticks, fractions, color=['r', 'b', 'y', 'violet', 'c'])
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Volume Fraction [m^3/m^3]')
    ax.set_title(f'Composition of Stratum {stratum_number}', fontweight='bold')
    # Dashed major and dotted minor grid lines along the fraction axis
    ax.grid(which='major', axis='y', linestyle='--', linewidth=0.5, color='0.8')
    ax.minorticks_on()
    ax.grid(which='minor', axis='y', linestyle=':', linewidth=0.3, color='0.9')
    fig.tight_layout()
    fig.savefig(f'stratum_{stratum_number}_composition.png')


def main():
    """
    Interactive driver: open a restart file, ask for a grid point, a slope and
    whether subsurface strata are shown, then save and display the plots.

    The NetCDF file name comes from the first command-line argument, or is
    asked for interactively. The file is closed once the slope's variables
    are loaded (load_slope_variables reads each one with [:]), so it is not
    held open during the remaining prompts or while figures are displayed.
    """
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = input("Enter the NetCDF file name: ").strip()

    if not os.path.isfile(filename):
        sys.exit(f"Error: File '{filename}' does not exist.")

    with nc.Dataset(filename, 'r') as dataset:
        required_dims = ['Time', 'physical_points', 'nslope', 'nb_str_max']
        for dim in required_dims:
            if dim not in dataset.dimensions:
                sys.exit(f"Error: Missing dimension '{dim}' in file.")

        ngrid = len(dataset.dimensions['physical_points'])
        nslope = len(dataset.dimensions['nslope'])
        nb_str_max = len(dataset.dimensions['nb_str_max'])

        print(f"File '{filename}' opened successfully.")
        print(f"Number of grid points (physical_points): {ngrid}")
        print(f"Number of slopes: {nslope}")
        print(f"Maximum number of strata per slope: {nb_str_max}")

        igrid_input = get_int_input("Enter grid point number", 1, ngrid) - 1
        islope_input = get_int_input("Enter slope number", 1, nslope) - 1

        show_subsurface = get_yes_no_input("Show subsurface layers?")
        exclude_sub = not show_subsurface

        # Load the chosen slope's variables into memory; nothing below needs
        # the file to stay open.
        slope_data = load_slope_variables(dataset, islope_input)

    # Number of strata in the profile: those whose top elevation is above
    # the surface (> 0) when subsurface layers are hidden, all of them
    # otherwise. This is the same filter calculate_contents applies.
    top_column = np.asarray(slope_data['top_elev'][0, :, igrid_input])
    if exclude_sub:
        nb_str_surf_max = int(np.count_nonzero(top_column > 0))
    else:
        nb_str_surf_max = top_column.shape[0]
    # Without this guard an empty profile would make the prompt below ask for
    # an integer between 1 and 0 forever.
    if nb_str_surf_max == 0:
        print("No stratum layers remain after filtering. Cannot proceed.")
        return

    # Bound the prompt by the strata actually present in the profile, so the
    # chosen index always maps to a column computed by calculate_contents.
    istr_input = get_int_input(
        "Enter stratum index for detailed plot", 1, nb_str_surf_max
    ) - 1

    height_arr, contents_arr, thicknesses_arr = calculate_contents(
        slope_data, igrid_input, exclude_sub
    )

    component_labels = ("CO2 Ice", "H2O Ice", "Dust", "Air", "Pore ice")

    plot_profiles(height_arr, contents_arr, thicknesses_arr, component_labels)
    plot_stratum_layer(contents_arr, istr_input, component_labels)

    # Show all figures at once
    plt.show()


# Run the interactive tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

