Training and Deploying an Image Captioning System

Model Diagram

Demo

Upload your own photo to be captioned. (I don't store your uploaded files anywhere.)

For the rest of this post I show end-to-end training of the captioning system in a reproducible Jupyter notebook style. This notebook was run on Google Colab on a high-RAM, GPU-accelerated runtime. All code for training and deployment is also available here.

Download the data from the COCO site.

In [1]:
!curl -o "annotations_trainval2014.zip" http://images.cocodataset.org/annotations/annotations_trainval2014.zip
!unzip "annotations_trainval2014.zip"
!curl -o "train2014.zip" http://images.cocodataset.org/zips/train2014.zip
!curl -o "val2014.zip" http://images.cocodataset.org/zips/val2014.zip
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  241M  100  241M    0     0  86.2M      0  0:00:02  0:00:02 --:--:-- 86.2M
Archive:  annotations_trainval2014.zip
  inflating: annotations/instances_train2014.json  
  inflating: annotations/instances_val2014.json  
  inflating: annotations/person_keypoints_train2014.json  
  inflating: annotations/person_keypoints_val2014.json  
  inflating: annotations/captions_train2014.json  
  inflating: annotations/captions_val2014.json  
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 12.5G  100 12.5G    0     0  39.6M      0  0:05:24  0:05:24 --:--:-- 89.1M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 6337M  100 6337M    0     0  38.8M      0  0:02:42  0:02:42 --:--:-- 16.3M

Import some libraries for the initial data processing.

In [0]:
import matplotlib.pyplot as plt
from io import BytesIO
import cv2
import json
import numpy as np
import zipfile
import torchtext
import string
import torch

Now I'm going to iterate through all the pictures in the archives and resize them to 224x224 while preserving the aspect ratio (by padding with black pixels). I'll save each resized picture into a NumPy array to use later when building the model.

In [0]:
base_path = "."
ds_to_fn = {'train':'train2014.zip','val':'val2014.zip'}
size = 224
In [0]:
def pad_image(img, height, width):
    # Pad the bottom and right edges with black pixels so the image
    # reaches the target height and width.
    h, w = img.shape[:2]
    t = 0
    b = height - h
    l = 0
    r = width - w
    return cv2.copyMakeBorder(img, t, b, l, r, 
                              cv2.BORDER_CONSTANT, value=0)

def resize_and_pad(img, height, width, resample=cv2.INTER_AREA):
    # Grayscale images are stacked into three identical channels.
    if len(img.shape)==2:
        img = np.stack([img,img,img],axis=2)
    target_aspect_ratio = height/width
    im_h, im_w, _ = img.shape
    im_aspect_ratio = im_h/im_w
    # Scale the longer side to the target size so the resized image fits
    # inside height x width, then pad the remainder with black.
    if im_aspect_ratio>target_aspect_ratio:
        target_height = height
        target_width = int(im_w * target_height/im_h)
    else:
        target_width = width
        target_height = int(im_h * target_width/im_w)
    resized = cv2.resize(img, (target_width, target_height),
                         interpolation=resample)
    return pad_image(resized, height, width)
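
As a quick sanity check (a sketch I've added, not part of the original notebook), resize_and_pad can be applied to a synthetic 480x640 image: the content should be scaled into the top 168 rows of a 224x224 canvas, with black padding below.

test_img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # hypothetical test image
out = resize_and_pad(test_img, size, size)
assert out.shape == (224, 224, 3)
# 480 * (224/640) = 168 rows of content; rows 168 onward are black padding
assert out[168:].sum() == 0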
In [0]:
pics, im_fn_to_index = {}, {}
for ds in ['train','val']:
    fn = ds_to_fn[ds]
    archive = zipfile.ZipFile(f"{base_path}/{fn}")
    file_list = archive.filelist
    # Preallocate one uint8 array per split (the -1 accounts for the
    # directory entry in the archive).
    pics[ds] = np.zeros((len(file_list)-1,size,size,3),dtype=np.uint8)
    im_fn_to_index[ds] = {}
    for count,file_obj in enumerate(file_list):
        im_fn = file_obj.filename
        if not im_fn.endswith('.jpg'):
            continue
        # Decode each JPEG directly from the zip without extracting to disk.
        with archive.open(file_obj) as open_file:
            res = BytesIO(open_file.read())
            pic = plt.imread(res,'jpg')
        ind = len(im_fn_to_index[ds])
        pics[ds][ind] = resize_and_pad(pic, size, size)
        im_fn_to_index[ds][im_fn] = ind
    archive.close()
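
As a lightweight check (not in the original notebook), the preallocated arrays and filename indexes can be compared to confirm every JPEG in each archive was read and resized:

# Each split should have one 224x224x3 entry per JPEG in its archive.
for ds in ['train', 'val']:
    print(ds, pics[ds].shape, len(im_fn_to_index[ds]))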

For word embeddings I am going to use the pretrained GloVe embeddings that are downloadable through torchtext. I chose the 100k most common words, since less frequent words would add more dimensions to the output space without much benefit: most words in the captions are simple and therefore common. I also exclude any word containing a digit or punctuation other than a dash or apostrophe, and any word with an uppercase letter, since I will only use lowercase text.

In [6]:
vocab = torchtext.vocab.GloVe(name='840B', dim=300, max_vectors=100000)
.vector_cache/glove.840B.300d.zip: 2.18GB [16:54, 2.15MB/s]                           
  5%|▍         | 99408/2196017 [00:09<03:28, 10063.89it/s]
In [0]:
punctuation = set(c for c in string.punctuation if c not in "-'")
digits = set(str(i) for i in range(10))

inds_to_use = []
seen_lower = set()
upper_added = {}
words = set()
for i,word in enumerate(vocab.itos):
    # Keep only lowercase words containing no digits and no punctuation
    # other than a dash or apostrophe.
    if not any(c in punctuation or c in digits for c in word):
        if not all(c in "-'" for c in word) and word.islower():
            inds_to_use.append(i)
            words.add(word)

# Restrict the vocabulary and its embedding matrix to the kept words.
vocab.itos = np.array(vocab.itos)[inds_to_use]
vocab.stoi = {s:i for i,s in enumerate(vocab.itos)}
vocab.vectors = vocab.vectors[inds_to_use]
In [8]:
## size of remaining vocab
len(vocab.stoi),len(vocab.itos),vocab.vectors.size()
Out[8]:
(41746, 41746, torch.Size([41746, 300]))
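
With the filtered vocabulary in place, a word's 300-dimensional vector can be looked up through the rebuilt stoi mapping. A minimal sketch ('dog' is just an illustrative word; any lowercase word that survived the filter works):

# Look up the embedding of a word that survived the filtering.
idx = vocab.stoi['dog']
vec = vocab.vectors[idx]
print(vec.shape)  # torch.Size([300])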

Install some libraries that I use for pretrained building blocks.

In [9]:
!pip install efficientnet_pytorch
!pip install transformers
Collecting efficientnet_pytorch
  Downloading https://files.pythonhosted.org/packages/b8/cb/0309a6e3d404862ae4bc017f89645cf150ac94c14c88ef81d215c8e52925/efficientnet_pytorch-0.6.3.tar.gz
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from efficientnet_pytorch) (1.5.0+cu101)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch->efficientnet_pytorch) (1.18.5)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->efficientnet_pytorch) (0.16.0)
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... done
  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-cp36-none-any.whl size=12422 sha256=3c070e4d1baf1641c460b24655bb552213f1ffdf78f29b9fa830390a9ca7165d
  Stored in directory: /root/.cache/pip/wheels/42/1e/a9/2a578ba9ad04e776e80bf0f70d8a7f4c29ec0718b92d8f6ccd
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.6.3
Collecting transformers
  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
     |████████████████████████████████| 675kB 3.5MB/s 
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
     |████████████████████████████████| 1.1MB 15.5MB/s 
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)
Collecting tokenizers==0.7.0
  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
     |████████████████████████████████| 3.8MB 24.4MB/s 
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
     |████████████████████████████████| 890kB 42.5MB/s 
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.9)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.15.1)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=50458b0730a1d32307c4a7a67e21070b13d99080fd955c82a2827fc2d8c19c6d
  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45
Successfully built sacremoses
Installing collected packages: sentencepiece, tokenizers, sacremoses, transformers
Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.7.0 transformers-2.11.0

Import everything I will need for defining and training the neural net.

In [0]:
import torch.nn as nn
import torchvision
import efficientnet_pytorch
import transformers
import scipy.stats


base_path = "."

Load all the data and labels, and define all the dictionaries that I will use later during training.

In [0]:
with open(f'{base_path}/annotations/captions_train2014.json','r') as f:
    annot_train = json.load(f)

with open(f'{base_path}/annotations/captions_val2014.json','r') as f:
    annot_val = json.load(f)
In [12]:
LONGEST_CAPTION = max(len(d['caption'].split())
                      for d in annot_train['annotations'] +\
                      annot_val['annotations'])
LONGEST_CAPTION
Out[12]:
50
In [0]:
train_pics = pics['train']
val_pics = pics['val']
train_immap = im_fn_to_index['train']
val_immap = im_fn_to_index['val']
In [0]:
# Zip entries look like "train2014/COCO_train2014_....jpg"; strip the
# directory prefix so keys match the file names in the annotations.
train_fn_to_index = {key.split('/')[1]:val 
                     for key,val in train_immap.items()}
train_index_to_fn = {val:key for key,val in train_fn_to_index.items()}

val_fn_to_index = {key.split('/')[1]:val for key,val in val_immap.items()}
val_index_to_fn = {val:key for key,val in val_fn_to_index.items()}
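
As a spot check (not in the original notebook), the rebuilt filename index can be compared against the annotation files loaded above to confirm every annotated training image was found in the archive:

# Count annotated training images missing from the resized-image index.
missing = [d['file_name'] for d in annot_train['images']
           if d['file_name'] not in train_fn_to_index]
print(len(missing))  # expected to be 0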
In [0]:
train_imfn_to_imid = {d['file_name']:d['id'] 
                      for d in annot_train['images']}
# Each image has several reference captions; this dict comprehension
# keeps only the last caption seen for each image id.
train_imid_to_caption = {d['image_id']:d['caption']
                         for d in annot_train['annotations']}
train_imfn_to_caption = {fn:train_imid_to_caption[id_]
                         for fn,id_ in train_imfn_to_imid.items()}

val_imfn_to_imid = {d['file_name']:d['id'] for d in annot_val['images']}
val_imid_to_caption = {d['image_id']:d['caption']
                       for d in annot_val['annotations']}
val_imfn_to_caption = {fn:val_imid_to_caption[id_]
                       for fn,id_ in val_imfn_to_imid.items()}

Define a function to show some training images and their provided captions from the training annotations, then view 10 random images and their captions.

In [0]:
def show_im_and_cap_train(indexes):
    for index in indexes:
        imfn = train_index_to_fn[index]
        caption = train_imfn_to_caption[imfn]
        fig = plt.figure(figsize=(7,7))
        plt.imshow(train_pics[index])
        plt.title(caption)
        plt.show()
In [17]:
show_im_and_cap_train(np.random.randint(0,len(train_pics),10))