# Import
import numpy as np
import datetime
import math
from ...date.set import set_dlim

# Projection parameters class
class Epsg:
    """Projection-dependent tiling parameters for COG generation.

    Supported EPSG codes: 3995 (north polar stereographic), 3031 (south
    polar stereographic) and 4326 (equirectangular, the default).

    Attributes set by the constructor:
        epsg            : EPSG code.
        levels          : 2 x N array; row 0 is the COG level of each image,
                          row 1 its PPU (pixels per ``unit``), ascending.
        x_ranges        : tile width of each COG level (in ``unit_str``).
        y_ranges        : tile height of each COG level (in ``unit_str``).
        bbox_max        : [xmin, ymin, xmax, ymax] maximum extent; xmin/ymin
                          also serve as the tile-grid origin.
        unit, unit_str  : PPU unit and its name.
        x_strfmt, x_plus_str, x_minus_str : format string and sign labels
                          for x values.
        y_strfmt, y_plus_str, y_minus_str : format string and sign labels
                          for y values.
    """

    # ----------------------------------------------------------------------------
    # Constructor
    # ----------------------------------------------------------------------------
    def __init__(self,epsg=4326):
        """Set the projection parameters for ``epsg``.

        Args:
            epsg: EPSG code; one of 3995, 3031 or 4326 (default).

        Raises:
            ValueError: if ``epsg`` is not a supported code.
        """
        if epsg in (3995, 3031):
            self._set_polar_stereo(epsg)
        elif epsg == 4326:
            self._set_eqr()
        else:
            raise ValueError(
                f"Unsupported EPSG code: {epsg} (supported: 3995, 3031, 4326)")

    # ----------------------------------------------------------------------------
    # Polar stereographic projection : North (EPSG:3995) / South (EPSG:3031)
    # ----------------------------------------------------------------------------
    def _set_polar_stereo(self, epsg):
        """Set parameters shared by both polar stereographic projections."""

        # EPSG
        self.epsg = int(epsg)

        # COG levels, PPU list (pixels per 32768 = 2**15 m)
        self.levels = np.array([[0,0,0,1, 1, 1, 2,  2,  2],
                                [1,2,4,8,16,32,64,128,256]])

        # Set each cog level's x,y range
        self.x_ranges = [2**24, 2**21, 2**18]
        self.y_ranges = [2**24, 2**21, 2**18]

        # Set bbox max
        self.bbox_max = [-2**23,-2**23,2**23,2**23]

        # Set ppu's unit
        self.unit     = 32768
        self.unit_str = "m"

        # x : projected x coordinate (metres)
        self.x_strfmt    = "{:07.0f}"
        self.x_plus_str  = "P"
        self.x_minus_str = "M"

        # y : projected y coordinate (metres)
        self.y_strfmt    = "{:07.0f}"
        self.y_plus_str  = "P"
        self.y_minus_str = "M"

    # ----------------------------------------------------------------------------
    # EQR projection (EPSG:4326, default)
    # ----------------------------------------------------------------------------
    def _set_eqr(self):
        """Set parameters for the equirectangular (lat/lon) projection."""

        # EPSG
        self.epsg = 4326

        # COG levels, PPU list (pixels per degree)
        self.levels = np.array([[   0,  0, 0, 1, 1, 1, 2,  2,  2,  3,   3,   3,   4,    4,    4],
                                [1.25,2.5, 5,10,20,40,90,180,360,900,1800,3600,9000,18000,36000]])

        # Set each cog level's x,y range
        self.x_ranges = [180, 90, 10, 1, 0.1]
        self.y_ranges = [180, 90, 10, 1, 0.1]

        # Set bbox max
        self.bbox_max = [-360,-90,360,90]

        # Set ppu's unit
        self.unit     = 1
        self.unit_str = "degree"

        # x : longitude
        self.x_strfmt    = "{:06.2f}"
        self.x_plus_str  = "E"
        self.x_minus_str = "W"

        # y : latitude
        self.y_strfmt    = "{:05.2f}"
        self.y_plus_str  = "N"
        self.y_minus_str = "S"

    # ----------------------------------------------------------------------------
    # set_params
    # ----------------------------------------------------------------------------
    def set_cog_params(self, product_dict, ppu, bbox_max,output_path,dstr_raw):
        """Build the parameter dict of every COG level for one product/date.

        Args:
            product_dict: collection dict; reads ``summaries``, ``extent``,
                ``duration``, ``assets`` and ``id``. It is not modified.
            ppu: requested pixels per unit; selects the finest image needed.
            bbox_max: requested [xmin, ymin, xmax, ymax]; it is not modified.
            output_path: base directory object supporting ``joinpath``
                (e.g. ``pathlib.Path``).
            dstr_raw: raw date string passed to ``set_dlim``.

        Returns:
            List of dicts, one per COG level, highest (finest) level first.

        Side effects:
            Creates ``output_path/<id>/<dstr>/<level>`` directories.

        Raises:
            ValueError: if ``ppu`` exceeds the largest PPU in ``self.levels``.
        """

        # Set dlim from date and format
        dfmt      = product_dict["summaries"]["je:stac_date_format"][0]
        ext_tmp   = product_dict["extent"]["temporal"]["interval"]
        ext_bbox  = product_dict["extent"]["spatial"]["bbox"]
        duration  = product_dict["duration"]
        dlim,dstr = set_dlim(dfmt,dstr_raw,duration,ext_tmp)

        # Convert datetime format to TIFF standard
        dlimtif = self.conv_dlim(dlim)

        # Collect each band's COG settings. Each "cog" dict is copied so that
        # attaching "labels" does not mutate the caller's product_dict.
        assets     = product_dict["assets"]
        band_names = list(assets)
        bands = {}
        for name, asset in assets.items():
            band_dict = dict(asset["cog"])
            # Use the band's own classes. Fall back to the first band's
            # classes for assets that define none.
            if "classification:classes" in asset:
                band_dict["labels"] = asset["classification:classes"]
            else:
                band_dict["labels"] = assets[band_names[0]]["classification:classes"]
            bands[name] = band_dict

        # Set maximum COG level: the first (smallest) PPU that reaches ppu
        candidates = np.flatnonzero(self.levels[1] >= ppu)
        if candidates.size == 0:
            raise ValueError(
                f"ppu={ppu} exceeds the largest supported PPU "
                f"({self.levels[1].max()}) for EPSG:{self.epsg}")
        idx_cog       = candidates[0]
        cog_level_max = int(self.levels[0][idx_cog])

        # Set all cog level's information
        cog = []
        for level in range(cog_level_max + 1):

            # Images of this level, limited to those not finer than idx_cog
            idx_img = np.flatnonzero(self.levels[0] == level)
            idx_img = idx_img[idx_img <= idx_cog]

            # Set all bbox
            bbox_all = self.set_bbox_all(level, bbox_max, ext_bbox)

            # Generate folder (no-op if it already exists)
            path_tmp = output_path.joinpath(product_dict["id"])\
                                  .joinpath(dstr)\
                                  .joinpath(str(level))
            path_tmp.mkdir(parents=True, exist_ok=True)

            # Set dict
            ppu_max  = self.levels[1][idx_img.max()]
            cog.append({"epsg": self.epsg,
                        "unit": self.unit,
                        "dlim": dlimtif,
                        "cog_level": level,
                        "overviews": self.set_overviews(len(idx_img)),
                        "ppu_max": ppu_max,
                        "x_range": self.x_ranges[level],
                        "x_strfmt": self.x_strfmt,
                        "x_plus_str" : self.x_plus_str,
                        "x_minus_str" : self.x_minus_str,
                        "y_range": self.y_ranges[level],
                        "y_strfmt": self.y_strfmt,
                        "y_plus_str" : self.y_plus_str,
                        "y_minus_str" : self.y_minus_str,
                        "bbox_max": self.bbox_max,
                        "bbox_all": bbox_all,
                        "bands": bands,
                        "out_path": path_tmp})

        # Highest level first
        cog.reverse()

        # Return
        return cog

    # ----------------------------------------------------------------------------
    # set_overviews
    # ----------------------------------------------------------------------------
    def set_overviews(self,img_number):
        """Return the overview factors for a COG level holding ``img_number`` images.

        1 image -> None, 2 -> [2], 3 -> [2, 4], any other count -> None.
        The images of one level differ in PPU by successive factors of 2
        (see ``self.levels``).
        """
        return {2: [2], 3: [2, 4]}.get(img_number)

    # ----------------------------------------------------------------------------
    # set_bbox_all
    # ----------------------------------------------------------------------------
    @staticmethod
    def _snap(value, origin, step, rounder):
        """Snap ``value`` onto the grid ``origin + k*step`` using ``rounder``.

        ``rounder`` is ``math.floor`` or ``math.ceil``. The grid index is
        first rounded to 9 decimals so that float noise from non-binary
        steps (e.g. 0.1) cannot push it onto a neighbouring grid line.
        """
        return rounder(round((value - origin) / step, 9)) * step + origin

    def set_bbox_all(self,idx,bbox_max_source,bbox_ext):
        """Split a requested bbox into the grid tiles of one COG level.

        Args:
            idx: COG level, used as an index into ``self.x_ranges`` and
                ``self.y_ranges``.
            bbox_max_source: requested [xmin, ymin, xmax, ymax]. It is not
                modified.
            bbox_ext: data extent, given as one [xmin, ymin, xmax, ymax] box
                or a list of such boxes. Their union is used.

        Returns:
            Float ndarray of shape (n_tiles, 4). Each row is
            [xmin, ymin, xmax, ymax]. Rows run x-major: all y tiles of the
            first x column, then those of the next column. The array is
            empty when the clamped request covers no tile.
        """

        # Union of all extent boxes (a flat 4-element box is also accepted)
        ext = np.atleast_2d(np.asarray(bbox_ext, dtype=float))
        ext_union = [ext[:, 0].min(), ext[:, 1].min(),
                     ext[:, 2].max(), ext[:, 3].max()]

        # Clamp the request to the data extent, working on a new list so
        # the caller's bbox stays untouched
        src = [max(bbox_max_source[0], ext_union[0]),
               max(bbox_max_source[1], ext_union[1]),
               min(bbox_max_source[2], ext_union[2]),
               min(bbox_max_source[3], ext_union[3])]

        # Set x,y range list
        x_range = self.x_ranges[idx]
        y_range = self.y_ranges[idx]
        x0, y0  = self.bbox_max[0], self.bbox_max[1]

        # Snap outward onto the tile grid anchored at (x0, y0)
        bbox = [self._snap(src[0], x0, x_range, math.floor),
                self._snap(src[1], y0, y_range, math.floor),
                self._snap(src[2], x0, x_range, math.ceil),
                self._snap(src[3], y0, y_range, math.ceil)]

        # Use integer tile counts rather than a float-step np.arange, whose
        # length is unstable for steps like 0.1
        nx = max(int(round((bbox[2] - bbox[0]) / x_range)), 0)
        ny = max(int(round((bbox[3] - bbox[1]) / y_range)), 0)
        x_edges = bbox[0] + np.arange(nx + 1, dtype=float) * x_range
        y_edges = bbox[1] + np.arange(ny + 1, dtype=float) * y_range

        # Mesh all tiles (x outer, y inner) and flatten to (n_tiles, 4)
        xmin, ymin = np.meshgrid(x_edges[:-1], y_edges[:-1], indexing="ij")
        xmax, ymax = np.meshgrid(x_edges[1:],  y_edges[1:],  indexing="ij")
        return np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)

    # ----------------------------------------------------------------------------
    # conv_dlim : Convert date lim to TIFF/WMS standard
    # ----------------------------------------------------------------------------
    def conv_dlim(self,dlim):
        """Convert ISO-8601 datetime strings to the "YYYY:MM:DD HH:MM:SS" format.

        Any "Z" is removed before parsing, because ``datetime.fromisoformat``
        does not accept it on older Python versions.
        """
        return [datetime.datetime.fromisoformat(d.replace("Z","")).strftime("%Y:%m:%d %H:%M:%S")
                for d in dlim]




