DETR - End-to-End Object Detection with Transformers
"Attention Is All You Need", the paper that introduced the Transformer, changed the state of NLP and has achieved great heights. Though it was developed mainly for NLP, recent research focuses on leveraging it across different verticals of deep learning. One such attempt is the DEtection TRansformer (DETR), an object detection model developed by the Facebook Research team that cleverly utilizes the Transformer architecture.
The Detection Transformer leverages the transformer network (both encoder and decoder) to detect objects in images. Facebook's researchers argue that for object detection, one part of the image should attend to other parts of the image for better results, especially with occluded and partially visible objects, and what better tool for that than a transformer.
The Architecture
The main motive behind DETR is to effectively remove the need for many hand-designed components, such as non-maximum suppression and anchor generation, which explicitly encode prior knowledge about the task and make the process complex and computationally expensive.
The DETR model consists of a pre-trained CNN backbone, which produces a set of lower-dimensional features. These features are scaled and added to a positional encoding, and the result is fed into a Transformer consisting of an encoder and a decoder, in a manner quite similar to the encoder-decoder transformer described in "Attention Is All You Need". The output of the decoder is then fed into a fixed number of prediction heads, which consist of a predefined number of feed-forward networks. Each output of one of these prediction heads consists of a class prediction as well as a predicted bounding box. The loss is calculated by computing the bipartite matching loss.
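To make that data flow concrete, here is a minimal sketch of the pipeline in PyTorch. This is illustrative only: the module choices, the random stand-in for the positional encoding, and the shapes are simplified assumptions, not DETR's actual implementation.

import torch
import torch.nn as nn
from torchvision.models import resnet50

# Backbone: ResNet-50 without its pooling and classification head
backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
conv_proj = nn.Conv2d(2048, 256, kernel_size=1)   # project features to the transformer width
transformer = nn.Transformer(d_model=256, num_encoder_layers=6, num_decoder_layers=6)
query_embed = nn.Parameter(torch.rand(100, 256))  # 100 learned object queries
class_head = nn.Linear(256, 92)                   # 91 COCO classes + 1 "no object" class
bbox_head = nn.Linear(256, 4)                     # (cx, cy, w, h), normalized to [0, 1]

x = torch.rand(1, 3, 512, 512)                    # a dummy image
feat = conv_proj(backbone(x))                     # (1, 256, 16, 16) feature map
B, C, H, W = feat.shape
pos = torch.rand(H * W, B, C)                     # random stand-in for the positional encoding
src = feat.flatten(2).permute(2, 0, 1) + pos      # (H*W, B, 256) sequence for the encoder
tgt = query_embed.unsqueeze(1)                    # (100, 1, 256) decoder input
hs = transformer(src, tgt)                        # (100, 1, 256) decoded query embeddings
logits = class_head(hs)                           # class prediction per query
boxes = bbox_head(hs).sigmoid()                   # bounding box per query

Note the learned object queries: the decoder starts from 100 learned embeddings rather than from a shifted target sequence as in NLP.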
The main ingredients of DETR are a set-based global loss that forces unique predictions via bipartite matching, and a transformer encoder-decoder architecture.
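The bipartite matching itself is the classic assignment problem and can be solved with the Hungarian algorithm. Below is a toy sketch using SciPy; the L1-only cost is a simplification (DETR's real matching cost also includes the class probability and a generalized IoU term):

import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy example: 3 ground-truth boxes, 5 predictions, cost = pairwise L1 distance
gt = np.array([[0.2, 0.2, 0.1, 0.1], [0.5, 0.5, 0.2, 0.2], [0.8, 0.3, 0.1, 0.2]])
pred = np.random.rand(5, 4)
cost = np.abs(pred[:, None, :] - gt[None, :, :]).sum(-1)  # (5, 3) cost matrix
pred_idx, gt_idx = linear_sum_assignment(cost)            # one-to-one optimal assignment
# Each ground-truth box is matched to exactly one prediction; all
# unmatched predictions are supervised as the "no object" class.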
Refer to the DETR paper (https://arxiv.org/abs/2005.12872) for more insight into its architecture and unique loss function.
Global Wheat Head Detection using DETR
Wheat is a staple across the globe. Its popularity as a food and crop makes wheat widely studied. To get large and accurate data about wheat fields worldwide, plant scientists use image detection of "wheat heads"—spikes atop the plant containing grain. However, accurate wheat head detection in outdoor field images can be visually challenging. There is often overlap of dense wheat plants, and the wind can blur the photographs. Both make it difficult to identify single heads. Additionally, appearances vary due to maturity, color, genotype, and head orientation. Finally, because wheat is grown worldwide, different varieties, planting densities, patterns, and field conditions must be considered. Models developed for wheat phenotyping need to generalize between different growing environments.
Now let us use the DETR architecture to model global wheat head detection.
!git clone https://github.com/facebookresearch/detr.git
Cloning the DETR GitHub repo so that we can import its matcher and unique loss functions.
import os
import random
import sys

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Progress bar
from tqdm.autonotebook import tqdm

## SKLEARN
from sklearn.model_selection import StratifiedKFold

## TORCH
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Albumentations
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# Glob
from glob import glob

## DETR functions for the matcher and loss
sys.path.append('./detr/')
from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion
AverageMeter - a class for averaging the loss, metrics, etc. over epochs
class AverageMeter(object):
    """Compute and store the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
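For example, a minimal sketch of how it is used during training (with made-up loss values):

# Example: averaging a per-batch loss over an epoch
meter = AverageMeter()
for batch_loss in [0.9, 0.7, 0.6]:   # stand-in per-batch loss values
    meter.update(batch_loss, n=8)    # weight each value by the batch size
print(meter.avg)                     # running average, here (0.9 + 0.7 + 0.6) / 3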
Basic configuration for the model
# number of folds
n_folds = 5

# random seed
seed = 42

# number of classes
num_classes = 2

# number of bboxes predicted per output image (object queries)
num_queries = 100

# null (background) class coefficient
null_class_coef = 0.5

# batch size for training
BATCH_SIZE = 8

# learning rate for training
LR = 2e-5

# epochs for training
EPOCHS = 10
Seeding everything for better reproducibility of results
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmarking must be off for determinism

seed_everything(seed)
Preparing Data
!kaggle datasets download 'phunghieu/global-wheat-detection-512x512' -p /content/wheat-detection --unzip
Preparing data for the model. The data can be split into as many folds as you want; the split is stratified based on the number of boxes and the image source.
marking = pd.read_csv('/content/wheat-detection/train.csv')
# Expand the bbox string column "[x, y, w, h]" into four numeric columns
bboxs = np.stack(marking['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))
for i, column in enumerate(['x', 'y', 'w', 'h']):
    marking[column] = bboxs[:, i]
marking.drop(columns=['bbox'], inplace=True)
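To see what this parsing does, here is one illustrative bbox entry, assuming the CSV stores boxes as bracketed strings:

# Illustrative parse of a single bbox string (assumed CSV format)
s = '[834.0, 222.0, 56.0, 36.0]'
np.fromstring(s[1:-1], sep=',')  # -> array([834., 222., 56., 36.]), i.e. x, y, w, h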
You can read more about data splitting [here].
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

df_folds = marking[['image_id']].copy()
df_folds.loc[:, 'bbox_count'] = 1
df_folds = df_folds.groupby('image_id').count()
df_folds.loc[:, 'source'] = marking[['image_id', 'source']].groupby('image_id').min()['source']

# Stratify on source and (binned) box count so every fold sees a similar mix
df_folds.loc[:, 'stratify_group'] = np.char.add(
    df_folds['source'].values.astype(str),
    df_folds['bbox_count'].apply(lambda x: f'_{x // 15}').values.astype(str)
)

df_folds.loc[:, 'fold'] = 0
for fold_number, (train_index, val_index) in enumerate(
        skf.split(X=df_folds.index, y=df_folds['stratify_group'])):
    df_folds.loc[df_folds.iloc[val_index].index, 'fold'] = fold_number
Data Augmentations
# perform data augmentation on the train dataset
def get_train_transforms():
    return A.Compose(
        [
            A.OneOf([
                A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2,
                                     val_shift_limit=0.2, p=0.9),
                A.RandomBrightnessContrast(brightness_limit=0.2,
                                           contrast_limit=0.2, p=0.9)
            ], p=0.9),
            A.ToGray(p=0.01),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(height=512, width=512, p=1),
            A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=0, p=0.5),
            ToTensorV2(p=1.0)
        ],
        p=1.0,
        bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0,
                                 label_fields=['labels'])
    )

# perform data augmentation on the validation dataset
def get_valid_transforms():
    return A.Compose(
        [
            A.Resize(height=512, width=512, p=1.0),
            ToTensorV2(p=1.0)
        ],
        p=1.0,
        bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0,
                                 label_fields=['labels'])
    )
Creating the dataset
DETR accepts data in COCO format, which is (x, y, w, h). (For those who do not know, there are two widely used formats: COCO (x_min, y_min, width, height) and Pascal VOC (x_min, y_min, x_max, y_max).) So now we need to prepare our data in that format.
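If you ever need to convert between the two conventions, it is a one-liner. coco_to_pascal below is a hypothetical helper for illustration; it is not used in this pipeline:

def coco_to_pascal(box):
    # (x_min, y_min, width, height) -> (x_min, y_min, x_max, y_max)
    x, y, w, h = box
    return [x, y, x + w, y + h]

coco_to_pascal([10, 20, 30, 40])  # -> [10, 20, 40, 60]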
DIR_TRAIN = '/content/wheat-detection/train/'
class WheatDataset(Dataset):
    def __init__(self, image_ids, dataframe, transforms=None):
        self.image_ids = image_ids
        self.df = dataframe
        self.transforms = transforms

    def __len__(self) -> int:
        return self.image_ids.shape[0]

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        records = self.df[self.df['image_id'] == image_id]

        image = cv2.imread(f'{DIR_TRAIN}/{image_id}.jpg', cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)  # convert to RGB
        image /= 255.0  # normalize the image

        # DETR takes in data in coco format (x, y, w, h)
        boxes = records[['x', 'y', 'w', 'h']].values

        # area of the bounding boxes
        area = boxes[:, 2] * boxes[:, 3]
        area = torch.as_tensor(area, dtype=torch.float32)

        # the main class is labelled as zero
        labels = np.zeros(len(boxes), dtype=np.int32)

        if self.transforms:
            sample = {'image': image, 'bboxes': boxes, 'labels': labels}
            sample = self.transforms(**sample)
            image = sample['image']
            boxes = sample['bboxes']
            labels = sample['labels']

            # normalizing the bboxes
            _, h, w = image.shape
            boxes = A.augmentations.bbox_utils.normalize_bboxes(sample['bboxes'], rows=h, cols=w)

        target = {}
        target['boxes'] = torch.as_tensor(boxes, dtype=torch.float32)
        target['labels'] = torch.as_tensor(labels, dtype=torch.long)
        target['image_id'] = torch.tensor([index])
        target['area'] = area

        return image, target, image_id
Model
- The initial DETR model is trained on the COCO dataset, which has 91 classes + 1 background class, so we need to modify it to take our own number of classes.
- The DETR model also takes in 100 queries, i.e., it outputs a total of 100 bboxes for every image; we can very well change that too.
class DETRModel(nn.Module):
    def __init__(self, num_classes, num_queries):
        super(DETRModel, self).__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries

        self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
        self.in_features = self.model.class_embed.in_features

        # replace the classification head with one sized for our classes
        self.model.class_embed = nn.Linear(in_features=self.in_features,
                                           out_features=self.num_classes)
        self.model.num_queries = self.num_queries

    def forward(self, images):
        return self.model(images)
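As a quick sanity check, the model should return a dict of class logits and normalized boxes. The snippet below is a sketch, assuming the pretrained weights can be downloaded from torch hub:

# Quick shape check on a dummy image
m = DETRModel(num_classes=2, num_queries=100)
m.eval()
with torch.no_grad():
    out = m([torch.rand(3, 512, 512)])  # DETR's forward accepts a list of image tensors
print(out['pred_logits'].shape)         # torch.Size([1, 100, 2])
print(out['pred_boxes'].shape)          # torch.Size([1, 100, 4]), values in [0, 1]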
Matcher and Bipartite Matching Loss
Now we make use of the unique loss that the model uses, and for that we need to define the matcher. DETR calculates three individual losses:
- Classification loss for labels (its weight can be set by loss_ce)
- Bbox loss (its weight can be set by loss_bbox and loss_giou)
- Loss for the background (no-object) class (its weight is set by the eos_coef passed to SetCriterion)
matcher = HungarianMatcher()

weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}
losses = ['labels', 'boxes', 'cardinality']
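SetCriterion returns a dict of individual losses; the total loss used for backpropagation is the weighted sum over only the keys present in weight_dict, exactly as the training loop below does. A small illustration with made-up values:

# Illustration with made-up loss values, mirroring the training loop below
loss_dict = {'loss_ce': torch.tensor(0.7),
             'loss_bbox': torch.tensor(0.3),
             'loss_giou': torch.tensor(0.4),
             'cardinality_error': torch.tensor(2.0)}
total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
# cardinality_error is logged only: it has no entry in weight_dict, so it is excluded
print(total)  # tensor(1.4000)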
Training Function
Training DETR is unique and different from Faster R-CNN and EfficientDet, as we train the criterion as well.
def train_fn(data_loader, model, criterion, optimizer, device, scheduler, epoch):
    model.train()
    criterion.train()

    summary_loss = AverageMeter()
    tk0 = tqdm(data_loader, total=len(data_loader))

    for step, (images, targets, image_ids) in enumerate(tk0):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        output = model(images)

        loss_dict = criterion(output, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        summary_loss.update(losses.item(), BATCH_SIZE)
        tk0.set_postfix(loss=summary_loss.avg)

    return summary_loss
Eval Function
def eval_fn(data_loader, model, criterion, device):
    model.eval()
    criterion.eval()

    summary_loss = AverageMeter()

    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader))
        for step, (images, targets, image_ids) in enumerate(tk0):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            output = model(images)

            loss_dict = criterion(output, targets)
            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

            summary_loss.update(losses.item(), BATCH_SIZE)
            tk0.set_postfix(loss=summary_loss.avg)

    return summary_loss
Run Engine
def collate_fn(batch):
    return tuple(zip(*batch))
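This custom collate function is needed because each image has a different number of boxes, so the targets cannot be stacked into a single tensor; tuple(zip(*batch)) simply regroups the samples. For a toy batch of two samples:

# Toy batch of two (image, target, image_id) samples with different box counts
batch = [('img0', {'boxes': 3}, 'id0'),
         ('img1', {'boxes': 7}, 'id1')]
collate_fn(batch)
# -> (('img0', 'img1'), ({'boxes': 3}, {'boxes': 7}), ('id0', 'id1'))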
def run(fold):
    df_train = df_folds[df_folds['fold'] != fold]
    df_valid = df_folds[df_folds['fold'] == fold]

    train_dataset = WheatDataset(
        image_ids=df_train.index.values,
        dataframe=marking,
        transforms=get_train_transforms())

    valid_dataset = WheatDataset(
        image_ids=df_valid.index.values,
        dataframe=marking,
        transforms=get_valid_transforms())

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn)

    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn)

    device = torch.device('cuda')
    model = DETRModel(num_classes=num_classes, num_queries=num_queries)
    model = model.to(device)
    criterion = SetCriterion(num_classes - 1, matcher, weight_dict,
                             eos_coef=null_class_coef, losses=losses)
    criterion = criterion.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    best_loss = 10**5
    for epoch in range(EPOCHS):
        train_loss = train_fn(train_data_loader, model, criterion, optimizer,
                              device, scheduler=None, epoch=epoch)
        valid_loss = eval_fn(valid_data_loader, model, criterion, device)

        print('|EPOCH {}| TRAIN_LOSS {}| VALID_LOSS {}|'.format(
            epoch + 1, train_loss.avg, valid_loss.avg))

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            print('Best model found for Fold {} in Epoch {}........Saving Model'.format(fold, epoch + 1))
            torch.save(model.state_dict(), f'detr_best_{fold}.pth')

run(fold=0)
Validate Model
def view_sample(df_valid, model, device):
    valid_dataset = WheatDataset(image_ids=df_valid.index.values,
                                 dataframe=marking,
                                 transforms=get_valid_transforms())

    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn)

    images, targets, image_ids = next(iter(valid_data_loader))
    _, h, w = images[0].shape  # for denormalizing the boxes

    images = list(img.to(device) for img in images)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    # ground-truth boxes of the 5th image in the batch
    boxes = targets[4]['boxes'].cpu().numpy()
    boxes = [np.array(box).astype(np.int32)
             for box in A.augmentations.bbox_utils.denormalize_bboxes(boxes, h, w)]
    sample = images[4].permute(1, 2, 0).cpu().numpy()

    model.eval()
    model.to(device)
    cpu_device = torch.device('cpu')

    with torch.no_grad():
        outputs = model(images)

    outputs = [{k: v.to(cpu_device) for k, v in outputs.items()}]

    fig, ax = plt.subplots(1, 1, figsize=(16, 8))

    # draw ground-truth boxes in red
    for box in boxes:
        cv2.rectangle(sample,
                      (box[0], box[1]),
                      (box[2] + box[0], box[3] + box[1]),
                      (220, 0, 0), 1)

    # predicted boxes of the same (5th) image in the batch
    oboxes = outputs[0]['pred_boxes'][4].detach().cpu().numpy()
    oboxes = [np.array(box).astype(np.int32)
              for box in A.augmentations.bbox_utils.denormalize_bboxes(oboxes, h, w)]
    prob = outputs[0]['pred_logits'][4].softmax(1).detach().cpu().numpy()[:, 0]

    # draw predicted boxes with confidence > 0.5 in blue
    for box, p in zip(oboxes, prob):
        if p > 0.5:
            color = (0, 0, 220)
            cv2.rectangle(sample,
                          (box[0], box[1]),
                          (box[2] + box[0], box[3] + box[1]),
                          color, 1)

    ax.set_axis_off()
    ax.imshow(sample)
Now let's view the validation result:
model = DETRModel(num_classes=num_classes, num_queries=num_queries)
model.load_state_dict(torch.load('./detr_best_0.pth'))

view_sample(df_folds[df_folds['fold'] == 0], model=model, device=torch.device('cuda'))
What’s next
I hope you have gained some valuable insight into DETR and its implementation from this introductory post. Subsequent articles will focus on fine-tuning the DETR model and more advanced topics such as Deformable DETR.
Thank you for reading this article. I hope it gave you some good insights! Feel free to leave any feedback. :)