Tailor#
Tailor is a component of Finetuner. It converts any general model into an embedding model. Given a general model (either written from scratch, or from the PyTorch/Keras/Hugging Face model zoo), Tailor performs micro-operations on the model architecture and outputs an embedding model for the Tuner.
Given a general model with weights, Tailor preserves its weights and performs (some of) the following steps:
finding all dense layers by iterating over layers;
chopping off all layers after a certain dense layer;
freezing weights of specific layers;
adding a new bottleneck module on top of the embedding model.
Finally, Tailor outputs an embedding model that can be fine-tuned in Tuner. To make the best use of Tailor, follow this journey:

Display the model summary.
Select an embedding layer as your model output.
Decide your freeze strategy.
(Optional) Attach a bottleneck module.

display model summary#
Tailor provides a helper function finetuner.display() that gives a table summary of a Keras/PyTorch/Paddle model.
Let’s see how to create an embedding model with ResNet-50.
Load a pre-trained ResNet-50 via your favourite deep learning backend and call display.

PyTorch:

```python
import torchvision
import finetuner as ft

model = torchvision.models.resnet50(pretrained=True)
ft.display(model, input_size=(3, 224, 224))
```
Keras:

```python
import tensorflow as tf
import finetuner as ft

model = tf.keras.applications.ResNet50(weights='imagenet')
ft.display(model)
```
Paddle:

```python
import paddle
import finetuner as ft

model = paddle.vision.models.resnet50(pretrained=True)
ft.display(model, input_size=(3, 224, 224))
```
Now you can see the tabular summary of the model in your console/notebook. The summary lists all layers with their output shapes, number of parameters, and whether they are trainable.
PyTorch:

```text
 name                    output_shape_display   nb_params   trainable
──────────────────────────────────────────────────────────────────────
 conv2d_1                [64, 112, 112]         9408        True
 batchnorm2d_2           [64, 112, 112]         128         True
 relu_3                  [64, 112, 112]         0           False
 maxpool2d_4             [64, 56, 56]           0           False
 conv2d_5                [64, 56, 56]           4096        True
 ...
 bottleneck_172          [2048, 7, 7]           0           False
 adaptiveavgpool2d_173   [2048, 1, 1]           0           False
 linear_174              [1000]                 2049000     True

Green layers are trainable layers, Cyan layers are non-trainable layers
or frozen layers. Gray layers indicates this layer has been replaced by
an Identity layer. Use to_embedding_model(...) to create embedding model.
```
Keras:

```text
 name                    output_shape_display   nb_params   trainable
────────────────────────────────────────────────────────────────────
 input_2                 []                     0           False
 conv1_pad               [230, 230, 3]          0           False
 conv1_conv              [112, 112, 64]         9472        True
 conv1_bn                [112, 112, 64]         256         True
 conv1_relu              [112, 112, 64]         0           False
 pool1_pad               [114, 114, 64]         0           False
 ...
 conv5_block3_out        [7, 7, 2048]           0           False
 avg_pool                [2048]                 0           False
 predictions             [1000]                 2049000     True

Green layers are trainable layers, Cyan layers are non-trainable layers
or frozen layers. Gray layers indicates this layer has been replaced by
an Identity layer. Use to_embedding_model(...) to create embedding model.
```
Paddle:

```text
 name                    output_shape_display   nb_params   trainable
──────────────────────────────────────────────────────────────────────
 conv2d_1                [64, 112, 112]         9408        True
 batchnorm2d_2           [64, 112, 112]         256         True
 relu_3                  [64, 112, 112]         0           False
 maxpool2d_4             [64, 56, 56]           0           False
 conv2d_5                [64, 56, 56]           4096        True
 ...
 bottleneckblock_172     [2048, 7, 7]           0           False
 adaptiveavgpool2d_173   [2048, 1, 1]           0           False
 linear_174              [1000]                 2049000     True

Green layers are trainable layers, Cyan layers are non-trainable layers
or frozen layers. Gray layers indicates this layer has been replaced by
an Identity layer. Use to_embedding_model(...) to create embedding model.
```
Select an embedding layer as your model output.#
After plotting the table summary of the pre-trained model, we can decide which layer to use as the “embedding layer”.
As you can see in the ResNet-50 summary above, the layer named linear_174 (PyTorch/Paddle) or predictions (Keras) is the final classification layer for the 1000 ImageNet classes. This layer is not a good choice of “embedding layer”. We’ll use adaptiveavgpool2d_173 (PyTorch/Paddle) or avg_pool (Keras) as our “embedding layer” instead.
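As a quick sanity check, you can already cut the model at this layer without freezing anything or attaching a bottleneck. Below is a minimal sketch (PyTorch shown; the Keras/Paddle calls differ only in model loading and layer name), assuming freeze and projection_head are optional arguments:

```python
import torchvision
import finetuner as ft

model = torchvision.models.resnet50(pretrained=True)

# Cut the network at the pooling layer; the classification
# head after it is removed.
embedding_model = ft.tailor.to_embedding_model(
    model=model,
    layer_name='adaptiveavgpool2d_173',
    input_size=(3, 224, 224),
)
```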
Keep the layer name in mind; we’ll put everything together after deciding the freeze strategy and the bottleneck layer.
Decide your freeze strategy.#
To apply a pre-trained model to your downstream task, in most cases you do not want to train everything from scratch. Finetuner allows you to freeze the entire model or specific layers with the freeze argument.
If freeze=True, Finetuner freezes the entire pre-trained model. If freeze is a list of layer names, for example freeze=['conv2d_1', 'conv2d_5'], only those two layers are frozen.
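A sketch of both options (the two layer names here are an arbitrary illustration; any names from the display summary work):

```python
import torchvision
import finetuner as ft

model = torchvision.models.resnet50(pretrained=True)

# Option 1: freeze the whole backbone.
frozen_model = ft.tailor.to_embedding_model(
    model=model,
    layer_name='adaptiveavgpool2d_173',
    freeze=True,
    input_size=(3, 224, 224),
)

# Option 2: freeze only two named layers; everything else stays trainable.
partially_frozen_model = ft.tailor.to_embedding_model(
    model=model,
    layer_name='adaptiveavgpool2d_173',
    freeze=['conv2d_1', 'conv2d_5'],
    input_size=(3, 224, 224),
)
```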
Again, keep this in mind; we’ll put everything together after attaching the bottleneck layer.
Attach a bottleneck module.#
Sometimes you want to add a bottleneck module (also called a projection head) on top of your embedding model. Such a bottleneck is usually a simple multi-layer perceptron. It can improve the embedding quality and potentially reduce the output dimensionality.
Tailor allows you to attach this small bottleneck module on top of your embedding model. In the example below, we put everything together with the to_embedding_model method:

choosing the embedding layer;
freezing the weights of specific layers;
attaching a bottleneck module.
Tailor provides a high-level API finetuner.tailor.to_embedding_model, which can be used as follows:

PyTorch:

```python
import torch.nn as nn
import torchvision
import finetuner as ft

model = torchvision.models.resnet50(pretrained=True)


class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(2048, 2048, bias=True)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(2048, 1024, bias=True)

    def forward(self, x):
        return self.l2(self.relu(self.l1(x)))


new_model = ft.tailor.to_embedding_model(
    model=model,
    layer_name='adaptiveavgpool2d_173',
    freeze=['conv2d_1', 'batchnorm2d_2'],  # or set to True to freeze the entire model
    projection_head=SimpleMLP(),
    input_size=(3, 224, 224),
)
```
Keras:

```python
import tensorflow as tf
import finetuner as ft

model = tf.keras.applications.ResNet50(weights='imagenet')

bottleneck_model = tf.keras.models.Sequential()
bottleneck_model.add(tf.keras.layers.InputLayer(input_shape=(2048,)))
bottleneck_model.add(tf.keras.layers.Dense(2048, activation='relu'))
bottleneck_model.add(tf.keras.layers.Dense(1024))

new_model = ft.tailor.to_embedding_model(
    model=model,
    layer_name='avg_pool',
    freeze=['conv1_conv', 'conv1_bn'],  # or set to True to freeze the entire model
    projection_head=bottleneck_model,
    input_size=(3, 224, 224),
)
```
Paddle:

```python
import paddle
import paddle.nn as nn
import finetuner as ft

model = paddle.vision.models.resnet50(pretrained=True)


class SimpleMLP(nn.Layer):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(2048, 2048)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(2048, 1024)

    def forward(self, x):
        return self.l2(self.relu(self.l1(x)))


new_model = ft.tailor.to_embedding_model(
    model=model,
    layer_name='adaptiveavgpool2d_173',
    freeze=['conv2d_1', 'batchnorm2d_2'],  # or set to True to freeze the entire model
    projection_head=SimpleMLP(),
    input_size=(3, 224, 224),
)
```
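To sanity-check the result, you can push a dummy batch through the new model. A sketch for the PyTorch variant above, assuming Tailor flattens the pooled [2048, 1, 1] output before the projection head (which the 2048-input MLP implies):

```python
import torch

x = torch.rand(8, 3, 224, 224)  # a dummy batch of 8 images
embeddings = new_model(x)
print(embeddings.shape)  # expected: torch.Size([8, 1024])
```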
If you display the new_model, you’ll notice that the fc layer has been replaced with an Identity layer, and the layers you specified in freeze are no longer trainable. Last but not least, you attached a trainable bottleneck module that reduces the output dimensionality to 1024.
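You can verify these changes by displaying the tailored model again:

```python
# Frozen layers now appear in cyan, and the replaced fc layer
# in gray (shown as an Identity layer).
ft.display(new_model, input_size=(3, 224, 224))
```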
Tips#
For PyTorch/Paddle models, passing the correct input_size and input_dtype is essential for to_embedding_model and display to work.
You can chop off layers and concatenate new layers afterwards. To get the exact layer name, first use display to list all layers.
Different frameworks may give different layer names. PyTorch and Paddle layer names are usually consistent with each other.