Deep Learning: How to Train a Joint Embedding using PyTorch
Introduction
In one of my latest research projects, I have been working with multimodal chatbots that integrate both vision and text. I'm excited to share the work I've done recently on my joint embedding model.
What is a Joint Embedding?
A joint embedding is exactly what the name suggests: an embedding that joins together two modes of media, in my case, vision and text.
The whole idea of a joint embedding is to train a model that learns to represent different types of media in one shared format. For example, you can train a model that, given both an image of an orange and a caption literally saying "orange", outputs an array describing both the image and the text.
The space this output array lives in is the multimodal space. Another way to understand the purpose of a joint embedding is this: an image with an associated caption, e.g., a picture of a dog and the caption "dog", will produce very similar, if not identical, values in the multimodal space when passed through the model. Unrelated images and captions, such as a picture of a watermelon and the text "car", will produce very different coordinates. The space can therefore be used to measure how closely related different pieces of media are.
The image above, which I found online, gives a visual description of what's happening. As you can see, the label "dog" is in the bottom right corner, and images of dogs project very close to it. Two different types of media, images and text, can be related to each other and understood once passed through the model.
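To make this concrete, here is a minimal sketch of the idea in PyTorch. This is not my actual Model class (that lives in the repository); the two linear projections, the feature dimensions, and the use of cosine similarity are illustrative assumptions, but they capture how two encoders map different media into one shared space.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyJointEmbedding(nn.Module):
    """Projects image features and caption features into one shared space."""
    def __init__(self, image_dim=2048, text_dim=300, embed_dim=1024):
        super().__init__()
        self.image_proj = nn.Linear(image_dim, embed_dim)  # image branch
        self.text_proj = nn.Linear(text_dim, embed_dim)    # text branch

    def forward(self, image_feature, caption_feature):
        # L2-normalize so cosine similarity is just a dot product
        img = F.normalize(self.image_proj(image_feature), dim=-1)
        txt = F.normalize(self.text_proj(caption_feature), dim=-1)
        return img, txt

# A matching image/caption pair should score higher than a mismatched one
toy_model = ToyJointEmbedding()
img, txt = toy_model(torch.randn(1, 2048), torch.randn(1, 300))
similarity = (img * txt).sum(dim=-1)  # cosine similarity in the multimodal space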
Why would we need a Joint Embedding?
In projects involving both vision and text, you'll often have a textual description of an object or situation and want to retrieve a related image. With a joint embedding, you can pass your text through the model and use the resulting embedding to retrieve similar images. It's like searching Google Images with text and having related images show up: that exact functionality. This makes a joint embedding incredibly powerful and useful.
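As a rough illustration of what that retrieval looks like once the embeddings exist, the snippet below ranks a gallery of pre-computed image embeddings against a single text embedding. The function name, the gallery, and the embedding size are all made up for the example; the point is simply that retrieval reduces to a nearest-neighbour search in the multimodal space.

import torch
import torch.nn.functional as F

def retrieve_images(text_embedding, image_embeddings, k=5):
    """Return indices of the k images closest to a text query in the joint space."""
    # Assumes both inputs are already L2-normalized, so dot products are cosine similarities
    scores = image_embeddings @ text_embedding
    return torch.topk(scores, k).indices

# Hypothetical gallery of 100 image embeddings and one caption embedding
gallery = F.normalize(torch.randn(100, 1024), dim=-1)
query = F.normalize(torch.randn(1024), dim=0)
top_matches = retrieve_images(query, gallery)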
The Code
I built the model from scratch using PyTorch, which is by far my favorite framework for deep learning research at the moment. Below is my final train.py script for training the joint embedding.
Although there is a lot of work happening behind the scenes, I wanted to demonstrate how simple and straightforward the final product can be. My code is completely open source and available on GitHub with a description of how to use it.
from data import Data
from settings import config
from model import Model
from loss import PairwiseRankingLoss as Loss
from optimizer import Optimizer

# Load data
data = Data()

# Track score to save best model
score = 0

# Use k-fold cross-validation for model selection
for train, test, fold in data.k_folds(5):

    # Prepare data to use the current fold
    data.process(train, test, fold)

    # Load model
    model = Model(data)

    # Model loss function
    loss = Loss()

    # Optimizer
    optimizer = Optimizer(model)

    # Begin epochs
    for epoch in range(config["num_epochs"]):

        # Process batches
        for caption, image_feature in data:

            # Pass data through model
            caption, image_feature = model(caption, image_feature)

            # Compute loss
            cost = loss(caption, image_feature)

            # Zero gradients, perform back-propagation, and update weights
            optimizer.backprop(cost)

        # Evaluate model results after each epoch
        model.evaluate(data)

    # Final evaluation for this fold - save if results are better
    model_score = model.evaluate(data)
    if model_score > score:
        score = model_score
        model.save()
        data.save_dictionaries()
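The PairwiseRankingLoss imported above is defined in the repository; the sketch below shows what a pairwise (max-margin) ranking loss for this kind of model typically looks like, assuming the model returns L2-normalized caption and image embeddings. The margin value and the exact reduction are assumptions, not necessarily what my implementation uses. Similarly, optimizer.backprop(cost) is a small wrapper that presumably bundles the usual zero_grad(), backward(), and step() calls.

import torch
import torch.nn as nn

class ToyPairwiseRankingLoss(nn.Module):
    """Max-margin ranking loss over a batch of matched caption/image embeddings."""
    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, captions, images):
        # Cosine similarity matrix; the diagonal holds the matching pairs
        scores = captions @ images.t()
        positives = scores.diag().view(-1, 1)

        # Penalize any mismatched pair scoring within `margin` of its true match
        cost_captions = (self.margin + scores - positives).clamp(min=0)
        cost_images = (self.margin + scores - positives.t()).clamp(min=0)

        # Do not penalize the matching pairs themselves
        mask = torch.eye(scores.size(0), dtype=torch.bool, device=scores.device)
        cost_captions = cost_captions.masked_fill(mask, 0)
        cost_images = cost_images.masked_fill(mask, 0)

        return cost_captions.sum() + cost_images.sum()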
Follow up
Please check out my GitHub repository here