import os
import numpy as np
from PIL import Image
from osgeo import osr, gdal, gdalconst
from osgeo_utils.gdal_merge import main as gdal_merge
from .set import gdal_dtype as set_gdal_dtype
from .color import gen_ctable

# Read each TIFF's information
def multiple(idx_cog,target_src,source,cog,band):
    """Warp every source onto one output tile and attach color information.

    Parameters
    ----------
    idx_cog : int
        Index of the output tile; forwarded to ``warp_files`` which uses
        ``cog["bbox_all"][idx_cog]`` as the output bounds.
    target_src : hashable
        Key selecting the file list and SRS in each source's ``"fnames"`` and
        ``"srs"`` mappings.
    source : sequence of dict
        Source descriptions; each must provide ``"fnames"``, ``"srs"``,
        ``"layer_number"``, ``"params"`` and ``"regen_rgb"``.
    cog : dict
        Output settings (``"bands"``, ``"epsg"``, ``"bbox_all"``, ...).
    band : hashable
        Output band key in ``cog["bands"]``.

    Returns
    -------
    img : numpy.ndarray
        Image stack; axis 0 indexes layers/bands. For a 3-layer image it is
        extended with an alpha band ("RGB") or reduced to one palette band
        ("PALETTE").
    ctable : object or None
        Color table from ``convert_img_bit`` / ``gen_ctable`` when the band's
        ``"pint"`` is "PALETTE", otherwise ``None``.
    """
    img = []
    for src in source:
        # Warp/mosaic the files of this source that cover the output tile
        img_tmp = warp_files(src["fnames"][target_src],
                             src["srs"][target_src],
                             cog, band, idx_cog,
                             src["layer_number"],
                             src["params"],
                             src["regen_rgb"])

        # A regenerated-RGB source returns every band of the warped dataset,
        # so it replaces the stack instead of being appended as one layer.
        if src["regen_rgb"]:
            img = img_tmp
        else:
            img.append(img_tmp)

    img = np.array(img)

    pint = cog["bands"][band]["pint"]

    # Default: no color table. This also covers a 3-layer image whose
    # "pint" is neither "RGB" nor "PALETTE", which previously left
    # ctable unbound and raised UnboundLocalError.
    ctable = None

    if img.shape[0] == 3:
        if pint == "RGB":
            # Add an alpha channel to true-color output
            img = add_alpha(img)
        elif pint == "PALETTE":
            # Convert 24-bit true color to an 8-bit palette image
            nodata = cog["bands"][band]["dn"]["nodata"]
            img, ctable = convert_img_bit(img, nodata)
    elif pint == "PALETTE":
        # Build the color table from the band's labels (LCCS)
        ctable = gen_ctable(cog["bands"][band]["labels"])

    return img, ctable

