#! /usr/bin/env python
# JBM
# 24/11/2016: Corrected altitude
# 09/01/2017: Improved initialization of time axis
from netCDF4 import Dataset as NetCDFFile
from netCDF4 import num2date
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as dates
import datetime
from matplotlib.dates import date2num
import pandas
from optparse import OptionParser
from pylab import savefig
import sys, getopt

def _parse_args(argv):
    """Parse the command line and return ``(inputfile, outputdir)``.

    Prints the usage message and exits with status 0 on ``-h``, or with
    status 2 when the options cannot be parsed or ``-i``/``-o`` is missing.
    """
    errmsg = 'Use: ' + str(sys.argv[0]) + ' -i <inputfile> -o <outputdir>'
    inputfile = ''
    outputdir = ''
    try:
        # 'h' takes no argument; declaring it as "h:" would turn a bare -h
        # into a parse error instead of a help request.
        opts, _ = getopt.getopt(argv, "hi:o:")
    except getopt.GetoptError:
        print(errmsg)
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print(errmsg)
            sys.exit()
        elif opt == '-i':
            inputfile = arg
        elif opt == '-o':
            outputdir = arg
    if not inputfile or not outputdir:
        print('Please specify an input file and an output directory.')
        print(errmsg)
        sys.exit(2)
    return inputfile, outputdir


def _short_name(ncfile):
    """Return the tag used in the plot title and the output file name.

    The basename of *ncfile* is split on its last five underscores and
    fields 0, 2 and 3 are re-joined with '_'.  Raises IndexError when the
    basename has fewer than four underscore-separated fields.
    """
    decomp = ncfile.rsplit('/', 1)[-1].rsplit('_', 5)
    return decomp[0] + '_' + decomp[2] + '_' + decomp[3]


def _load_obs_monthly(path):
    """Read the ';'-separated observation file and return monthly means.

    The 'Date' column (format '%Y-%m-%d %H:%M:%S') becomes the index and
    '.' is read as missing.  Only numeric columns are averaged: pandas >= 2
    raises TypeError instead of silently skipping text columns such as
    'Date'.
    """
    data = pandas.read_csv(path, sep=';', na_values=".")
    data.index = pandas.to_datetime(data['Date'], format='%Y-%m-%d %H:%M:%S')
    print(data.columns)
    return data.select_dtypes(include='number').resample('M').mean()


def _decode_time(time):
    """Decode the model time axis into a list of ``datetime.datetime``.

    The unit is guessed from the first value, counted from 2010-01-01:
    above 86400 -> seconds, above 1 -> days, otherwise months.
    """
    if time[0] > 86400:
        units = 'seconds since 2010-01-01 00:00:00'
    elif time[0] > 1:
        units = 'days since 2010-01-01 00:00:00'
    else:
        # NOTE(review): recent cftime releases may reject 'months since'
        # for the default calendar -- confirm if monthly-indexed files occur.
        units = 'months since 2010-01-01 00:00:00'
    decoded = num2date(time[:], units=units)
    # num2date may return cftime objects, which matplotlib cannot place on a
    # date axis by itself; rebuild plain datetimes from their fields.
    return [datetime.datetime(d.year, d.month, d.day,
                              d.hour, d.minute, d.second, d.microsecond)
            for d in decoded]


