# Import
import numpy as np
from osgeo import gdal, osr
from .bounds.intersect import bbox as intersect_bbox
from ..utils.progress import bar as progress_bar
from .integ import multiple as integ_multi_tiff
from .set import params_tags as set_params_tags
from pathlib import Path
from .read import info as read_cog_info

# multi_level
def multi_level(source,cog):
    """Generate the COGs for every level described in ``cog``.

    Levels are processed in the order they appear in ``cog``. For each level,
    the sources stored under ``"COG-level-<n>"`` are turned into COG tiles by
    ``multi_area``. After every level except the last, ``update_source``
    registers the freshly written files as the source of the next level.

    Parameters
    ----------
    source : dict
        Maps ``"COG-level-<n>"`` keys to a list of per-layer source dicts.
        The first entry's ``"band"`` value is used for the whole level.
    cog : sequence of dict
        Per-level settings, indexed ``0 .. len(cog) - 1``; each entry must
        provide ``"cog_level"``.
    """
    n_levels = len(cog)

    for idx in range(n_levels):
        level_cfg = cog[idx]
        level = level_cfg["cog_level"]

        # Same text as '{:1.0f}'.format(level), e.g. 3 or 3.0 -> "COG-level-3"
        key = "COG-level-" + format(level, "1.0f")
        layer_band = source[key][0]["band"]
        print(f" - COG Level {level} Generating...")

        # Write every area tile of this level
        multi_area(source[key], level_cfg)

        # The last level has no successor, so no new source is needed
        has_next_level = idx != n_levels - 1
        if has_next_level:
            source = update_source(source, level_cfg, layer_band)

# update_source
def update_source(old_source,cog,band):
    """Register the COG files just written as the source of the next level.

    Files matching ``*/*<band>*tif*`` one directory below ``cog["out_path"]``
    are collected and summarized by ``read_cog_info``. One source entry per
    layer is then stored under ``"COG-level-<cog_level - 1>"``. There are 3
    layers when the band's ``pint`` is ``"RGB"`` and 1 otherwise.

    Parameters
    ----------
    old_source : dict
        Existing ``"COG-level-<n>"`` -> list-of-layer-dicts mapping.
        It is modified in place.
    cog : dict
        Settings of the level just generated. Keys read: ``cog_level``,
        ``bands``, ``out_path``, ``epsg``, ``unit``.
    band : str
        Band name, used as a key into ``cog["bands"]`` and in the file glob.

    Returns
    -------
    dict
        ``old_source`` itself (same object) with the new level key added.
    """
    level = cog["cog_level"]
    next_key = "COG-level-" + format(level - 1, "1.0f")

    band_cfg = cog["bands"][band]
    n_layers = 3 if band_cfg["pint"] == "RGB" else 1

    # Tiles written for this band, one directory below the output root
    files = np.array(list(Path(cog["out_path"]).glob(f"*/*{band}*tif*")))

    # Combined bbox, SRS, max pixels-per-unit (presumably per cog["unit"];
    # confirm in read_cog_info) and the regen_rgb flag of all those files
    bbox, srs, ppu_max, regen_rgb = read_cog_info(band, files, cog["epsg"], cog["unit"])

    # Every layer shares the same files and geometry; only layer_number differs
    old_source[next_key] = [
        {
            "band": band,
            "params": {
                "dn": band_cfg["dn"],
                "dn2value": band_cfg["dn2value"],
            },
            "layer_number": layer + 1,
            "fnames": files,
            "bbox": bbox,
            "srs": srs,
            "ppu_max": ppu_max,
            "regen_rgb": regen_rgb,
        }
        for layer in range(n_layers)
    ]

    return old_source

# multi_area : generate multi area's cog
def multi_area(source,cog):
    """Generate one COG per target area of a single level.

    For every bounding box in ``cog["bbox_all"]``, the sources intersecting
    it are found with ``intersect_bbox``. Intersection uses
    ``source[0]["bbox"]`` only, on the assumption that all layers share the
    same bboxes. If any source intersects, the sources are merged into one
    image, the GDAL parameters/tags are prepared, and the COG is written by
    ``single``. The progress bar is advanced after every area, whether or
    not it was written.

    Parameters
    ----------
    source : list of dict
        Per-layer source entries of the current level; the first entry's
        ``"band"`` and ``"bbox"`` are used.
    cog : dict
        Level settings; ``cog["bbox_all"]`` lists the target areas.
    """
    band = source[0]["band"]
    areas = cog["bbox_all"]
    n_areas = len(areas)

    for idx_cog in range(n_areas):
        hits = intersect_bbox(areas[idx_cog], source[0]["bbox"])

        if any(hits):
            img, ctable = integ_multi_tiff(idx_cog, hits, source, cog, band)
            gopt, params, custom_tags, metadata = set_params_tags(
                img, band, ctable, cog, idx_cog)
            single(gopt, params, custom_tags, metadata)

        progress_bar(idx_cog, n_areas)

