# Import
# Prevent the MCP process from crashing due to matplotlib attempting to open a GUI window.
# Set the backend to Agg (non-interactive) before any other imports.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mcp.server.fastmcp import FastMCP, Image
from jaxa.earth import je
import requests
import numpy as np
import sys
import os
import contextlib
from datetime import datetime, timedelta

# Initialize FastMCP server
mcp = FastMCP("JAXA_Earth_API_Assistant")

# Markdown catalog listing every dataset served by the JAXA Earth API.
JE_CATALOG_URL = "https://data.earth.jaxa.jp/app/mcp/catalog.v2.md"

# Seconds to wait for the catalog download; without a timeout a stalled
# connection would block the tool call (and the whole server) indefinitely.
CATALOG_TIMEOUT_SEC = 60

# Accepted formats for the bounds in `dlim`.
_DATETIME_FMT = "%Y-%m-%dT%H:%M:%S"
_DATE_FMT = "%Y-%m-%d"


@contextlib.contextmanager
def suppress_stdout():
    """Send everything written to stdout to stderr for the duration of the block.

    The MCP stdio transport uses stdout for protocol messages, so stray output
    from a library would corrupt the stream. contextlib.redirect_stdout only
    captures Python-level writes to sys.stdout; to also catch code that writes
    straight to OS-level file descriptor 1, fd 1 is temporarily pointed at
    stderr's descriptor as well and restored afterwards.
    """
    try:
        stdout_fd = sys.stdout.fileno()
    except (AttributeError, OSError, ValueError):
        # sys.stdout is not backed by a real descriptor (e.g. replaced by a
        # test harness): fall back to Python-level redirection only.
        stdout_fd = None

    if stdout_fd is None:
        with contextlib.redirect_stdout(sys.stderr):
            yield
        return

    # Push out anything already buffered so it reaches the real stdout rather
    # than being emitted to stderr after the descriptor swap.
    sys.stdout.flush()
    saved_fd = os.dup(stdout_fd)
    try:
        os.dup2(sys.stderr.fileno(), stdout_fd)
        with contextlib.redirect_stdout(sys.stderr):
            yield
    finally:
        os.dup2(saved_fd, stdout_fd)
        os.close(saved_fd)


def _close_figures() -> None:
    """Close every open pyplot figure.

    NOTE(review): assumes je renders through pyplot. With the non-interactive
    Agg backend nothing ever closes those figures, so they would otherwise
    accumulate for the lifetime of the server process.
    """
    plt.close("all")


def _is_broadcast_error(e: Exception) -> bool:
    """Return True if the exception is caused by a broadcast shape mismatch."""
    msg = str(e).lower()
    return "broadcast" in msg or "shape" in msg


def _parse_datetime(value: str) -> datetime:
    """Parse one `dlim` bound.

    Accepts 'YYYY-MM-DDThh:mm:ss' (anything after the first 19 characters,
    such as a trailing 'Z', is ignored) or a bare 'YYYY-MM-DD' date, which is
    read as midnight.

    Raises:
        ValueError: if the string matches neither format.
    """
    text = value.strip()
    try:
        return datetime.strptime(text[:19], _DATETIME_FMT)
    except ValueError:
        pass
    try:
        return datetime.strptime(text, _DATE_FMT)
    except ValueError:
        raise ValueError(
            f"Invalid date {value!r}: expected 'YYYY-MM-DDThh:mm:ss' or 'YYYY-MM-DD'."
        ) from None


