#!python3
# plot the OSNAP section
# to avoid attribute error due to version in proplot ...
from importlib.metadata import version
from packaging.version import parse
#
import sys
prfx='/discover/nobackup/projects/gmao/cal_ocn/abozec1/PYTHON/'
sys.path.append(prfx)

import myenv as my
import numpy as np
import netCDF4 as nc
from hycom.io import read_hycom_fields, read_hycom_grid,sub_var2
from myutilities.density import sigma0_hycom
from matplotlib.colors import BoundaryNorm
from myutilities.tvplus import tvplus
from myutilities.spherdist import spherdist
from time import process_time
print('Starting here ...')

# ---------------------------------------------------------------------------
# Model grid: read the 'plon', 'plat', 'pscx' and 'pscy' fields from the
# regional grid file.
# ---------------------------------------------------------------------------
data_folder = prfx + '../MOM6-expt/GLBb0.08/'
io_grd = data_folder
filet = 'regional.grid.a'
grid_data = io_grd + filet
# To list what the grid file contains:
#   print(F"The fields available are: {read_field_grid_names(grid_data)}")
fieldg = ['plon', 'plat', 'pscx', 'pscy']
grid_field = read_hycom_grid(grid_data, fieldg)
lon = grid_field['plon'][:, :]
lat = grid_field['plat'][:, :]

# Optional zonal shift of the grid (currently disabled):
#   ishift = -483    -> plon = -179.89
#   ishift = -1197   -> plon = 360.0
#   lon_model = np.roll(lon, ishift, axis=1)
#   lat_model = np.roll(lat, ishift, axis=1)
#   D = np.roll(D, ishift, axis=1)      (after D is read below)

# Grid spacings, converted to km
pscx = grid_field['pscx'][:, :] * 1e-3
pscy = grid_field['pscy'][:, :] * 1e-3

# Grid dimensions
idm = 4500
jdm = 3298

# Depth field and wet mask: NaN depths are flagged as dry (wet = 0) and the
# depth itself is reset to 0 there; every other point has wet = 1.
D = sub_var2(data_folder + 'regional.depth.a', idm, jdm, 1)
land = np.isnan(D)
wet = np.ones([jdm, idm])
wet[land] = 0.
D[land] = 0.

# Section definitions (from Xiaobiao's Matlab code).
# One row per segment:
#   (name, number of steps, [i, j] start, [i, j] increment per step)
_osnap_rows = [
    ('OSNAP1',  38, [535, 1164],  [1,  0]),
    ('OSNAP2',  19, [573, 1164],  [2,  1]),
    ('OSNAP3',  45, [611, 1183],  [1,  3]),
    ('OSNAP4',  18, [656, 1318],  [2,  1]),
    ('OSNAP5',  50, [710, 1332],  [3, -1]),
    ('OSNAP6',  24, [860, 1282],  [1, -1]),
    ('OSNAP7', 155, [884, 1259],  [1,  0]),
    ('OSNAP8',  35, [1039, 1258], [3, -1]),
]
sec = {
    "name":  [row[0] for row in _osnap_rows],
    "numb":  [row[1] for row in _osnap_rows],
    "ijst":  [row[2] for row in _osnap_rows],
    "ijinc": [row[3] for row in _osnap_rows],
}

# The RAPID line (26N) is split in two parts; 26.5N alone can be used for the
# RAPID array.
#   FC27N: Florida Current @ 27N
#   RAPID: RAPID array @ 26.5N
_rapid_rows = [
    ('FC27N',  20, [224, 716], [1, 0]),
    ('RAPID', 790, [264, 709], [1, 0]),
]
sec1 = {
    "name":  [row[0] for row in _rapid_rows],
    "numb":  [row[1] for row in _rapid_rows],
    "ijst":  [row[2] for row in _rapid_rows],
    "ijinc": [row[3] for row in _rapid_rows],
}
#%-----------------------------------------------------------------
#
# The section i/j indices are defined on the Atlantic model grid; ijoff is the
# offset that maps them onto the global grid.
ijoff = [2348-1, 1140-1]

# choose sections:
nsec = 4                    # index of one individual segment within sec
dx = 1
dy = 1
tot_sec = len(sec["numb"])  # chain every segment of the table

# Total number of steps along the chained section (running total printed
# per segment).
nn = 0
for i in range(tot_sec):
    nn = nn + sec["numb"][i]
    print(i, nn)

