Semantic Segmentation of Satellite Images Based on Deep Learning Algorithms

Vihan Tyagi
5 min read · May 9, 2021


Hello everyone! I hope you're all safe. My name is Vihan Tyagi and I'm a final-year CSE student at Bennett University. In this blog, I'm going to briefly discuss the project I worked on during my training at the Space Applications Centre (SAC), ISRO. So, let's get right to it.

WHAT IS SEMANTIC SEGMENTATION?

Semantic segmentation is the task of classifying an image at the pixel level: every pixel is assigned a class label, and different objects belonging to the same class receive the same label. For example, if there are two cars in an image, both are classified under the same label in semantic segmentation, unlike in instance segmentation, where each car would get its own label.

DATASET

The first thing you deal with when you start a deep learning project is preparing a dataset. If you don't have access to labeled Earth observation satellite imagery, you can use open-source projects like OpenStreetMap (OSM) to download ground truth for the exact locations your images cover. For my project, I downloaded map tiles from an online tile server based on OSM data: wmflabs osm-no-labels (since text annotations would add noise to our data). Here's the link.

https://tiles.wmflabs.org/osm-no-labels/${z}/${x}/${y}.png

x = X tile, y = Y tile, z = Zoom level (Spatial Resolution)
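
As a rough illustration, here's how you might fetch one such tile in Python. The deg2num helper is the standard slippy-map conversion from latitude/longitude to tile indices; the coordinates below are just a hypothetical example:

import math
import requests

def deg2num(lat_deg, lon_deg, zoom):
    # Standard slippy-map conversion from lat/lon to tile indices
    lat_rad = math.radians(lat_deg)
    n = 2 ** zoom
    x = int((lon_deg + 180.0) / 360.0 * n)
    y = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
    return x, y

z = 15                                # zoom level (spatial resolution)
x, y = deg2num(23.03, 72.58, z)       # hypothetical example coordinates
url = f"https://tiles.wmflabs.org/osm-no-labels/{z}/{x}/{y}.png"
tile = requests.get(url, timeout=30)
with open(f"label_{z}_{x}_{y}.png", "wb") as f:
    f.write(tile.content)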

Satellite image and its ground truth downloaded from OSM

After you get access to the satellite images, you have to georeference the labeled data in QGIS, using the satellite images (GeoTIFF files) as the reference. After georeferencing the labeled data, divide your dataset into train, validation, and test splits in a proper directory structure; one possible layout is shown below. My dataset had a few thousand images and their corresponding labels.
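
One possible layout (the folder names here are illustrative, not the exact structure from the project):

dataset/
    train/
        images/
        labels/
    val/
        images/
        labels/
    test/
        images/
        labels/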

UNet MODEL

Let's look at the model architecture. Originally developed for biomedical image segmentation, UNet is the go-to model when dealing with satellite images: it provides good localization accuracy and doesn't require a huge dataset.

The UNet model can be divided into an encoder part and a decoder part. The main idea behind UNet is to supplement the usual contracting network with successive layers in which the max-pooling operations are replaced by upsampling operators (transposed convolutions) in the expansive part; these layers therefore increase the resolution of the output. To localize, high-resolution features from the contracting path are combined with the upsampled output, and a successive convolution layer can then learn to assemble a more precise output from this information. This symmetry between the contracting and expansive parts gives the network its U shape, hence the name.

The model function takes the total number of classes as input; in our project this was 12 (originally 50+, but similar classes were merged in our class map). Now that the model has been defined, let's compile it by specifying the loss, the optimizer, and the metric by which we'll evaluate the model's performance. For this particular use case, categorical cross-entropy loss, the Adam optimizer, and the accuracy metric were chosen.
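
The full model code isn't reproduced here, but a minimal two-level UNet sketch in tf.keras might look like this (the real network would be deeper; the 12 classes come from our class map):

import tensorflow as tf
from tensorflow.keras import layers, Model

def conv_block(x, filters):
    # Two 3x3 convolutions, as in the original UNet paper
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    return x

def unet_model(num_classes, input_shape=(256, 256, 3)):
    inputs = layers.Input(shape=input_shape)
    # Contracting path (encoder)
    c1 = conv_block(inputs, 64)
    p1 = layers.MaxPooling2D()(c1)
    c2 = conv_block(p1, 128)
    p2 = layers.MaxPooling2D()(c2)
    # Bottleneck
    b = conv_block(p2, 256)
    # Expansive path (decoder) with skip connections from the encoder
    u2 = layers.Conv2DTranspose(128, 2, strides=2, padding="same")(b)
    c3 = conv_block(layers.concatenate([u2, c2]), 128)
    u1 = layers.Conv2DTranspose(64, 2, strides=2, padding="same")(c3)
    c4 = conv_block(layers.concatenate([u1, c1]), 64)
    # Per-pixel class probabilities
    outputs = layers.Conv2D(num_classes, 1, activation="softmax")(c4)
    return Model(inputs, outputs)

model = unet_model(num_classes=12)
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])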

DATA GENERATORS

Now, before we start training our model (model.fit), we should also briefly discuss image data generators. When dealing with huge datasets, it's very memory-intensive to load the entire dataset and feed it to the model; most of the time you would get an OOM (out of memory) error. You can solve this issue by loading the data in batches of a pre-defined size. The TensorFlow class:

tf.keras.preprocessing.image.ImageDataGenerator(…)

allows you to preprocess your data (normalize, resize, etc.). However, its label output is expected to have the same shape as the input, which doesn't work for our use case, since our one-hot encoded labels have shape x * y * classes. To solve this issue, we can create a custom data generator, sketched below.
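
A minimal sketch of such a generator built on tf.keras.utils.Sequence; the file lists and the assumption that labels are stored as single-channel class-index masks are mine, so adapt them to your own data:

import numpy as np
import tensorflow as tf

class SegmentationGenerator(tf.keras.utils.Sequence):
    """Yields (image batch, one-hot mask batch) pairs from lists of file paths."""

    def __init__(self, image_paths, label_paths, num_classes, batch_size=64):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.num_classes = num_classes
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, idx):
        s = idx * self.batch_size
        imgs, masks = [], []
        for img_path, lbl_path in zip(self.image_paths[s:s + self.batch_size],
                                      self.label_paths[s:s + self.batch_size]):
            img = tf.keras.preprocessing.image.load_img(img_path, target_size=(256, 256))
            imgs.append(np.asarray(img) / 255.0)    # normalize to [0, 1]
            # Assumes labels are single-channel masks of class indices
            lbl = tf.keras.preprocessing.image.load_img(
                lbl_path, target_size=(256, 256), color_mode="grayscale")
            one_hot = tf.keras.utils.to_categorical(np.asarray(lbl), self.num_classes)
            masks.append(one_hot)                   # shape (256, 256, num_classes)
        return np.stack(imgs), np.stack(masks)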

TRAIN AND PREDICT

First, we create train, validation, and test generators and set callbacks for early stopping (in case the model isn't learning) and model checkpointing (for saving the model at the end of each epoch). Then just call the model.fit function as shown below:

Training your model
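
Concretely, a minimal version of that training setup might look like this (the epoch count, patience, file name, and generator variables are placeholders):

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Stop training if validation loss hasn't improved for 10 epochs
    EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
    # Save the best model seen so far at the end of each epoch
    ModelCheckpoint("best_model.h5", monitor="val_loss", save_best_only=True),
]

history = model.fit(train_generator,
                    validation_data=val_generator,
                    epochs=100,
                    callbacks=callbacks)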

When all the epochs are complete, load your best saved model and predict masks for the satellite images in your test data.

predictions = model.predict(test_generator)

The predictions would be of shape (256, 256, total classes). First, apply the argmax function along the class axis so that the output is of shape 256*256, where each pixel holds the index (in the class map) of the class it belongs to. After replacing each index with its defined RGB value, plot the results:
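
A sketch of that post-processing; the class_to_rgb mapping below is a made-up three-class example standing in for our 12-class map:

import numpy as np
import matplotlib.pyplot as plt

# Hypothetical class-index -> RGB mapping (the real class map had 12 entries)
class_to_rgb = {0: (0, 0, 0), 1: (128, 128, 128), 2: (0, 255, 0)}

mask = np.argmax(predictions[0], axis=-1)        # (256, 256) class indices
rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
for idx, color in class_to_rgb.items():
    rgb[mask == idx] = color                     # paint each class its color

plt.imshow(rgb)
plt.axis("off")
plt.show()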

Prediction using the best-saved model

CHALLENGES FACED AND PROPOSED SOLUTIONS

  • Half the work in a deep learning project is generating a good-quality dataset. Even though this is relatively easy now thanks to open-source projects like OSM, a dataset of a few thousand images is still small. You can use data augmentation in your custom data generator to feed different variations of the data to the model and make it more robust.
  • Even on OSM's no-labels tile server, certain places are represented by symbols; these add noise to our dataset and in turn reduce the accuracy of our model.
  • The OOM error discussed earlier was a major issue, which we solved by loading the dataset in batches of 64 via the custom data generator.
  • Class imbalance: even though our results are decent, certain classes, such as roads, are not detected properly. This is due to class imbalance and can be addressed by giving more weight to the road class and other under-represented classes. This has to be done with sample weights rather than class weights, because class weights can't be used for targets with 3+ dimensions (batch_size, 256, 256, classes). Sample weights can be enabled by including sample_weight_mode = "temporal" in model.compile; see the sketch after this list.
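
As one way to implement this in tf.keras (an alternative to sample_weight_mode that I find simpler): have the data pipeline yield a per-pixel weight map alongside each image and label, and the loss will broadcast the weights. The weight values below are made up, and the three-class example stands in for our 12 classes:

import tensorflow as tf

def add_sample_weights(image, label):
    # Hypothetical per-class weights: under-represented classes (e.g. roads) get more weight
    class_weights = tf.constant([1.0, 1.0, 5.0])
    class_idx = tf.argmax(label, axis=-1)                 # (256, 256) indices from one-hot labels
    sample_weights = tf.gather(class_weights, class_idx)  # (256, 256) per-pixel weights
    return image, label, sample_weights

With a tf.data pipeline this is just train_ds.map(add_sample_weights); with a custom Sequence generator, return the same (image, label, weights) triple from __getitem__.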

CONCLUSION

My training experience at SAC (ISRO) helped me learn a lot. They have a very supportive environment for trainees, and my mentor, Mr. Ashutosh Gupta, was very helpful and guided me throughout the course of this internship.

We still plan to work on the project to further increase the model's performance and improve the boundary adherence of the predictions for different classes. We also plan to try out other architectures that have performed well on similar problems, like pix2pix (a generative adversarial network).

This project was part of my final semester internship and I would like to thank the Bennett CSE department for their constant guidance. Thank you for reading and stay safe.
