# Import
import itertools
import os

import numpy as np
from osgeo import osr, gdal, gdalconst

from ..utils.progress import bar as progress_bar

# Read each TIFF's information
def info(band_name, fnames,target_epsg,unit):
    """Collect georeferencing information for a list of TIFF files.

    Each file is opened read-only with GDAL. Its native bounding box is built
    from the geotransform, transformed into ``target_epsg``, and used to
    estimate a pixels-per-unit resolution.

    Parameters
    ----------
    band_name : object
        Not used by this function; kept for interface compatibility.
    fnames : sequence of str or path-like
        Input rasters; each is opened via ``gdal.Open(str(fname))``.
    target_epsg : int
        EPSG code every bounding box is transformed into.
    unit : float
        Scale factor applied to the pixels-per-target-unit estimate.

    Returns
    -------
    bbox_out : numpy.ndarray of float64, shape (len(fnames), 4)
        ``[xmin, ymin, xmax, ymax]`` of each file in ``target_epsg``.
    srs : numpy.ndarray of int32, shape (len(fnames),)
        Authority code (``AUTHORITY`` attribute, value 1) of each file's
        spatial reference.
    ppu_max : numpy.float64
        Largest per-file value of ``unit`` times the mean of the x and y
        pixel counts per target-CRS unit.
    regen_rgb : bool
        True when the files whose band 1 carries a color table do not all
        carry the same table; False otherwise (including when none do).

    Raises
    ------
    ValueError
        If ``fnames`` is empty, or a file has no spatial reference or no
        authority code.
    RuntimeError
        If GDAL cannot open a file.
    """
    n_files = len(fnames)
    if n_files == 0:
        # ppu.max() below has no identity for an empty array
        raise ValueError("info() requires at least one input file")

    # Show progress
    print(" - Getting input tiff files information ...")

    # Send GDAL's log output to the platform null device ("nul" on Windows,
    # "/dev/null" on POSIX). A literal "NUL" would create a real file named
    # NUL in the working directory on POSIX systems.
    gdal.SetConfigOption('CPL_LOG', os.devnull)

    bbox_out = np.zeros([n_files, 4], dtype=np.float64)
    ppu = np.zeros(n_files, dtype=np.float64)
    srs = np.zeros(n_files, dtype=np.int32)
    ctables = []

    for i, fname in enumerate(fnames):

        src = gdal.Open(str(fname), gdalconst.GA_ReadOnly)
        if src is None:
            raise RuntimeError("GDAL could not open raster: {}".format(fname))

        # Authority (EPSG) code of the raster's spatial reference
        sref = src.GetSpatialRef()
        code = sref.GetAttrValue('AUTHORITY', 1) if sref is not None else None
        if code is None:
            raise ValueError(
                "No spatial reference authority code in raster: {}".format(fname))
        srs[i] = int(code)

        # Native bounding box [xmin, ymin, xmax, ymax] from the geotransform.
        # NOTE(review): assumes a north-up, unrotated geotransform
        # (tsf[2] == tsf[4] == 0, tsf[5] < 0) — confirm for the input data.
        tsf = src.GetGeoTransform()
        bbox_in = [tsf[0],
                   tsf[3] + src.RasterYSize * tsf[5],
                   tsf[0] + src.RasterXSize * tsf[1],
                   tsf[3]]
        bbox_out[i, :] = tform_bbox(bbox_in, srs[i], target_epsg)

        # Average of x and y pixels per target-CRS unit, scaled by `unit`
        ppu_x = src.RasterXSize / (bbox_out[i, 2] - bbox_out[i, 0])
        ppu_y = src.RasterYSize / (bbox_out[i, 3] - bbox_out[i, 1])
        ppu[i] = unit * ((ppu_x + ppu_y) / 2)

        # Record the complete color table of band 1 (if any). Whole tables
        # are compared rather than a checksum, because different palettes can
        # have equal sums. GetCount() is used instead of a fixed 256 so that
        # shorter and longer palettes are read correctly.
        ct = src.GetRasterBand(1).GetRasterColorTable()
        if ct is not None:
            ctables.append(tuple(tuple(ct.GetColorEntry(k))
                                 for k in range(ct.GetCount())))

        # Release the dataset before opening the next one
        src = None

        # Show progress
        progress_bar(i, n_files)

    # RGB must be regenerated when the color-mapped inputs disagree
    regen_rgb = len(set(ctables)) > 1

    ppu_max = ppu.max()

    return bbox_out, srs, ppu_max, regen_rgb

# Transform bbox
def tform_bbox(bbox_input,epsg_input,epsg_output):
    """Transform a bounding box between two EPSG coordinate systems.

    Parameters
    ----------
    bbox_input : sequence of 4 floats
        ``[xmin, ymin, xmax, ymax]`` in ``epsg_input``, in x/y order
        (easting/longitude first, as produced from a GDAL geotransform).
    epsg_input, epsg_output : int or str
        EPSG codes; anything accepted by ``int()``.

    Returns
    -------
    sequence of 4 floats
        ``bbox_input`` itself when both codes are equal. Otherwise, the tuple
        ``(xmin, ymin, xmax, ymax)`` returned by ``TransformBounds`` in
        ``epsg_output``, also in x/y order.
    """
    # Normalize the codes so e.g. numpy ints and strings compare equal
    epsg_in = int(epsg_input)
    epsg_out = int(epsg_output)

    # Same CRS: nothing to transform
    if epsg_in == epsg_out:
        return bbox_input

    input_sref = osr.SpatialReference()
    input_sref.ImportFromEPSG(epsg_in)

    output_sref = osr.SpatialReference()
    output_sref.ImportFromEPSG(epsg_out)

    # Since GDAL 3, an SRS imported from EPSG follows the authority axis
    # order (lat/lon for EPSG:4326). The bbox is in x/y order on input, and
    # x/y is wanted on output, so BOTH sides need traditional GIS order.
    # Setting it only on the output swaps the axes of geographic inputs.
    for sref in (input_sref, output_sref):
        sref.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    tformer = osr.CreateCoordinateTransformation(input_sref, output_sref)

    # Number of points used to densify each edge of the box before
    # transforming (21 is the value recommended by GDAL).
    densify_pts = 21

    return tformer.TransformBounds(bbox_input[0], bbox_input[1],
                                   bbox_input[2], bbox_input[3], densify_pts)

# Read single cog
def single(fn,rnum):
    """Read selected bands of one raster file into memory.

    Parameters
    ----------
    fn : str or path-like
        Raster file; passed to ``gdal.Open`` via ``str()``.
    rnum : sequence of int
        Band indices (GDAL bands are 1-based) to read, in output order.

    Returns
    -------
    img : numpy.ndarray
        The ``ReadAsArray()`` result of each requested band, stacked along
        axis 0.
    ctable : gdal.ColorTable or None
        Independent copy of band 1's color table, or None if it has none.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``fn``.
    """
    src = gdal.Open(str(fn), gdalconst.GA_ReadOnly)
    if src is None:
        raise RuntimeError("GDAL could not open raster: {}".format(fn))

    # The band's color table belongs to the dataset, and `src` is released
    # when this function returns. Clone it so the returned object stays valid.
    ctable = src.GetRasterBand(1).GetRasterColorTable()
    if ctable is not None:
        ctable = ctable.Clone()

    img = np.array([src.GetRasterBand(band).ReadAsArray() for band in rnum])

    # Close the dataset explicitly
    src = None

    return img, ctable