def _fetch_single_windows(collection: str, band: str, dlim: list, bbox: list,
                          image_size: int = 300) -> list:
    """
    Fetch data by splitting the date range dlim into individual windows.

    This is a helper to avoid broadcast shape mismatch errors that occur when
    je.ImageProcess processes multiple timesteps at once.

    Strategy:
        1. First attempt get_images() over the full dlim range.
        2. If a broadcast error occurs during ImageProcess, split the date range
           in half and process each half recursively.
        3. If a failing range is shorter than one day, skip it.

    Args:
        collection: JAXA Earth API collection ID.
        band: band name in the collection.
        dlim: [start, end] date strings ('YYYY-MM-DDThh:mm:ss' or 'YYYY-MM-DD').
        bbox: [min_x, min_y, max_x, max_y]; max_x must exceed min_x.
        image_size: approximate output width in pixels, used to derive the
            resolution (pixels per bbox unit) requested from the API.

    Returns:
        list of je.ImageData, one entry per date window that processed
        successfully. A window may still contain several timesteps.

    Raises:
        ValueError: malformed dlim/bbox.
        Exception: any non-broadcast error raised by the je library.
    """
    if len(dlim) < 2:
        raise ValueError(f"dlim must be [start, end], got {dlim!r}")
    if len(bbox) < 4:
        raise ValueError(f"bbox must be [min_x, min_y, max_x, max_y], got {bbox!r}")
    width = bbox[2] - bbox[0]
    if width <= 0:
        raise ValueError(f"bbox max_x must be greater than min_x, got {bbox!r}")

    ppu = image_size / width
    start_dt = _parse_datetime(dlim[0])
    end_dt = _parse_datetime(dlim[1])

    def _try_fetch(d0: datetime, d1: datetime):
        """Download the images for [d0, d1] with stdout suppressed."""
        window = [d0.strftime(_DATETIME_FMT), d1.strftime(_DATETIME_FMT)]
        with suppress_stdout():
            return je.ImageCollection(collection=collection, ssl_verify=False)\
                .filter_date(dlim=window)\
                .filter_resolution(ppu=ppu)\
                .filter_bounds(bbox=bbox)\
                .select(band=band)\
                .get_images()

    def _collect(d0: datetime, d1: datetime) -> list:
        """Return [data] for [d0, d1], or recurse on halves on broadcast errors."""
        try:
            data = _try_fetch(d0, d1)
            # Call show_images as a probe so broadcast errors surface here,
            # not later in the calling tool.
            with suppress_stdout():
                je.ImageProcess(data).show_images()
        except Exception as e:
            if not _is_broadcast_error(e):
                raise
        else:
            return [data]
        finally:
            _close_figures()

        # Range is less than one day — give up and skip.
        if (d1 - d0).days < 1:
            return []
        # Split in half (whole seconds, matching the formatted strings) and recurse.
        mid = (d0 + (d1 - d0) / 2).replace(microsecond=0)
        return _collect(d0, mid) + _collect(mid + timedelta(seconds=1), d1)

    return _collect(start_dt, end_dt)


def _timestep_arrays(data) -> list:
    """Render `data` with je and return one numpy array per timestep.

    NOTE(review): assumes axis 0 of `raster.img` indexes acquisition dates
    (earlier code read only `raster.img[0]`) — confirm against jaxa.earth docs.
    """
    with suppress_stdout():
        img = je.ImageProcess(data).show_images()
    arrays = [np.asarray(arr) for arr in img.raster.img]
    _close_figures()
    return arrays


def _array_stats(arr) -> dict:
    """NaN-aware min/max/mean/std of `arr` as plain Python floats."""
    return {
        "min": float(np.nanmin(arr)),
        "max": float(np.nanmax(arr)),
        "mean": float(np.nanmean(arr)),
        "std": float(np.nanstd(arr)),
    }


# Search collection ID/Bands from JAXA Earth API
@mcp.tool()
async def search_collections_id():
    """
    This guide introduces data available through the JAXA Earth API.
    Based on user requests, please select and respond with the appropriate dataset ID and bands.
    Returned text is a list of all datasets available via the JAXA Earth API.
    Each parameter is described as follows:
    id: The ID uniquely identifying the dataset
    title: The dataset title
    description: The dataset description
    bands: The IDs of the data included in the dataset. Multiple IDs are separated by commas (,).
    keywords: Keywords associated with the dataset. These include the name of the observing satellite or the space agency managing the data.
    startDate: A string representing the start date and time of the dataset period in ISO8601 format. This indicates that data from this date onward is available.
    endDate: A string representing the end date and time of the dataset period in ISO8601 format. A value of "present" indicates that data is still being updated daily.
    The end of each dataset's parameters is indicated by "---".
    bbox: Geographic extent of the dataset (bounding box).
    For EPSG:4326, this indicates [minimum longitude, minimum latitude, maximum longitude, maximum latitude].
    For EPSG:3995 or EPSG:3031, this indicates [minimum X, minimum Y, maximum X, maximum Y].
    epsg: EPSG code indicating the projection method of the dataset.
    Based on the above, please select the optimal dataset ID and bands for the user's request and provide your response.
    """
    # NOTE(review): TLS certificate verification is disabled (as is
    # ssl_verify for je); presumably a certificate workaround — confirm it is
    # still required.
    with suppress_stdout():
        response = requests.get(JE_CATALOG_URL, verify=False, timeout=CATALOG_TIMEOUT_SEC)
    # Fail loudly rather than returning an HTTP error page as the catalog.
    response.raise_for_status()
    return response.text