# single: generate single cog
def single(gopt, params, tags, metadata):
    """Write one image as a GeoTIFF with internal overviews (COG layout).

    The raster is first assembled in a GDAL ``MEM`` dataset. This covers the
    geotransform, EPSG projection, dataset metadata, per-band tags, pixel
    data and optional overviews. The dataset is then copied to
    ``params["fn_out"]`` with the ``GTiff`` driver and
    ``COPY_SRC_OVERVIEWS=YES``, so the in-memory overviews are carried into
    the output file.

    Parameters
    ----------
    gopt : list of str
        GDAL creation options, used for the MEM dataset and (plus
        ``COPY_SRC_OVERVIEWS=YES``) for the GTiff copy. The caller's list is
        NOT modified.
    params : dict
        Keys read:
        - ``img_size``: (bands, rows, columns);
        - ``dtype``: GDAL data type;
        - ``lonlim`` / ``latlim``: extent, where ``lonlim[0]`` and
          ``latlim[1]`` give the upper-left corner;
        - ``dlon`` / ``dlat``: pixel size;
        - ``epsg``;
        - ``img``: per-band 2-D arrays indexed by band;
        - ``img_lv``: overview levels, or None for no overviews;
        - ``fn_out``: output path.
    tags : dict
        ``ColorMap`` (color table or None), ``GDAL_NODATA`` (value or None)
        and ``GDAL_METADATA`` holding ``unit``, ``scale`` and ``offset``.
    metadata : dict
        Dataset-level metadata passed to ``SetMetadata``.

    Raises
    ------
    RuntimeError
        If GDAL fails to create the in-memory raster or the output file.
    """
    n_bands = params["img_size"][0]
    n_rows = params["img_size"][1]
    n_cols = params["img_size"][2]

    # Assemble everything in memory first; the file is written by CreateCopy
    mem_driver = gdal.GetDriverByName('MEM')
    raster = mem_driver.Create("", n_cols, n_rows, n_bands,
                               params["dtype"], options=gopt)
    if raster is None:
        raise RuntimeError("GDAL failed to create the in-memory raster")

    # North-up geotransform anchored at the upper-left corner
    raster.SetGeoTransform((params["lonlim"][0], params["dlon"], 0,
                            params["latlim"][1], 0, -params["dlat"]))
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromEPSG(params["epsg"])
    raster.SetProjection(raster_srs.ExportToWkt())

    # Dataset-level metadata (TIFF tags)
    raster.SetMetadata(metadata)

    # Overview resampling. GDAL accepts "AVERAGE", "AVERAGE_MAGPHASE", "RMS",
    # "BILINEAR", "CUBIC", "CUBICSPLINE", "GAUSS", "LANCZOS", "MODE",
    # "NEAREST" or "NONE". Images with 3+ bands use NEAREST, others use MODE.
    ovv_method = "NEAREST" if n_bands >= 3 else "MODE"

    gdal_meta = tags["GDAL_METADATA"]
    for i in range(n_bands):
        rband = raster.GetRasterBand(i + 1)

        if tags["ColorMap"] is not None:
            rband.SetColorTable(tags["ColorMap"])

        if tags["GDAL_NODATA"] is not None:
            rband.SetNoDataValue(tags["GDAL_NODATA"])

        # Stored in the GDAL_METADATA TIFF tag
        rband.SetUnitType(gdal_meta["unit"])
        rband.SetScale(gdal_meta["scale"])
        rband.SetOffset(gdal_meta["offset"])

        # Record the resampling method used for the overviews (JAXA Earth metadata)
        rband.SetMetadata({"OVV_RESAMPLING": ovv_method})

        # The 4th band (index 3) is flagged as alpha
        if i == 3:
            rband.SetColorInterpretation(gdal.GCI_AlphaBand)

        rband.WriteArray(params["img"][i])

    if params["img_lv"] is not None:
        raster.BuildOverviews(ovv_method, params["img_lv"])

    # Copy the options so repeated calls with a shared list don't accumulate
    cog_opts = list(gopt) + ["COPY_SRC_OVERVIEWS=YES"]
    gtiff_driver = gdal.GetDriverByName('GTiff')
    cogout = gtiff_driver.CreateCopy(str(params["fn_out"]), raster,
                                     options=cog_opts)
    if cogout is None:
        raster = None
        raise RuntimeError(f"GDAL failed to write COG: {params['fn_out']}")

    # Flush to disk and release both datasets
    cogout.FlushCache()
    raster = None
    cogout = None