# Grid indices of the nn+1 points of the section: the start point of the
# first segment followed by one point per step.
ii_pies = np.zeros(nn+1)
jj_pies = np.zeros(nn+1)
ii_pies[0] = ijoff[0] + sec["ijst"][0][0] - 1  # ijst -1 matlab vs python
jj_pies[0] = ijoff[1] + sec["ijst"][0][1] - 1
# Repeat each segment's (i, j) increment once per step, then accumulate.
steps = np.repeat(np.asarray(sec["ijinc"][:tot_sec], dtype=float),
                  sec["numb"][:tot_sec], axis=0)
ii_pies[1:] = ii_pies[0] + np.cumsum(steps[:, 0] * dx)
jj_pies[1:] = jj_pies[0] + np.cumsum(steps[:, 1] * dy)

# get lon lat of the profiles
# All nn+1 points are filled: leaving the last one at (0, 0) would corrupt
# the final cumulative distance below.
jj_idx = jj_pies.astype(int)
ii_idx = ii_pies.astype(int)
lon_pies = np.asarray(lon[jj_idx, ii_idx], dtype=float)
lat_pies = np.asarray(lat[jj_idx, ii_idx], dtype=float)

# distance of the section: cumulative distance from the first point
dist = np.zeros(nn+1)
for n in range(1, nn+1):
    dist[n] = dist[n-1] + spherdist(lon_pies[n-1], lat_pies[n-1],
                                    lon_pies[n], lat_pies[n])*1e-3  # convert to km
# number of layers
kdm = 41
nk = '{:2d}'.format(kdm)

## plot settings for the OSNAP section
plt = 'pcolor'
variable = 'temp'    # 'temp' selects potential temperature, anything else salinity
orientation = 'yz'
max_depth = 4000.0   # maximum depth of the plot (m)
year1 = '2012'
# Figure super-title. For a single segment use
#   'Section '+sec["name"][nsec]+' GLBb0.08 ...' instead of 'Section OSNAP ...'.
if variable == 'temp':
    stitle = 'Section OSNAP GLBb0.08 pot. temp. (C) ' + year1
else:
    stitle = 'Section OSNAP GLBb0.08 salinity (psu) ' + year1

mon = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']

# get figures
PNG = 1   # 1: save the figure as a PNG file in png_dir
png_dir = prfx + 'PNG/GLBb0.08/'


