{opticon}`tag` {bdg-primary}`Forest` {bdg-secondary}`Modelling` {bdg-warning}`Standard` {bdg-info}`Python`
Accurately delineating trees using `detectron2`, a library that provides state-of-the-art deep learning detection and segmentation algorithms.

An established deep learning model, Mask R-CNN, was deployed from the `detectron2` library to delineate tree crowns accurately. A pre-trained model, named `detectreeRGB`, is provided to predict the location and extent of tree crowns from a top-down RGB image captured by drone, aircraft or satellite. `detectreeRGB` was implemented in `python` 3.8 using `pytorch` v1.7.1 and `detectron2` v0.5. Further details can be found in the repository documentation.
The project was supported by the UKRI Centre for Doctoral Training in Application of Artificial Intelligence to the study of Environmental Risks (AI4ER) (EP/S022961/1).
:::{note}
The authors acknowledge the authors of the Detectron2 package, which provides the Mask R-CNN architecture.
:::
```python
import cv2
from PIL import Image
import os
import numpy as np
import urllib.request
import glob

# intake library and plugin
import intake
from intake_zenodo_fetcher import download_zenodo_files_for_entry

# geospatial libraries
import geopandas as gpd
from rasterio.transform import from_origin
import rasterio.features
import fiona
from shapely.geometry import shape, mapping, box
from shapely.geometry.multipolygon import MultiPolygon

# machine learning libraries
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

# visualisation
import holoviews as hv
import geoviews.tile_sources as gts
import matplotlib.pyplot as plt
import hvplot.pandas
import hvplot.xarray

import warnings
warnings.filterwarnings(action='ignore')

hv.extension('bokeh', width=100)
```
```python
# Define the project main folder
notebook_folder = './notebook'

# Set the folder structure
config = {
    'in_geotiff': os.path.join(notebook_folder, 'input', 'tiff'),
    'in_png': os.path.join(notebook_folder, 'input', 'png'),
    'model': os.path.join(notebook_folder, 'model'),
    'out_geotiff': os.path.join(notebook_folder, 'output', 'raster'),
    'out_shapefile': os.path.join(notebook_folder, 'output', 'vector'),
}

# Create every folder in the structure that does not exist yet
for folder in config.values():
    os.makedirs(folder, exist_ok=True)
```
Let's fetch a sample aerial image from a Zenodo repository using `intake`.
```python
# write a catalog YAML file (YAML nesting is defined by indentation)
catalog_file = os.path.join(notebook_folder, 'catalog.yaml')

with open(catalog_file, 'w') as f:
    f.write('''
sources:
  sepilok_rgb:
    driver: rasterio
    description: 'NERC RGB images of Sepilok, Sabah, Malaysia (collection)'
    metadata:
      zenodo_doi: "10.5281/zenodo.5494629"
    args:
      urlpath: "{{ CATALOG_DIR }}/input/tiff/Sep_2014_RGB_602500_646600.tif"
''')

# open the catalog and download the files of each entry from Zenodo
cat_tc = intake.open_catalog(catalog_file)
for catalog_entry in list(cat_tc):
    download_zenodo_files_for_entry(
        cat_tc[catalog_entry],
        force_download=False
    )

# load the GeoTIFF lazily as a dask-backed xarray DataArray
tc_rgb = cat_tc["sepilok_rgb"].to_dask()
```
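YAML expresses nesting through indentation, so the catalog string has to keep its exact leading spaces. The following is a minimal sketch of one way to keep that indentation aligned with the surrounding Python code. `write_catalog` is a hypothetical helper, not part of `intake`, and it assumes the description contains no single quotes.

```python
import os
import tempfile
import textwrap


def write_catalog(path, name, zenodo_doi, urlpath, description=''):
    """Write a one-source intake catalog (hypothetical helper) and return its text.

    textwrap.dedent removes the indentation shared by every line, so the YAML
    can be indented like the Python code while its own nesting stays intact.
    Assumes `description` contains no single quote.
    """
    yaml_text = textwrap.dedent(f'''\
        sources:
          {name}:
            driver: rasterio
            description: '{description}'
            metadata:
              zenodo_doi: "{zenodo_doi}"
            args:
              urlpath: "{urlpath}"
        ''')
    with open(path, 'w') as f:
        f.write(yaml_text)
    return yaml_text


# Example: write the catalog used in this notebook into a temporary folder
with tempfile.TemporaryDirectory() as tmp:
    catalog_text = write_catalog(
        os.path.join(tmp, 'catalog.yaml'),
        name='sepilok_rgb',
        zenodo_doi='10.5281/zenodo.5494629',
        urlpath='{{ CATALOG_DIR }}/input/tiff/Sep_2014_RGB_602500_646600.tif',
        description='NERC RGB images of Sepilok, Sabah, Malaysia (collection)',
    )
print(catalog_text)
```

In the notebook, a call like `write_catalog(catalog_file, 'sepilok_rgb', ...)` would replace the literal string before `intake.open_catalog(catalog_file)` is called.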
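Let's investigate the `DataArray`. What are its shape, number of bands and CRS, and which area does it cover?

```python
print('shape =', tc_rgb.shape)
# `count` is a DataArray method, so read the band count from the dimension sizes instead
print('number of bands =', tc_rgb.sizes['band'])
print('crs =', tc_rgb.attrs.get('crs'))
# range of the pixel-centre coordinates
print('x range =', float(tc_rgb.x.min()), '-', float(tc_rgb.x.max()))
print('y range =', float(tc_rgb.y.min()), '-', float(tc_rgb.y.max()))
```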
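In a raster opened this way, the `x` and `y` coordinates of the `DataArray` typically mark pixel centres. The outer edges of the image therefore lie half a pixel further out. The following is a minimal sketch of recovering the full bounds from the centre coordinates. It assumes a regular grid with at least two pixels along each axis, and `bounds_from_centres` is a hypothetical helper.

```python
import numpy as np


def bounds_from_centres(x, y):
    """Return (left, bottom, right, top) of a regular grid from its pixel-centre coordinates.

    Assumes evenly spaced coordinates with at least two values per axis; works whether
    the coordinates increase or decrease (y usually decreases from north to south).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    half_dx = abs(x[1] - x[0]) / 2
    half_dy = abs(y[1] - y[0]) / 2
    return (float(x.min() - half_dx), float(y.min() - half_dy),
            float(x.max() + half_dx), float(y.max() + half_dy))


# Example: a 3 x 2 grid of 1 m pixels whose upper-left corner is at (100, 204)
bounds = bounds_from_centres([100.5, 101.5, 102.5], [203.5, 202.5])
print(bounds)  # (100.0, 202.0, 103.0, 204.0)
```

Applied to this notebook, `bounds_from_centres(tc_rgb.x.values, tc_rgb.y.values)` would give the extent of the Sepilok tile.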
This workflow feeds Mask R-CNN an 8-bit `png` image rather than the GeoTIFF, so let's export the RGB bands to a `png` file.
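```python
# origin of the tile, as encoded in the GeoTIFF file name
minx = 602500
miny = 646600

# extract the red, green and blue bands as NumPy arrays
R = tc_rgb[0].values
G = tc_rgb[1].values
B = tc_rgb[2].values

# Stack the bands in R, G, B order. cv2.imwrite interprets channels as B, G, R, so the file
# on disk has red and blue swapped; cv2.imread swaps them back, so `im` below is R, G, B.
rgb = np.dstack((R, G, B))
# Rescale from the 16-bit range (0-65535) to the 8-bit range (0-255) used for the png.
# You will have to change the rescaling depending on the values of your tiff!
rgb_rescaled = np.clip(np.round(255 * rgb / 65535), 0, 255).astype(np.uint8)

# save this as png, named with the origin of the specific tile - change the filepath!
filepath = os.path.join(config['in_png'], f'tile_{minx}_{miny}.png')
cv2.imwrite(filepath, rgb_rescaled)
```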
```python
# read the png back and display it
im = cv2.imread(filepath)

plot_input = plt.figure(figsize=(15, 15))
plt.imshow(Image.fromarray(im))
plt.title('Input image', fontsize='xx-large')
plt.axis('off')
plt.show()
```
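The hard-coded division by 65535 assumes 16-bit data. As the comment in the export cell warns, other GeoTIFFs need a different scaling. The following is a minimal sketch that derives the maximum from the array's integer dtype instead. `bit_depth_max` and `to_uint8` are hypothetical helpers. They assume the data use the full dtype range, which is not true of, for example, 12-bit sensor values stored as `uint16`.

```python
import numpy as np


def bit_depth_max(dtype):
    """Largest value representable by an integer dtype (255 for uint8, 65535 for uint16)."""
    return np.iinfo(dtype).max


def to_uint8(bands, max_value):
    """Linearly map [0, max_value] to [0, 255], then round, clip and cast to uint8."""
    scaled = 255.0 * np.asarray(bands, dtype=float) / max_value
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


# Example with 16-bit values
rgb16 = np.array([[[0, 32768, 65535]]], dtype=np.uint16)
rgb8 = to_uint8(rgb16, bit_depth_max(rgb16.dtype))
print(rgb8.dtype, rgb8.tolist())  # uint8 [[[0, 128, 255]]]
```

In the export cell, the scaling line would then read `rgb_rescaled = to_uint8(rgb, bit_depth_max(rgb.dtype))`.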
```python
# define the URL to retrieve the pre-trained model weights from Zenodo
fn = 'model_final.pth'
url = f'https://zenodo.org/record/5515408/files/{fn}?download=1'
urllib.request.urlretrieve(url, os.path.join(config['model'], fn))
```
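`urlretrieve` downloads the weights every time the cell runs. The following is a minimal sketch of skipping the download when the file is already present. It writes to a temporary `.part` file first, so an interrupted download is not mistaken for a complete one. `fetch_if_missing` is a hypothetical helper, and the example uses a local `file://` URL in place of the Zenodo link.

```python
import os
import pathlib
import tempfile
import urllib.request


def fetch_if_missing(url, destination):
    """Download `url` to `destination` unless a non-empty file already exists there.

    Returns True if a download took place. The data is first written to
    `destination + '.part'` and only renamed once urlretrieve has finished.
    """
    if os.path.exists(destination) and os.path.getsize(destination) > 0:
        return False
    os.makedirs(os.path.dirname(destination) or '.', exist_ok=True)
    partial_path = destination + '.part'
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, destination)
    return True


# Example: a local file:// URL stands in for the Zenodo download link
with tempfile.TemporaryDirectory() as tmp:
    source = os.path.join(tmp, 'weights.bin')
    with open(source, 'wb') as f:
        f.write(b'fake weights')
    target = os.path.join(tmp, 'model', 'model_final.pth')
    first = fetch_if_missing(pathlib.Path(source).as_uri(), target)   # downloads
    second = fetch_if_missing(pathlib.Path(source).as_uri(), target)  # skips
    with open(target, 'rb') as f:
        content = f.read()
print(first, second)  # True False
```

In the notebook, the call would be `fetch_if_missing(url, os.path.join(config['model'], fn))`.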
The following lines configure the main `detectron2` settings for prediction and load them into a `DefaultPredictor` object.
```python
cfg = get_cfg()

# model and hyperparameter selection: Mask R-CNN with a ResNet-101 FPN backbone from the model zoo
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
# training-related settings; they have no effect on DefaultPredictor
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # a single class: tree crown

# to make predictions on a CPU, keep the following line; if using a GPU, comment it out
cfg.MODEL.DEVICE = 'cpu'

# path to the saved pre-trained model weights
cfg.MODEL.WEIGHTS = os.path.join(config['model'], 'model_final.pth')

# confidence threshold: detections with a lower score are discarded
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.15

# load the settings into a predictor and run it on the input image
predictor = DefaultPredictor(cfg)
outputs = predictor(im)
```
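`SCORE_THRESH_TEST` is applied inside `detectron2`'s inference code, so only detections scoring above 0.15 reach `outputs`. The NumPy sketch below only illustrates the idea on made-up scores and masks. `filter_by_score` is a hypothetical helper, not a `detectron2` function.

```python
import numpy as np


def filter_by_score(scores, masks, threshold):
    """Keep the instances whose confidence score is strictly greater than threshold."""
    scores = np.asarray(scores)
    keep = scores > threshold
    return scores[keep], masks[keep]


# Made-up example: four candidate crowns with 2 x 2 binary masks
scores = np.array([0.90, 0.10, 0.15, 0.40])
masks = np.zeros((4, 2, 2), dtype=bool)
kept_scores, kept_masks = filter_by_score(scores, masks, threshold=0.15)
print(kept_scores, kept_masks.shape)  # [0.9 0.4] (2, 2, 2)
```

You can also keep only the more confident crowns after prediction, without re-running the model. The `Instances` object returned by the predictor accepts a boolean index, for example `outputs['instances'][outputs['instances'].scores > 0.5]`.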
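```python
# Draw the predicted instances. The channels are reversed for the Visualizer (which expects RGB)
# and reversed back with cvtColor, so the figure keeps the channel order of the input image above.
# ColorMode.IMAGE_BW greys out the pixels that are not part of any predicted crown.
v = Visualizer(im[:, :, ::-1], scale=1.5, instance_mode=ColorMode.IMAGE_BW)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
image = cv2.cvtColor(v.get_image(), cv2.COLOR_BGR2RGB)

plot_predictions = plt.figure(figsize=(15, 15))
plt.imshow(Image.fromarray(image))
plt.title('Predictions', fontsize='xx-large')
plt.axis('off')
plt.show()
```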
```python
# binary masks of the predicted crowns, shape (num_instances, height, width)
mask_array = outputs['instances'].pred_masks.cpu().numpy()
# get confidence scores too
mask_array_scores = outputs['instances'].scores.cpu().numpy()

# one band per instance: 255 inside the crown, 0 elsewhere
fresh_output = np.where(mask_array, 255, 0).astype(float)

# pixel size, assuming the png tile covers 140 m x 140 m on the ground
pixel_height = 140 / fresh_output.shape[1]
pixel_width = 140 / fresh_output.shape[2]

# This is an affine transform built from the tile origin with hard-coded offsets.
# It needs to be altered significantly for other tiles.
transform = from_origin(minx - 20, miny + 120, pixel_width, pixel_height)

output_raster = os.path.join(config['out_geotiff'], f'predicted_rasters_{minx}_{miny}.tif')
with rasterio.open(output_raster, 'w', driver='GTiff',
                   height=fresh_output.shape[1], width=fresh_output.shape[2],
                   count=fresh_output.shape[0], dtype=str(fresh_output.dtype),
                   crs='+proj=utm +zone=50 +datum=WGS84 +units=m +no_defs',
                   transform=transform) as new_dataset:
    new_dataset.write(fresh_output)
```
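`from_origin(west, north, xsize, ysize)` describes a north-up grid. Moving one column east adds `xsize` to x, and moving one row down subtracts `ysize` from y. The following is a pure-Python sketch of that mapping, which is handy for checking where a mask pixel lands on the ground. `pixel_to_map` is a hypothetical helper. The 700-pixel tile in the example is a size assumed for illustration, not the size of the Sepilok `png`.

```python
def pixel_to_map(col, row, west, north, pixel_width, pixel_height):
    """Map coordinates of a pixel corner in a north-up grid with upper-left corner (west, north).

    Use col + 0.5, row + 0.5 to get the centre of a pixel instead of its upper-left corner.
    """
    x = west + col * pixel_width
    y = north - row * pixel_height
    return x, y


# Same origin as the transform above, with an assumed 700 x 700 px tile covering 140 m
west, north = 602500 - 20, 646600 + 120
size_px = 700
res = 140 / size_px  # 0.2 m per pixel
upper_left = pixel_to_map(0, 0, west, north, res, res)
lower_right = pixel_to_map(size_px, size_px, west, north, res, res)
print(upper_left, lower_right)
```

With rasterio, multiplying the `transform` defined above by a `(col, row)` pair, as in `transform * (0, 0)`, should give the same corner coordinates.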
```python
# Polygonize each instance band and write one MultiPolygon feature per predicted crown
shp_schema = {'geometry': 'MultiPolygon', 'properties': {'pixelvalue': 'int', 'score': 'float'}}
output_shapefile = os.path.join(config['out_shapefile'], f'predicted_polygons_{minx}_{miny}_0.shp')

# Read the instance bands with Rasterio and write them with the raster's CRS
with rasterio.open(output_raster) as src, \
        fiona.open(output_shapefile, 'w', driver='ESRI Shapefile',
                   schema=shp_schema, crs_wkt=src.crs.to_wkt()) as shp:
    for i in range(src.count):
        src_band = np.float32(src.read(i + 1))
        conf = mask_array_scores[i]
        # Polygonize with Rasterio. `shapes()` returns an iterable of (geom, value) tuples;
        # keep only the shapes covering the crown (value 255)
        shapes = rasterio.features.shapes(src_band, transform=src.transform)
        polygons = [shape(geom) for geom, value in shapes if value == 255.0]
        if not polygons:
            continue  # empty mask: nothing to write for this instance
        shp.write({
            'geometry': mapping(MultiPolygon(polygons)),
            'properties': {'pixelvalue': 255, 'score': float(conf)},
        })
```
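The polygonize loop keeps only the shapes with value 255 and merges them into one `MultiPolygon` per crown. The same step can be sketched without `shapely` by working directly on the GeoJSON-like geometry dicts that `rasterio.features.shapes` yields. `to_multipolygon_record` is a hypothetical helper that produces a record in the `{'geometry': ..., 'properties': ...}` layout passed to `shp.write`. The example geometries are made up.

```python
def to_multipolygon_record(shapes, score, crown_value=255.0):
    """Collect the crown polygons from (geometry, value) pairs into one fiona-style record.

    `shapes` holds GeoJSON-like (geometry, value) pairs such as those yielded by
    rasterio.features.shapes; pairs whose value differs from crown_value (the
    background) are dropped. Returns None when no crown polygon is left.
    """
    polygons = [geom['coordinates'] for geom, value in shapes
                if value == crown_value and geom['type'] == 'Polygon']
    if not polygons:
        return None
    return {
        'geometry': {'type': 'MultiPolygon', 'coordinates': polygons},
        'properties': {'pixelvalue': int(crown_value), 'score': float(score)},
    }


# Made-up example: one 1 x 1 crown (value 255) inside a 5 x 5 background (value 0)
crown = {'type': 'Polygon', 'coordinates': [[(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)]]}
background = {'type': 'Polygon', 'coordinates': [[(0, 0), (5, 0), (5, 5), (0, 5), (0, 0)]]}
record = to_multipolygon_record([(crown, 255.0), (background, 0.0)], score=0.87)
print(record['geometry']['type'], len(record['geometry']['coordinates']))  # MultiPolygon 1
```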
```python
# load and plot polygons
in_shp = glob.glob(os.path.join(config['out_shapefile'], '*.shp'))
poly_df = gpd.read_file(in_shp[0])

plot_vector = poly_df.hvplot(hover_cols=['score'], legend=False).opts(
    fill_color=None, line_color=None, alpha=0.5, width=800, height=600, xaxis=None, yaxis=None
)
plot_vector
```
```python
# load and plot RGB image
r = tc_rgb.sel(band=[1, 2, 3])
# stretch the values so that the 99th percentile (over all bands) maps to 255
normalized = r / (r.quantile(.99, skipna=True) / 255)
# values of 255 and above become NaN here
mask = normalized.where(normalized < 255)
int_arr = mask.astype(int)

plot_rgb = int_arr.astype('uint8').hvplot.rgb(
    x='x', y='y', bands='band', data_aspect=0.8, hover=False, legend=False, rasterize=True,
    xaxis=None, yaxis=None, title='Tree crown delineation by detectreeRGB'
)
```
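Note that the RGB image shows some artifacts introduced by the normalization procedure. The `where` call most likely causes them: it sets every value at or above the 99th percentile to NaN, and casting NaN to an integer type does not produce a meaningful pixel value.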
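One way to avoid these artifacts is to clip the stretched values instead of masking them, so that no NaN reaches the integer cast. This sketch assumes that saturating the brightest pixels at 255 is acceptable. `percentile_stretch` is a hypothetical helper. Like the cell above, it uses a single percentile computed over all three bands.

```python
import numpy as np


def percentile_stretch(bands, percentile=99.0):
    """Scale so the given percentile maps to 255, then clip to 0-255 and cast to uint8.

    Pixels brighter than the percentile saturate at 255 instead of becoming NaN,
    and NaN inputs (e.g. nodata) are set to 0 before the integer cast.
    """
    arr = np.asarray(bands, dtype=float)
    upper = np.nanpercentile(arr, percentile)  # one value over all bands, as in the cell above
    if not upper > 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = np.nan_to_num(arr * (255.0 / upper), nan=0.0)
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


# Example: a NaN, a mid value and the maximum, stretched to the 100th percentile
small = percentile_stretch(np.array([np.nan, 1.0, 2.0]), percentile=100)
print(small.tolist())  # [0, 128, 255]

# Example: values 0-100 stretched to the default 99th percentile
ramp = percentile_stretch(np.arange(101, dtype=float))
```

For the `DataArray` above, `r.copy(data=percentile_stretch(r.values))` would keep the `band`, `y` and `x` coordinates that `hvplot.rgb` relies on.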
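```python
# overlay the predicted crown polygons on the RGB image and save the result as html
plot_predictions_interactive = plot_rgb * plot_vector
plot_predictions_interactive

hvplot.save(plot_predictions_interactive, os.path.join(notebook_folder, 'interactive_predictions.html'))
```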
We have read in a raster, chosen a tile, made predictions on it and converted them to shapefiles, which can be examined in GIS software!

`png` using a pretrained Mask R-CNN model, `detectreeRGB`.

Please see CITATION.cff for the full citation information. The citation file can be exported to APA or BibTeX formats (learn more here).
Codebase: version 1.0.0 with commit 16a5a1c
License: The code in this notebook is licensed under the MIT License. The Environmental Data Science book is licensed under the Creative Commons by Attribution 4.0 license. See further details here.
Contact: If you have any suggestion or report an issue with this notebook, feel free to create an issue or send a direct message to environmental.ds.book@gmail.com.
```python
from datetime import date

print('Notebook repository version: v1.0.2')
print(f'Last tested: {date.today()}')
```
The cell below saves the notebook outputs so that they can be registered in a Zenodo repository curated by the Environmental DS book.
```python
# named notebook_outputs so the predictor's `outputs` is not overwritten
notebook_outputs = {
    'static_figures': {
        'filenames': ['static_input', 'static_predictions'],
        'data': [plot_input, plot_predictions]},
    'interactive_figures': {
        'filenames': ['interactive_vector', 'interactive_predictions'],
        'data': [plot_vector, plot_predictions_interactive]},
}

# save static figures as png
static = notebook_outputs['static_figures']
for name, fig in zip(static['filenames'], static['data']):
    fig.savefig(os.path.join(notebook_folder, name + '.png'))

# save interactive figures as html
interactive = notebook_outputs['interactive_figures']
for name, plot in zip(interactive['filenames'], interactive['data']):
    hvplot.save(plot, os.path.join(notebook_folder, name + '.html'))
```