#!/usr/bin/env python3
"""Concatenate indexed PEM NetCDF outputs along the Time dimension.

Builds a cumulative time axis from consecutive files and writes one merged
dataset for downstream diagnostics.
"""


import os
import re
import sys
import glob
import readline
import argparse
import xarray as xr
import numpy as np


# Matches for the text most recently completed, reused across the successive
# `state` calls readline makes for a single Tab press.
_completion_cache = {"text": None, "matches": []}


def complete_path(text, state):
    """Readline completer returning the *state*-th filesystem path starting with *text*.

    Readline calls this with state = 0, 1, 2, ... until it returns None. The
    matches are computed once (on state 0, or when *text* changes) and cached,
    instead of re-globbing the filesystem on every call.

    Args:
        text: The partial path typed so far; it is matched literally.
        state: Index of the candidate requested by readline.

    Returns:
        The matching path, or None once the candidates are exhausted.
    """
    if state == 0 or _completion_cache["text"] != text:
        # glob.escape makes characters such as '[', '*' and '?' in the typed
        # path match literally instead of acting as wildcards.
        _completion_cache["text"] = text
        _completion_cache["matches"] = sorted(glob.glob(glob.escape(text) + "*"))
    matches = _completion_cache["matches"]
    return matches[state] if state < len(matches) else None

# Enable Tab completion of filesystem paths at the interactive prompts. Only
# whitespace and ';' delimit words, so characters such as '/' and '-' stay part
# of the path being completed.
readline.set_completer_delims(' \t\n;')
readline.set_completer(complete_path)
# A libedit-backed readline (the default for Python on macOS) ignores GNU
# readline's "tab: complete" syntax and needs its own binding for Tab.
if (getattr(readline, "backend", None) == "editline"
        or "libedit" in (readline.__doc__ or "")):
    readline.parse_and_bind("bind ^I rl_complete")
else:
    readline.parse_and_bind("tab: complete")