# Show satellite image using JAXA Earth API
@mcp.tool()
async def show_images(
    collection: str = "JAXA.EORC_ALOS.PRISM_AW3D30.v3.2_global",
    band: str = "DSM",
    dlim: list[str] = ["2021-01-01T00:00:00", "2021-01-01T00:00:00"],
    bbox: list[float] = [135.0, 37.5, 140.0, 42.5],
) -> list:
    """Show satellite image using JAXA Earth API based on user input.

    Args:
        collection: JAXA Earth API collection ID.
        band: band name in the collection.
        dlim: date range limit request from start to end. yyyy-mm-ddThh:mm:ss formatted date string.
        bbox: Geographic extent of the dataset (bounding box).
        For EPSG:4326, this indicates [minimum longitude, minimum latitude, maximum longitude, maximum latitude].
    """
    output = []
    for data in _fetch_single_windows(collection, band, dlim, bbox):
        with suppress_stdout():
            img_data = je.ImageProcess(data).show_images(output="buffer")
        output.extend(Image(data=png_buffer, format="png") for png_buffer in img_data.png_buffers)
        _close_figures()
    return output


# Calculate satellite data's spatial statistics using JAXA Earth API
@mcp.tool()
async def calc_spatial_stats(
    collection: str = "JAXA.EORC_ALOS.PRISM_AW3D30.v3.2_global",
    band: str = "DSM",
    dlim: list[str] = ["2021-01-01T00:00:00", "2021-01-01T00:00:00"],
    bbox: list[float] = [135.0, 37.5, 140.0, 42.5],
) -> list:
    """Calculate satellite data's spatial statistics values using JAXA Earth API based on user input.

    Args:
        collection: JAXA Earth API collection ID.
        band: band name in the collection.
        dlim: date range limit request from start to end. yyyy-mm-ddThh:mm:ss formatted date string.
        bbox: Geographic extent of the dataset (bounding box).
        For EPSG:4326, this indicates [minimum longitude, minimum latitude, maximum longitude, maximum latitude].
    """
    timeseries = []
    for data in _fetch_single_windows(collection, band, dlim, bbox):
        with suppress_stdout():
            img_data = je.ImageProcess(data).calc_spatial_stats()
        timeseries.extend(img_data.timeseries)
    return timeseries


# Show satellite data's spatial statistics using JAXA Earth API
@mcp.tool()
async def show_spatial_stats(
    collection: str = "JAXA.EORC_ALOS.PRISM_AW3D30.v3.2_global",
    band: str = "DSM",
    dlim: list[str] = ["2021-01-01T00:00:00", "2021-01-01T00:00:00"],
    bbox: list[float] = [135.0, 37.5, 140.0, 42.5],
) -> list:
    """Show satellite data's spatial statistics result image using JAXA Earth API based on user input.

    Args:
        collection: JAXA Earth API collection ID.
        band: band name in the collection.
        dlim: date range limit request from start to end. yyyy-mm-ddThh:mm:ss formatted date string.
        bbox: Geographic extent of the dataset (bounding box).
        For EPSG:4326, this indicates [minimum longitude, minimum latitude, maximum longitude, maximum latitude].
    """
    output_images = []
    for data in _fetch_single_windows(collection, band, dlim, bbox):
        with suppress_stdout():
            img_data = je.ImageProcess(data)\
                .calc_spatial_stats()\
                .show_spatial_stats(output="buffer")
        output_images.extend(
            Image(data=png_buffer, format="png") for png_buffer in img_data.png_buffers_stats
        )
        _close_figures()
    return output_images