# warp files
def warp_files(fnames,s_srs,cog,band,idx,lnum,sparams,re_rgb):
    """Mosaic and reproject ``fnames`` onto the grid of one output tile.

    Parameters
    ----------
    fnames : iterable of path-like
        Input rasters; each is passed to GDAL as ``str(fname)``.
    s_srs : object
        Source SRS information. Currently unused: GDAL reads the projection
        from each file. Forcing a single source SRS could fail when the
        inputs use different projections.
    cog : dict
        Output settings. Reads ``"epsg"``, ``"ppu_max"``, ``"x_range"``,
        ``"y_range"``, ``"unit"`` and ``"bbox_all"``.
    band : hashable
        Output band key in ``cog["bands"]``.
    idx : int
        Index into ``cog["bbox_all"]``. That entry is passed to ``gdal.Warp``
        as ``outputBounds`` (minX, minY, maxX, maxY) in the destination SRS.
    lnum : int
        1-based band number to read when ``re_rgb`` is false.
    sparams : dict
        Source parameters (``"dn"``, ``"dn2value"``).
    re_rgb : bool
        Expand palette-indexed inputs to RGB first and return all bands.

    Returns
    -------
    numpy.ndarray
        Band ``lnum`` as a 2-D array, or all bands when ``re_rgb`` is true.
        The result is converted with ``convert_img_dtype`` when the input and
        output data type or nodata differ.

    Raises
    ------
    RuntimeError
        If GDAL fails to expand an input to RGB or to produce the warped
        dataset.
    """
    # Destination projection
    dst_srs = osr.SpatialReference()
    dst_srs.ImportFromEPSG(cog["epsg"])

    # Data type / nodata used while warping, and whether a conversion to the
    # output parameters is needed afterwards
    data_type, nodata, same_inout = set_data_params(sparams,cog,band)

    # Output size in pixels: [rows, cols]
    img_size = [int(cog["ppu_max"]*cog["y_range"]/cog["unit"]),
                int(cog["ppu_max"]*cog["x_range"]/cog["unit"])]

    if re_rgb:
        # Expand palette-indexed inputs to RGB in memory before warping
        inputDs = []
        for fname in fnames:
            ds = gdal.Translate("", str(fname),
                                format    = "MEM",
                                rgbExpand = "rgb")
            if ds is None:
                raise RuntimeError(f"gdal.Translate could not expand {fname} to RGB")
            inputDs.append(ds)
    else:
        inputDs = [str(fname) for fname in fnames]

    warpOptions = ["NUM_THREADS=ALL_CPUS"]

    # Warp to the exact tile area and resolution
    # (creation options such as ["COMPRESS=DEFLATE"] could be added here)
    outDs = gdal.Warp('',inputDs,
        format          = "MEM",
        outputBounds    = cog["bbox_all"][idx],
        outputBoundsSRS = dst_srs,
        width           = img_size[1],
        height          = img_size[0],
        dstSRS          = dst_srs,
        outputType      = set_gdal_dtype(data_type),
        workingType     = gdal.GDT_Float32,
        warpOptions     = warpOptions,
        warpMemoryLimit = "9999",
        resampleAlg     = "MODE", # most frequent value; NEAREST is the alternative
        srcNodata       = nodata,
        dstNodata       = nodata,
        multithread     = True,
        overviewLevel   = "AUTO")

    # Release the input datasets
    inputDs = None

    # gdal.Warp returns None on failure (unless gdal.UseExceptions is active)
    if outDs is None:
        raise RuntimeError(f"gdal.Warp failed for output tile {idx}")

    if re_rgb:
        img = outDs.ReadAsArray()
    else:
        img = outDs.GetRasterBand(lnum).ReadAsArray()

    # Release the in-memory output dataset
    outDs = None

    # Convert to the output data type / nodata if they differ from the input
    if not same_inout:
        img = convert_img_dtype(img,sparams,cog,band)

    return img

# set data params
def set_data_params(sparams,cog,band):
    """Choose the data type and nodata value used while warping a source.

    Parameters
    ----------
    sparams : dict
        Source parameters; ``sparams["dn"]["data_type"]`` and
        ``sparams["dn"]["nodata"]`` are read.
    cog : dict
        Output settings; ``cog["bands"][band]["dn"]`` is read.
    band : hashable
        Band key in ``cog["bands"]``.

    Returns
    -------
    tuple
        ``(data_type, nodata, same_inout)``. If both the data type and the
        nodata value compare equal between output and input, the output's
        values are returned with ``same_inout=True``. Otherwise the input's
        values are returned with ``same_inout=False``, signalling that the
        image must be converted afterwards.
    """
    dn_in = sparams["dn"]
    type_in, nodata_in = dn_in["data_type"], dn_in["nodata"]

    dn_out = cog["bands"][band]["dn"]
    type_out, nodata_out = dn_out["data_type"], dn_out["nodata"]

    # Both the data type and the nodata value must be unchanged
    unchanged = bool((type_out == type_in) & (nodata_out == nodata_in))

    if unchanged:
        return type_out, nodata_out, True
    return type_in, nodata_in, False


