Image-to-Image Search with TripletMarginLoss
Using image queries to search for visually similar images is a very popular use case. However, pre-trained models often do not deliver the best results out of the box, because they are trained on general data that lacks knowledge related to your specific task. Here’s where Finetuner comes in! It enables you to easily add task-specific knowledge to a model.
This guide will demonstrate how to fine-tune a ResNet model for image-to-image retrieval.
Note: please switch to a GPU/TPU runtime, otherwise this will be extremely slow!
!pip install 'finetuner[full]'
More specifically, we will fine-tune ResNet50 on the Totally Looks Like Dataset. The dataset consists of 6016 pairs of images (12032 in total).
The two images of each pair form a positive pair. Negative pairs are constructed by taking two images that do not belong to the same pair. Following this approach, we construct triplets and use the TripletMarginLoss. You can find more details in the how Finetuner works section.
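The following is a minimal plain-Python sketch of how (anchor, positive, negative) triplets can be derived from image pairs. It is only an illustration and does not reproduce Finetuner's own batching logic, which happens internally. The helper `make_triplets` and the file names are hypothetical.

```python
import random
from typing import List, Tuple

Pair = Tuple[str, str]
Triplet = Tuple[str, str, str]


def make_triplets(pairs: List[Pair], seed: int = 0) -> List[Triplet]:
    """Build (anchor, positive, negative) triplets from image pairs.

    Hypothetical helper for illustration only: the two images of a pair
    become anchor and positive; the negative is drawn from a different pair.
    """
    if len(pairs) < 2:
        raise ValueError("need at least two pairs to sample negatives")
    rng = random.Random(seed)
    triplets = []
    for i, (anchor, positive) in enumerate(pairs):
        # Pick a random index j != i without building a candidate list.
        j = rng.randrange(len(pairs) - 1)
        if j >= i:
            j += 1
        negative = rng.choice(pairs[j])  # either image of another pair works
        triplets.append((anchor, positive, negative))
    return triplets


pairs = [
    ("left_0.jpg", "right_0.jpg"),
    ("left_1.jpg", "right_1.jpg"),
    ("left_2.jpg", "right_2.jpg"),
]
for triplet in make_triplets(pairs):
    print(triplet)
```

Any image outside the anchor's own pair can serve as a negative. This is why a dataset of 6016 pairs yields many more possible negatives than positives.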
After fine-tuning, the embeddings of positive pairs are expected to be pulled closer, while the embeddings for negative pairs are expected to be pushed away.
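The "pull closer / push away" behaviour can be written as a hinge on distances. For an anchor a, a positive p and a negative n, the triplet margin loss is max(d(a, p) - d(a, n) + margin, 0). The sketch below computes it with Euclidean distance on toy 2-D vectors. The margin of 1.0 and the vectors are arbitrary choices for illustration, not the settings Finetuner uses.

```python
import math
from typing import Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean (L2) distance between two vectors of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin: float = 1.0) -> float:
    """max(d(a, p) - d(a, n) + margin, 0) for a single triplet."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


anchor = [0.0, 0.0]
# Positive is close and negative is far: the margin is satisfied, so the loss is 0.
easy = triplet_margin_loss(anchor, [0.1, 0.0], [3.0, 0.0])
# Negative is closer than the positive: the loss is positive, and minimising it
# pulls the positive towards the anchor and pushes the negative away.
hard = triplet_margin_loss(anchor, [2.0, 0.0], [0.5, 0.0])
print(easy, hard)  # → 0.0 2.5
```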
Our journey starts locally. We have to prepare the data and push it to the Jina AI Cloud, after which Finetuner can retrieve each dataset by its name. For this example,
we’ve already prepared the data, and we’ll provide Finetuner with just the names of the training, query and index datasets (e.g.
You don’t have to push your data to the Jina AI Cloud before fine-tuning. Instead of a name, you can provide a DocumentArray and Finetuner will upload your data directly.
Important: If your documents refer to locally stored images, please call doc.load_uri_to_blob() before starting Finetuner to reduce network transmission and speed up training.
import finetuner
from finetuner import DocumentArray, Document

finetuner.login(force=True)

train_data = DocumentArray.pull('finetuner/tll-train-data', show_progress=True)
query_data = DocumentArray.pull('finetuner/tll-test-query-data', show_progress=True)
index_data = DocumentArray.pull('finetuner/tll-test-index-data', show_progress=True)

train_data.summary()
Now let’s see which backbone models we can use. You can see all the available models by calling
For this example, we’re going to go with
Now that we have selected our model and loaded the training and evaluation datasets as
DocumentArrays, we can start our fine-tuning run.
from finetuner.callback import EvaluationCallback

run = finetuner.fit(
    model='resnet-base',
    train_data='finetuner/tll-train-data',
    batch_size=128,
    epochs=5,
    learning_rate=1e-4,
    device='cuda',
    callbacks=[
        EvaluationCallback(
            query_data='finetuner/tll-test-query-data',
            index_data='finetuner/tll-test-index-data',
        )
    ],
)
Let’s understand what this piece of code does:
As you can see, we have to provide the model, which we picked before.
We also set
description, which are optional, but recommended in order to retrieve your run easily and have some context about it.
Furthermore, we had to provide names of the
Additionally, we use
Lastly, we set the number of epochs and provide a
Now that we’ve created a run, let’s see its status. You can monitor the run by checking the status -
run.status() - and the logs -
# note: the fine-tuning may take around 30 minutes
for entry in run.stream_logs():
    print(entry)
Since some runs might take up to several hours, it’s important to know how to reconnect to Finetuner and retrieve your runs.
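import finetuner

finetuner.login()
# `run.name` only works while the previous `run` object is still in scope;
# in a fresh session, pass the name of your run as a string instead.
run = finetuner.get_run(run.name)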
You can continue monitoring the runs by checking the status -
finetuner.run.Run.status() or the logs -
Currently, we don’t have a user-friendly way to get evaluation metrics from the
finetuner.callback.EvaluationCallback we initialized previously.
For now, you can call run.logs() after the run has finished to see the evaluation results:
Training [5/5] ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 76/76 0:00:00 0:03:15 • loss: 0.003
[16:39:13] DEBUG    Metric: 'model_average_precision' Value: 0.19598    __main__.py:202
           DEBUG    Metric: 'model_dcg_at_k' Value: 0.28571             __main__.py:202
           DEBUG    Metric: 'model_f1_score_at_k' Value: 0.04382        __main__.py:202
           DEBUG    Metric: 'model_hit_at_k' Value: 0.46013             __main__.py:202
           DEBUG    Metric: 'model_ndcg_at_k' Value: 0.28571            __main__.py:202
           DEBUG    Metric: 'model_precision_at_k' Value: 0.02301       __main__.py:202
           DEBUG    Metric: 'model_r_precision' Value: 0.19598          __main__.py:202
           DEBUG    Metric: 'model_recall_at_k' Value: 0.46013          __main__.py:202
           DEBUG    Metric: 'model_reciprocal_rank' Value: 0.19598      __main__.py:202
           INFO     Done ✨                                             __main__.py:204
           INFO     Saving fine-tuned models ...                        __main__.py:207
           INFO     Saving model 'model' in /usr/src/app/tuned-models/model ... __main__.py:218
           INFO     Pushing saved model to Jina AI Cloud ...            __main__.py:225
[16:39:41] INFO     Pushed model artifact ID: '62b33cb0037ad91ca7f20530' __main__.py:231
           INFO     Finished 🚀                                         __main__.py:233
                                                                        __main__.py:248
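Several of the logged metrics are tied to each other. Assume each query image has exactly one matching image in the index, namely its counterpart from the original pair. Under that assumption, the following relationships hold:

- recall@k equals hit@k.
- precision@k equals hit@k divided by k.
- Average precision reduces to the reciprocal rank.
- DCG equals nDCG.

The sketch below computes these quantities for a single ranked result list with one relevant item. It uses common textbook definitions evaluated over the top-k results. It is not Finetuner's or DocArray's exact implementation, and the helper `metrics_single_relevant` is hypothetical.

```python
import math
from typing import Dict, List


def metrics_single_relevant(ranked: List[str], relevant: str, k: int) -> Dict[str, float]:
    """Retrieval metrics for one query that has exactly one relevant item.

    Hypothetical helper: `ranked` is the ranked result list, `relevant` the
    single correct item, and all metrics are computed over the top-k results.
    """
    top_k = ranked[:k]
    # 1-based rank of the relevant item within the top k, or None if it is missing
    rank = top_k.index(relevant) + 1 if relevant in top_k else None
    hit = 1.0 if rank is not None else 0.0
    rr = 1.0 / rank if rank is not None else 0.0
    precision = hit / k  # at most one relevant item among k results
    recall = hit  # one relevant item in total, so recall@k is 0 or 1
    f1 = 2 * precision * recall / (precision + recall) if hit else 0.0
    # One common DCG formulation; the ideal DCG is 1, so nDCG equals DCG here.
    dcg = 1.0 / math.log2(rank + 1) if rank is not None else 0.0
    return {
        "hit_at_k": hit,
        "recall_at_k": recall,
        "precision_at_k": precision,
        "f1_score_at_k": f1,
        "reciprocal_rank": rr,
        "average_precision": rr,  # AP with a single relevant item reduces to 1/rank
        "dcg_at_k": dcg,
        "ndcg_at_k": dcg,
    }


m = metrics_single_relevant(["b", "a", "c"], relevant="a", k=20)
print(m["precision_at_k"], m["reciprocal_rank"])  # → 0.05 0.5
```

As a sanity check on the log above, averaging precision@k = hit/k over queries gives 0.46013 / 20 ≈ 0.02301. This matches the logged model_precision_at_k if k = 20. Note that k is inferred here, not stated in the log.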
After the run has finished successfully, you can download the tuned model to your local machine:
artifact = run.save_artifact('resnet-model')
Now that you have saved the artifact to your host machine, let’s use the fine-tuned model to encode a new
Inference with ONNX
In case you set
to_onnx=True when calling
model = finetuner.get_model(artifact, is_onnx=True)
# wrap a single query Document in a DocumentArray of length 1
query = DocumentArray([query_data[0]])

model = finetuner.get_model(artifact=artifact, device='cuda')

finetuner.encode(model=model, data=query)
finetuner.encode(model=model, data=index_data)

assert query.embeddings.shape == (1, 2048)
And finally, you can use the embedded query to find the top-k visually related images within index_data as follows:
query.match(index_data, limit=10, metric='cosine')
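Conceptually, matching with metric='cosine' ranks every index embedding by its cosine distance to the query embedding, where cosine distance is one minus cosine similarity. It then keeps the `limit` closest ones. The plain-Python sketch below illustrates that idea on made-up 3-dimensional vectors. It is not DocArray's implementation, and the function `top_k_cosine` and the image IDs are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """1 - cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(u, v))
    norm = math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / norm


def top_k_cosine(
    query: Sequence[float], index: Dict[str, Sequence[float]], limit: int
) -> List[Tuple[str, float]]:
    """Return the `limit` index items closest to `query`, nearest first."""
    scored = [(doc_id, cosine_distance(query, emb)) for doc_id, emb in index.items()]
    scored.sort(key=lambda item: item[1])  # smaller distance = more similar
    return scored[:limit]


index = {
    "img_a": [1.0, 0.0, 0.0],
    "img_b": [0.9, 0.1, 0.0],
    "img_c": [0.0, 1.0, 0.0],
    "img_d": [-1.0, 0.0, 0.0],
}
query = [1.0, 0.05, 0.0]
print([doc_id for doc_id, _ in top_k_cosine(query, index, limit=2)])  # → ['img_a', 'img_b']
```

Cosine distance ignores vector length and only compares directions. This is why it is a common choice for comparing embeddings such as the 2048-dimensional ResNet outputs above.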
Before and after
We can directly compare the results of our fine-tuned model with its zero-shot counterpart to get a better idea of how fine-tuning affects the results of a search. While the differences between the two models may be subtle for some queries, some of the examples below (such as the second example) show that the model after fine-tuning is able to better match similar images.
To save you some time, we have plotted some examples where the model’s ability to return similar images has clearly improved:
On the other hand, there are also cases where the fine-tuned model performs worse and fails to correctly match images that it previously could. This case is much rarer than the previous one. For this dataset, there were 108 cases where the fine-tuned model returned the correct pair while the zero-shot model did not. There were only 33 cases where the fine-tuned model returned an incorrect image even though the zero-shot model had returned the correct one. Nevertheless, it can still happen:
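Counts like the 108 improvements and 33 regressions mentioned above can be obtained by recording, for every query, whether each model returned the correct pair. You then tally the queries on which the two models disagree. The following is a minimal sketch, assuming two hypothetical per-query lists of booleans. The toy values are invented and do not reproduce the reported figures.

```python
from typing import List, Tuple


def compare_models(zero_shot_correct: List[bool], tuned_correct: List[bool]) -> Tuple[int, int]:
    """Count queries fixed by fine-tuning and queries broken by it."""
    if len(zero_shot_correct) != len(tuned_correct):
        raise ValueError("both lists must cover the same queries")
    pairs = list(zip(zero_shot_correct, tuned_correct))
    improved = sum(1 for before, after in pairs if after and not before)
    regressed = sum(1 for before, after in pairs if before and not after)
    return improved, regressed


before = [False, True, False, True, False]
after = [True, True, True, False, False]
print(compare_models(before, after))  # → (2, 1)
```

With the figures reported above, fine-tuning yields a net gain of 108 - 33 = 75 correctly matched queries.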