# Get satellite raster numerical data using JAXA Earth API
@mcp.tool()
async def get_raster_data(
    collection: str = "JAXA.EORC_ALOS.PRISM_AW3D30.v3.2_global",
    band: str = "DSM",
    dlim: list[str] = ["2021-01-01T00:00:00", "2021-01-01T00:00:00"],
    bbox: list[float] = [135.0, 37.5, 140.0, 42.5],
) -> list:
    """Get satellite raster numerical data (numpy array) using JAXA Earth API based on user input.

    Returns a list of dicts (one per acquisition date), each containing:
    shape, dtype, min, max, mean, std, and the raw pixel values as a nested list.

    Args:
        collection: JAXA Earth API collection ID.
        band: band name in the collection.
        dlim: date range limit request from start to end. yyyy-mm-ddThh:mm:ss formatted date string.
        bbox: Geographic extent of the dataset (bounding box).
        For EPSG:4326, this indicates [minimum longitude, minimum latitude, maximum longitude, maximum latitude].
    """
    results = []
    for data in _fetch_single_windows(collection, band, dlim, bbox):
        # Every timestep in the window, not just the first one.
        for raster_array in _timestep_arrays(data):
            results.append({
                "shape": list(raster_array.shape),
                "dtype": str(raster_array.dtype),
                **_array_stats(raster_array),
                "data": raster_array.tolist(),
            })
    return results


# Show satellite image AND return spatial statistics together using JAXA Earth API
@mcp.tool()
async def show_images_and_stats(
    collection: str = "JAXA.EORC_ALOS.PRISM_AW3D30.v3.2_global",
    band: str = "DSM",
    dlim: list[str] = ["2021-01-01T00:00:00", "2021-01-01T00:00:00"],
    bbox: list[float] = [135.0, 37.5, 140.0, 42.5],
) -> dict:
    """Show satellite image(s) AND return spatial statistics (numerical values) together
    using JAXA Earth API based on user input.

    Returns a dict with:
      "images"      : list of PNG Image objects for visual display (one per acquisition date)
      "stats_images": list of PNG Image objects for statistics plots
      "timeseries"  : list of dicts with datetime + statistical values
                      (mean, min, max, std, etc.) per acquisition date
                      — use these numbers to answer quantitative questions
      "summary"     : dict summarising overall min/max/mean/std across all dates and the bbox

    If a broadcast error occurs across multiple timesteps, the date range is
    automatically split and processed.

    Args:
        collection: JAXA Earth API collection ID.
        band: band name in the collection.
        dlim: date range limit request from start to end. yyyy-mm-ddThh:mm:ss formatted date string.
        bbox: Geographic extent of the dataset (bounding box).
        For EPSG:4326, this indicates [minimum longitude, minimum latitude, maximum longitude, maximum latitude].
    """
    images = []
    stats_images = []
    timeseries = []
    all_arrays = []
    for data in _fetch_single_windows(collection, band, dlim, bbox):
        with suppress_stdout():
            # Image
            img_buf = je.ImageProcess(data).show_images(output="buffer")
            # Statistics plot
            img_stats = je.ImageProcess(data)\
                .calc_spatial_stats()\
                .show_spatial_stats(output="buffer")
        images.extend(Image(data=buf, format="png") for buf in img_buf.png_buffers)
        stats_images.extend(Image(data=buf, format="png") for buf in img_stats.png_buffers_stats)
        timeseries.extend(img_stats.timeseries)
        # Raster arrays (for summary) — every timestep of the window.
        all_arrays.extend(_timestep_arrays(data))

    # Concatenate all timesteps and compute summary; an empty result (all
    # windows skipped) yields None statistics instead of a concatenate error.
    if all_arrays:
        combined = np.concatenate([np.ravel(a) for a in all_arrays])
        stats = _array_stats(combined)
    else:
        combined = np.array([], dtype=float)
        stats = {"min": None, "max": None, "mean": None, "std": None}

    summary = {
        "band": band,
        "collection": collection,
        "bbox": bbox,
        "period": {"start": dlim[0], "end": dlim[1]},
        "n_dates": len(all_arrays),
        **stats,
        "valid_pixels": int(np.count_nonzero(~np.isnan(combined.astype(float)))),
        "total_pixels": int(combined.size),
    }

    # NOTE(review): FastMCP may serialize Image objects nested inside a dict
    # as plain text rather than image content — verify images render on the
    # client; a list return may be required.
    return {
        "images": images,
        "stats_images": stats_images,
        "timeseries": timeseries,
        "summary": summary,
    }


# Run the MCP server
def main():
    mcp.run(transport='stdio')


if __name__ == "__main__":
    main()