RetinaNet / Focal Loss (Object Detection)

Computer vision has revolutionized how we interact with technology, and object detection is one of the areas that has seen the most progress in recent years. With object detection, we can train algorithms to accurately identify and locate objects within an image or video. One of the most popular object detection algorithms is RetinaNet, introduced by Facebook AI Research in 2017. In this article, we'll take a closer look at RetinaNet, understand how it works, and compare it to other object detection approaches. Whether you're a computer vision expert or just starting out, this article aims to give you a solid understanding of RetinaNet and how it can be used to solve real-world problems.

What is RetinaNet?

RetinaNet is a single-stage, anchor-based object detection algorithm designed to match the accuracy of two-stage detectors while keeping the speed and simplicity of one-stage designs. It predicts object bounding boxes and class probabilities densely, directly from feature maps, using a set of pre-defined anchor boxes as references. Its key contribution is addressing the issue of class imbalance in dense object detection, which it does with a novel loss function that balances the contributions of positive and negative samples. To see where RetinaNet sits in the landscape, it helps to first understand the difference between anchor-based and anchor-free methods.
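Before diving into the details, here is a minimal sketch of how you might try RetinaNet in practice, assuming you have torchvision installed (newer torchvision releases replace pretrained=True with a weights argument):

import torch
import torchvision

# Load RetinaNet with a ResNet-50 + FPN backbone, pretrained on COCO
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
model.eval()

# A dummy 3-channel image with values in [0, 1]; replace with a real image tensor
image = torch.rand(3, 480, 640)

with torch.no_grad():
    # The model takes a list of images and returns one dict per image
    outputs = model([image])

# Each dict holds post-NMS 'boxes', 'labels', and 'scores'
boxes = outputs[0]["boxes"]
scores = outputs[0]["scores"]
print(boxes[scores > 0.5])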

Anchor-based vs Anchor-free:

A visual comparison of anchor-based and anchor-free methods: the red bounding box is the ground truth, the blue bounding box is a predefined anchor, and the green lines are the offsets. (a) Anchor-based methods predict offsets relative to a predefined anchor. (b) Anchor-free methods directly estimate the offsets from a point to the object's boundaries.


Anchor-based methods involve pre-defining a set of anchor boxes, also known as default boxes, and using them as references to predict object bounding boxes. The anchor boxes are typically chosen to cover a range of aspect ratios and scales so that they can be matched to the size and shape of different objects. The goal of anchor-based methods is to classify each anchor box as either positive (contains an object) or negative (does not contain an object), and then refine the position and size of the box to better fit the object.
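As an illustration, here is a small NumPy sketch of how a set of anchors could be generated at a single feature-map location; the scales and aspect ratios below mirror the ones RetinaNet uses (three scales and three ratios, giving nine anchors per location):

import numpy as np

def generate_anchors(base_size=32, scales=(1.0, 1.26, 1.59), ratios=(0.5, 1.0, 2.0)):
    """Return (N, 4) anchors as (x1, y1, x2, y2) centered on the origin."""
    anchors = []
    for scale in scales:
        for ratio in ratios:  # ratio is height / width
            area = (base_size * scale) ** 2
            # Keep the area fixed while varying the aspect ratio
            w = np.sqrt(area / ratio)
            h = w * ratio
            anchors.append([-w / 2, -h / 2, w / 2, h / 2])
    return np.array(anchors)

print(generate_anchors().shape)  # (9, 4): 3 scales x 3 aspect ratios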

Anchor-free methods, on the other hand, do not rely on pre-defined anchor boxes. Instead, they predict object bounding boxes and class probabilities directly from feature maps. This approach results in a simpler and more efficient architecture, but can also make it more challenging to handle objects of different sizes and aspect ratios.

Novel loss function:

RetinaNet's novel loss function is one of its key contributions to the field of object detection. It is designed to handle class imbalance, a common issue in object detection: a dense detector evaluates a huge number of candidate locations, the vast majority of which are easy background examples, and some object classes have far more instances than others. This imbalance biases the model towards the frequent (mostly background) examples, making it harder to accurately detect the rarer ones.

RetinaNet's loss function addresses this imbalance with the balanced focal loss, which down-weights the contribution of easy examples (where the model is confident and correct) and focuses training on the hard examples (where the model is less confident or wrong). By scaling the standard cross-entropy loss with a modulating factor, the focal loss prevents the flood of easy negatives from overwhelming the loss, reducing the impact of class imbalance and improving overall performance.
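Concretely, the focal loss from the RetinaNet paper is defined as:

FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t)

where p_t is the model's estimated probability for the ground-truth class, α_t equals α for positive samples and 1 - α for negative ones, and γ ≥ 0 controls how strongly easy examples are down-weighted (the paper uses α = 0.25 and γ = 2). Setting γ = 0 recovers the ordinary α-balanced cross-entropy loss.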

The balanced focal loss has been shown to significantly improve object detection accuracy on a variety of datasets, making RetinaNet a strong candidate for solving object detection problems in real-world applications.

Comparing Focal Loss with Cross Entropy Loss

A NumPy implementation of the balanced focal loss:


import numpy as np


def balanced_focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """
    Compute the alpha-balanced focal loss between y_true and y_pred.

    Args:
        y_true: ground-truth labels (0 or 1), shape (batch_size, num_classes)
        y_pred: predicted probabilities, shape (batch_size, num_classes)
        alpha: weight for balancing positive and negative samples
        gamma: focusing parameter for down-weighting easy examples

    Returns:
        The mean balanced focal loss over the batch.
    """
    # Clip predictions to avoid log(0)
    y_pred = np.clip(y_pred, 1e-8, 1 - 1e-8)
    # p_t is the predicted probability of the ground-truth class
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    # alpha_t weights positives by alpha and negatives by (1 - alpha)
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    # The (1 - p_t)**gamma factor down-weights easy, well-classified examples
    loss = -alpha_t * (1 - p_t) ** gamma * np.log(p_t)
    # Sum over classes, then average over the batch
    return np.mean(np.sum(loss, axis=-1))
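
A quick sanity check with made-up labels and predictions:

# Two samples, three classes, one-hot ground truth (toy values for illustration)
y_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
y_pred = np.array([[0.90, 0.05, 0.05],   # easy: confident and correct
                   [0.30, 0.40, 0.30]])  # hard: barely correct
print(balanced_focal_loss(y_true, y_pred))
# The hard example dominates the total, which is exactly what the focal term is for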

RetinaNet architecture:

The RetinaNet architecture, from the RetinaNet paper.

  • The feature extractor network is typically a pre-trained deep convolutional neural network (CNN), such as a ResNet, topped with a Feature Pyramid Network (FPN) that produces multi-scale feature maps from the input image. These feature maps are passed through the detection head, which consists of two sub-networks: the classification sub-network and the regression sub-network.
  • The classification sub-network takes the feature maps as input and predicts, for each anchor box, the probability that it contains an object of each class.
  • The regression sub-network takes the same feature maps as input and predicts offsets that refine each anchor box to better fit the nearby object's bounding box.
  • The anchor boxes are predefined bounding boxes of different scales and aspect ratios placed at every location in the feature maps. During training, the anchor boxes are matched with the ground-truth bounding boxes; the focal loss is used to train the classification sub-network, while a standard regression loss (smooth L1 in the paper) trains the box regression sub-network.
  • At inference time, the predicted probabilities and bounding boxes are combined to generate the final detections. The non-maximum suppression (NMS) algorithm is typically used to filter out overlapping detections and keep only the most confident ones (see the sketch after this list).
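
Here is a minimal NumPy sketch of greedy NMS; real implementations usually run it per class and on the GPU, but the idea is the same:

import numpy as np

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression. boxes is (N, 4) as (x1, y1, x2, y2)."""
    order = np.argsort(scores)[::-1]  # indices sorted by score, best first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Intersection of the best box with all remaining boxes
        x1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        y1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        x2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        y2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
        inter = np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        areas = (boxes[order[1:], 2] - boxes[order[1:], 0]) * \
                (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (area_i + areas - inter)
        # Drop boxes that overlap the kept box too much
        order = order[1:][iou <= iou_threshold]
    return keep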

For more details, see the RetinaNet paper. The full code is available on my GitHub.
