- finetuner.tuner.dataset package
- finetuner.tuner.keras package
- finetuner.tuner.miner package
- finetuner.tuner.paddle package
- finetuner.tuner.pytorch package
- finetuner.tuner.fit(embed_model, train_data, eval_data=None, preprocess_fn=None, collate_fn=None, epochs=10, batch_size=256, num_items_per_class=None, loss='SiameseLoss', optimizer=None, learning_rate=0.001, device='cpu', **kwargs)
Finetune the model on the training data.
embed_model (AnyDNN) – The embedding model to fine-tune
train_data (DocumentSequence) – Data on which to train the model
eval_data (optional) – Data on which to evaluate the model at the end of each epoch
preprocess_fn (optional) – A pre-processing function applied to documents on the fly. It takes a document from the dataset as input and outputs whatever content the framework-specific dataloader (and model) accepts.
collate_fn (optional) – The collation function that merges the content of individual items into a batch. It accepts a list with the content of each item and outputs a tensor (or a list/dict of tensors) that is fed directly into the embedding model.
epochs (int) – Number of epochs to train the model
batch_size (int) – The batch size to use for training and evaluation
loss (Union[str, BaseLoss]) – Which loss to use in training. Supported losses are:
  - SiameseLoss for a Siamese network
  - TripletLoss for a Triplet network
num_items_per_class (Optional[int]) – Number of items from a single class to include in the batch. Only relevant for class datasets
learning_rate (float) – Learning rate for the default optimizer. If you provide a custom optimizer, this learning rate does not apply.
optimizer (optional) – The optimizer to use for training. If none is passed, an Adam optimizer is used by default, with the learning rate specified by the learning_rate parameter.
device (str) – The device to which to move the model. Supported options are
- Return type
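The data-related parameters of fit (batch_size, num_items_per_class, preprocess_fn and collate_fn) can be pictured with the framework-free sketch below. It is an illustration under stated assumptions, not finetuner's implementation: make_class_batches, build_batch, to_char_codes and pad_collate are hypothetical helpers, documents are stood in for by plain strings, and the "tensor" is a padded list of lists. Each batch is assembled from groups of num_items_per_class same-class items, every item passes through preprocess_fn, and collate_fn merges the results.

```python
import random
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence


def make_class_batches(
    labels: Sequence[int], batch_size: int, num_items_per_class: int, seed: int = 0
) -> List[List[int]]:
    """Hypothetical helper: group item indices into batches built from
    chunks of `num_items_per_class` items that share a class label."""
    if batch_size % num_items_per_class != 0:
        raise ValueError("batch_size must be a multiple of num_items_per_class")
    rng = random.Random(seed)

    # Collect the indices of the items belonging to each class.
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Cut every class into same-class chunks; leftover items are dropped.
    chunks: List[List[int]] = []
    for indices in by_class.values():
        rng.shuffle(indices)
        for start in range(0, len(indices) - num_items_per_class + 1, num_items_per_class):
            chunks.append(indices[start:start + num_items_per_class])

    # Shuffle the chunks and concatenate them into full batches (partial batches are dropped).
    rng.shuffle(chunks)
    chunks_per_batch = batch_size // num_items_per_class
    return [
        [idx for chunk in chunks[i:i + chunks_per_batch] for idx in chunk]
        for i in range(0, len(chunks) - chunks_per_batch + 1, chunks_per_batch)
    ]


def build_batch(
    docs: Sequence[Any],
    batch_indices: Sequence[int],
    preprocess_fn: Callable[[Any], Any],
    collate_fn: Callable[[List[Any]], Any],
) -> Any:
    """Apply preprocess_fn to each selected item, then merge the contents with collate_fn."""
    contents = [preprocess_fn(docs[i]) for i in batch_indices]
    return collate_fn(contents)


def to_char_codes(text: str) -> List[int]:
    """Example preprocess_fn: lowercase the text and map characters to integer codes."""
    return [ord(c) for c in text.lower()]


def pad_collate(contents: List[List[int]]) -> List[List[int]]:
    """Example collate_fn: right-pad every sequence with 0 to the longest length."""
    width = max(len(c) for c in contents)
    return [c + [0] * (width - len(c)) for c in contents]
```

For example, build_batch(["Ab", "C"], [0, 1], to_char_codes, pad_collate) returns [[97, 98], [99, 0]]. Grouping same-class items inside each batch is what makes in-batch positives available to metric-learning losses, which is why num_items_per_class only matters for class datasets.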
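The two supported losses differ in what they compare: a Siamese loss scores pairs of embeddings, a Triplet loss scores (anchor, positive, negative) triples. The sketch below shows both in their common textbook form, assuming Euclidean distance and a margin of 1.0. finetuner's own SiameseLoss and TripletLoss may use a different distance (for example cosine) or margin, so treat this as an illustration of the idea rather than the library's implementation.

```python
import math
from typing import Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two embeddings of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def siamese_loss(
    emb1: Sequence[float], emb2: Sequence[float], is_similar: bool, margin: float = 1.0
) -> float:
    """Contrastive loss for one pair: similar pairs are penalised by their squared
    distance, dissimilar pairs only while they are closer than `margin`."""
    d = euclidean(emb1, emb2)
    if is_similar:
        return d ** 2
    return max(0.0, margin - d) ** 2


def triplet_loss(
    anchor: Sequence[float],
    positive: Sequence[float],
    negative: Sequence[float],
    margin: float = 1.0,
) -> float:
    """Zero once the positive is closer to the anchor than the negative by at least `margin`."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)
```

In fit itself the loss is selected by name, e.g. loss='TripletLoss' instead of the default 'SiameseLoss'.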
- finetuner.tuner.save(embed_model, model_path, *args, **kwargs)
Save the embedding model.
embed_model (AnyDNN) – The embedding model to save
model_path (str) – Path to the file or folder where the model is saved
args – Positional arguments to pass to the framework-specific tuner’s save method
kwargs – Keyword arguments to pass to the framework-specific tuner’s save method
- Return type
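save hands the model to a framework-specific tuner and passes args and kwargs through unchanged. The sketch below shows that dispatch pattern in plain Python. SAVERS, register_saver, detect_framework and the stub savers are hypothetical, and inferring the framework from the modules of the model's base classes is an assumption, not necessarily how finetuner does it.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry mapping a framework name to its save routine.
SAVERS: Dict[str, Callable[..., Any]] = {}


def register_saver(framework: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that records `fn` as the save routine for `framework`."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        SAVERS[framework] = fn
        return fn
    return decorator


def detect_framework(model: Any) -> str:
    """Infer the framework from the top-level package of the model's base classes."""
    for cls in type(model).__mro__:
        root = cls.__module__.split(".")[0]
        if root == "torch":
            return "pytorch"
        if root in ("keras", "tensorflow"):
            return "keras"
        if root == "paddle":
            return "paddle"
    raise TypeError(f"Cannot infer the framework of {type(model).__name__}")


def save(embed_model: Any, model_path: str, *args: Any, **kwargs: Any) -> Any:
    """Forward model_path, *args and **kwargs unchanged to the framework-specific saver."""
    saver = SAVERS[detect_framework(embed_model)]
    return saver(embed_model, model_path, *args, **kwargs)


# Stub savers: a real one might wrap torch.save, Keras Model.save or paddle.save.
# These only report what they received, so the sketch runs without any framework installed.
@register_saver("pytorch")
def _save_pytorch(model: Any, path: str, *args: Any, **kwargs: Any) -> Tuple[Any, ...]:
    return ("pytorch", path, args, kwargs)


@register_saver("keras")
def _save_keras(model: Any, path: str, *args: Any, **kwargs: Any) -> Tuple[Any, ...]:
    return ("keras", path, args, kwargs)


@register_saver("paddle")
def _save_paddle(model: Any, path: str, *args: Any, **kwargs: Any) -> Tuple[Any, ...]:
    return ("paddle", path, args, kwargs)
```

Walking the class hierarchy, rather than looking only at the model's own class, matters because user-defined models usually live in the user's module while their base class comes from the framework.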