def main(argv):
    """Compare LMDz air temperature with tower observations and save a PNG.

    argv: command-line arguments without the program name; must provide
    ``-i <NetCDF file>`` and ``-o <output directory>``.

    Monthly means of the observations (columns tm1..tm6) are read from a
    fixed file under ``datapath``.  The model ``temp`` field is taken at the
    grid point nearest (lat=-75, lon=123).  Model levels 0-2 and the six
    observation levels are plotted, and the figure is written to
    ``<outputdir>/tempDC-<shortname>.png``.
    """
    inputfile, outputdir = _parse_args(argv)
    print('Input file is ', inputfile)
    print('Output directory is ', outputdir)

    ncfile = inputfile
    shortname = _short_name(ncfile)

    datapath = '/thredds/ipsl/fabric/lmdz/AXE4/'

    # LOADING THE OBSERVATIONS
    # -----------------------------------------------------------------
    datamth = _load_obs_monthly(datapath + 'temp10+_modified.dat')
    # Heights (m) shown in the legend for observation columns tm1..tm6.
    levels_obs = [3.5, 10.9, 18.3, 25.6, 33., 42.2]

    # LOADING THE GCM RESULTS
    # -----------------------------------------------------------------
    longi_user = 123.
    lati_user = -75.
    alti_id = range(4)
    alti_var = np.zeros(4)
    flabel = [''] * 4
    with NetCDFFile(ncfile) as nc:
        # [:] copies the data into memory, so it stays usable after close.
        time = nc.variables['time_counter'][:]
        temp = nc.variables['temp'][:]
        longi = nc.variables['lon'][:]
        lati = nc.variables['lat'][:]
        longi_id = np.abs(longi - longi_user).argmin()
        lati_id = np.abs(lati - lati_user).argmin()
        # Test the variable names directly: searching str(nc.variables)
        # would also match attributes or names that merely contain 'geop'.
        if 'geop' in nc.variables and 'phis' in nc.variables:
            geop = nc.variables['geop'][:]
            phis = nc.variables['phis'][:]
            for ilev in range(4):
                # Legend altitude: mean over the first axis (presumably
                # time) of (geop - phis) / 9.8.
                alti_var[ilev] = np.mean(
                    geop[:, alti_id[ilev], lati_id, longi_id]
                    - phis[:, lati_id, longi_id]) / 9.8
                flabel[ilev] = "LMDz (z=" + str("%.1f" % alti_var[ilev]) + "m)"
        else:
            flabel = ["LMDz (z=6m approx.)", "LMDz (z=20m approx.)",
                      "LMDz (z=35m approx.)", "LMDz (z=53m approx.)"]
    print(longi[longi_id])
    print(lati[lati_id])

    date = _decode_time(time)

    # DISPLAYING THE RESULTS
    # -----------------------------------------------------------------
    fig = plt.figure()
    ax = fig.gca()
    ax.set_title(shortname)
    ax.set_ylabel("Air temperature (degC)")
    plt.grid()

    # Each colour pairs one model level with two observation levels; the
    # fourth model level is read above but not plotted.
    # NOTE(review): observations are drawn against the model dates, so
    # datamth must have exactly one row per model time step.
    for ilev, color in enumerate(('k', 'b', 'r')):
        plt.plot(date, temp[:, alti_id[ilev], lati_id, longi_id] - 273.15,
                 color + '-o', linewidth=2, label=flabel[ilev])  # GCM
        for iobs in (2 * ilev, 2 * ilev + 1):
            plt.plot(date, datamth['tm%d' % (iobs + 1)], color + '--',
                     linewidth=2,
                     label='OBS ' + str("%.1f" % levels_obs[iobs]) + 'm (2010)')

    ax.xaxis.set_major_locator(dates.MonthLocator())
    ax.xaxis.set_major_formatter(dates.DateFormatter('%b'))
    plt.legend(loc=(1.03, 0.8))  # anchored outside the axes, to the right
    plt.xticks(rotation=30, ha='right')
    plt.subplots_adjust(bottom=0.2, right=0.65)  # keep room for the legend

    savefig(outputdir + '/' + 'tempDC-' + shortname + '.png',
            bbox_inches='tight')
    plt.close(fig)

  # -----------------------------------------------------------------
  # -----------------------------------------------------------------
  # -----------------------------------------------------------------

# Script entry point: when executed directly, pass every command-line
# argument after the program name on to main().
if __name__ == "__main__":
    main(argv=sys.argv[1:])

#------------------------------------------------------------------