# convert image bit
def convert_img_bit(img_in,nodata):
    """Quantize a 3-band true-color image into one 8-bit palette band.

    Parameters
    ----------
    img_in : numpy.ndarray
        Array laid out as (band, row, col) with 3 bands. It is handed to
        ``PIL.Image.fromarray`` as an RGB image, so presumably uint8.
        TODO: confirm for all callers.
    nodata : int
        Value written to pixels whose minimum over the bands is 0.

    Returns
    -------
    img_out : numpy.ndarray
        Array of shape (1, row, col) holding palette indices starting at 1,
        with ``nodata`` at masked pixels.
    ctable : gdal.ColorTable
        256-entry table. Entry 0 is black (reserved for nodata); entry ``i``
        is the color of quantized index ``i - 1``; unused entries are black.
    """
    # Palette size: 255 rather than 256 because index 0 is reserved for nodata
    qnum = 255

    # Pixels where any band is 0 become nodata
    is_nodata = np.min(img_in, axis=0) == 0

    # (band, row, col) -> (row, col, band), the layout PIL expects
    img_pil = Image.fromarray(np.transpose(img_in, [1, 2, 0]))

    # Quantize 24-bit color to at most qnum colors (method=0: median cut)
    img_q = img_pil.quantize(colors=qnum, method=0, dither=1)

    # Newer Pillow releases (9.1+) trim the palette to the colors actually
    # produced instead of padding it to 256 entries, so a fixed reshape to
    # (256, 3) fails. Reshape by rows and copy into a fixed 256-entry table,
    # shifted by one so entry 0 stays black for nodata.
    used = np.reshape(img_q.getpalette(), [-1, 3])[:qnum]
    palette = np.zeros([qnum + 1, 3], dtype=int)
    palette[1:1 + len(used)] = used

    ctable = gdal.ColorTable()
    for i, (red, green, blue) in enumerate(palette):
        # Plain Python ints for the GDAL binding; alpha is always opaque
        ctable.SetColorEntry(i, (int(red), int(green), int(blue), 255))

    # Shift quantized indices (0..qnum-1) by one to match the table, then mask
    img_out = np.array(img_q) + 1
    img_out[is_nodata] = nodata

    return np.array([img_out]), ctable

# make alpha image and add
def add_alpha(img_in):
    """Append an alpha band derived from the per-pixel minimum over bands.

    Parameters
    ----------
    img_in : array_like
        Image laid out with the band axis first.

    Returns
    -------
    numpy.ndarray
        ``img_in`` with one extra band on axis 0. The alpha value is 255
        where the minimum over the bands is > 0; elsewhere it keeps that
        minimum (0 for unsigned input). The dtype follows ``img_in``.
    """
    band_min = np.min(img_in, axis=0)

    # Opaque wherever every band is positive
    opaque = band_min > 0
    band_min[opaque] = 255

    return np.concatenate([img_in, band_min[np.newaxis, ...]])

# convert image for RGB
def convert_img_dtype(img_in,sparams,cog,band,*,gamma=1.5,bit=8):
    """Rescale a source image to the output band's data type (RGB display).

    Steps:
    1. Map nodata to NaN.
    2. Apply ``slope * dn + offset``.
    3. Apply gamma correction on the ``[0, 2**bit - 1]`` scale.
    4. Clip to the output ``[min, max]``.
    5. Replace every NaN with the output nodata and cast.

    Parameters
    ----------
    img_in : array_like
        Input digital numbers; it is copied, not modified.
    sparams : dict
        Source parameters; reads ``["dn"]["nodata"]`` and
        ``["dn2value"]["slope"/"offset"]``.
    cog : dict
        Output settings; reads ``cog["bands"][band]["dn"]`` keys
        ``"data_type"``, ``"nodata"``, ``"min"`` and ``"max"``.
    band : hashable
        Band key in ``cog["bands"]``.
    gamma : float, optional
        Gamma exponent; values are raised to ``1/gamma`` (default 1.5).
    bit : int, optional
        Bit depth defining the full scale ``2**bit - 1`` (default 8).

    Returns
    -------
    numpy.ndarray
        Array cast to the output ``data_type``.
    """
    # Input params
    nodata_in = sparams["dn"]["nodata"]
    slope     = sparams["dn2value"]["slope"]
    offset    = sparams["dn2value"]["offset"]

    # Output params
    dn_out        = cog["bands"][band]["dn"]
    data_type_out = dn_out["data_type"]
    nodata_out    = dn_out["nodata"]
    min_out       = dn_out["min"]
    max_out       = dn_out["max"]

    # Float64 working copy (the np.float alias was removed in NumPy 1.24)
    img_out = np.array(img_in, dtype=np.float64)

    # Mark nodata as NaN so it passes through the arithmetic untouched
    img_out[img_out == nodata_in] = np.nan
    img_out = slope*img_out + offset

    # Gamma correction; negative values would yield NaN and end up as nodata
    full_scale = 2**bit - 1
    img_out = full_scale*(img_out/full_scale)**(1/gamma)

    # Clip to the output range (NaN compares False and is left alone)
    img_out[img_out < min_out] = min_out
    img_out[img_out > max_out] = max_out

    # NaN never compares equal to itself, so `== np.nan` never matched;
    # np.isnan is required to restore nodata before the integer cast.
    img_out[np.isnan(img_out)] = nodata_out

    return img_out.astype(data_type_out)

