#!/usr/bin/env python # coding: utf-8 # ## Vision-Language Model Training on Sentinel-2 Imagery # This tutorial constructs a training dataset from a [Planetary Computer](https://planetarycomputer.microsoft.com/) data collection and then fine-tunes Microsoft's [Florence-2 Vision-Language Model](https://huggingface.co/microsoft/Florence-2-base-ft) on that dataset. In this notebook we'll: # - Generate a physically uniform distribution of `n` coordinates # - Collect labels for the geographic regions using [OpenStreetMap](https://www.openstreetmap.org/) # - Collect and download images from the Planetary Computer's [STAC API](https://github.com/radiantearth/stac-api-spec) # - Generate natural-language captions using Azure OpenAI # - Fine-tune a Vision-Language Model on the generated dataset # This example uses [Sentinel-2 Level-2A](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a) data. The techniques in this notebook can be applied (with minor modifications) to other remote-sensing datasets. # ### Getting Started # You will need to download the `map.json` file from the [PlanetaryComputerExamples GitHub repository](https://github.com/microsoft/PlanetaryComputerExamples) under `tutorials/assets/map.json` as a prerequisite to running the code cells in this notebook. It is a file that contains Polygons representing the land areas on Earth for our point sampling. You may edit this file to only include areas you wish to sample from either to match the geographic availability of the dataset or to create more specialized models. # Below are all the packages that you will need to run this notebook. # In[ ]: # Imports import os import json import shapely from shapely.ops import transform from shapely.geometry import MultiPolygon, Polygon, LineString, box import matplotlib.pyplot as plt import math import random import pandas import csv import folium import overpass import overpy import pystac_client import planetary_computer import geopandas import dask_geopandas import pyproj import rasterio import rasterio.mask import ast # Modules for step IV. Florence-2 training from datasets import Dataset, DatasetDict from sklearn.model_selection import train_test_split from transformers import AutoModelForCausalLM, AutoProcessor import torch from torch.utils.data import Dataset as TorchDataset from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AdamW, get_scheduler from PIL import Image # Dependencies # numpy # flash_attn # einops # timm # In[3]: from openai import AzureOpenAI from azure.identity import DefaultAzureCredential # Below, we set the parameters of how many data points to generate, and where to save the data. `n` is how many data points we create. `d` is the square grid side length in degrees. `DATA_DIR` specifies the master data directory. After each step in the process, we save the collected data into intermediate `.csv` files. This step becomes necessary as the number of samples `n` increases. 
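# Before committing to a value for `d`, it can help to sanity-check the ground footprint it implies. The short back-of-the-envelope calculation below is an illustrative addition (not part of the original pipeline) that uses the same ~69-miles-per-degree approximation adopted later in this notebook: the default `d = 0.05` corresponds to roughly 5.5 km per grid side, or about 550 pixels at Sentinel-2's 10 m ground sample distance.

# In[ ]:

side_km = 0.05 * 69 * 1.609  # degrees * ~miles per degree * ~km per mile
approx_pixels_per_side = side_km * 1000 / 10  # metres per side / 10 m GSD
print(f"~{side_km:.1f} km per grid side, ~{approx_pixels_per_side:.0f} px per side")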
# In[4]: # Set Hyperparameters n = 100 # number of areas to randomly sample DATA_DIR = "data/tutorial_data/" # Set the directory for the data d = 0.05 # degree measurement size NPARTITIONS = int(n / 50) # Set dataset filepaths LAND_MAP = "map.json" # Set the directory for the Polygons of the the land masses # Directory to store the .csv file of randomly sampled grids GRIDS = f"{DATA_DIR}grids.csv" # Store the query .json file path in .csv GRIDS_TAGS = f"{DATA_DIR}grids_tags.csv" # Store collected tags and metadata GRIDS_TAGS_IMAGES = f"{DATA_DIR}grids_tags_images.csv" IMAGES_TAGS = f"{DATA_DIR}images_tags.csv" # Store the collected images and tag lists IMAGES_CAPTIONS = f"{DATA_DIR}images_captions.csv" # Store the generated captions IMAGES_DIR = f"{DATA_DIR}images/" # Directory for cropped images MODEL_DIR = f"{DATA_DIR}model/" # Directory for the model files print(f"Master data directory: {DATA_DIR}") print(f"Generate {n} grids with size {d} x {d}") # Make the directories if they do not already exist. # In[4]: os.makedirs(DATA_DIR, exist_ok=True) os.makedirs(IMAGES_DIR, exist_ok=True) os.makedirs(MODEL_DIR, exist_ok=True) print(f"Directories {DATA_DIR}, {IMAGES_DIR}, and {MODEL_DIR} created") # ### I. Object Collection # This section uses the OverPass API to query OpenStreetMap (OSM). We randomly select geographic regions all over the world and from these regions, collect all of the OSM tags associated within the bounding box of the grid. The objects selected are stored in a `.csv` file for reference by future code blocks. # #### 1. Generating the Grids # # Randomly select `n` (ex. 100) grids across the globe, each with size `d`° x `d`° (ex. 0.05°). # To sample data from land areas only, we supply a GeoJson file `map.json` as a collection of polygons representing land masses on Earth. It is a somewhat rough estimate, but for our purposes it allows us to sample images that contain the features we are interested in (land formations and vegetation). Feel free to edit the `map.json` file to only include areas you wish to sample from! Here we plot these polygons for visualization: # In[5]: # Load GeoJSON of land polygons with open(LAND_MAP) as file: land = json.load(file) # Create a MultiPolygon from the land geometries by extracting the coordinates ob = MultiPolygon( [Polygon((feature["coordinates"])[0]) for feature in land["geometries"]] ) # Plot Multipolygon fig, ax = plt.subplots() ax.set_aspect("equal") for polygon in ob.geoms: x, y = polygon.exterior.xy ax.plot(x, y) plt.show() # For our random sampling process, we will use Gaussian-distributed random vectors normalized to the unit sphere for a physically uniform sampling of locations. # In[6]: def generate_random_coordinate() -> tuple[float, float]: """ This function generates a random coordinate on the Earth's surface using a uniform normal distribution on the unit sphere. Returns: (longitude, latitude): A tuple of the longitude and latitude coordinates of the generated coordinate, in this order. 
""" gx = random.gauss(0.0, 1.0) gy = random.gauss(0.0, 1.0) gz = random.gauss(0.0, 1.0) norm2 = gx**2 + gy**2 + gz**2 norm1 = 1.0 / math.sqrt(norm2) x = gx * norm1 y = gy * norm1 z = gz * norm1 radLat = math.asin(z) radLon = math.atan2(y, x) lat = math.degrees(radLat) lon = math.degrees(radLon) return lon, lat print("Function generate_random_coordinate defined!") # *Some important values* # - 1 degree of latitude = 69 miles # - 1 degree of longitude at equator = 69.172 miles # - 1 degree of longitude = cosine (latitude in radians) * length of degree (miles) at equator # In[7]: LONGITUDE = 69.172 # 1 degree of longitude is approximately 69.172 miles at the equator LATITUDE = 69 # 1 degree of latitude is approximately 69 miles # Now that we have the points as the centers for our bounding boxes, we are ready to create the boxes. Latitude lines are very evenly separated at approximately 69 miles no matter your longitude. Yet, longitude lines are the furthest apart at the equator at ~69.172 miles and eventually approach 0 miles apart as you approach the poles. This means that we cannot simply add and subtract `d/2` degrees to each coordinate. We need to account for this ratio of longitude separation according to the latitude. So our grids are standardized by the longitude and latitude of degrees `d` at the *equator* so ~ `d` * 69mi x `d` * 69mi sized grids. # To avoid the hassle of splitting a grid into two geometries for crossing the latitude or longitude degree boundaries, we remove this possibility. # In[8]: def not_crossing_antimeridian(lon: float, lat: float, distance: float) -> bool: """ This function ensures that the generated coordinate would not cross the antimeridian when a square of a given distance is centered at the (lon, lat) coordinate. Args: lon: float representing the decimal longitude of the coordinate you wish to check lat: float representing the decimal latitude of the coordinate you wish to check Returns: bool: True if generated coordinate would not cross the antimeridian, False otherwise. """ distance = ( distance / 2 ) # distance is the length of the side of the square, so we need to divide by 2 distance = distance * LONGITUDE # Grid dimensions in miles # convert latitude to radians lat_rad = math.radians(lat) long_degree = math.cos(lat_rad) * LONGITUDE longitude = distance / long_degree latitude = distance / LATITUDE minX = lon - longitude minY = lat - latitude maxX = lon + longitude maxY = lat + latitude if ( (minX < -180) or (minY < -90) or (maxX > 180) or (maxY > 90) ): # Goes out of bounds return False return True print("Function not_crossing_antimeridian defined!") # This function generates random coordinates and checks to see if the coordinate is in the bounds of the `map.json` Polygons. This continues until we have `n` valid points. # In[9]: def random_points_in_bounds(n: int) -> list[tuple[float, float]]: """ This function generates n random coordinates located in the land boundaries of Earth's surface. Args: n: The number of random coordinates to generate. Returns: A list of tuples of the longitude and latitude coordinates of the generated coordinates. 
""" # Load GeoJSON of land polygons with open(LAND_MAP) as file: land = json.load(file) # Create a MultiPolygon from the land geometries by extracting the coordinates land_areas = MultiPolygon( [Polygon((feature["coordinates"])[0]) for feature in land["geometries"]] ) points = [] # Continuously generate a random coordinate and check if it is in the MultiPolygon while len(points) < n: point = generate_random_coordinate() valid = False valid = shapely.contains_xy( land_areas, point[0], point[1] ) and not_crossing_antimeridian(point[0], point[1], d) if valid: points.append(point) return points print("Function random_points_in_bounds defined!") # In[10]: def create_bbox(lon: float, lat: float, distance: float) -> list[float]: """ This function generates a square bounding box of degree d x d (normalized to latitude and longitude ratios) with the provided point located in the center. Args: lon: The longitude of the coordinate. lat: The latitude of the coordinate. distance: The length of the side of the square bounding box in degrees. Returns: A bounding box as a list of the form [minX, minY, maxX, maxY] """ distance = ( distance / 2 ) # distance is the length of the side of the square, so we need to divide by 2 distance = distance * LONGITUDE # Grid dimensions in miles # convert latitude to radians lat_rad = math.radians(lat) long_degree = math.cos(lat_rad) * LONGITUDE longitude = distance / long_degree latitude = distance / LATITUDE minX = lon - longitude minY = lat - latitude maxX = lon + longitude maxY = lat + latitude # Point cannot go out of bounds because it is already checked in the # random_points_in_bounds function return [minX, minY, maxX, maxY] print("Function create_bbox defined!") # The following code block generates the `n` random square grids and saves the sampled coordinates to the `.csv` file specified at the path `GRIDS`. # In[11]: # Generate n random square grids grids = [] for point in random_points_in_bounds(n): minX, minY, maxX, maxY = create_bbox(point[0], point[1], d) grids.append([minX, minY, maxX, maxY]) # Create .csv file for sampled coordinates to be saved df = pandas.DataFrame(grids, columns=["minX", "minY", "maxX", "maxY"]) df["uid"] = range(1, n + 1) df.to_csv(GRIDS, index=False) print(f"Dataframe saved to {GRIDS}") # Now that we have randomly sampled our `n` points from Earth, let us plot the selected grids on a map for visualization of the generated distribution. # In[12]: # Open the csv file in reader mode with open(GRIDS, "r") as csvfile: reader = csv.reader(csvfile) grids = list(reader)[1:] # Create a map centered at a specific location m = folium.Map(location=[0, 0], zoom_start=2) # Iterate over the grids and create a rectangle for each bounding box for bbox in grids: min_lon, min_lat, max_lon, max_lat = bbox[0], bbox[1], bbox[2], bbox[3] rectangle = folium.Rectangle( bounds=[(min_lat, min_lon), (max_lat, max_lon)], color="blue", fill=True, fill_color="blue", fill_opacity=0.3, ) rectangle.add_to(m) # Display the map m # #### 2. Reading the data from the CSV file # Instead of generating a new dataset every time we want to test the next section, let us read the data from the `GRIDS` `.csv` file into a `pandas` dataframe so we can start here as a checkpoint. # In[13]: # Read the CSV file into a DataFrame grids_df = pandas.read_csv(GRIDS) # Display the DataFrame grids_df # #### 3. Query the objects located in these grids from the OSM database using the [OverPass API](https://overpass-turbo.eu/#) # Each object in OpenStreetMap is described by one or more *tags*. 
Each *tag* consists of two text fields: *key* and *value*. A *key* is a category or type of feature (e.g., "surface"), while a *value* describes the specific feature or subcategory given the *key* (e.g., "asphalt").

# For training a Vision-Language Model like Microsoft's Florence-2 model, you need *labeled* imagery. These labels can come from a variety of sources, but for this tutorial I chose to use OpenStreetMap, a human-curated 'Wikipedia' for land areas. Every road, building, traffic light, river, lake, you name it... is included as a *tag* for a geometry. **Note**: Any label source works for this notebook as long as the tag lists keep the same format (a list of [*key*, *value*] pairs, e.g. `[['railway', 'rail'], ['waterway', 'river'], ['road', 'track']]`); the choice of source has no impact on `II. Image Collection` or `III. Caption Generation`. Here, for each grid, we query all objects from OpenStreetMap and save the tags to a column of the .csv file.

# For simplicity, and because "noisy but semantically diverse image-text datasets [are capable]" ([SkyScript](https://doi.org/10.48550/arXiv.2312.12856)), we will simply apply all "relevant" OSM tags. The meaning of "relevant tags" or "accepted keys" will be defined below.

# In[14]:

# Connect to the Overpass API
overpass_api = overpass.API(
    endpoint="https://overpass-api.de/api/interpreter", timeout=100
)

# View the [most common tags](https://taginfo.openstreetmap.org/tags) in OpenStreetMap

# The `accepted_keys` variable holds the "relevant tags" that can be associated with land formations and objects we can identify from Sentinel-2. Feel free to update this variable to include additional keys. For example, when working with [NAIP](https://planetarycomputer.microsoft.com/dataset/naip) data, which has a much smaller GSD, you may want to include building types or crosswalks.

# In[15]:

# Collect the tags for each grid
tag_lists = []
gsd = (
    10  # Ground Sample Distance in meters for Sentinel-2, differs for other collections
)
accepted_keys = [
    "natural",
    "landuse",
    "leisure",
    "amenity",
    "building",
    "waterway",
    "aeroway",
    "highway",
    "surface",
    "service",
    "leaf_type",
    "leaf_cycle",
    "railway",
    "parking",
    "generator",
    "water",
    "material",
    "smoothness",
    "tracktype",
    "place",
    "ford",
    "wetland",
    "crop",
    "addr:city",
]

# We rename some of the above keys to give them a more present-day semantic meaning for humans. Example: 'aeroway' is changed to 'airport'.

# In[16]:

# Rename certain tag names
def rename_tags(tag_list: list[list[str]]) -> list[list[str]]:
    """
    This function renames keys and values in the tag_list.

    Args:
        tag_list: A list of [key, value] pairs that represent object tags.

    Returns:
        The same list of [key, value] pairs with the renamed keys and values.
    """
    for tag in tag_list:
        if tag[0] == "highway":
            # Keep the key for major highways (OSM values 'primary', 'motorway', 'trunk')
            if tag[1] != "primary" and tag[1] != "motorway" and tag[1] != "trunk":
                tag[0] = "road"
        elif tag[0] == "aeroway":
            tag[0] = "airport"
        elif tag[0] == "lit":
            tag[0] = "light"
        elif tag[0] == "leisure":
            tag[0] = "leisure land"
        elif tag[0] == "addr:city":
            tag[0] = "city"
            tag[1] = "yes"
    return tag_list


print("Function rename_tags defined!")

# When working with geospatial data, it is always imperative to double-check the coordinate system -- and, if it uses latitude and longitude, the order in which those values appear. Planetary Computer and the functions above use a `minX`, `minY`, `maxX`, `maxY` ordering, an X/Y view of coordinates similar to the Cartesian plane.
# However, some conventions put latitude first and then longitude (you may have learned it this way yourself), and OpenStreetMap uses that convention. Always keep your eyes peeled for these differences!

# "Bounding box clauses always start with the lowest latitude (southernmost) followed by lowest longitude (westernmost), then highest latitude (northernmost) then highest longitude (easternmost)" ([OpenStreetMap Wiki](https://wiki.openstreetmap.org/wiki/Overpass_API))

# ==> OSM queries are in the format `minY`, `minX`, `maxY`, `maxX`

# The following cell will take around 2 minutes for `n` = 100

# In[17]:

MILE = 1609  # 1 mile is 1609 meters

for i in range(len(grids_df)):  # Query each grid saved in grids.csv for OSM objects
    if len(grids_df) > 10000:
        print(
            "Too many queries! The API for OpenStreetMap, OverPass, has a limit "
            "of 10,000 queries per day. Please do not overload the API."
        )
        break
    tags = set()
    # Need to switch our lon, lat convention for OSM query
    # This query gathers all nodes that fall within the grid
    minX, minY, maxX, maxY = (
        grids_df["minX"][i],
        grids_df["minY"][i],
        grids_df["maxX"][i],
        grids_df["maxY"][i],
    )
    query = f"(node({minY},{minX},{maxY},{maxX});<;);"
    # Get the result of the query
    result_json = overpass_api.get(query, responseformat="json")
    # Parse the JSON response into a readable format using OverPy
    result = overpy.Result.from_json(result_json)
    ways = result.get_elements(filter_cls=overpy.Way)
    for way in ways:
        locations = way._node_ids
        locations = [result.get_nodes(x) for x in locations]
        locations = [[node[0].lat, node[0].lon] for node in locations if node != []]
        # If the way is a closed polygon, calculate its area and skip its tags when
        # the object is too small to see given the gsd
        latitude_ratio = math.cos(math.radians((maxY + minY) / 2))
        minimum_area = (gsd * gsd) / (latitude_ratio * MILE * LATITUDE) ** 2
        if not (
            (len(locations) > 3)
            and (LineString(locations).is_closed)
            and (Polygon(locations).area < minimum_area)
        ):
            for tag in way.tags:
                if tag in accepted_keys:
                    pair = (tag, way.tags[tag])
                    tags.add(pair)
    tag_list = rename_tags([list(tag) for tag in list(tags)])
    tag_lists.append(tag_list)

# Add the tag lists to the dataframe
grids_df["tags"] = tag_lists
grids_df

# In[18]:

# Save the dataframe to a CSV file
grids_df.to_csv(GRIDS_TAGS, index=False)
print(f"Dataframe saved to {GRIDS_TAGS}")

# Now that we have successfully queried OSM for all of the randomly sampled grids (parsing each response with [OverPy](https://python-overpy.readthedocs.io/en/latest/)), let us visualize an arbitrarily selected grid and overlay its OSM objects on a map. The tag lists were also saved to `GRIDS_TAGS` as a checkpoint; the short sketch below shows how to reload them. The visualization cell after it may take a while depending on the number of objects contained inside the area.
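# Because `.to_csv` stores each Python tag list as a plain string, the `tags` column needs to be parsed back into lists when the checkpoint is reloaded (section III does the same with `ast.literal_eval`). A minimal sketch of reloading the checkpoint we just wrote (the `reloaded_df` name is only for illustration):

# In[ ]:

# Reload the GRIDS_TAGS checkpoint and parse the stringified tag lists
reloaded_df = pandas.read_csv(GRIDS_TAGS)
reloaded_df["tags"] = reloaded_df["tags"].apply(ast.literal_eval)
print(reloaded_df["tags"].iloc[0][:5])  # first few [key, value] pairs of one grid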
# In[19]:

# Select a grid to visualize with the OSM data overlaid
i = 1
minY, minX, maxY, maxX = (
    grids_df["minY"][i],
    grids_df["minX"][i],
    grids_df["maxY"][i],
    grids_df["maxX"][i],
)

# Create a map centered at a specific location
m = folium.Map(
    location=[
        (maxY + minY) / 2,
        (maxX + minX) / 2,
    ],  # Center the map at the center of the bounding box
    zoom_start=14,
    min_lat=minY,
    max_lat=maxY,
    min_lon=minX,
    max_lon=maxX,
    control_scale=True,
)

# Create a rectangle for the bounding box and display it on the map
rectangle = folium.Rectangle(
    bounds=[(minY, minX), (maxY, maxX)],
    color="blue",
    fill=True,
    fill_color="blue",
    fill_opacity=0.3,
)
rectangle.add_to(m)

# Re-run the Overpass query for the bounding box
query = f"(node({minY},{minX},{maxY},{maxX});<;);"
print(query)
data = overpass_api.get(query, responseformat="json")
result = overpy.Result.from_json(data)

# Iterate over the ways in the query results and create a polygon for each way
ways = result.get_elements(filter_cls=overpy.Way)
for way in ways:
    locations = way._node_ids
    locations = [result.get_nodes(x) for x in locations]
    locations = [[node[0].lat, node[0].lon] for node in locations if node != []]
    for location in locations:
        folium.Circle(
            location=location,
            radius=20,
        ).add_to(m)
    if len(locations) != 1:
        linestring = LineString(locations)
        if linestring.is_closed:
            folium.Polygon(
                locations=locations,
                color="red",
                weight=2,
                fill=True,
                fill_color="red",
                fill_opacity=0.3,
            ).add_to(m)
        else:
            folium.PolyLine(
                locations=locations,
                color="blue",
                weight=2,
            ).add_to(m)
    else:
        folium.PolyLine(
            locations=locations,
            color="blue",
            weight=2,
        ).add_to(m)

# Display the map
m

# Now that we have queried OSM to obtain all the objects in the randomly selected grids, we will move on to `II. Image Collection` for the given grids!

# ### II. Image Collection

# This section uses the STAC API to query the Planetary Computer and collect the images covering the randomly generated grids produced in the previous section, `I. Object Collection`.

# To access the data, we'll create a `pystac_client.Client`. The `modifier` part is what lets us download the data assets from Azure Blob Storage.

# In[20]:

# Connect to the Planetary Computer's data catalogs
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

# Below, we read in the data from the `.csv` file that was saved in the previous section.

# In[21]:

# Read the CSV file into a DataFrame
df = pandas.read_csv(GRIDS_TAGS)

# Display the DataFrame
df

# #### 1. Create partitioned geometries of the sampled grids

# We can use the STAC API to search for assets meeting some criteria. This might include the date and time the asset covers, its spatial extent, or any other property captured in the STAC item's metadata.
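# For example, a minimal stand-alone search against the first grid might look like the sketch below. The bounding box, date range, and cloud-cover filter are illustrative values only; the partitioned query in the next cells is what actually drives image collection.

# In[ ]:

# Illustrative only: search one grid's bounding box for low-cloud Sentinel-2 scenes
example_bbox = list(df[["minX", "minY", "maxX", "maxY"]].iloc[0])
example_search = catalog.search(
    collections=["sentinel-2-l2a"],
    bbox=example_bbox,
    datetime="2024-01-01/2024-06-01",
    query={"eo:cloud_cover": {"lt": 10}},  # filter on a STAC item property
)
print(f"Found {len(example_search.item_collection())} matching items")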
First, to make querying the STAC API faster, we partition the grids into partitions of 50 points each, by spatial locality, increasing the efficiency of the query protocol (which becomes more efficient per partition the greater `n` becomes) # In[22]: # Create the geometries for the boxes goems = [ box(*row) for row in df[["minX", "minY", "maxX", "maxY"]].itertuples(index=False) ] gdf = geopandas.GeoDataFrame(df, geometry=goems) # In[23]: # Use dask_geopandas to handle large dataframes of geospatial data ddf = dask_geopandas.from_geopandas(gdf, npartitions=1) hd = ddf.hilbert_distance().compute() gdf["hd"] = hd gdf = gdf.sort_values( "hd" ) # Sort the data by the Hilbert distance (spatial locality -- similar hd values are close) dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False) dgdf.compute() # Compute the dask dataframe # #### 2. Iterate over the partitions, saving the STAC metadata and the cropped images # In[24]: metadata_df = pandas.DataFrame() # We use `rasterio` in order to read in the .tif files from the STAC Item. For our training purposes, we are interested in the visual imagery .href .tif. In future experimentation with fine-tuning or training VLMs including data other than just RGB (ex. infrared, elevation, vegetation, etc.) is very interesting for a model to understand *all aspects* of a geographic location. # In[25]: for i in range(NPARTITIONS): print(f"Calculating partition {i}...") chunk = dgdf.to_delayed()[i].compute() query = catalog.search( collections=["sentinel-2-l2a"], intersects=chunk.unary_union, datetime="2024-01-01/2024-06-01", ) items = query.item_collection() if items != []: geodf = geopandas.GeoDataFrame.from_features(items.to_dict()["features"]) geodf["assets"] = [item.assets for item in items] print(f"Saving images to {IMAGES_DIR}...") for j in range(len(chunk)): uid, minY, minX, maxY, maxX = ( chunk["uid"].iloc[j], chunk["minY"].iloc[j], chunk["minX"].iloc[j], chunk["maxY"].iloc[j], chunk["maxX"].iloc[j], ) grid_bbox = box(minX, minY, maxX, maxY) filtered_items = geodf[geodf.contains(grid_bbox)] if not filtered_items.empty: selected_item_index = filtered_items["eo:cloud_cover"].idxmin(axis=0) selected_item = filtered_items.loc[selected_item_index] water = selected_item["s2:water_percentage"] snow = selected_item["s2:snow_ice_percentage"] vegetated = selected_item["s2:vegetation_percentage"] cloud = selected_item["eo:cloud_cover"] metadata_df = pandas.concat( [ metadata_df, pandas.DataFrame([[uid, water, snow, vegetated, cloud]]), ], ignore_index=True, ) if not os.path.exists(f"{IMAGES_DIR}/{uid}.tif"): with rasterio.open(selected_item.assets["visual"].href) as src: # Create a Window and calculate the transform from the dataset project = pyproj.Transformer.from_crs( pyproj.CRS("EPSG:4326"), src.crs, always_xy=True ).transform grid_bbox = transform(project, grid_bbox) out_image, out_transform = rasterio.mask.mask( src, [grid_bbox], crop=True ) out_meta = src.meta # Create a new cropped raster to write to out_meta.update( { "driver": "GTiff", "height": out_image.shape[1], "width": out_image.shape[2], "transform": out_transform, } ) # Save the cropped image into a new .tif file in IMAGES_DIR with rasterio.open( f"{IMAGES_DIR}/{uid}.tif", "w", **out_meta ) as dest: dest.write(out_image) print(f"Images saved to {IMAGES_DIR}") # #### 3. 
Merge the metadata DataFrame with the grids-and-tags DataFrame

# In[26]:

metadata_df.columns = [
    "uid",
    "water_percentage",
    "snow_ice_percentage",
    "vegetated_percentage",
    "cloud_cover",
]
metadata_df

# In[27]:

df = pandas.read_csv(GRIDS_TAGS)
images = []
for i in range(n):
    uid = df["uid"].iloc[i]
    images.append(f"{uid}.tif")
df = df.assign(image=images)
df

# In[28]:

# Merge the metadata DataFrame with the original DataFrame (GRIDS_TAGS)
df = pandas.merge(df, metadata_df, on=["uid"])
df

# In[29]:

# Save the dataframe with uid, minX, minY, maxX, maxY, tags, image,
# water_percentage, snow_ice_percentage, vegetated_percentage,
# cloud_cover to a .csv file
df.to_csv(GRIDS_TAGS_IMAGES, index=False)
print(f"Dataframe saved to {GRIDS_TAGS_IMAGES}")

# We will use the collected images and the tag lists saved during object collection to generate natural language captions that will be paired with the images in `III. Caption Generation`.

# ### III. Caption Generation

# This section takes the images selected in `II. Image Collection` and the objects selected in `I. Object Collection` and combines them into a `.csv` file for training.

# Read in the `GRIDS_TAGS_IMAGES` file so that we have access to the previously collected bounding boxes, tag lists, image files, and STAC metadata (`uid`, `minX`, `minY`, `maxX`, `maxY`, `tags`, `image`, `water_percentage`, `snow_ice_percentage`, `vegetated_percentage`, and `cloud_cover`).

# In[33]:

# Read the CSV file into a DataFrame
grids_images_df = pandas.read_csv(GRIDS_TAGS_IMAGES)

# Display the DataFrame
grids_images_df

# #### 1. Assembling the tag list strings from the tags

# We have the tag lists from `I. Object Collection`, and now we add tags for the metadata collected in `II. Image Collection`!

# In[34]:

# Add tags for metadata columns
def add_metadata_tags(
    tag_list: list[list[str]], csv_filepath: str, row: int
) -> list[list[str]]:
    """Append qualitative [key, value] tags for the STAC metadata columns of a row."""
    df = pandas.read_csv(csv_filepath)
    metadata_keys = [
        "water_percentage",
        "snow_ice_percentage",
        "vegetated_percentage",
        "cloud_cover",
    ]
    for key in metadata_keys:
        value = df[key][row]
        if value > 95:
            tag_list.append((key, "very high"))
        elif value > 80:
            tag_list.append((key, "high"))
        elif value > 40:
            tag_list.append((key, "medium"))
        elif value > 10:
            tag_list.append((key, "low"))
        elif value > 1:
            tag_list.append((key, "very low"))
        elif value > 0.01:
            tag_list.append((key, "almost none"))
        else:
            tag_list.append((key, "none"))
    return tag_list


print("Function add_metadata_tags defined!")

# For each grid-image-tags pair, concatenate the key-value pairs into an easy-to-interpret string for the LLM.

# In[35]:

captions = []
for i in range(len(grids_images_df)):
    tags = grids_images_df["tags"][i]
    tag_list = list(ast.literal_eval(tags))  # list of tuples to list of lists
    tag_list = [list(tag) for tag in tag_list]
    tag_list = add_metadata_tags(tag_list, GRIDS_TAGS_IMAGES, i)
    caption = ""
    for tag in tag_list:
        caption += tag[0] + " : " + tag[1] + ", "
    captions.append(caption[0:-2])

# Save the tag lists to the path specified under the global variable `IMAGES_TAGS`.

# In[36]:

# Create a new pandas dataframe
images_tags_df = pandas.DataFrame()

# Save the descriptions as a new column in the DataFrame
images_tags_df = images_tags_df.assign(caption=captions)

# Save the images as a new column in the DataFrame
images_tags_df = images_tags_df.assign(image=grids_images_df["image"])

# Final DataFrame with images and their captions
images_tags_df.to_csv(IMAGES_TAGS, index=False)
print(f"Dataframe saved to {IMAGES_TAGS}")
# #### 2. Generating natural language captions from tag lists using GPT-3.5 (optional)

# This step requires access to an Azure OpenAI resource endpoint. It uses GPT-3.5 to turn tag lists of [*key*, *value*] pairs into natural language descriptions. Feel free to skip this step and head to `IV. Model Fine-Tuning` to avoid the costs associated with using Azure and Azure OpenAI.

# ##### a. Prerequisites: Create an Azure OpenAI resource

# [Create an Azure OpenAI resource](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
# To use your resource's endpoint, navigate to your newly created Azure OpenAI resource under Microsoft Azure. Under Resource Management, click on Keys & Endpoint. Copy and paste this endpoint into the `endpoint` variable in the cell below.

# Create a model deployment for GPT-3.5. Go to your Azure OpenAI resource and click 'Model Deployments' then 'Manage Deployments'. Click 'Create new deployment' and enter 'my-gpt-35' as the deployment name and 'gpt-35-turbo' under Select a model.

# In[37]:

endpoint = ""

# ##### b. Generate the captions

# Read in the data from the `.csv` file created in step 1, then turn the tag lists into natural language captions using GPT-3.5.

# In[38]:

df = pandas.read_csv(IMAGES_TAGS)
df

# Specify the endpoint for your Azure OpenAI resource. To find your resource's endpoint, navigate to your created Azure OpenAI resource under Microsoft Azure. Under Resource Management, click on Keys & Endpoint. Copy and paste this endpoint into the `endpoint` variable.

# In the cell below, we create a client that accesses our Azure OpenAI resource's endpoint through the API and authenticates by generating an `azure_ad_token`. These tokens must be refreshed after they expire, otherwise they produce an authentication error. To remedy this, we include a try/except clause in the caption generation loop in case the token expires during cell execution.
#
# We also define a function `get_completion`, which extracts the LLM's output given the model and the prompt.

# Open a terminal, run `az login --use-device-code`, and follow the instructions if you encounter a `ClientAuthenticationError`.

# In[39]:

credential = DefaultAzureCredential()
azure_ad_token = credential.get_token(
    "https://cognitiveservices.azure.com/.default"
).token
client = AzureOpenAI(
    azure_ad_token=azure_ad_token, api_version="2024-02-01", azure_endpoint=endpoint
)


# Connect to OpenAI API
def get_completion(model, prompt, new_client):
    messages = [{"role": "user", "content": prompt}]
    response = new_client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=None,
    )
    return response.choices[0].message.content


# In the Prerequisites instructions above, you created a model deployment for gpt-35-turbo. This is where you specify the model used for prompt completion. Enter the model deployment name that you registered with your resource.

# In[43]:

model = "my-gpt-35"

# In[44]:

images_captions_df = pandas.DataFrame()

# For each image-tags pair, we generate a natural language description from the tag list and save it as a new row in `images_captions_df`. Caption generation speed is limited both by the tokens-per-minute limit set during resource creation and, more likely, by the number of API calls allowed per minute (for example, a 16K tokens-per-minute rate limit allows only 96 API calls per minute). Depending on the size of the tag lists, your limiting factor could change.

# In[1]:

print("Generating captions...")
for i in range(len(df)):
    try:
        prompt = (
            "Give a concise natural language caption of an image "
            "using the following image tag descriptors: "
        ) + df["caption"].iloc[i]
        completion = get_completion(model, prompt, client)
        image_url = df["image"].iloc[i]
        images_captions_df = pandas.concat(
            [images_captions_df, pandas.DataFrame([[image_url, completion]])],
            ignore_index=True,
        )
    except Exception:
        print("Token Expired. Refreshing Token...")
Refreshing Token...") azure_ad_token = credential.get_token( "https://cognitiveservices.azure.com/.default" ).token client = AzureOpenAI( azure_ad_token=azure_ad_token, api_version="2024-02-01", azure_endpoint=endpoint, ) prompt = ( "Give a concise natural language caption of an image" "using the following image tag descriptors:" ) + df["caption"].iloc[i] completion = get_completion(model, prompt, client) image_url = df["image"].iloc[i] images_captions_df = pandas.concat( [images_captions_df, pandas.DataFrame([[image_url, completion]])], ignore_index=True, ) print("Generation complete!") # In[46]: images_captions_df.rename(columns={0: "image", 1: "caption"}, inplace=True) images_captions_df # Save the generated captions and the image names to a .csv file # In[47]: # Save df to csv images_captions_df.to_csv(IMAGES_CAPTIONS, index=False) # ### IV. Model Fine-Tuning # For this step, you will need access to a machine with a compute GPU # In[1]: # Test if torch is able to detect your compute GPU torch.cuda.is_available() # #### 1. Data Formatting # If you skipped the part where we turn the tag lists into natural language captions, replace `IMAGES_CAPTIONS` with `IMAGES_TAGS` # In[7]: images_captions_df = pandas.read_csv(IMAGES_CAPTIONS) # In the following cell, we format the images_captions pairs .csv into a dataframe format for Florence-2 Fine Tuning. For more information visit the [Huggingface article](https://huggingface.co/blog/finetune-florence2) # In[8]: # Rename captions to answers images_captions_df.rename(columns={"caption": "answers"}, inplace=True) # Wrap answers column in a list images_captions_df["answers"] = images_captions_df["answers"].apply(lambda x: [x]) # Add question id column images_captions_df["question_id"] = range(1, len(images_captions_df) + 1) # Add question column images_captions_df["question"] = "What does the image describe?" # Add question_type column images_captions_df["question_types"] = [["form"]] * len(images_captions_df) images_captions_df # In[9]: # Split the data into training and validation train_df, val_df = train_test_split(images_captions_df, test_size=0.1, random_state=42) # Transform train_df and val_df into Dataset() train_dataset = Dataset.from_pandas(train_df) val_dataset = Dataset.from_pandas(val_df) # Wrap train_dataset and val_dataset in DatasetDict() data = DatasetDict({"train": train_dataset, "validation": val_dataset}) # View the data structure data # #### 2. Model Definitions # We can load the model and processor using the `AutoModelForCasualLM` and `AutoProcessor` classes from the transformers library. We need to pass `trust_remote_code=True` because the model uses custom code -- it has not been natively integrated into transformers yet. We will also freeze the vision encoder to make fine-tuning less expensive. # In[ ]: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = AutoModelForCausalLM.from_pretrained( "microsoft/Florence-2-base-ft", # Load model and processor from Florence-2-base-ft trust_remote_code=True, revision="refs/pr/6", ).to(device) processor = AutoProcessor.from_pretrained( "microsoft/Florence-2-base-ft", trust_remote_code=True, revision="refs/pr/6" ) for param in model.vision_tower.parameters(): param.is_trainable = False # Below we define the dataset class to hold the structure of the training data, including the image, question, and expected response. 
# In[11]: class SenDataset(TorchDataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): example = self.data[idx] question = example["question"] first_answer = example["answers"][0] image_path = IMAGES_DIR + example["image"] image = Image.open(image_path) image = image.convert("RGB") return question, first_answer, image # We'll now build the data collator that builds training batches from the dataset samples, and start training. In A100 with 40GB memory, we can fit in 6 examples. If you're training on T4, you can use a batch size of 1. # In[12]: # Define the collate function def collate_fn(batch): questions, answers, images = zip(*batch) inputs = processor( text=list(questions), images=list(images), return_tensors="pt", padding=True ).to(device) return inputs, answers train_dataset = SenDataset(data["train"]) # Define the training dataset val_dataset = SenDataset(data["validation"]) # Define the validation dataset batch_size = 2 num_workers = 0 train_loader = DataLoader( train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True, ) # #### 3. Model Training! # We can train the model now. # Run the following code on your compute GPU. Your model may take several hours to train depending on epochs and dataset size. With `n=100` it should only take a few minutes. # In[ ]: epochs = 5 # Define the number of epochs (times model will see training dataset) optimizer = AdamW(model.parameters(), lr=1e-6) num_training_steps = epochs * len(train_loader) lr_scheduler = get_scheduler( name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps, ) for epoch in range(epochs): model.train() train_loss = 0 i = -1 for inputs, answers in tqdm( train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}" ): i += 1 input_ids = inputs["input_ids"] pixel_values = inputs["pixel_values"] labels = processor.tokenizer( text=answers, return_tensors="pt", padding=True, return_token_type_ids=False, max_length=1024, truncation=True, ).input_ids.to(device) outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels) loss = outputs.loss loss.backward() optimizer.step() lr_scheduler.step() optimizer.zero_grad() train_loss += loss.item() avg_train_loss = train_loss / len(train_loader) print(f"Average Training Loss: {avg_train_loss}") model.eval() val_loss = 0 with torch.no_grad(): for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"): inputs, answers = batch input_ids = inputs["input_ids"] pixel_values = inputs["pixel_values"] labels = processor.tokenizer( text=answers, return_tensors="pt", padding=True, return_token_type_ids=False, max_length=1024, truncation=True, ).input_ids.to(device) outputs = model( input_ids=input_ids, pixel_values=pixel_values, labels=labels ) loss = outputs.loss val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) print(f"Average Validation Loss: {avg_val_loss}") # Save the model and processor to a location in your datastore uri. torch.save(model.state_dict(), (f"{MODEL_DIR}/florence-2-state-dict-{epoch}")) processor.save_pretrained((f"{MODEL_DIR}/florence-2-processor-{epoch}/")) torch.save(model.state_dict(), (f"{MODEL_DIR}/final-florence-2-state-dict")) processor.save_pretrained((f"{MODEL_DIR}/final-florence-2-processor/")) # ### V. 
Model Inference # Load the trained model and processor from `IV. Model Training` # In[6]: state_dict = torch.load(f"{MODEL_DIR}/final-florence-2-state-dict") nl_trained_model = AutoModelForCausalLM.from_pretrained( "microsoft/Florence-2-base-ft", trust_remote_code=True, state_dict=state_dict ).eval() # In[7]: nl_trained_processor = AutoProcessor.from_pretrained( f"{MODEL_DIR}/final-florence-2-processor", trust_remote_code=True ) # Load a test set of images and captions to compare the ground truth labels with the generated output. In order to create a test dataset, you can run the dataset creation steps with `n=100` to have `n` new samples to test your model with. Perhaps save the test images in a separate images directory to prevent duplicate image names. Save this as `test.csv` under `DATA_DIR` # In[11]: test_df = pandas.read_csv(f"{DATA_DIR}test.csv") test_df # In[12]: i = 0 # Select an example to look at from your test dataframe image = Image.open(f"{DATA_DIR}images/" + test_df.iloc[i]["image"]) # Display the image image # Define the prediction function # In[13]: device = torch.device("cpu") # In[14]: # Define the function to take in the prompt, model, and processor and return the text def run_example(model, processor, task_prompt, text_input=None): if text_input is None: prompt = task_prompt else: prompt = task_prompt + text_input inputs = processor(text=prompt, images=image, return_tensors="pt") generated_ids = model.generate( input_ids=inputs["input_ids"].to(device), pixel_values=inputs["pixel_values"].to(device), max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3, ) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] parsed_answer = processor.post_process_generation( generated_text, task=task_prompt, image_size=(image.width, image.height), ) return parsed_answer # In[15]: # Print out the ground truth caption (the caption from the test dataframe) print(test_df.iloc[i]["caption"]) # In[16]: # Model fine-tuned on natural language captions generated by GPT 3.5 task_prompt = "" run_example(nl_trained_model, nl_trained_processor, task_prompt)
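# As a quick qualitative check, you can loop over a few test samples and print the ground-truth caption next to the model's output. This is a minimal sketch reusing `run_example` and `test_df` from above; note that `run_example` reads the module-level `image` variable, so it is reassigned before each call.

# In[ ]:

# Compare ground-truth captions with generated captions for a few test samples
for i in range(min(3, len(test_df))):
    image = Image.open(f"{DATA_DIR}images/" + test_df.iloc[i]["image"]).convert("RGB")
    prediction = run_example(nl_trained_model, nl_trained_processor, task_prompt="")
    print(f"Ground truth: {test_df.iloc[i]['caption']}")
    print(f"Prediction:   {prediction}")
    print("-" * 40)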