DETR - End to End object DEtection with TRansformer
DETR model predictions on the Global Wheat Detection Kaggle competition dataset


"Attention Is All You Need", the paper that introduced the Transformer, changed the state of the art in NLP. Although the Transformer was developed mainly for NLP, recent research focuses on how to leverage it in other areas of deep learning. One such attempt is the DEtection TRansformer (DETR), an object detection model developed by the Facebook AI Research team that cleverly uses the Transformer architecture.

The Detection Transformer uses the full Transformer network (both the encoder and the decoder) to detect objects in images. Facebook's researchers argue that, for object detection, every part of an image should be able to interact with every other part to get better results, especially for occluded and partially visible objects, and the self-attention in a Transformer provides exactly that.

The Architecture

DETR Architecture; from https://arxiv.org/pdf/2005.12872v3.pdf

The main motivation behind DETR is to remove the need for many hand-designed components, such as non-maximum suppression or anchor generation, that explicitly encode prior knowledge about the task and make the pipeline complex and computationally expensive.

The DETR model consists of a pretrained CNN backbone (ResNet-50 or ResNet-101 in the paper) that produces a low-resolution feature map. A 1x1 convolution reduces the number of channels, and the feature map is flattened and supplemented with positional encodings before being fed into a Transformer encoder-decoder, in a manner quite similar to the original encoder-decoder Transformer. The decoder receives a fixed number N of learned object queries and produces N output embeddings. Each embedding is passed through shared feed-forward prediction heads that output a class prediction (which can be a special "no object" class) and a predicted bounding box. The loss is computed after a bipartite matching between the predicted and the ground-truth objects.
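To make this data flow concrete, here is a toy DETR-style model in PyTorch. It is a minimal sketch, not the official implementation: the two-layer convolutional backbone is a hypothetical stand-in for the pretrained ResNet, the learned row/column positional embeddings are added once to the encoder input (DETR uses sine encodings added inside every attention layer), and the object queries are fed directly as the decoder input. The class name MiniDETR and all sizes are illustrative.

```python
import torch
import torch.nn as nn


class MiniDETR(nn.Module):
    """Toy DETR-style detector (illustrative sketch, not the official model).

    Supports feature maps up to max_hw x max_hw (images up to 4 * max_hw pixels).
    """

    def __init__(self, num_classes, num_queries=100, hidden_dim=64,
                 nheads=4, num_layers=2, max_hw=32):
        super().__init__()
        # Hypothetical stand-in backbone: two strided convs (4x downsampling).
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 128, 3, stride=2, padding=1), nn.ReLU(),
        )
        # 1x1 conv reduces backbone channels to the transformer width.
        self.input_proj = nn.Conv2d(128, hidden_dim, kernel_size=1)
        self.transformer = nn.Transformer(
            d_model=hidden_dim, nhead=nheads,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=4 * hidden_dim, batch_first=True)
        # Learned object queries: one output slot (one predicted box) per query.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # Learned row/column positional embeddings (simplified positional encoding).
        self.row_embed = nn.Parameter(torch.rand(max_hw, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(max_hw, hidden_dim // 2))
        # Prediction heads shared by all queries.
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)  # +1 = "no object"
        self.bbox_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 4))

    def forward(self, images):
        feats = self.input_proj(self.backbone(images))      # [B, D, H, W]
        b, d, h, w = feats.shape
        pos = torch.cat([
            self.col_embed[:w].unsqueeze(0).expand(h, w, d // 2),
            self.row_embed[:h].unsqueeze(1).expand(h, w, d // 2),
        ], dim=-1).flatten(0, 1)                             # [H*W, D]
        src = feats.flatten(2).transpose(1, 2) + pos         # [B, H*W, D]
        queries = self.query_embed.weight.unsqueeze(0).repeat(b, 1, 1)  # [B, Q, D]
        hs = self.transformer(src, queries)                  # [B, Q, D]
        return {
            "pred_logits": self.class_embed(hs),             # [B, Q, num_classes + 1]
            "pred_boxes": self.bbox_embed(hs).sigmoid(),     # [B, Q, 4], in [0, 1]
        }


model = MiniDETR(num_classes=1, num_queries=10)
out = model(torch.rand(2, 3, 64, 64))
print(out["pred_logits"].shape, out["pred_boxes"].shape)
# torch.Size([2, 10, 2]) torch.Size([2, 10, 4])
```

The number of predictions per image equals the number of query embeddings, independent of how many objects the image contains; the loss below decides which predictions are matched to real objects.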

The main ingredients of DETR are a set-based global loss that forces unique predictions via bipartite matching, and a Transformer encoder-decoder architecture.
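The bipartite matching step can be sketched with SciPy's Hungarian-algorithm solver, scipy.optimize.linear_sum_assignment, which is also what the HungarianMatcher in the DETR repository calls. This simplified version (the function match_predictions is hypothetical) builds the cost only from the class probability and the L1 box distance; the real matcher also adds a GIoU term. Every prediction left unmatched is trained to predict "no object".

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_predictions(pred_probs, pred_boxes, gt_labels, gt_boxes,
                      cost_class=1.0, cost_bbox=1.0):
    """Hypothetical helper: one-to-one matching of queries to ground truth.

    pred_probs: [Q, C + 1] class probabilities per query
    pred_boxes: [Q, 4] predicted boxes (normalized)
    gt_labels:  [T] ground-truth class ids
    gt_boxes:   [T, 4] ground-truth boxes, same format as pred_boxes
    """
    # Classification cost: the more probable the target class, the cheaper.
    c_class = -pred_probs[:, gt_labels]                                      # [Q, T]
    # Box cost: L1 distance between every prediction and every target.
    c_bbox = np.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum(-1)  # [Q, T]
    cost = cost_class * c_class + cost_bbox * c_bbox
    # Hungarian algorithm: minimum-cost one-to-one assignment.
    return linear_sum_assignment(cost)


# Toy example: 3 queries, 2 ground-truth wheat heads (class 0; index 1 = "no object").
pred_probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
pred_boxes = np.array([[0.1, 0.1, 0.2, 0.2], [0.5, 0.5, 0.2, 0.2], [0.7, 0.7, 0.1, 0.1]])
gt_labels = np.array([0, 0])
gt_boxes = np.array([[0.7, 0.7, 0.1, 0.1], [0.1, 0.1, 0.2, 0.2]])

rows, cols = match_predictions(pred_probs, pred_boxes, gt_labels, gt_boxes)
print(rows, cols)  # [0 2] [1 0]: query 0 -> target 1, query 2 -> target 0, query 1 unmatched
```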

Refer to the DETR paper (https://arxiv.org/abs/2005.12872) for more insight into its architecture and its loss function.

Global Wheat head Detection using DETR


Wheat is a staple across the globe. Its popularity as a food and crop makes wheat widely studied. To get large and accurate data about wheat fields worldwide, plant scientists use image detection of "wheat heads"—spikes atop the plant containing grain. However, accurate wheat head detection in outdoor field images can be visually challenging. There is often overlap of dense wheat plants, and the wind can blur the photographs. Both make it difficult to identify single heads. Additionally, appearances vary due to maturity, color, genotype, and head orientation. Finally, because wheat is grown worldwide, different varieties, planting densities, patterns, and field conditions must be considered. Models developed for wheat phenotyping need to generalize between different growing environments.

Now let us use the DETR architecture for global wheat head detection.

!git clone https://github.com/facebookresearch/detr.git

Clone the DETR GitHub repository so that we can import its Hungarian matcher and its set-based loss (SetCriterion).

import torch
import random, os
import numpy as np
import pandas as pd


## Progress bar
from tqdm.autonotebook import tqdm


## CV2
import cv2


## SKLEARN
from sklearn.model_selection import StratifiedKFold


## TORCH
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


## DETR Function for loss
import sys
sys.path.append('./detr/')


from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion


## ALBUMENTATIONS
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


## PLOTTING
import matplotlib.pyplot as plt


#Glob
from glob import glob

AverageMeter: a helper class for averaging the loss (or any other metric) over an epoch.

class AverageMeter(object):
  """Compute and store the average and current values"""
  def __init__(self):
    self.reset()


  def reset(self):
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0
  
  def update(self, val, n=1):
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

Basic configuration for the model

# number of folds
n_folds = 5


# random seed
seed = 42


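# number of classes: 1 (wheat head) + 1 ("no object")
num_classes = 2


# number of object queries, i.e. boxes predicted per image
num_queries = 100


# weight of the "no object" class in the classification loss (eos_coef)
null_class_coef = 0.5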


# batch size for training
BATCH_SIZE = 8


# learning rate for training
LR = 2e-5


# epoch for training
EPOCHS = 10

Seed everything so that results can be reproduced.

def seed_everything(seed):
  random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.backends.cudnn.deterministic = True
  # benchmark mode selects algorithms at run time and can break determinism
  torch.backends.cudnn.benchmark = False

seed_everything(seed)

Preparing Data

!kaggle datasets download 'phunghieu/global-wheat-detection-512x512' -p /content/wheat-detection --unzip

Prepare the data for the model. The images can be split into as many folds as you want; the split is stratified by the image source and the number of boxes per image.

marking = pd.read_csv('/content/wheat-detection/train.csv')

bboxs = np.stack(marking['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))

for i, column in enumerate(['x', 'y', 'w', 'h']):
  marking[column] = bboxs[:,i]

marking.drop(columns=['bbox'], inplace=True)
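The np.fromstring call above strips the surrounding brackets by slicing and then parses the comma-separated numbers. An alternative sketch, assuming every bbox string is a JSON-style list such as "[834.0, 222.0, 56.0, 36.0]", parses it with json.loads instead; the two rows below are made up for illustration.

```python
import json

import numpy as np
import pandas as pd

# Made-up rows in the same layout as train.csv: bbox = "[x_min, y_min, width, height]".
marking = pd.DataFrame({
    "image_id": ["img_a", "img_a"],
    "bbox": ["[834.0, 222.0, 56.0, 36.0]", "[226.0, 548.0, 130.0, 58.0]"],
    "source": ["usask_1", "usask_1"],
})

# json.loads turns each string into a list of floats, no bracket slicing needed.
bboxs = np.array(marking["bbox"].map(json.loads).tolist(), dtype=np.float32)
for i, column in enumerate(["x", "y", "w", "h"]):
    marking[column] = bboxs[:, i]
marking = marking.drop(columns=["bbox"])
print(marking)
```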

You can read more about this data-splitting strategy [here].

skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)


df_folds = marking[['image_id']].copy()
df_folds.loc[:, 'bbox_count'] = 1
df_folds = df_folds.groupby('image_id').count()
df_folds.loc[:, 'source'] = marking[['image_id', 'source']].groupby('image_id').min()['source']
df_folds.loc[:, 'stratify_group'] = np.char.add(
    df_folds['source'].values.astype(str),
    df_folds['bbox_count'].apply(lambda x: f'_{x // 15}').values.astype(str)
)
df_folds.loc[:, 'fold'] = 0


for fold_number, (train_index, val_index) in enumerate(skf.split(X=df_folds.index, y=df_folds['stratify_group'])):
  df_folds.loc[df_folds.iloc[val_index].index, 'fold'] = fold_number
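To see what the stratification key looks like, here is a toy sketch with made-up annotations: each image gets a key built from its source and its box count divided by 15 (integer division), the same rule as above. The pandas named aggregation is just a compact way to write it, not what the code above does.

```python
import pandas as pd

# Made-up per-box annotations: image "a" has 20 boxes, "b" has 3, "c" has 40.
marking = pd.DataFrame({
    "image_id": ["a"] * 20 + ["b"] * 3 + ["c"] * 40,
    "source": ["usask_1"] * 20 + ["ethz_1"] * 3 + ["ethz_1"] * 40,
})

# One row per image: number of boxes and the image's source.
df_folds = marking.groupby("image_id").agg(
    bbox_count=("source", "count"), source=("source", "first"))

# Stratification key: source + "_" + bucket of the box count (buckets of 15 boxes).
df_folds["stratify_group"] = (
    df_folds["source"] + "_" + (df_folds["bbox_count"] // 15).astype(str))
print(df_folds)
```

Images from the same source with similar box densities share a key, so StratifiedKFold spreads each (source, density) group evenly across the folds.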

Data Augmentations

# data augmentation for the training set

def get_train_transforms():
    return A.Compose(
        [
            A.OneOf([
                A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.9),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),
            ], p=0.9),
            A.ToGray(p=0.01),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(height=512, width=512, p=1),
            # A.Cutout is deprecated in newer Albumentations releases (A.CoarseDropout replaces it)
            A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=0, p=0.5),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
        bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0, label_fields=['labels']))


# validation set: no augmentation, only resizing and conversion to a tensor

def get_valid_transforms():
    return A.Compose(
        [A.Resize(height=512, width=512, p=1.0),
         ToTensorV2(p=1.0)],
        p=1.0,
        bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0, label_fields=['labels']))
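Because bbox_params is set, Albumentations updates the boxes together with the image. For COCO boxes the flip geometry is simple; the two helpers below are a hypothetical sketch of that arithmetic, not Albumentations' own code.

```python
import numpy as np


def hflip_coco_boxes(boxes, image_width):
    """Hypothetical helper: COCO boxes after a horizontal flip of the image."""
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4).copy()
    # The old right edge (x_min + w) becomes the new left edge after mirroring.
    boxes[:, 0] = image_width - (boxes[:, 0] + boxes[:, 2])
    return boxes


def vflip_coco_boxes(boxes, image_height):
    """Hypothetical helper: COCO boxes after a vertical flip of the image."""
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4).copy()
    # The old bottom edge (y_min + h) becomes the new top edge.
    boxes[:, 1] = image_height - (boxes[:, 1] + boxes[:, 3])
    return boxes


box = [[10, 20, 30, 40]]  # x_min, y_min, w, h on a 512 x 512 image
print(hflip_coco_boxes(box, 512))  # x_min becomes 512 - (10 + 30) = 472
print(vflip_coco_boxes(box, 512))  # y_min becomes 512 - (20 + 40) = 452
```

Width and height never change under a flip; only the top-left corner moves.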

Creating dataset

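The annotations are in COCO format, (x_min, y_min, width, height). The other widely used format is Pascal VOC, (x_min, y_min, x_max, y_max). The dataset below keeps the COCO layout and normalizes the coordinates to [0, 1] by the image width and height. Note that the original DETR repository represents target boxes as normalized (center_x, center_y, width, height) instead; here the model is trained to predict normalized COCO-style boxes, which is also how its predictions are decoded in the visualization step.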
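For reference, a small sketch of the conversions mentioned above, assuming NumPy arrays of shape [N, 4]: COCO to Pascal VOC, and COCO to the normalized (center_x, center_y, width, height) layout used by the original DETR repository. The function names are hypothetical.

```python
import numpy as np


def coco_to_pascal_voc(boxes):
    """Hypothetical helper: (x_min, y_min, w, h) -> (x_min, y_min, x_max, y_max)."""
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    return np.concatenate([boxes[:, :2], boxes[:, :2] + boxes[:, 2:]], axis=1)


def coco_to_normalized_cxcywh(boxes, image_width, image_height):
    """Hypothetical helper: pixel (x_min, y_min, w, h) -> normalized (cx, cy, w, h)."""
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    x, y, w, h = boxes.T
    cxcywh = np.stack([x + w / 2, y + h / 2, w, h], axis=1)
    scale = np.array([image_width, image_height, image_width, image_height], dtype=np.float32)
    return cxcywh / scale


box = [[100, 200, 50, 20]]  # a 50 x 20 box on a 512 x 512 image
print(coco_to_pascal_voc(box))                   # corners (100, 200) and (150, 220)
print(coco_to_normalized_cxcywh(box, 512, 512))  # center (125, 210), all divided by 512
```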

DIR_TRAIN = '/content/wheat-detection/train'

class WheatDataset(Dataset):
  def __init__(self, image_ids, dataframe, transforms=None):
    self.image_ids = image_ids
    self.df = dataframe
    self.transforms = transforms

  def __len__(self) -> int:
    return self.image_ids.shape[0]

  def __getitem__(self, index):
    image_id = self.image_ids[index]
    records = self.df[self.df['image_id'] == image_id]

    image = cv2.imread(f'{DIR_TRAIN}/{image_id}.jpg', cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)  # convert BGR to RGB
    image /= 255.0  # scale pixel values to [0, 1]

    # COCO-format boxes: (x_min, y_min, w, h) in pixels
    boxes = records[['x', 'y', 'w', 'h']].values

    # single foreground class (wheat head), labelled 0
    labels = np.zeros(len(boxes), dtype=np.int64)

    if self.transforms:
      sample = self.transforms(image=image, bboxes=boxes, labels=labels)
      image = sample['image']
      boxes = sample['bboxes']
      labels = sample['labels']
    else:
      image = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW, like ToTensorV2

    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)

    # box area in pixels, computed after the transforms so it matches the returned boxes
    area = torch.as_tensor(boxes[:, 2] * boxes[:, 3], dtype=torch.float32)

    # normalize the boxes to [0, 1] by the image width and height
    _, h, w = image.shape
    boxes = boxes / np.array([w, h, w, h], dtype=np.float32)

    target = {}
    target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
    target['labels'] = torch.as_tensor(labels, dtype=torch.long)
    target['image_id'] = torch.tensor([index])
    target['area'] = area

    return image, target, image_id

Model

  • The pretrained DETR model was trained on the COCO dataset; its classification head outputs 92 logits (91 class slots plus one "no object" class), so we replace it with a head sized for our own number of classes.
  • DETR uses 100 object queries, i.e. it outputs 100 boxes for every image. This number can be changed too, but the query embedding table then has to be resized (and is no longer pretrained).
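class DETRModel(nn.Module):
  def __init__(self, num_classes, num_queries):
    super(DETRModel, self).__init__()
    self.num_classes = num_classes
    self.num_queries = num_queries

    # DETR with a ResNet-50 backbone, pretrained on COCO
    self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)

    # new classification head: num_classes outputs = our classes + "no object"
    self.in_features = self.model.class_embed.in_features
    self.model.class_embed = nn.Linear(in_features=self.in_features, out_features=self.num_classes)

    # the number of predictions is set by the size of the query embedding table,
    # so a different num_queries needs a new (randomly initialized) table
    if self.model.query_embed.num_embeddings != self.num_queries:
      self.model.query_embed = nn.Embedding(self.num_queries, self.model.query_embed.embedding_dim)
    self.model.num_queries = self.num_queries

  def forward(self, images):
    return self.model(images)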

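Matcher and Bipartite Matching Loss

Now we make use of DETR's loss, and for that we need to define the matcher. With the settings below, SetCriterion computes three weighted losses:

  • Classification loss: cross-entropy over the classes plus "no object", with the no-object class down-weighted by eos_coef (weight set by loss_ce)
  • Bounding-box L1 loss (weight set by loss_bbox)
  • Generalized IoU (GIoU) box loss (weight set by loss_giou)
The 'cardinality' entry only logs the error in the number of predicted objects; it is computed without gradients and does not contribute to training.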
matcher = HungarianMatcher()
weight_dict = {'loss_ce': 1, 'loss_bbox': 1 , 'loss_giou': 1}
losses = ['labels', 'boxes', 'cardinality']
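The role of null_class_coef (passed to SetCriterion as eos_coef) can be illustrated with a small sketch of the classification term: every query the matcher leaves unmatched gets the "no object" label, and that class is down-weighted in the cross-entropy. The toy logits and the matched query are made up; the per-class weight vector mirrors the empty_weight buffer that SetCriterion builds, but this is an illustration rather than the repository's code.

```python
import torch
import torch.nn.functional as F

num_classes = 1   # wheat head only; index num_classes (= 1) means "no object"
eos_coef = 0.5    # plays the role of null_class_coef

# Per-class weights: 1 for real classes, eos_coef for "no object".
empty_weight = torch.ones(num_classes + 1)
empty_weight[-1] = eos_coef

# Toy logits for 1 image and 4 queries: [batch, queries, num_classes + 1].
logits = torch.tensor([[[0.2, 1.0], [0.1, 2.0], [3.0, -1.0], [0.0, 0.5]]])

# Unmatched queries default to "no object"; suppose the matcher paired query 2 with a wheat head.
target_classes = torch.full((1, 4), num_classes, dtype=torch.long)
target_classes[0, 2] = 0

# cross_entropy expects the class dimension second: [batch, classes, queries].
loss_ce = F.cross_entropy(logits.transpose(1, 2), target_classes, weight=empty_weight)
print(loss_ce.item())
```

With weights, the mean reduction divides by the sum of the weights of the targets, so the many "no object" queries count only half as much as the few matched ones.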

Training Function

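Training DETR differs a little from Faster R-CNN and EfficientDet: the model only returns raw predictions, and the loss is computed by a separate SetCriterion module that first matches predictions to ground-truth boxes. The criterion has no trainable parameters, but we switch it between train and eval modes together with the model.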

def train_fn(data_loader, model, criterion, optimizer, device, scheduler, epoch):
  model.train()
  criterion.train()

  summary_loss = AverageMeter()

  tk0 = tqdm(data_loader, total=len(data_loader))

  for step, (images, targets, image_ids) in enumerate(tk0):
    images = list(image.to(device) for image in images)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    output = model(images)

    # weighted sum of the individual DETR losses
    loss_dict = criterion(output, targets)
    weight_dict = criterion.weight_dict
    loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if scheduler is not None:
      scheduler.step()

    # weight by the actual batch size (the last batch may be smaller)
    summary_loss.update(loss.item(), len(images))
    tk0.set_postfix(loss=summary_loss.avg)

  return summary_loss
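The official DETR training script additionally clips the gradient norm (its default maximum norm is 0.1), which train_fn above does not do. A minimal sketch of where the clipping call would go in a training step, using a toy nn.Linear model as a stand-in for DETRModel:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                       # toy stand-in for DETRModel
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
max_norm = 0.1                                # DETR's default clipping threshold

x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# Rescale all gradients in place so their combined norm is at most max_norm;
# the return value is the total norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
```

The call sits between backward() and step(), so the optimizer only ever sees the clipped gradients.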

Eval Function

def eval_fn(data_loader, model,criterion, device):
  model.eval()
  criterion.eval()
  summary_loss = AverageMeter()
  
  
  with torch.no_grad():
    tk0 = tqdm(data_loader, total=len(data_loader))
    
    for step, (images, targets, image_ids) in enumerate(tk0):
      images = list(image.to(device) for image in images)
      targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
      
      output = model(images)
      
      loss_dict = criterion(output, targets)
      weight_dict = criterion.weight_dict
      
      loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

      # weight by the actual batch size (the last batch may be smaller)
      summary_loss.update(loss.item(), len(images))
      tk0.set_postfix(loss=summary_loss.avg)
  
  return summary_loss

Run Engine

def collate_fn(batch):
  # keep images and targets as tuples: images have different numbers of boxes
  return tuple(zip(*batch))

def run(fold):
  df_train = df_folds[df_folds['fold'] != fold]
  df_valid = df_folds[df_folds['fold'] == fold]


  train_dataset = WheatDataset(
      image_ids=df_train.index.values,
      dataframe=marking,
      transforms=get_train_transforms())
  
  valid_dataset = WheatDataset(
      image_ids=df_valid.index.values,
      dataframe=marking,
      transforms=get_valid_transforms())
  
  train_data_loader = DataLoader(
      train_dataset,
      batch_size=BATCH_SIZE,
      shuffle=True,
      num_workers=4,
      collate_fn=collate_fn)


  valid_data_loader = DataLoader(
      valid_dataset,
      batch_size=BATCH_SIZE,
      shuffle=False,
      num_workers=4,
      collate_fn=collate_fn)
  
  device = torch.device('cuda')
  model = DETRModel(num_classes=num_classes,num_queries=num_queries)
  model = model.to(device)
  criterion = SetCriterion(num_classes-1, matcher, weight_dict, eos_coef = null_class_coef, losses=losses)
  criterion = criterion.to(device)
  
  optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


  best_loss = 10**5
  
  for epoch in range(EPOCHS):
    train_loss = train_fn(train_data_loader, model,criterion, optimizer,device,scheduler=None,epoch=epoch)
    valid_loss = eval_fn(valid_data_loader, model,criterion, device)
    
    print('|EPOCH {}| TRAIN_LOSS {}| VALID_LOSS {}|'.format(epoch+1,train_loss.avg,valid_loss.avg))
    
    if valid_loss.avg < best_loss:
      best_loss = valid_loss.avg
      print('Best model found for Fold {} in Epoch {}........Saving Model'.format(fold,epoch+1))
      torch.save(model.state_dict(), f'detr_best_{fold}.pth')


run(fold=0)
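Both data loaders use collate_fn because images contain different numbers of wheat heads, so their targets cannot be stacked into a single tensor. zip(*batch) simply regroups a list of (image, target, image_id) samples into three tuples. A toy sketch with strings standing in for the image tensors:

```python
def collate_fn(batch):
    # Regroup [(image, target, id), ...] into (images, targets, ids) without stacking.
    return tuple(zip(*batch))


# Toy batch: strings stand in for image tensors; targets hold different numbers of boxes.
batch = [
    ("img0", {"boxes": [[0.1, 0.1, 0.2, 0.2]]}, "id0"),
    ("img1", {"boxes": []}, "id1"),
    ("img2", {"boxes": [[0.3, 0.3, 0.1, 0.1], [0.5, 0.5, 0.2, 0.2]]}, "id2"),
]
images, targets, image_ids = collate_fn(batch)
print(images)                               # ('img0', 'img1', 'img2')
print([len(t["boxes"]) for t in targets])   # [1, 0, 2]
```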

Validate Model

def view_sample(df_valid, model, device, idx=4):
  # idx: which image of the first validation batch to display
  valid_dataset = WheatDataset(image_ids=df_valid.index.values,
                               dataframe=marking,
                               transforms=get_valid_transforms())

  valid_data_loader = DataLoader(
      valid_dataset,
      batch_size=BATCH_SIZE,
      shuffle=False,
      num_workers=4,
      collate_fn=collate_fn)

  images, targets, image_ids = next(iter(valid_data_loader))
  _, h, w = images[idx].shape  # used to de-normalize the boxes
  scale = np.array([w, h, w, h], dtype=np.float32)

  # ground-truth boxes in pixels, (x_min, y_min, w, h)
  boxes = (targets[idx]['boxes'].numpy() * scale).astype(np.int32)
  # cv2 drawing functions need a contiguous HWC array
  sample = np.ascontiguousarray(images[idx].permute(1, 2, 0).numpy())

  model.eval()
  model.to(device)
  with torch.no_grad():
    outputs = model([img.to(device) for img in images])

  fig, ax = plt.subplots(1, 1, figsize=(16, 8))

  # ground truth in red (pixel values are floats in [0, 1], so colors are too)
  for box in boxes:
    x, y, bw, bh = (int(v) for v in box)
    cv2.rectangle(sample, (x, y), (x + bw, y + bh), (1.0, 0, 0), 1)

  # predictions for the same image in blue, kept if P(wheat) > 0.5
  oboxes = (outputs['pred_boxes'][idx].cpu().numpy() * scale).astype(np.int32)
  prob = outputs['pred_logits'][idx].softmax(-1).cpu().numpy()[:, 0]

  for box, p in zip(oboxes, prob):
    if p > 0.5:
      x, y, bw, bh = (int(v) for v in box)
      cv2.rectangle(sample, (x, y), (x + bw, y + bh), (0, 0, 1.0), 1)

  ax.set_axis_off()
  ax.imshow(sample)
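The decoding in view_sample can be summarized as a small stand-alone function, sketched below under the same assumptions as this notebook: two logits per query (index 0 = wheat, index 1 = "no object") and predicted boxes that are normalized (x_min, y_min, w, h). The function name postprocess is hypothetical.

```python
import numpy as np


def postprocess(pred_logits, pred_boxes, image_width, image_height, threshold=0.5):
    """Hypothetical helper: keep queries whose wheat probability exceeds threshold.

    pred_logits: [Q, 2] logits (column 0 = wheat, column 1 = "no object")
    pred_boxes:  [Q, 4] normalized (x_min, y_min, w, h)
    Returns (pixel boxes [K, 4] as int32, scores [K]).
    """
    logits = np.asarray(pred_logits, dtype=np.float64)
    # Numerically stable softmax over the class dimension.
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = exp / exp.sum(axis=1, keepdims=True)
    scores = probs[:, 0]
    keep = scores > threshold
    scale = np.array([image_width, image_height, image_width, image_height])
    boxes = (np.asarray(pred_boxes)[keep] * scale).astype(np.int32)
    return boxes, scores[keep]


boxes, scores = postprocess(
    pred_logits=[[2.0, 0.0], [0.0, 2.0]],
    pred_boxes=[[0.1, 0.2, 0.05, 0.05], [0.5, 0.5, 0.1, 0.1]],
    image_width=512, image_height=512)
print(boxes, scores)  # only the first query survives: P(wheat) = sigmoid(2) ~ 0.88
```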


Now visualize the validation results (ground-truth boxes in red, predicted boxes in blue):

model = DETRModel(num_classes=num_classes,num_queries=num_queries)
model.load_state_dict(torch.load("./detr_best_0.pth"))
view_sample(df_folds[df_folds['fold'] == 0],model=model,device=torch.device('cuda'))

What’s next

I hope this introductory post gave you some valuable insight into DETR and how to implement it. Subsequent articles will focus on fine-tuning the DETR model and on more advanced topics such as Deformable DETR.

Thank you for reading this article. Feel free to leave any feedback :)
