Training and Deploying an Image Captioning System
Demo
Upload your own photo to be captioned:
For the rest of this post I walk through an end-to-end training of the captioning system in a reproducible Jupyter notebook style. The notebook was run on Google Colab on a high-RAM, GPU-accelerated runtime. All code for training and deployment is also available here.
Download the data from the COCO site.
!curl -o "annotations_trainval2014.zip" http://images.cocodataset.org/annotations/annotations_trainval2014.zip
!unzip "annotations_trainval2014.zip"
!curl -o "train2014.zip" http://images.cocodataset.org/zips/train2014.zip
!curl -o "val2014.zip" http://images.cocodataset.org/zips/val2014.zip
Import the libraries needed for the initial data processing.
import matplotlib.pyplot as plt
from io import BytesIO
import cv2
import json
import numpy as np
import zipfile
import torchtext
import string
import torch
Now I'm going to iterate through all the pictures in the archives and resize them to 224x224 while preserving the aspect ratios (by padding with black pixels). I'll save each resized picture to a numpy array to use later when building the model.
base_path = "."
ds_to_fn = {'train':'train2014.zip','val':'val2014.zip'}
size = 224
def pad_image(img, height, width):
h, w = img.shape[:2]
t = 0
b = height - h
l = 0
r = width - w
return cv2.copyMakeBorder(img, t, b, l, r,
cv2.BORDER_CONSTANT, value=0)
def resize_and_pad(img, height, width, resample=cv2.INTER_AREA):
if len(img.shape)==2:
img = np.stack([img,img,img],axis=2)
target_aspect_ratio = height/width
im_h, im_w, _ = img.shape
im_aspect_aspect_ratio = im_h/im_w
if im_aspect_aspect_ratio>target_aspect_ratio:
target_height = height
target_width = int(im_w * target_height/im_h)
else:
target_width = width
target_height = int(im_h * target_width/im_w)
resized = cv2.resize(img, (target_width, target_height),
interpolation=resample)
return pad_image(resized, height, width)
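As a quick sanity check (not part of the original pipeline), we can feed a synthetic non-square image through resize_and_pad and confirm it comes back as 224x224 with black padding:

## sanity check: a 100x200 grey image should come back 224x224,
## with the extra rows at the bottom filled with black padding
test_img = np.full((100, 200, 3), 128, dtype=np.uint8)
padded = resize_and_pad(test_img, size, size)
print(padded.shape)       # (224, 224, 3)
print(padded[-1].max())   # 0, i.e. the bottom rows are black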
pics, im_fn_to_index = {}, {}
for ds in ['train','val']:
fn = ds_to_fn[ds]
archive = zipfile.ZipFile(f"{base_path}/{fn}")
file_list = archive.filelist
pics[ds] = np.zeros((len(file_list)-1,size,size,3),dtype=np.uint8)
im_fn_to_index[ds] = {}
for count,file_obj in enumerate(file_list):
im_fn = file_obj.filename
if not im_fn.endswith('.jpg'):
continue
with archive.open(file_obj) as open_file:
res = BytesIO(open_file.read())
pic = plt.imread(res,'jpg')
ind = len(im_fn_to_index[ds])
pics[ds][ind] = resize_and_pad(pic, size, size)
im_fn_to_index[ds][im_fn] = ind
archive.close()
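All of the resized images are held in memory, which is why the high-RAM runtime is needed; a rough check of the footprint (illustrative):

## rough size of the in-memory uint8 image arrays
for ds in ['train', 'val']:
    print(ds, pics[ds].shape, f"{pics[ds].nbytes / 1e9:.1f} GB")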
For word embeddings I am going to use the pretrained GloVe embeddings that are downloadable from torchtext. I chose to take the 100k most common words, since less frequent words would add dimensions to the output space without much benefit: most words in the captions are simple and therefore common. I also exclude any words containing punctuation other than a dash or apostrophe, as well as any words with an uppercase letter, since I will only use lowercase text.
vocab = torchtext.vocab.GloVe(name='840B', dim=300, max_vectors=100000)
punctuation = set(c for c in string.punctuation if c not in "-'")
digits = set(str(i) for i in range(10))
inds_to_use = []
seen_lower = set()
upper_added = {}
words = set()
for i,word in enumerate(vocab.itos):
if not any(c in punctuation or c in digits for c in word):
if not all(c in "-'" for c in word) and word.islower():
inds_to_use.append(i)
words.add(word)
vocab.itos = np.array(vocab.itos)[inds_to_use]
vocab.stoi = {s:i for i,s in enumerate(vocab.itos)}
vocab.vectors = vocab.vectors[inds_to_use]
## size of remaining vocab
len(vocab.stoi),len(vocab.itos),vocab.vectors.size()
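As a quick check that the filtered vocabulary still behaves as expected (illustrative, not in the original notebook), a common word should map to an index and a 300-dimensional vector:

## spot-check a common word in the filtered vocab
word = 'dog'
if word in vocab.stoi:
    idx = vocab.stoi[word]
    print(word, idx, vocab.vectors[idx].shape)   # torch.Size([300])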
Install the libraries I use for pretrained building blocks.
!pip install efficientnet_pytorch
!pip install transformers
Import everything I will need for defining and training the neural net.
import torch.nn as nn
import torchvision
import efficientnet_pytorch
import transformers
import scipy.stats
base_path = "."
Load the caption annotations and define all the lookup dictionaries that I will use later during training.
with open(f'{base_path}/annotations/captions_train2014.json','r') as f:
annot_train = json.load(f)
with open(f'{base_path}/annotations/captions_val2014.json','r') as f:
annot_val = json.load(f)
LONGEST_CAPTION = max(len(d['caption'].split())
for d in annot_train['annotations'] +\
annot_val['annotations'])
LONGEST_CAPTION
train_pics = pics['train']
val_pics = pics['val']
train_immap = im_fn_to_index['train']
val_immap = im_fn_to_index['val']
train_fn_to_index = {key.split('/')[1]:val
for key,val in train_immap.items()}
train_index_to_fn = {val:key for key,val in train_fn_to_index.items()}
val_fn_to_index = {key.split('/')[1]:val for key,val in val_immap.items()}
val_index_to_fn = {val:key for key,val in val_fn_to_index.items()}
train_imfn_to_imid = {d['file_name']:d['id']
for d in annot_train['images']}
train_imid_to_caption = {d['image_id']:d['caption']
for d in annot_train['annotations']}
train_imfn_to_caption = {fn:train_imid_to_caption[id_]
for fn,id_ in train_imfn_to_imid.items()}
val_imfn_to_imid = {d['file_name']:d['id'] for d in annot_val['images']}
val_imid_to_caption = {d['image_id']:d['caption']
for d in annot_val['annotations']}
val_imfn_to_caption = {fn:val_imid_to_caption[id_]
for fn,id_ in val_imfn_to_imid.items()}
Define a function to show some training images with their provided captions from the training annotations, then view 10 random images and their captions.
def show_im_and_cap_train(indexes):
for index in indexes:
imfn = train_index_to_fn[index]
caption = train_imfn_to_caption[imfn]
fig = plt.figure(figsize=(7,7))
plt.imshow(train_pics[index])
plt.title(caption)
plt.show()
show_im_and_cap_train(np.random.randint(0,len(train_pics),10))
I add two special tokens to the vocabulary: an end token and an unknown token. The unknown token is needed on the embedding side (for caption words that are not in the vocabulary), and the end token is needed in the output space since the model must predict it to terminate a generated sequence.
UNK_TOK = '~~UNK~~'
vocab.itos = np.concatenate([[UNK_TOK],vocab.itos])
vocab.stoi = {v:k+1 for v,k in vocab.stoi.items()}
vocab.stoi[UNK_TOK] = 0 ## ends up being 1
vocab.vectors = torch.cat([torch.zeros(1,300),vocab.vectors])
END_TOK = '~~END~~'
vocab.itos = np.concatenate([[END_TOK],vocab.itos])
vocab.stoi = {v:k+1 for v,k in vocab.stoi.items()}
vocab.stoi[END_TOK] = 0
vocab.vectors = torch.cat([torch.zeros(1,300),vocab.vectors])
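After the two insertions the end token sits at index 0 and the unknown token at index 1; the inference code below relies on exactly that (it stops on a predicted 0 and filters out 1s). A quick check that the lookup tables and embedding matrix stayed aligned (illustrative):

## END_TOK should be index 0 and UNK_TOK index 1
print(vocab.stoi[END_TOK], vocab.stoi[UNK_TOK])    # 0 1
print(len(vocab.itos), vocab.vectors.size(0))      # these should match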
Next I save the captions as an array of integers that will be used to look up the embeddings. I also save an accompanying mask array that determines which positions to backpropagate on: since the sequences are padded to the maximum caption length, much of the captions array is just padding, which is not useful for teaching the model.
train_captions = torch.zeros(train_pics.shape[0],
LONGEST_CAPTION+1,dtype=torch.long)
train_loss_mask = torch.zeros(train_pics.shape[0],
LONGEST_CAPTION+1,dtype=torch.bool)
for i in range(train_pics.shape[0]):
caption = train_imfn_to_caption[train_index_to_fn[i]]
split = caption.split()
for word_ind, word in enumerate(split):
word = word.lower().replace('.','').replace(',','').replace(';','')
if not word:
continue
if word in vocab.stoi:
train_captions[i, word_ind] = vocab.stoi[word]
else:
#print(word)
train_captions[i, word_ind] = vocab.stoi[UNK_TOK]
train_loss_mask[i, word_ind] = True
train_captions[i, word_ind + 1] = vocab.stoi[END_TOK] ## cause it should generate end tok
train_loss_mask[i, word_ind + 1] = True
val_captions = torch.zeros(val_pics.shape[0],
LONGEST_CAPTION+1,dtype=torch.long)
val_loss_mask = torch.zeros(val_pics.shape[0],
LONGEST_CAPTION+1,dtype=torch.bool)
for i in range(val_pics.shape[0]):
caption = val_imfn_to_caption[val_index_to_fn[i]]
split = caption.split()
for word_ind, word in enumerate(split):
word = word.lower().replace('.','').replace(',','').replace(';','')
if not word:
continue
if word in vocab.stoi:
val_captions[i, word_ind] = vocab.stoi[word]
else:
val_captions[i, word_ind] = vocab.stoi[UNK_TOK]
val_loss_mask[i, word_ind] = True
val_captions[i, word_ind + 1] = vocab.stoi[END_TOK]
val_loss_mask[i, word_ind + 1] = True
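A small helper (not in the original notebook) makes it easy to eyeball that the integer encoding round-trips back to readable text, with the mask selecting only the real words plus the end token:

## decode an encoded caption back to words for inspection
def decode_caption(caption_row, mask_row):
    return ' '.join(vocab.itos[idx] for idx in caption_row[mask_row].tolist())

print(train_imfn_to_caption[train_index_to_fn[0]])
print(decode_caption(train_captions[0], train_loss_mask[0]))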
I define a torch Dataset, the image transforms (the same normalization used to train EfficientNet, since I am using the pretrained EfficientNet-B0 encoder, plus some Gaussian noise to reduce overfitting), and a DataLoader with batch size 84.
batch_size = 84
class Dataset(torch.utils.data.Dataset):
def __init__(self, pics, captions, loss_mask, pic_transform):
self.pics = pics
self.captions = captions
self.loss_mask = loss_mask
self.pic_transform = pic_transform
def __len__(self):
return self.pics.shape[0]
def __getitem__(self, idx):
return {'pics':self.pic_transform(self.pics[idx]),
'captions':self.captions[idx],
'loss_mask':self.loss_mask[idx]}
class AddGaussianNoise():
def __init__(self, mean=0., std=.25):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
normalize, AddGaussianNoise()
])
train_ds = Dataset(train_pics, train_captions,
train_loss_mask, train_transform)
train_dl = torch.utils.data.DataLoader(train_ds,
batch_size=batch_size,
shuffle=True)
val_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
normalize])
val_ds = Dataset(val_pics, val_captions, val_loss_mask, val_transform)
val_dl = torch.utils.data.DataLoader(val_ds,
batch_size=batch_size, shuffle=False)
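Pulling a single batch through the loader (illustrative) confirms the tensor shapes the model will see:

## peek at one training batch
batch = next(iter(train_dl))
print(batch['pics'].shape)        # (batch_size, 3, 224, 224)
print(batch['captions'].shape)    # (batch_size, LONGEST_CAPTION + 1)
print(batch['loss_mask'].shape)   # (batch_size, LONGEST_CAPTION + 1)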
Next I define the RNN model. I load the pretrained EfficientNet encoder and use the embedding of the image (after average pooling) as the initial hidden state of the GRU. To predict a word from each hidden state I use a single linear layer, with dropout in between. During inference I sample from the top 2 predictions at each step to generate 10 candidate sequences, then use a pretrained language model to pick the sentence with the highest score, which helps avoid some of the obvious grammatical issues that can occur in language generation.
def get_lm_score():
tokenizer = transformers.GPT2Tokenizer.from_pretrained('distilgpt2')
lm = transformers.GPT2LMHeadModel.from_pretrained('distilgpt2')
lm.eval()
for parameter in lm.parameters():
parameter.requires_grad = False
max_length = 86
def lm_score(sents):
## sents should be a list of strings
inds = torch.zeros(len(sents),max_length,dtype=torch.long)
mask = torch.ones(len(sents),max_length,dtype=torch.float)
for i in range(len(sents)):
tok = tokenizer.encode_plus(sents[i], add_special_tokens=True,
return_tensors='pt',
max_length=max_length)['input_ids'][0]
inds[i, :len(tok)] = tok
mask[i, len(tok):] = 0
logits = lm(inds)[0]
inds_flattened = inds.flatten()
indexer = torch.arange(0,inds_flattened.size(0),dtype=torch.long)
chosen_words = logits.view(logits.size(0)*logits.size(1),-1)[indexer,inds_flattened]
chosen_words = chosen_words.view(logits.size(0),logits.size(1))
lm_scores = nn.functional.logsigmoid(chosen_words * mask).sum(1).numpy()
lm_scores /= mask.sum(1).numpy()
return lm_scores
return lm_score
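For intuition, the scorer can be called directly on a couple of candidate sentences; it returns one length-normalized log-score per sentence, and the inference method below simply keeps the argmax. This call is just an illustration of the interface, not output from the original run:

## illustrative use of the language-model scorer
lm_score = get_lm_score()
print(lm_score(["a dog is running on the beach",
                "beach dog a the running on is"]))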
class Captioner(nn.Module):
def __init__(self, vocab):
super().__init__()
self.vocab = vocab
self.vocab_size = len(vocab.itos)
self.word_emb_size = 300
self.encoder = efficientnet_pytorch.EfficientNet.from_pretrained('efficientnet-b0')
self.pic_emb_size = 1280
self.average_pooling = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(p=.2)
self.decoder = nn.GRU(input_size=self.word_emb_size,
hidden_size=self.pic_emb_size,
batch_first=True)
self.classifier = nn.Linear(self.pic_emb_size,self.vocab_size)
self.start_tok_embed = nn.Parameter(torch.randn(self.word_emb_size,dtype=torch.float32))
self.lm_score = get_lm_score()
def forward(self, ims, caption_embs):
bs = ims.size(0)
im_embs = self.encoder.extract_features(ims)
im_embs = self.average_pooling(im_embs).view(bs,self.pic_emb_size)
hidden = im_embs.unsqueeze(0)
caption_embs = torch.cat([self.start_tok_embed.expand(bs,1,
self.word_emb_size),caption_embs],
axis=1)
out, _ = self.decoder(caption_embs,hidden)
out = self.dropout(out.reshape(bs*caption_embs.size(1),-1))
out = self.classifier(out)
out = out.view(bs,caption_embs.size(1),-1)
return out
def inference(self, im, device, num_sample=10, max_length=32, topk=2):
with torch.no_grad():
sents = []
for it in range(num_sample):
bs = 1
ims = im.unsqueeze(0)
im_embs = self.encoder.extract_features(ims)
im_embs = self.average_pooling(im_embs).view(bs,self.pic_emb_size)
hidden = im_embs.unsqueeze(0)
word_emb = self.start_tok_embed.expand(bs,1,self.word_emb_size)
preds = []
for i in range(max_length):
_, hidden = self.decoder(word_emb, hidden)
pred = self.classifier(hidden.squeeze(0)).squeeze()
pred = nn.functional.softmax(pred,dim=0)
top_preds = torch.topk(pred,topk)
top_preds_inds = top_preds.indices.cpu().numpy()
top_preds_values = top_preds.values.cpu().numpy()
top_preds_values = top_preds_values[top_preds_inds!=1]
top_preds_inds = top_preds_inds[top_preds_inds!=1]
top_preds_values = top_preds_values/top_preds_values.sum()
pred = np.random.choice(top_preds_inds,p=top_preds_values)
if pred==0:
break
word_emb = self.vocab.vectors[pred].view(bs,
1,self.word_emb_size).to(device)
preds.append(self.vocab.itos[pred])
sents.append(' '.join(preds))
scores = self.lm_score(sents)
print(sents)
return sents[np.argmax(scores)]
def get_word_embs(vocab, word_inds):
words = word_inds[:,:-1]
size = words.size()
return vocab.vectors[words.flatten()].view(size[0],size[1],300)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Captioner(vocab)
model = model.to(device)
for param in model.encoder.parameters():
param.requires_grad = False
model.encoder.eval()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
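Before training, a dry-run forward pass on one batch (illustrative) verifies that the output has the expected shape of (batch size, caption length + 1, vocabulary size):

## dry-run forward pass to sanity-check shapes
with torch.no_grad():
    batch = next(iter(train_dl))
    ims = batch['pics'].to(device)
    caption_embs = get_word_embs(vocab, batch['captions']).to(device)
    out = model(ims, caption_embs)
    print(out.shape)   # (batch_size, LONGEST_CAPTION + 1, vocab_size)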
Finally I train the model for 7 epochs on the training set, monitoring cross-entropy loss on the training and validation sets, and run the inference algorithm on 3 randomly selected validation photos each epoch.
for epoch in range(7):
train_losses = []
val_losses = []
model.train()
model.encoder.eval()
for i,batch in enumerate(train_dl):
pics = batch['pics'].to(device)
caption_embs = get_word_embs(vocab, batch['captions']).to(device)
loss_mask = batch['loss_mask'].to(device)
preds = model(pics, caption_embs)[loss_mask]
labels = batch['captions'].to(device)[loss_mask]
loss = criterion(preds,labels)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(),2)
optimizer.step()
optimizer.zero_grad()
train_losses.append(loss.item())
model.eval()
with torch.no_grad():
for i,batch in enumerate(val_dl):
pics = batch['pics'].to(device)
caption_embs = get_word_embs(vocab, batch['captions']).to(device)
loss_mask = batch['loss_mask'].to(device)
preds = model(pics, caption_embs)[loss_mask]
labels = batch['captions'].to(device)[loss_mask]
loss = criterion(preds,labels)
val_losses.append(loss.item())
rand_val_exs = np.random.randint(0,len(val_ds),size=3)
for idx in rand_val_exs:
item = val_ds[idx]
im = item['pics'].to(device)
cap = model.inference(im, device)
plt.imshow(val_pics[idx])
plt.title(cap)
plt.show()
print(f"epoch: {epoch}, tr_loss: {np.mean(train_losses)}, "
f"vl_loss: {np.mean(val_losses)}")
In this notebook I walked through the end-to-end process I used to train a caption-generation model on Microsoft's COCO dataset. I also productionized the model to run inference on user-supplied photos, putting it behind an API served by a web server. You can see all the code for the project here.
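The deployment code itself is not shown in this notebook, but a minimal Flask sketch gives the flavor of the serving side. The endpoint name, port, and preprocessing here are assumptions for illustration, not the actual production code; the real implementation is in the linked repository.

## minimal sketch of a captioning endpoint (illustrative; not the production code)
from flask import Flask, request, jsonify

app = Flask(__name__)
model.eval()

@app.route('/caption', methods=['POST'])
def caption_endpoint():
    # decode the uploaded image, then apply the same resize/pad and normalization as validation
    data = np.frombuffer(request.files['image'].read(), dtype=np.uint8)
    img = cv2.cvtColor(cv2.imdecode(data, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    img = resize_and_pad(img, size, size)
    tensor = val_transform(img).to(device)
    return jsonify({'caption': model.inference(tensor, device)})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)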