Updated on August 13, 2024

Fine-Tuning in Sentence Transformers 3

Information Retrieval


Embedding models are one of the backbones of successful Retrieval-Augmented Generation (RAG) applications, crucial in retrieving relevant contexts to generate accurate answers. However, we often train embedding models on general knowledge, which limits their effectiveness when applied to specific domains or company-specific data. Customizing embeddings for your particular use case can improve the retrieval performance of your RAG application, leading to more accurate and relevant results.

With the release of Sentence Transformers 3, fine-tuning embedding models has become more accessible and efficient than before. This powerful library allows developers and researchers to fine-tune or enhance pre-trained models with domain-specific data, potentially achieving performance comparable to or surpassing proprietary models at a fraction of the cost.

This blog post will walk you through fine-tuning an embedding model using Sentence Transformers 3. We'll demonstrate how, with just a modest dataset and accessible computational resources, you can significantly improve retrieval performance for your specific use case. Our example will focus on fine-tuning the all-mpnet-base-v2 model for biomedical question answering, showcasing the potential for domain-specific improvements.

We'll cover everything from dataset preparation and model selection to fine-tuning, including:

  • Installing necessary libraries and setting up your environment
  • Preparing and formatting your dataset
  • Establishing a baseline and evaluation protocol
  • Defining an appropriate loss function
  • Configuring training arguments
  • Fine-tuning the model using the SentenceTransformerTrainer
  • Evaluating the results and comparing performance

By the end of this tutorial, you'll have a clear understanding of how to leverage Sentence Transformers 3 to create custom embedding models tailored to your specific needs, potentially kickstarting your own data flywheel and improving your RAG application's performance.


Prerequisites

We will install the following libraries:

  • PyTorch
  • Sentence Transformers (HF)
  • Transformers (HF)
  • Datasets (HF)

Throughout this tutorial, we are using Python 3.11.5.

text
!pip install --upgrade \
    "torch==2.1.2" \
    "tensorboard==2.17.0" \
    "sentence-transformers==3.0.1" \
    "datasets==2.19.1" \
    "transformers==4.41.2" \
    "accelerate==0.31.0"

After installing the necessary libraries, register for a Hugging Face account, as we will use the Hugging Face Hub to push our models and training logs.

Get your access token here.

python
# Log into your HF account and store your token (access key) on the disk
from huggingface_hub import login

login(token="ADD YOUR TOKEN HERE", add_to_git_credential=False)

Dataset Preparation

The Hugging Face Hub has many datasets that we can use to fine-tune embedding models. You can look here at the required dataset structure needed for fine-tuning embeddings.

We will use enelpol/rag-mini-bioasq, which contains 4,719 question-answer pairs from the BioASQ challenge datasets for biomedical semantic indexing and question answering (QA). We will use this dataset in a positive-pair (anchor, positive) configuration, where each question acts as the anchor and its answer as the positive.

We must load the dataset using the Hugging Face datasets library.

python
from datasets import load_dataset

# Load dataset from HF hub
train_dataset = load_dataset("enelpol/rag-mini-bioasq", name="question-answer-passages", split="train")
test_dataset = load_dataset("enelpol/rag-mini-bioasq", name="question-answer-passages", split="test")

print(train_dataset[0])
print(test_dataset[0])
text
Downloading readme: 100%|██████████| 1.76k/1.76k [00:00<00:00, 3.91MB/s]
Downloading data: 100%|██████████| 1.12M/1.12M [00:00<00:00, 1.27MB/s]
Downloading data: 100%|██████████| 187k/187k [00:00<00:00, 874kB/s]
Generating train split: 100%|██████████| 4012/4012 [00:00<00:00, 126521.01 examples/s]
Generating test split: 100%|██████████| 707/707 [00:00<00:00, 145418.44 examples/s]
{'question': 'What is the implication of histone lysine methylation in medulloblastoma?', 'answer': 'Aberrant patterns of H3K4, H3K9, and H3K27 histone lysine methylation were shown to result in histone code alterations, which induce changes in gene expression, and affect the proliferation rate of cells in medulloblastoma.', 'id': 1682, 'relevant_passage_ids': [23179372, 19270706, 23184418]}
{'question': 'Is capmatinib effective for glioblastoma?', 'answer': 'No. Combination of capmatinib buparlisib resulted in no clear activity in patients with recurrent PTEN-deficient glioblastoma.', 'id': 4213, 'relevant_passage_ids': [31776899]}

The dataset has the following format.

json
{"question": "<question>", "answer": "<answer with some information>", "id": "<id>", "relevant_passage_ids": "<[list of ids of relevant passages]>"},
{"question": "<question>", "answer": "<answer with some information>", "id": "<id>", "relevant_passage_ids": "<[list of ids of relevant passages]>"},
{"question": "<question>", "answer": "<answer with some information>", "id": "<id>", "relevant_passage_ids": "<[list of ids of relevant passages]>"},
...

Since this format differs slightly from the one Sentence Transformers expects, we select and rename the columns to match the expected (anchor, positive) format.

Once the formatting is ready, we save the train and test datasets to disk.

python
# Rename the columns to the (anchor, positive) format
train_dataset = train_dataset.rename_column("question", "anchor")
train_dataset = train_dataset.rename_column("answer", "positive")
test_dataset = test_dataset.rename_column("question", "anchor")
test_dataset = test_dataset.rename_column("answer", "positive")

# Add "id" column if not present
if "id" not in train_dataset.column_names:
    train_dataset = train_dataset.add_column("id", range(len(train_dataset)))
if "id" not in test_dataset.column_names:
    test_dataset = test_dataset.add_column("id", range(len(test_dataset)))

# Save datasets to disk
train_dataset.to_json("train_dataset.json", orient="records")
test_dataset.to_json("test_dataset.json", orient="records")
text
Creating json from Arrow format: 100%|██████████| 5/5 [00:00<00:00, 88.32ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 132.94ba/s]

Baseline and Evaluation

Following dataset preparation, our next step is establishing a baseline method and evaluation protocol. This crucial step allows us to gauge the effectiveness of future model refinements against a known starting point. We'll assess how well a pre-existing model handles our specific data and how it performs after fine-tuning.

We've selected all-mpnet-base-v2 as the base model we will fine-tune. This model isn't particularly performant compared to other models of similar size, but let's see how far we can get with fine-tuning. With only 110 million parameters and a 768-dimensional embedding space, it scores 57.78 on the MTEB Leaderboard, below OpenAI's text-embedding-ada-002 at 60.99. We will also compare it with bge-base-en-v1.5, which likewise has 109 million parameters and a 768-dimensional embedding space, and achieves an impressive 63.55 on the MTEB Leaderboard.

Since we want to improve the Information Retrieval (IR) capabilities of the embeddings, we will quantify performance with the InformationRetrievalEvaluator. This tool assesses how well our model can fetch the most relevant documents for given queries, calculating metrics including Mean Reciprocal Rank (MRR), Recall@K, and Normalized Discounted Cumulative Gain (NDCG). A useful explanation of these IR metrics can be found here.
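
To make these metrics concrete, here is a small, self-contained illustration (not the library's implementation) of how MRR@K and Recall@K are computed for a single query from a ranked result list; the document IDs are made up for the example.

python
# Toy illustration of MRR@K and Recall@K for a single query
def mrr_at_k(ranked_ids, relevant_ids, k=10):
    # Reciprocal rank of the first relevant document in the top-k, else 0
    for rank, doc_id in enumerate(ranked_ids[:k], start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0

def recall_at_k(ranked_ids, relevant_ids, k=10):
    # Fraction of the relevant documents that appear in the top-k results
    hits = sum(1 for doc_id in ranked_ids[:k] if doc_id in relevant_ids)
    return hits / len(relevant_ids)

# One query whose single relevant document is ranked second
print(mrr_at_k(["d7", "d3", "d9"], {"d3"}))     # 0.5
print(recall_at_k(["d7", "d3", "d9"], {"d3"}))  # 1.0

The evaluator averages such per-query scores over all test queries; NDCG additionally discounts relevant documents by how far down the ranking they appear.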

To conduct our evaluation, we will utilize a comprehensive document pool combining train and test data for the corpus, while queries will be sourced exclusively from the test set. This approach ensures we assess the model’s ability to retrieve relevant documents from a larger corpus that includes unseen data, providing a more robust and realistic evaluation of its retrieval capabilities.

python
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model_id = "sentence-transformers/all-mpnet-base-v2"
large_model_id = "BAAI/bge-base-en-v1.5"

# Load the models
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)
large_model = SentenceTransformer(
    large_model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

# Load the train and test datasets. Concatenate them into a single corpus dataset only to add retrieval difficulty
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

# Convert the datasets to dictionaries mapping id -> text
corpus = dict(zip(corpus_dataset["id"], corpus_dataset["positive"]))
queries = dict(zip(test_dataset["id"], test_dataset["anchor"]))

# Create a mapping of the relevant documents for each query.
# In this case, we only have 1 relevant document per query
relevant_docs = {}
for q_id in queries:
    relevant_docs[q_id] = [q_id]

model_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name=model_id,
    score_functions={"cosine": cos_sim},
)
text
Generating train split: 707 examples [00:00, 17752.37 examples/s]
Generating train split: 4012 examples [00:00, 173569.34 examples/s]

We use the model_evaluator to evaluate both the baseline all-mpnet-base-v2 model and the bge-base-en-v1.5 reference model. Later, we will also use it to evaluate the fine-tuned model.

python
# Evaluate the models
model_results = model_evaluator(model)
large_model_results = model_evaluator(large_model)

#Display the results
print(model_results)
print(large_model_results)

We obtain the following results for both models.

| Metric       | all-mpnet-base-v2 | bge-base-en-v1.5 |
|--------------|-------------------|------------------|
| accuracy@1   | 0.7850            | 0.8515           |
| accuracy@3   | 0.8755            | 0.9349           |
| accuracy@5   | 0.9024            | 0.9491           |
| accuracy@10  | 0.9278            | 0.9590           |
| precision@1  | 0.7850            | 0.8515           |
| precision@3  | 0.2918            | 0.3116           |
| precision@5  | 0.1805            | 0.1898           |
| precision@10 | 0.0928            | 0.0959           |
| recall@1     | 0.7850            | 0.8515           |
| recall@3     | 0.8755            | 0.9349           |
| recall@5     | 0.9024            | 0.9491           |
| recall@10    | 0.9278            | 0.9590           |
| ndcg@10      | 0.8571            | 0.9122           |
| mrr@10       | 0.8348            | 0.8965           |
| map@100      | 0.8367            | 0.8973           |

Defining the Loss Function

In this case, we are using the MultipleNegativesRankingLoss to fine-tune our embedding model. We use this loss function to align with our dataset format, which consists of positive text pairs. You can take a look at dataset format information and loss function information to determine which loss function to use based on your use case.

python
from sentence_transformers.losses import MultipleNegativesRankingLoss

model_id = "sentence-transformers/all-mpnet-base-v2"

model = SentenceTransformer(model_id)

train_loss = MultipleNegativesRankingLoss(model)
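
To build intuition for what this loss optimizes, below is a minimal sketch of the in-batch-negatives idea behind MultipleNegativesRankingLoss: within a batch, each anchor should be more similar to its own positive than to every other positive in the batch, which act as negatives. This is only an illustration of that idea, not the library's implementation; the scale factor of 20 mirrors the loss's default similarity scaling.

python
import torch
import torch.nn.functional as F

def in_batch_negatives_loss(anchor_emb, positive_emb, scale=20.0):
    # anchor_emb, positive_emb: (batch_size, dim) sentence embeddings
    a = F.normalize(anchor_emb, dim=-1)
    p = F.normalize(positive_emb, dim=-1)
    # Cosine similarity of every anchor against every positive in the batch
    scores = a @ p.T * scale  # shape: (batch_size, batch_size)
    # The matching positive for anchor i sits on the diagonal (index i);
    # all other positives in the batch serve as negatives
    labels = torch.arange(scores.size(0))
    return F.cross_entropy(scores, labels)

# Example with random embeddings: larger batches provide more in-batch negatives
loss = in_batch_negatives_loss(torch.randn(32, 768), torch.randn(32, 768))
print(loss)

Because every other example in the batch supplies negatives for free, larger batch sizes generally make this loss more effective, which is also why the NO_DUPLICATES batch sampler is used in the training configuration below.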

Fine-tuning the Model

Now that we've prepared our data and model, we're ready to fine-tune our embedding model using the SentenceTransformerTrainer.

To configure our training process, we'll use the SentenceTransformerTrainingArguments class. This tool allows us to specify various parameters that can impact training performance and help with tracking and debugging. We'll be using parameter values based on those recommended in the Sentence Transformers documentation. However, it's important to note that these are just starting points. You should experiment with different values tailored to your specific dataset and task for optimal results.

python
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="mpnet_base-bioasq",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-bioasq-basic-training-args",  # Will be used in W&B if `wandb` is installed
)

from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(["positive", "anchor"]),
    loss=train_loss,
    evaluator=model_evaluator,
)

# start training the model
trainer.train()

# Save the fine-tuned model to the output directory
trainer.save_model()

# Alternative way to save the model: model.save_pretrained("models/mpnet-base-all-rest-of-name/final")

# push model to hub
trainer.model.push_to_hub("all-mpnet-base-v2-bioasq-1epoc-batch32-100")

text
| Step | Training Loss | Sentence-transformers/all-mpnet-base-v2 Cosine Accuracy@1 | Cosine Accuracy@3 | Cosine Accuracy@5 | Cosine Accuracy@10 | Cosine NDCG@10 | Cosine MRR@10 | Cosine MAP@100 |
|------|---------------|----------------------------------------------------------|-------------------|-------------------|---------------------|----------------|---------------|----------------|
| 100 | 0.115500 | 0.845827 | 0.934936 | 0.947666 | 0.960396 | 0.909272 | 0.892253 | 0.893660 |
The training process completed in 46 seconds over 1 epoch (126/126 steps).

The training on ~4k samples took around one minute on an Nvidia A10G instance from Modal Labs. At the time of writing (August 2024), the instance costs 1.1 USD/hour, so one minute of training comes to roughly 1.1 / 60 ≈ 0.02 USD, well under 0.1 USD.

Now we can evaluate the fine-tuned model using the model_evaluator from earlier.

python
from sentence_transformers import SentenceTransformer

fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
fine_tuned_results = model_evaluator(fine_tuned_model)

print(fine_tuned_results)

Focusing on the two metrics most relevant to our use case, we get the following results.

| Model                         | MRR@10 | NDCG@10 |
|-------------------------------|--------|---------|
| all-mpnet-base-v2 (baseline)  | 0.8347 | 0.8571  |
| bge-base-en-v1.5              | 0.8965 | 0.9122  |
| all-mpnet-base-v2 (fine-tuned)| 0.8919 | 0.9093  |

The fine-tuned model shows significant improvements over the baseline, with a 6.85% relative increase in MRR@10 and a 6.09% relative increase in NDCG@10, bringing it roughly to the performance level of the bge-base-en-v1.5 embeddings.
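
With evaluation done, here is a minimal sketch of how the fine-tuned model could slot into the retrieval step of a RAG pipeline; the passages and query below are placeholder examples, and the model is loaded from the output directory used during training.

python
from sentence_transformers import SentenceTransformer, util

# Load the fine-tuned model from the training output directory
retriever = SentenceTransformer("mpnet_base-bioasq")

# Placeholder corpus and query for illustration
passages = [
    "Capmatinib is a MET inhibitor that has been evaluated in recurrent glioblastoma.",
    "Aberrant histone lysine methylation patterns have been reported in medulloblastoma.",
]
query = "Is capmatinib effective for glioblastoma?"

# Embed the corpus and the query, then rank passages by cosine similarity
passage_emb = retriever.encode(passages, convert_to_tensor=True, normalize_embeddings=True)
query_emb = retriever.encode(query, convert_to_tensor=True, normalize_embeddings=True)
hits = util.semantic_search(query_emb, passage_emb, top_k=2)[0]

for hit in hits:
    print(round(hit["score"], 3), passages[hit["corpus_id"]])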

Conclusion

Embedding models play a crucial role in the success of RAG applications, as the quality of retrieved context directly impacts the generated answers. Using the Sentence Transformers 3 library, we fine-tuned the all-mpnet-base-v2 model on a biomedical question-answering dataset.

Results show substantial improvements:

  • MRR@10 increased from 0.8347 to 0.8919 (6.85% improvement)
  • NDCG@10 improved from 0.8571 to 0.9093 (6.09% improvement)

Our fine-tuned model achieved performance comparable to the more performant bge-base-en-v1.5 model despite starting from a lower baseline.

The fine-tuning process has become highly accessible and efficient. With only 4,719 question-answer pairs, we achieved these improvements in approximately one minute of training time on an Nvidia A10G GPU, at an estimated cost of less than 0.1 USD, making this a cost-effective approach for enhancing domain-specific retrieval. These results show the value of customizing embedding models for specific domains or use cases: significant performance gains are possible even with a relatively small dataset and minimal training time.