# merge files (function for backup measure)
def merge_files(fnames,cog,band,idx,lnum):
    """Mosaic ``fnames`` with gdal_merge and return one band (backup method).

    gdal_merge always uses nearest-neighbour sampling; this cannot be changed.

    Parameters
    ----------
    fnames : iterable of path-like
        Input rasters.
    cog : dict
        Output settings. Reads ``"bands"``, ``"out_path"`` (an object with
        ``joinpath``), ``"bbox_all"`` and ``"ppu_max"``.
    band : hashable
        Band key in ``cog["bands"]``.
    idx : int
        Index into ``cog["bbox_all"]``. The entry is read as
        (minX, minY, maxX, maxY), matching its use as ``outputBounds`` in
        ``warp_files``.
    lnum : int
        1-based band number to read from the merged file.

    Returns
    -------
    numpy.ndarray
        The requested band of the merged raster.

    Raises
    ------
    RuntimeError
        If gdal_merge did not produce a readable output file.
    """
    data_type = cog["bands"][band]["dn"]["data_type"]
    nodata    = cog["bands"][band]["dn"]["nodata"]

    # gdal_merge's -ot expects a GDAL type name ("Byte", "Int16", ...).
    # Translate the configured type the same way warp_files does for
    # gdal.Warp's outputType.
    gdal_type_name = gdal.GetDataTypeName(set_gdal_dtype(data_type))

    out_tmp_path = str(cog["out_path"].joinpath("merge_temp.tif"))
    input_files  = [str(fname) for fname in fnames]

    bbox  = cog["bbox_all"][idx]
    # NOTE(review): warp_files sizes pixels as unit/ppu_max; this uses
    # 1/ppu_max, so the two agree only when cog["unit"] == 1. Confirm.
    pixel = f"{1/cog['ppu_max']}"

    parameters = (['', '-o', out_tmp_path]
                  + input_files
                  + ['-of', 'GTiff',
                     '-init', f"{nodata}",
                     '-ot', gdal_type_name,
                     '-ul_lr', f"{bbox[0]}", f"{bbox[3]}", f"{bbox[2]}", f"{bbox[1]}",
                     '-co', 'COMPRESS=DEFLATE',
                     '-n', f"{nodata}",
                     '-ps', pixel, pixel])

    # gdal_merge updates an existing output file instead of recreating it,
    # so remove any leftover from an earlier failed run first.
    if os.path.exists(out_tmp_path):
        os.remove(out_tmp_path)

    src = None
    try:
        gdal_merge(parameters)

        src = gdal.Open(out_tmp_path, gdalconst.GA_ReadOnly)
        if src is None:
            raise RuntimeError(f"gdal_merge did not produce {out_tmp_path}")
        img = src.GetRasterBand(lnum).ReadAsArray()
    finally:
        # Close the dataset before deleting the temporary file, even on error
        src = None
        if os.path.exists(out_tmp_path):
            os.remove(out_tmp_path)

    return img


        

