Text-to-Image Search via CLIP#
Traditionally, searching images from text (text-image retrieval) relies heavily on human annotations; this is commonly referred to as Text/Tag-based Image Retrieval (TBIR).
The OpenAI CLIP model maps the dense vectors extracted from text and image into the same semantic space and produces a strong zero-shot model to measure the similarity between text and images.
This guide will showcase fine-tuning a CLIP
model for text-to-image retrieval.
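To get a feeling for this zero-shot behaviour before any fine-tuning, here is a minimal sketch that scores an image against two text prompts. It uses the Hugging Face transformers and Pillow packages rather than Finetuner, and the prompts are made up for illustration:
import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

image = Image.open(requests.get(
    'https://upload.wikimedia.org/wikipedia/commons/4/4e/Single_apple.png', stream=True
).raw)
inputs = processor(
    text=['an apple', 'a leather jacket'], images=image, return_tensors='pt', padding=True
)
outputs = model(**inputs)
print(outputs.logits_per_image.softmax(dim=1))  # probability that the image matches each prompt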
Note: please consider switching to a GPU/TPU Runtime for faster inference.
Install#
!pip install 'finetuner[full]'
Task#
We'll be fine-tuning CLIP on the fashion captioning dataset, which contains information about fashion products.
For each product, the dataset contains a title and images of multiple variants of the product. We constructed a parent Document
for each picture, which contains two chunks: an image document and a text document holding the description of the product.
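To make this structure concrete, a parent Document could be assembled roughly as follows; the title, image path and modality tags are illustrative placeholders, not items from the actual dataset:
from finetuner import Document

# hypothetical product: one image variant plus its textual description
product = Document(
    chunks=[
        Document(uri='path/to/variant-1.jpg', modality='image'),
        Document(text='Black leather biker jacket', modality='text'),
    ]
)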
Data#
Our journey starts locally. We have to prepare the data and push it to the Jina AI Cloud; Finetuner can then retrieve the dataset by its name. For this example, we have already prepared the data, and we will provide the names of the training and evaluation datasets (fashion-train-data-clip and fashion-eval-data-clip) directly to Finetuner.
In addition, we also provide labeled queries and an index of labeled documents for evaluating the retrieval capabilities of the resulting fine-tuned model, stored in the datasets fashion-eval-data-queries and fashion-eval-data-index.
Push data to the cloud
We don't require you to push data to the Jina AI Cloud yourself. Instead of a name, you can provide a DocumentArray and Finetuner will do the job for you. When working with documents whose images are stored locally, please call doc.load_uri_to_blob() to reduce network transmission and speed up training.
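As a sketch, preparing your own local data along these lines could look like this (paths and titles are hypothetical); the resulting DocumentArray can then be passed directly as train_data to finetuner.fit instead of a dataset name:
from finetuner import Document, DocumentArray

local_data = DocumentArray([
    Document(
        chunks=[
            Document(uri='images/dress-001.jpg', modality='image'),  # hypothetical local path
            Document(text='Floral summer dress', modality='text'),
        ]
    )
])

# load the image bytes into each image chunk's blob before training
for doc in local_data:
    for chunk in doc.chunks:
        if chunk.uri:
            chunk.load_uri_to_blob()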
import finetuner
from finetuner import DocumentArray, Document
finetuner.login(force=True)
train_data = DocumentArray.pull('finetuner/fashion-train-data-clip', show_progress=True)
eval_data = DocumentArray.pull('finetuner/fashion-eval-data-clip', show_progress=True)
query_data = DocumentArray.pull('finetuner/fashion-eval-data-queries', show_progress=True)
index_data = DocumentArray.pull('finetuner/fashion-eval-data-index', show_progress=True)
train_data.summary()
Backbone model#
Currently, we support several CLIP variations from open-clip for text-to-image retrieval tasks.
You can see all available models either in the choose backbone section or by calling finetuner.describe_models().
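For instance, to print the table of supported backbones:
import finetuner

finetuner.describe_models()  # lists all supported backbone models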
Fine-tuning#
Now that we have the training and evaluation datasets loaded as DocumentArray
s and selected our model, we can start our fine-tuning run.
from finetuner.callback import EvaluationCallback
run = finetuner.fit(
    model='clip-base-en',
    train_data='finetuner/fashion-train-data-clip',
    eval_data='finetuner/fashion-eval-data-clip',
    epochs=5,
    learning_rate=1e-7,
    loss='CLIPLoss',
    device='cuda',
    callbacks=[
        EvaluationCallback(
            model='clip-text',
            index_model='clip-vision',
            query_data='finetuner/fashion-eval-data-queries',
            index_data='finetuner/fashion-eval-data-index',
        )
    ],
)
Let’s understand what this piece of code does:
- We start by providing the model name as well as the names of the training and evaluation data.
- We also provide some hyperparameters, such as the number of epochs and the learning_rate.
- We use CLIPLoss to optimize the CLIP model.
- We use an evaluation callback, which uses the 'clip-text' model for encoding the text queries and the 'clip-vision' model for encoding the images in 'fashion-eval-data-index'.
Monitoring#
Now that we've created a run, let's see its status. You can monitor the run by checking the status via run.status(), and the logs via run.logs() or run.stream_logs().
# note: fine-tuning might take ~20 minutes
for entry in run.stream_logs():
print(entry)
Since some runs might take up to several hours or even days, it's important to know how to reconnect to Finetuner and retrieve your run.
import finetuner
finetuner.login()
run = finetuner.get_run(run.name)
You can continue monitoring the run by checking the status via finetuner.run.Run.status() or the logs via finetuner.run.Run.logs().
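For example:
print(run.status())
print(run.logs())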
Evaluating#
Our EvaluationCallback ensures that, after each epoch during fine-tuning, an evaluation of our model is run. We can access the results of the last evaluation in the logs via print(run.logs()):
Training [5/5] ━━━━ 195/195 0:00… 0:0… • loss: 2.419 • val_loss: 2.803
[13:32:41] INFO Done ✨ __main__.py:195
DEBUG Finetuning took 0 days, 0 hours 5 minutes and 30 seconds
DEBUG Metric: 'clip-text-to-clip-vision_precision_at_k' Value: 0.28532
DEBUG Metric: 'clip-text-to-clip-vision_hit_at_k' Value: 0.94282
DEBUG Metric: 'clip-text-to-clip-vision_average_precision' Value: 0.53372
DEBUG Metric: 'clip-text-to-clip-vision_reciprocal_rank' Value: 0.67706
DEBUG Metric: 'clip-text-to-clip-vision_dcg_at_k' Value: 2.71247
...
Saving#
After the run has finished successfully, you can download the tuned model to your local machine:
artifact = run.save_artifact('clip-model')
Inference#
Now that you have saved the artifact to your host machine, let's use the fine-tuned model to encode a new Document:
text_da = DocumentArray([Document(text='some text to encode')])
image_da = DocumentArray([Document(uri='https://upload.wikimedia.org/wikipedia/commons/4/4e/Single_apple.png')])
clip_text_encoder = finetuner.get_model(artifact=artifact, select_model='clip-text')
clip_image_encoder = finetuner.get_model(artifact=artifact, select_model='clip-vision')
finetuner.encode(model=clip_text_encoder, data=text_da)
finetuner.encode(model=clip_image_encoder, data=image_da)
print(text_da.embeddings.shape)
print(image_da.embeddings.shape)
(1, 512)
(1, 512)
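Since both encoders embed into the same space, you can, for example, score the text against the image with cosine similarity. This is a minimal sketch using numpy, not a Finetuner API:
import numpy as np

text_emb = text_da.embeddings[0]
image_emb = image_da.embeddings[0]

# cosine similarity between the text embedding and the image embedding
score = float(np.dot(text_emb, image_emb) / (np.linalg.norm(text_emb) * np.linalg.norm(image_emb)))
print(score)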
What is select_model?
When fine-tuning CLIP, we are fine-tuning the CLIPVisionEncoder and CLIPTextEncoder in parallel.
The artifact contains two models: clip-vision
and clip-text
.
The parameter select_model tells Finetuner which model to use for inference. In the above example, we use clip-text to encode a Document with text content.
Inference with ONNX
In case you set to_onnx=True when calling the finetuner.fit function, please use model = finetuner.get_model(artifact, is_onnx=True) to load the model.
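For the CLIP artifact this would look roughly as follows; combining select_model with is_onnx is an assumption on our side, so check it against your Finetuner version:
# assuming the run was created with to_onnx=True in finetuner.fit(...)
clip_text_encoder = finetuner.get_model(artifact=artifact, select_model='clip-text', is_onnx=True)
clip_image_encoder = finetuner.get_model(artifact=artifact, select_model='clip-vision', is_onnx=True)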
Advanced: WiSE-FT#
WiSE-FT, proposed by Wortsman et al. in Robust fine-tuning of zero-shot models, has been proven to be an effective way to fine-tune models with a strong zero-shot capability, such as CLIP. As introduced in the paper:
Large pre-trained models such as CLIP or ALIGN offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning methods substantially improve accuracy on a given target distribution, they often reduce robustness to distribution shifts. We address this tension by introducing a simple and effective method for improving robustness while fine-tuning: ensembling the weights of the zero-shot and fine-tuned models (WiSE-FT).
Finetuner allows you to apply WiSE-FT easily; all you need to do is use the WiSEFTCallback.
Finetuner will trigger the callback when the fine-tuning job is finished and merge the weights of the pre-trained model and the fine-tuned model:
from finetuner.callback import WiSEFTCallback

run = finetuner.fit(
    model='clip-base-en',
    ...,
    loss='CLIPLoss',
-   callbacks=[],
+   callbacks=[WiSEFTCallback(alpha=0.5)],
)
The value you set for alpha should be greater than or equal to 0 and less than or equal to 1 (a conceptual sketch of the merge follows the list):
- if alpha is a float between 0 and 1, we merge the weights of the pre-trained model and the fine-tuned model.
- if alpha is 0, the fine-tuned model is identical to the pre-trained model.
- if alpha is 1, the pre-trained weights will not be utilized.
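Conceptually, the merge is a linear interpolation between the two sets of weights. The sketch below is not Finetuner's internal implementation, only an illustration of the rule described above, written against plain PyTorch-style state dicts:
def wise_ft_merge(pretrained_state: dict, finetuned_state: dict, alpha: float = 0.5) -> dict:
    """Interpolate two state dicts of matching tensors: alpha=0 keeps the
    pre-trained weights, alpha=1 keeps the fine-tuned weights."""
    return {
        name: (1 - alpha) * pretrained_state[name] + alpha * finetuned_state[name]
        for name in pretrained_state
    }

# usage sketch with two hypothetical models sharing the same architecture:
# merged = wise_ft_merge(pretrained_model.state_dict(), finetuned_model.state_dict(), alpha=0.5)
# finetuned_model.load_state_dict(merged)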
That’s it! Check out clip-as-service to learn how to plug in a fine-tuned CLIP model to our CLIP-specific service.
Before and after#
We can directly compare the results of our fine-tuned model with a pre-trained CLIP model by displaying the matches each model has for the same query. While the differences between the results of the two models are quite subtle for some queries, the examples below clearly show that fine-tuning increases the quality of the search results:
Results for query: "nightingale tee jacket" using a zero-shot model (top) and the fine-tuned model (bottom)
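To reproduce such a comparison yourself, one possible sketch is to encode the labeled queries and index documents (pulled earlier) with the two encoders and match them; the limit and the printed fields are illustrative:
# encode evaluation queries with the text encoder and the index with the vision encoder
finetuner.encode(model=clip_text_encoder, data=query_data)
finetuner.encode(model=clip_image_encoder, data=index_data)

# retrieve the top-9 nearest images per query and inspect the best matches
query_data.match(index_data, limit=9, metric='cosine')
for match in query_data[0].matches[:3]:
    print(match.id, match.scores['cosine'].value)
Running this once with a pre-trained (zero-shot) CLIP encoder and once with the fine-tuned artifact gives the kind of before/after comparison shown above.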