#!/usr/bin/env python3
################################################################
### Python script to concatenate the NetCDF files of the PEM ###
### along the dimension 'Time' into one NetCDF file          ###
################################################################


import os
import re
import sys
import glob
import readline
import argparse
import xarray as xr
import numpy as np


def complete_path(text, state):
    """Readline completer that proposes filesystem paths starting with ``text``.

    Readline calls a completer repeatedly with ``state`` = 0, 1, 2, ... and
    expects the ``state``-th candidate each time, then ``None`` once the
    candidates are exhausted.

    Args:
        text: Partial path typed so far.
        state: Zero-based index of the candidate to return.

    Returns:
        The candidate path at position ``state``, or ``None`` if there is no
        such candidate.
    """
    # Escape the typed text so glob metacharacters in it ('[', '*', '?') are
    # matched literally; only the appended '*' acts as a wildcard. Without
    # this, a name such as "run[1]" would never complete to itself.
    # Sorting keeps the candidate order identical across successive calls.
    matches = sorted(glob.glob(glob.escape(text) + '*'))
    return matches[state] if state < len(matches) else None

# Configure readline at import time so that Tab completes filesystem paths
# at the interactive prompts.
# Only space, tab, newline and ';' delimit completion words, which keeps path
# characters such as '/', '-' and '.' inside the word being completed.
readline.set_completer_delims(' \t\n;')
# Use complete_path() to generate the completion candidates.
readline.set_completer(complete_path)
# Bind the Tab key to readline's completion action.
readline.parse_and_bind("tab: complete")


def parse_args(argv=None):
    """Parse the command-line options of the concatenation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``folder``, ``basename``,
        ``start``, ``end`` and ``output``. Each attribute is ``None`` when
        its option is omitted, so the caller can tell "not given" apart
        from an explicit value.
    """
    parser = argparse.ArgumentParser(
        description="Concatenate multiple NetCDF files along the Time dimension"
    )
    parser.add_argument(
        "--folder", type=str,
        help="Path to the directory containing the NetCDF files"
    )
    parser.add_argument(
        "--basename", type=str,
        help="Base name of the files, e.g., 'diagpem' for files like diagpem1.nc"
    )
    parser.add_argument(
        "--start", type=int,
        help="Starting index of the files to include"
    )
    parser.add_argument(
        "--end", type=int,
        help="Ending index of the files to include (inclusive)"
    )
    # No baked-in default here: a non-None default would make
    # "args.output or <prompt>" always pick the default and silently skip
    # the interactive prompt for the output filename.
    parser.add_argument(
        "--output", type=str, default=None,
        help="Output filename for the concatenated NetCDF "
             "(asked for interactively when omitted)"
    )
    return parser.parse_args(argv)


def prompt_with_default(prompt_text, default, cast_fn=None):
    """Ask the user for a value, falling back to ``default`` on empty input.

    The prompt is repeated until the answer is empty or can be converted.

    Args:
        prompt_text: Question shown to the user.
        default: Value returned as-is (not passed through ``cast_fn``) when
            the user just presses Enter or types only whitespace.
        cast_fn: Optional callable applied to the answer (e.g. ``int``). A
            ``ValueError`` raised by it triggers a retry.

    Returns:
        ``default``, the stripped answer string, or ``cast_fn(answer)``.

    Exits the process with status 1 on Ctrl-C or end-of-input.
    """
    prompt = f"{prompt_text} [press Enter for default {default}]: "
    while True:
        try:
            user_input = input(prompt)
        except (KeyboardInterrupt, EOFError):
            # EOF (Ctrl-D or a closed stdin) is handled like Ctrl-C so the
            # script ends cleanly instead of with a traceback.
            print("\nInterrupted.")
            sys.exit(1)

        # Strip surrounding whitespace: tab completion commonly leaves a
        # trailing space that would otherwise end up in the path or name.
        answer = user_input.strip()
        if not answer:
            return default
        if not cast_fn:
            return answer
        try:
            return cast_fn(answer)
        except ValueError:
            print(f"Invalid value. Expecting {cast_fn.__name__}. Please try again.")