# Loop over the months to plot (0-based index); use range(12) for a full year.
for m in [0]:
    month = '{:02d}'.format(m+1)

    # Input file holding layer thickness, potential temperature and salinity.
    # Alternative inputs used previously:
    #   io    = '/discover/nobackup/projects/gmao/cal_ocn/abozec1/MOM6-expt/GLBb0.08/'
    #   filed = 'MONTHLY_Z/ocean_month_'+year1+'_m'+month+'.nc'
    #   filed = 'archm.'+year1+'_001_12.a'
    #   io    = '/css/gmao/cal-ocn/GLBb0.08/expt_10.5-23Jun2023/'
    #   title = 'GLBb0.08 OM4.1+++ '+nk+' layers '+year1+' M:'+month
    io = '/discover/nobackup/projects/gmao/cal_ocn/sakella/GLBb0.08/test_combine/mppncombine-1Proc/'
    filed = 'ocnm_' + year1 + '_001.nc'
    iof = io + filed
    title = 'GLBb0.08 G16.9 -1Proc- ' + nk + ' layers ' + year1 + ' D:001'

    # Grid indices of the nn profiles that are plotted, and their bounding box
    jsec = np.asarray(jj_pies[:nn]).astype(int)
    isec = np.asarray(ii_pies[:nn]).astype(int)
    j0, j1 = int(jsec.min()), int(jsec.max()) + 1
    i0, i1 = int(isec.min()), int(isec.max()) + 1

    # Read only the bounding box of the section (first time record) instead
    # of the full global fields, and close the file as soon as we are done.
    start = process_time()
    with nc.Dataset(iof) as dsd:
        h = dsd['h'][0, :, j0:j1, i0:i1]
        tmp0 = dsd['potT'][0, :, j0:j1, i0:i1]
        tmp1 = dsd['salt'][0, :, j0:j1, i0:i1]
    end = process_time()
    print(end-start)

    print('Data read! ')
    # extract the section: one column per profile -> arrays of shape (kdm, nn)
    jloc = jsec - j0
    iloc = isec - i0
    thk = np.array(np.ma.getdata(h[:, jloc, iloc]), dtype=float)
    temp = np.array(np.ma.getdata(tmp0[:, jloc, iloc]), dtype=float)
    sal = np.array(np.ma.getdata(tmp1[:, jloc, iloc]), dtype=float)
    topo = np.asarray(D[jsec, isec], dtype=float)

    # get sigma_0 (huge values are treated as missing and reset to 0 first)
    temp[temp > 1e5] = 0.
    sal[sal > 1e5] = 0.
    dens = sigma0_hycom(temp, sal)

    # Interface depths: start from the bottom depth and subtract the layer
    # thicknesses going up; dep[0, :] is left at 0.
    dep = np.zeros([kdm+1, nn])
    print(topo.shape, dep.shape)
    dep[kdm, :] = topo
    for k in range(kdm-1, 0, -1):
        dep[k, :] = dep[k+1, :] - thk[k, :]

    # Mid-depth of each layer, capped at max_depth (only needed by the
    # optional layer labels below; the last row is left at 0).
    dpm = np.zeros([kdm, nn])
    dpm[:kdm-1, :] = 0.5*(np.minimum(max_depth, dep[1:kdm, :])
                          + np.minimum(max_depth, dep[:kdm-1, :]))

    # Horizontal coordinate: along-section distance repeated for every interface
    xlab = 0
    longi = np.tile(dist[:nn], (kdm+1, 1))

    ## plot
    # get cmap
    cmap64 = my.mygc.get_cmapjet40()
    #cmap64=my.mygc.get_cmapfc100()
    xlabels = ['Distance (km)']
    levels_tem = np.linspace(2., 12., 41)
    levels_sal = np.linspace(34.8, 35.4, 40)

    # Field to plot, with the deepest layer repeated on the last interface so
    # that it has the same shape as dep/longi.
    if variable == 'temp':
        var = np.vstack([temp, temp[-1:, :]])
        var = np.clip(var, 2, 12)   # saturate at the ends of the colour scale
        levels_var = levels_tem
    else:
        var = np.vstack([sal, sal[-1:, :]])
        levels_var = levels_sal
    cmap = cmap64
    norm = BoundaryNorm(levels_var, ncolors=40, clip=True)  # for the pcolormesh variant

    print(np.min(var), np.max(var), np.min(dens), np.max(dens))
    fig, axs = my.plot.subplots(ncols=1, nrows=1, refwidth='14cm', refaspect=1.46, span=0)
    axs.format(suptitle=stitle)
    #axs[0].plot(dist,dep.T,color='k',lw=0.5)
    axs[0].contour(longi[:-1, :], dep[:-1, :], dens[:, :], lw=1, color='w', levels=[27.50, 27.80, 27.88])
    # pcolormesh variant:
    #   mappable = axs[0].pcolormesh(longi[:,:],dep[:,:],var[:-1,:-1],cmap=cmap,norm=norm,shading='flat',extend='neither')
    # Use the levels of the selected variable (not always the temperature ones).
    mappable = axs[0].contourf(longi, dep, var, cmap=cmap, extend='neither', levels=levels_var)

    axs[0].colorbar(mappable, loc='r', ticks=1.)
    # Optional layer labels:
    #   for k in np.arange(26)+15:
    #       layer='{:2d}'.format(k+1)
    #       axs[0].text(dist[int(nn/2)], 0.5*(dpm.T[int(nn/2),k]+dpm.T[int(nn/2)+1,k]), layer,fontsize=8, color='red')

    axs[0].format(yreverse=True, ylim=(max_depth, 0), title=title,
                  xlabel=xlabels[xlab], ylabel='Depth (m)', xtickdir='in', ytickdir='in')

    # Save only when requested; the file name does not exist otherwise.
    if PNG == 1:
        #file_png='glb8_'+sec["name"][nsec]+'_om41+++_'+nk+'l_'+variable+'_'+year1+'_'+month+'.png'
        file_png = 'glb8_OSNAP_g16.9_1proc_' + nk + 'l_' + variable + '_' + year1 + '_' + month + '.png'
        print(file_png)
        fig.savefig(png_dir+file_png, dpi=300,
                    facecolor='w', edgecolor='w', transparent=False)  ## .pdf,.eps
    my.plot.show()
    #my.plot.close()