def parse_args():
    """Parse the command-line options of the merge tool.

    Every option is optional. An option that is not given on the command line
    is left as ``None`` so the caller can prompt for it interactively.

    Returns:
        argparse.Namespace with ``folder``, ``basename``, ``start``, ``end`` and
        ``output`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Concatenate multiple NetCDF files along the Time dimension"
    )
    parser.add_argument(
        "--folder", type=str,
        help="Path to the directory containing the NetCDF files"
    )
    parser.add_argument(
        "--basename", type=str,
        help="Base name of the files, e.g., 'diagevo' for files like diagevo1.nc"
    )
    parser.add_argument(
        "--start", type=int,
        help="Starting index of the files to include"
    )
    parser.add_argument(
        "--end", type=int,
        help="Ending index of the files to include (inclusive)"
    )
    # Default to None (not "merged.nc") so an omitted --output can be detected
    # and prompted for; "merged.nc" is offered as the prompt's default instead.
    parser.add_argument(
        "--output", type=str, default=None,
        help="Output filename for the concatenated NetCDF "
             "(prompted for, with default merged.nc, when omitted)"
    )
    return parser.parse_args()


def prompt_with_default(prompt_text, default, cast_fn=None):
    """Ask the user for a value, falling back to *default* on empty input.

    Surrounding whitespace is stripped from the answer. When *cast_fn* is
    given, the answer is converted with it; a ValueError re-prompts until a
    valid value is entered. Ctrl-C or end of input (Ctrl-D / closed stdin)
    exits the program with status 1.

    Args:
        prompt_text: Question shown to the user.
        default: Value returned when the user just presses Enter; it is
            returned as-is (not passed through *cast_fn*).
        cast_fn: Optional converter such as ``int``.

    Returns:
        The converted answer, the stripped string answer, or *default*.
    """
    prompt = f"{prompt_text} [press Enter for default {default}]: "
    while True:
        try:
            user_input = input(prompt)
        except (KeyboardInterrupt, EOFError):
            print("\nInterrupted.")
            sys.exit(1)

        # Stray spaces around a typed path or filename are almost always
        # accidental and would make the file lookup fail.
        answer = user_input.strip()
        if not answer:
            return default
        if cast_fn is None:
            return answer
        try:
            return cast_fn(answer)
        except ValueError:
            print(f"Invalid value. Expecting {cast_fn.__name__}. Please try again.")


def find_index_range(folder, basename):
    """Return the smallest and largest N among files named ``<basename><N>.nc``.

    Only names that are exactly the basename, a decimal index and ``.nc`` are
    counted; e.g. ``diagevo_x2.nc`` is ignored for basename ``diagevo``.

    Args:
        folder: Directory to search.
        basename: Filename prefix before the numeric index.

    Returns:
        Tuple ``(min_index, max_index)``.

    Raises:
        FileNotFoundError: If no matching file exists in *folder*.
    """
    # Escape both parts so glob metacharacters in a directory or base name
    # (e.g. "run[a]") are matched literally rather than as wildcards.
    pattern = os.path.join(glob.escape(folder), f"{glob.escape(basename)}*.nc")
    name_re = re.compile(fr"{re.escape(basename)}(\d+)\.nc$")
    indices = [
        int(m.group(1))
        for m in (name_re.match(os.path.basename(f)) for f in glob.glob(pattern))
        if m
    ]
    if not indices:
        raise FileNotFoundError(f"No files matching {basename}*.nc found in {folder}")
    return min(indices), max(indices)


def _open_with_cumulative_time(file_list):
    """Open each file and shift its Time coordinate so the files follow each other.

    Each file's raw (undecoded) Time values are offset by the accumulated
    duration of the preceding files. A file's duration is taken as
    ``(Time[1] - Time[0]) * len(Time)``, which assumes evenly spaced steps; a
    file with a single time step counts as duration 1. Time attributes
    (e.g. units) are preserved on the shifted coordinate.

    Args:
        file_list: Paths of the NetCDF files, in concatenation order.

    Returns:
        ``(opened, shifted)``: the datasets as opened (they hold the file
        handles and must be closed by the caller) and the time-shifted
        datasets to concatenate.

    Raises:
        ValueError: If a file has no 'Time' coordinate. Any files already
            opened are closed before the error propagates.
    """
    opened = []
    shifted = []
    time_offset = 0
    try:
        for fpath in file_list:
            ds = xr.open_dataset(fpath, decode_times=False)
            opened.append(ds)

            if 'Time' not in ds.coords:
                raise ValueError(f"'Time' coordinate not found in {fpath}")

            time = ds['Time']
            time_vals = time.values
            # Rebuild the coordinate with its dims and attrs: assigning a bare
            # array would drop attributes such as 'units'.
            shifted.append(ds.assign_coords(
                Time=(time.dims, time_vals + time_offset, dict(time.attrs))
            ))

            if len(time_vals) > 1:
                dt = time_vals[1] - time_vals[0]
                duration = dt * len(time_vals)
            else:
                duration = 1
            time_offset += duration
    except BaseException:
        for ds in opened:
            ds.close()
        raise
    return opened, shifted


def _report_final_time(ds):
    """Print the largest Time value of *ds*, whether raw numeric or decoded."""
    try:
        tmax = ds['Time'].max().values
    except Exception as err:  # informational only; must not block writing
        print(f"Could not determine the final time value: {err}")
        return
    if np.issubdtype(tmax.dtype, np.number):
        print(f"Final time value: {float(tmax):.0f}")
    else:
        # Decoded times (datetime64 / cftime objects) cannot take a float format.
        print(f"Final time value: {tmax}")


def main():
    """Collect the options, concatenate the indexed files along Time and save them.

    Options given on the command line are used directly. Only the missing
    ones are asked for interactively, so a fully specified command line runs
    without any prompt.
    """
    args = parse_args()

    folder = args.folder or prompt_with_default(
        "Enter the folder path containing NetCDF files", "diags"
    )
    basename = args.basename or prompt_with_default(
        "Enter the base filename", "diagevo"
    )

    # Determine available index range
    min_idx, max_idx = find_index_range(folder, basename)
    print(f"Found files from index {min_idx} to {max_idx}.")

    # Prompt for start/end with discovered defaults
    start = args.start if args.start is not None else prompt_with_default(
        "Enter the starting file index", min_idx, cast_fn=int
    )
    end = args.end if args.end is not None else prompt_with_default(
        "Enter the ending file index", max_idx, cast_fn=int
    )

    if start < min_idx or end > max_idx or start > end:
        print(f"Invalid range: must be between {min_idx} and {max_idx}, and start <= end.")
        sys.exit(1)

    output = args.output or prompt_with_default(
        "Enter the output filename (including .nc)", "merged.nc"
    )

    # Build and verify file list (indices inside the range may be missing).
    file_list = [
        os.path.join(folder, f"{basename}{i}.nc")
        for i in range(start, end + 1)
    ]
    for fpath in file_list:
        if not os.path.isfile(fpath):
            raise FileNotFoundError(f"File not found: {fpath}")

    # The inputs stay open while the output is written; overwriting one of
    # them would fail or corrupt it.
    if os.path.realpath(output) in {os.path.realpath(p) for p in file_list}:
        print(f"Output file {output} must not be one of the input files.")
        sys.exit(1)

    opened, shifted = _open_with_cumulative_time(file_list)
    try:
        merged_ds = xr.concat(shifted, dim="Time")

        # Decode CF conventions (e.g. times with "... since ..." units) now that
        # the raw time axis is cumulative; best effort only.
        try:
            merged_ds = xr.decode_cf(merged_ds)
        except Exception as e:
            print(f"Warning: CF decoding failed: {e}\nProceeding with raw time values.")

        _report_final_time(merged_ds)
        merged_ds.to_netcdf(output)
    finally:
        # Inputs are read lazily, so they can only be closed after writing.
        for ds in opened:
            ds.close()
    print(f"Merged dataset written to {output}")


if __name__ == "__main__":
    main()
