#! /usr/bin/env python
from    netCDF4               import    Dataset
from	numpy		      import	*
import  numpy                 as        np
import  matplotlib.pyplot     as        mpl
from matplotlib.cm import get_cmap
import pylab
from matplotlib import ticker
import matplotlib.colors as colors
from mpl_toolkits.basemap import Basemap, shiftgrid
from FV3_utils import *

############################
# Configuration
step_begin = 2015  # index of the first history file to read
step_end = 2015    # index of the last history file to read (inclusive)

var = "temperature"  # variable to extract
tint = [30, 37]  # time window; values must match those written in the input file
# tint = None  # alternative: presumably no time selection -- confirm against FV3_utils.getvar
xarea = "0,1"  # NOTE(review): not used anywhere in this script
yarea = "0,1"  # NOTE(review): not used anywhere in this script

prefix = "Xhistins"
suffix = "_A.nc"
filename = f"../{prefix}{step_begin}{suffix}"
# NOTE(review): left open for the rest of the script; whether getvar returns
# in-memory copies (so the file could be closed here) is not visible.
nc1 = Dataset(filename)

lat = getvar(nc1, "latitude")
lon = getvar(nc1, "longitude")
alt = getvar(nc1, "altitude")
print("max alt: ", np.max(alt))
# Read only the selected time window. An unfiltered read of "Time" used to
# precede this line but was immediately overwritten, so it was dropped.
time = getvar(nc1, "Time", times=tint)
# Span of the selected window in Time units, truncated to an int; used as
# the day count below.
nbday = int(time[-1] - time[0])

myvar = getvar(nc1, var, times=tint)
# Keep the first two axes; take index 0 of the last two dimensions.
myvar = myvar[:, :, 0, 0]



# Append the remaining files (step_begin+1 .. step_end) along the time axis.
for step in range(step_begin + 1, step_end + 1):
    # Same directory as the first file. The original used "../../" here,
    # which disagreed with the path used for step_begin.
    filename = f"../{prefix}{step}{suffix}"
    # Open with a context manager so each extra file is closed after reading.
    with Dataset(filename) as nc:
        newvar = getvar(nc, var, times=tint)
        newvar = newvar[:, :, 0, 0]
        myvar = np.concatenate((myvar, newvar), axis=0)

        # Apply the same time window as for the variable, so the time axis
        # stays aligned with the data (the full Time axis was read before).
        newtime = getvar(nc, "Time", times=tint)
        new_nbday = int(newtime[-1] - newtime[0])
        newtime += nbday  # shift past the days accumulated from previous files
        nbday = new_nbday + nbday
        time = np.concatenate((time, newtime), axis=0)

# Number of time samples (rows of myvar) and of vertical levels.
nbtime = myvar.shape[0]
print(np.shape(myvar))
nbalt = np.size(alt)

if nbday <= 0:
    raise ValueError(
        f"selected time window spans {nbday} days; at least one full day is needed"
    )
# Samples per day. NOTE(review): assumes evenly spaced output -- confirm.
nbstep = nbtime // nbday
if nbstep < 1:
    raise ValueError(f"only {nbtime} samples for {nbday} days; cannot form a daily mean")

print("nbday=", nbday)
print("nbtime=", nbtime)
print("nbstep=", nbstep)
print("nbalt=", nbalt)

# The anomaly is evaluated at samples nbstep .. nbtime-nbstep-1; the first and
# last nbstep samples are skipped so every running-mean window fits in the data.
nbstep_mean = nbtime - nbstep * 2
if nbstep_mean <= 0:
    raise ValueError(
        f"need more than {2 * nbstep} samples for a one-day running mean, got {nbtime}"
    )
meanvar = np.zeros((nbstep_mean, nbalt), dtype='f')
anovar = np.zeros((nbstep_mean, nbalt), dtype='f')

# For each evaluated sample: mean over a one-day (nbstep-sample) window
# centred on that sample. The start is nbstep+i - nbstep//2, which matches the
# original window for even nbstep and fixes an off-by-one for odd nbstep.
for i in range(nbstep_mean):
    lo = nbstep + i - nbstep // 2
    meanvar[i, :] = np.mean(myvar[lo:lo + nbstep, :], axis=0)

# Anomaly of each evaluated sample relative to its running daily mean.
anovar[:, :] = myvar[nbstep:nbstep + nbstep_mean, :] - meanvar

print(meanvar[:, :])
print(myvar[0, :])

mpl.figure(figsize=(20, 10))
font = 26  # base font size for title, labels and ticks

# Fixed contour levels from -0.1 to 0.1; values outside this range are drawn
# with the end colours because of extend='both' below.
lev = np.linspace(-0.1, 0.1, 10)

# Time of each anomaly sample: row i of anovar was computed at myvar[nbstep + i].
# Slice by anovar's length so the x axis always matches the data.
time = time[nbstep:nbstep + anovar.shape[0]]

print((np.shape(time), np.shape(alt), np.shape(anovar)))

CF = mpl.contourf(time, alt, np.transpose(anovar), lev, cmap="coolwarm", extend='both')
cbar = mpl.colorbar(CF, shrink=1, format="%1.2f")
cbar.ax.set_title("[K]", y=1.04, fontsize=font)
cbar.ax.tick_params(labelsize=font)

mpl.title('Temperature anomaly', fontsize=font + 2)
mpl.ylabel('Altitude (km)', labelpad=10, fontsize=font)
mpl.xlabel('Time (Pluto days)', labelpad=10, fontsize=font)
mpl.xticks(fontsize=font - 3)
mpl.yticks(fontsize=font - 3)
# Show only the lowest part of the column. A preceding ylim(0, max(alt)) call
# was always overridden by this one and has been removed.
pylab.ylim([0, 230])

mpl.savefig(f"tempanom_{step_begin}_{step_end}.png", dpi=200)
mpl.show()




