Revolutionise Your PyTorch Workflow: How to Speed Up Your Deep Learning Training with This Simple Hack!
One of the recurring #pytorch / deep learning #designpatterns that you will come across in documentation and tutorials involves three steps:
1. Load the raw sample (typically a PIL image) inside the Dataset.
2. Apply the augmentations on the CPU in __getitem__.
3. Move the collated batch to the GPU in the training loop.
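As a point of reference, here is a minimal sketch of that conventional CPU-side pattern (the class name, transforms and parameters are illustrative, not taken from the repository):

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class CpuAugmentedDataset(Dataset):
    def __init__(self, data):
        self.data = data
        # step 2: augmentations run on the CPU inside the DataLoader workers
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # step 1: load the raw sample (e.g. a PIL image)
        return self.transform(self.data[idx])

data = ...  # load your dataset here
dataloader = DataLoader(CpuAugmentedDataset(data), batch_size=32, num_workers=4, pin_memory=True)
for batch in dataloader:
    batch = batch.to('cuda')  # step 3: move each batch to the GPU in the training loop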
While this pattern is effective for simple data augmentations and smaller image sizes, it can lead to CPU bottlenecks when working with more complex augmentations or larger images. In these cases, an alternative design pattern can be more efficient. Here's how it works:
1. Load the raw sample in __getitem__ as before.
2. Convert it to a tensor and move it to the GPU straight away.
3. Apply the augmentations on the GPU.
Why is this better? Moving data between main memory and GPU memory is time-consuming, and every time such a transfer happens the GPU is forced to wait. With this setup, the DataLoader workers prefetch the data all the way into GPU memory directly, so the GPU won't sit idle between batches.
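To get a feel for that cost, here is a minimal timing sketch of a single host-to-device copy (the batch shape is illustrative):

import time
import torch

x = torch.randn(32, 3, 1024, 1024)  # an illustrative batch of large images on the CPU
torch.cuda.synchronize()
start = time.perf_counter()
x = x.to('cuda')  # host-to-device copy; the GPU cannot train while it waits for this
torch.cuda.synchronize()
print(f"host-to-device copy took {time.perf_counter() - start:.4f}s")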
By following this pattern, you can significantly reduce CPU bottlenecks, speed up your deep learning training and utilise your expensive GPU resources more efficiently.
See the sample code below for an illustration, but if you want to see how this works in action, I recommend checking out this repository of mine on #github
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)
device = 'cuda'
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.to_tensor = transforms.Compose([
            transforms.ToTensor()
        ])
        self.transform = transforms.Compose([
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomGrayscale(),
            transforms.Normalize(imagenet_mean, imagenet_std)
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        # this is the interesting part: convert to a tensor, move it to the GPU,
        # and only then apply the augmentations, so they run on the GPU
        img = self.to_tensor(img)
        img = img.to(device)
        img = self.transform(img)
        return img
data = ... # load your dataset here
dataset = CustomDataset(data)
batch_size = 32
num_workers = 4
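# note: creating CUDA tensors inside DataLoader workers generally requires the 'spawn'
# multiprocessing start method (e.g. torch.multiprocessing.set_start_method('spawn'));
# where 'fork' is the default start method, either switch to 'spawn' or set num_workers=0
# pin_memory stays False because the samples returned by the dataset are already CUDA tensors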
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=False)
for batch in dataloader:
    ...  # do your training here; each batch arrives already on the GPU