def find_index_range(folder, basename):
    """Return the smallest and largest file index present in ``folder``.

    A file counts when its name is exactly ``<basename><digits>.nc``; the
    digits are read as the integer index.

    Args:
        folder: Directory to search.
        basename: Filename prefix that precedes the numeric index.

    Returns:
        Tuple ``(lowest_index, highest_index)``.

    Raises:
        FileNotFoundError: If no file name matches the expected form.
    """
    index_re = re.compile(fr"{re.escape(basename)}(\d+)\.nc$")
    candidates = glob.glob(os.path.join(folder, f"{basename}*.nc"))

    # The glob is only a coarse filter; the regex keeps the names whose
    # suffix after the basename is purely numeric.
    hits = (index_re.match(os.path.basename(path)) for path in candidates)
    found = [int(hit.group(1)) for hit in hits if hit]

    if not found:
        raise FileNotFoundError(f"No files matching {basename}*.nc found in {folder}")
    return min(found), max(found)


def main():
    """Interactively merge a run of indexed NetCDF files along ``Time``.

    Steps:
      1. Read the command-line options and prompt for the folder and base
         filename (command-line values, if any, are offered as the prompt
         defaults).
      2. Discover the available file indices, then take the start/end
         indices from the command line or prompt for them.
      3. Open every ``<folder>/<basename><i>.nc`` for ``i`` in
         ``start..end`` and shift its ``Time`` coordinate so that it
         continues after the previous file.
      4. Concatenate along ``Time``, attempt CF decoding, and write the
         result to the output file.

    Exits with status 1 when the requested index range is invalid.

    Raises:
        FileNotFoundError: If no matching files exist or a file inside the
            selected range is missing.
        ValueError: If a file has no ``Time`` coordinate.
    """
    args = parse_args()

    # Defaults
    default_folder = args.folder or "diags"
    default_basename = args.basename or "diagpem"

    # Prompt for folder and basename with defaults
    folder = prompt_with_default(
        "Enter the folder path containing NetCDF files", default_folder
    )
    basename = prompt_with_default(
        "Enter the base filename", default_basename
    )

    # Determine available index range
    min_idx, max_idx = find_index_range(folder, basename)
    print(f"Found files from index {min_idx} to {max_idx}.")

    # Prompt for start/end with discovered defaults
    start = args.start if args.start is not None else prompt_with_default(
        "Enter the starting file index", min_idx, cast_fn=int
    )
    end = args.end if args.end is not None else prompt_with_default(
        "Enter the ending file index", max_idx, cast_fn=int
    )

    # Validate range
    if start < min_idx or end > max_idx or start > end:
        print(f"Invalid range: must be between {min_idx} and {max_idx}, and start <= end.")
        sys.exit(1)

    # Output filename (prompted for only when no value came from the CLI)
    output = args.output or prompt_with_default(
        "Enter the output filename (including .nc)", "merged.nc"
    )

    # Build and verify file list; indices may have gaps between min and max,
    # so every file is checked before anything is opened.
    file_list = [
        os.path.join(folder, f"{basename}{i}.nc")
        for i in range(start, end + 1)
    ]
    for fpath in file_list:
        if not os.path.isfile(fpath):
            raise FileNotFoundError(f"File not found: {fpath}")

    datasets = []
    # Handles exactly as returned by open_dataset, kept so that they can be
    # closed on every exit path (assign_coords returns a new object).
    opened = []
    time_offset = 0

    try:
        # Offset Time values to make them cumulative
        for fpath in file_list:
            ds = xr.open_dataset(fpath, decode_times=False)
            opened.append(ds)

            if 'Time' not in ds.coords:
                raise ValueError(f"'Time' coordinate not found in {fpath}")

            time_vals = ds['Time'].values
            datasets.append(ds.assign_coords(Time=time_vals + time_offset))

            # The length of a file is estimated as (first time step) x
            # (number of samples), i.e. uniform sampling is assumed. A file
            # with a single sample has no step to measure and counts as 1.
            if len(time_vals) > 1:
                dt = time_vals[1] - time_vals[0]
                duration = dt * len(time_vals)
            else:
                duration = 1
            time_offset += duration

        # Concatenate
        merged_ds = xr.concat(datasets, dim="Time")

        # Optionally decode CF conventions after loading (best effort)
        try:
            merged_ds = xr.decode_cf(merged_ds)
        except Exception as e:
            print(f"Warning: CF decoding failed: {e}\nProceeding with raw time values.")

        # Inspect and save; the numeric format fails if Time is no longer a
        # plain number, in which case only a notice is printed.
        try:
            tmax = merged_ds.Time.max().values
            print(f"Final time value: {tmax:.0f}")
        except Exception:
            print("Time variables not decoded correctly.")

        merged_ds.to_netcdf(output)
    finally:
        # Release the input files whether or not the merge succeeded.
        for ds in opened:
            ds.close()

    print(f"Merged dataset written to {output}")


# Run the interactive merge only when executed as a script, not on import.
if __name__ == "__main__":
    main()
