Fine-Tuning Sentence Transformers for Domain Tasks
When I set out to train ModernPubMedBERT, generic embedding models were failing on medical terminology. This post covers the full pipeline: curating contrastive pairs, choosing loss functions, avoiding catastrophic forgetting, and benchmarking on BEIR.
Generic embedding models are impressive. They handle most text reasonably well out of the box. But “reasonably well” isn’t good enough when your domain has specialized terminology, unique semantic relationships, and users who expect precision.
When I built ModernPubMedBERT, I learned this the hard way. Medical terminology defeated off-the-shelf models. “Myocardial infarction” and “heart attack” are semantically identical—but generic models often missed this. Acronyms like “MI,” “CHF,” and “COPD” were practically foreign languages.
Here’s the complete pipeline I developed to fix this.
Why Generic Models Fail in Domains
Sentence transformers are trained on general-purpose datasets: Wikipedia, web crawls, news articles. They learn broad semantic relationships but miss domain-specific nuances.
Consider these biomedical pairs:
| Query | Correct Match | Incorrect Match |
|---|---|---|
| "treatment for hypertension" | "antihypertensive therapy" | "high blood pressure symptoms" |
| "NSAID adverse effects" | "ibuprofen side effects" | "pain medication benefits" |
| "metformin mechanism" | "biguanide action" | "diabetes medication" |
Generic models often rank the incorrect matches higher because they match surface-level keywords rather than understanding domain relationships.
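You can check this in a few lines. A minimal sketch, assuming a general-purpose checkpoint such as sentence-transformers/all-MiniLM-L6-v2 (any generic model illustrates the point), scoring the first row of the table above:

from sentence_transformers import SentenceTransformer, util

# A general-purpose checkpoint, used here only to illustrate the failure mode
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

query = "treatment for hypertension"
correct = "antihypertensive therapy"
incorrect = "high blood pressure symptoms"

q_emb, c_emb, i_emb = model.encode([query, correct, incorrect])

print("correct match:  ", util.cos_sim(q_emb, c_emb).item())
print("incorrect match:", util.cos_sim(q_emb, i_emb).item())
# A domain-tuned model should score the correct match clearly higher;
# generic models frequently favor the keyword-overlapping text instead.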
The Training Pipeline
Step 1: Curating Training Data
Quality training data is everything. For biomedical embeddings, I used three sources:
- PubMed Abstracts: Scientific papers with structured language
- Clinical Notes: De-identified, real-world medical documentation
- Medical Ontologies: SNOMED, ICD-10, MeSH for synonym relationships
The key insight: contrastive learning needs hard negatives. Easy negatives (completely unrelated texts) don’t teach the model much. Hard negatives (semantically similar but not equivalent) force the model to learn fine-grained distinctions.
from datasets import Dataset
import random

def create_training_pairs(
    documents: list[dict],
    ontology: dict[str, list[str]]
) -> Dataset:
    pairs = []
    for doc in documents:
        if concept := doc.get("concept"):
            # Positive: same concept, different phrasing
            for synonym in ontology.get(concept, []):
                pairs.append({
                    "anchor": doc["text"],
                    "positive": synonym,
                    "label": 1.0
                })
            # Hard negative: related but different concept
            # (get_related_concepts walks the ontology, e.g. sibling terms -- not shown here)
            for related in get_related_concepts(concept)[:3]:
                candidates = ontology.get(related)
                if not candidates:
                    continue  # skip rather than fall back to the anchor itself
                pairs.append({
                    "anchor": doc["text"],
                    "negative": random.choice(candidates),
                    "label": 0.0
                })
    return Dataset.from_list(pairs)
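To make the shape of the data concrete, here's a toy invocation. The concept IDs, synonym lists, and the stand-in for get_related_concepts are all made up for illustration:

# Toy stand-in for the ontology traversal helper used above (illustrative only)
def get_related_concepts(concept: str) -> list[str]:
    return {"MI": ["ANGINA"]}.get(concept, [])

documents = [{"text": "Patient presents with myocardial infarction.", "concept": "MI"}]
ontology = {
    "MI": ["heart attack", "acute myocardial infarction"],
    "ANGINA": ["chest pain on exertion", "stable angina"],
}

train_pairs = create_training_pairs(documents, ontology)
# Yields (anchor, positive) rows for the MI synonyms and an (anchor, negative)
# row drawn from the related concept "ANGINA"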
Step 2: Choosing the Right Loss Function
Not all contrastive losses are equal. After extensive experimentation, here’s what I learned:
MultipleNegativesRankingLoss works well when you have (anchor, positive) pairs and use in-batch negatives:
from sentence_transformers import losses
# Good for large batch sizes with diverse samples
loss = losses.MultipleNegativesRankingLoss(model)
TripletLoss is better when you have explicit (anchor, positive, negative) triplets:
# Better control over hard negatives
loss = losses.TripletLoss(
model,
distance_metric=losses.TripletDistanceMetric.COSINE,
triplet_margin=0.5
)
CosineSimilarityLoss works well with similarity scores:
# Good when you have graded similarity labels
loss = losses.CosineSimilarityLoss(model)
For ModernPubMedBERT, I used a combination:
from sentence_transformers import SentenceTransformer
from sentence_transformers.trainer import SentenceTransformerTrainer
def get_training_loss(model, dataset_type: str):
    if dataset_type == "synonym_pairs":
        return losses.MultipleNegativesRankingLoss(model)
    elif dataset_type == "triplets":
        return losses.TripletLoss(model)
    elif dataset_type == "similarity_scored":
        return losses.CosineSimilarityLoss(model)
    raise ValueError(f"Unknown dataset type: {dataset_type}")
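If the three dataset types are kept as separate datasets, the v3 trainer can pair each one with its own loss. A minimal sketch, assuming sentence-transformers v3+; the dataset variables below are placeholders:

# Sketch: pair each dataset with its loss; the dataset variables are placeholders
train_datasets = {
    "synonym_pairs": synonym_dataset,
    "triplets": triplet_dataset,
    "similarity_scored": scored_dataset,
}
train_losses = {name: get_training_loss(model, name) for name in train_datasets}

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_datasets,  # dict of datasets
    loss=train_losses,             # matching dict of losses, keyed by dataset name
)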
Step 3: Avoiding Catastrophic Forgetting
Fine-tuning too aggressively destroys general knowledge. The model might perfectly understand “pulmonary embolism” but forget how to handle regular English.
Several techniques help:
Learning Rate Scheduling: Warm up from a low learning rate, then decay linearly.
import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=100,
num_training_steps=len(dataloader) * epochs
)
Mixed Training Data: Include some general-domain examples.
from datasets import Dataset, concatenate_datasets

def create_mixed_dataset(
    domain_data: Dataset,
    general_data: Dataset,
    domain_ratio: float = 0.7
) -> Dataset:
    # Keep the mixed set the same size as the domain set:
    # domain_ratio of it from domain data, the remainder from general data
    domain_size = int(len(domain_data) * domain_ratio)
    general_size = len(domain_data) - domain_size
    mixed = concatenate_datasets([
        domain_data.shuffle().select(range(domain_size)),
        general_data.shuffle().select(range(general_size))
    ])
    return mixed.shuffle()
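The general-domain slice can come from any public sentence-pair corpus. As an assumed example, the AllNLI pairs published as sentence-transformers/all-nli on the Hugging Face Hub work; just make sure both datasets share the same column schema before concatenation:

from datasets import load_dataset

# Assumed general-domain source: the public AllNLI pair dataset.
# concatenate_datasets requires both datasets to have identical columns,
# so trim each side to the shared anchor/positive schema first.
general_data = load_dataset("sentence-transformers/all-nli", "pair", split="train")
general_data = general_data.select_columns(["anchor", "positive"])

domain_pairs = domain_data.select_columns(["anchor", "positive"])
train_dataset = create_mixed_dataset(domain_pairs, general_data, domain_ratio=0.7)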
Regularization: Keep weights close to the original.
import torch

class EWC:
    """Elastic Weight Consolidation to prevent forgetting."""

    def __init__(self, model, dataloader, importance: float = 1000):
        self.importance = importance
        # Snapshot of the weights before fine-tuning
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
        }
        # Diagonal Fisher information estimated from squared gradients
        # over the reference dataloader (helper not shown here)
        self.fisher = self._compute_fisher(model, dataloader)

    def penalty(self, model) -> torch.Tensor:
        # Quadratic penalty pulling important weights back toward their
        # original values; add this term to the task loss at each step
        loss = 0
        for n, p in model.named_parameters():
            loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return self.importance * loss
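Wiring the penalty into training is then a single extra term on the loss. A minimal sketch of a custom PyTorch loop; the names reference_loader, train_loader, and task_loss_fn are placeholders:

# Sketch of a custom training step with the EWC penalty added
ewc = EWC(model, reference_loader, importance=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for batch in train_loader:
    optimizer.zero_grad()
    task_loss = task_loss_fn(model, batch)       # e.g. a contrastive loss
    total_loss = task_loss + ewc.penalty(model)  # keep weights near the originals
    total_loss.backward()
    optimizer.step()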
Step 4: Training Configuration
The training configuration that worked best for ModernPubMedBERT:
from sentence_transformers import SentenceTransformerTrainingArguments
training_args = SentenceTransformerTrainingArguments(
output_dir="./models/modern-pubmed-bert",
num_train_epochs=3,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
learning_rate=2e-5,
weight_decay=0.01,
warmup_ratio=0.1,
fp16=True,
evaluation_strategy="steps",
eval_steps=500,
save_steps=1000,
logging_steps=100,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
)
trainer = SentenceTransformerTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss_fn,
)
trainer.train()
Evaluation: BEIR Benchmarks
Training is nothing without rigorous evaluation. I used the BEIR benchmark suite, which tests retrieval across multiple domains.
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch

def evaluate_on_beir(model, dataset_name: str) -> dict:
    # Download and load the dataset
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    data_path = util.download_and_unzip(url, "./datasets")
    corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

    # Create retriever; `model` must expose encode_queries/encode_corpus,
    # e.g. a checkpoint wrapped with beir.retrieval.models.SentenceBERT
    retriever = DenseRetrievalExactSearch(
        model,
        batch_size=64,
        corpus_chunk_size=50000
    )

    # Retrieve and score
    results = retriever.search(corpus, queries, top_k=100, score_function="cos_sim")
    ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(
        qrels, results, k_values=[1, 10, 100]
    )
    return {"ndcg": ndcg, "map": _map, "recall": recall, "precision": precision}
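Running the evaluation over several datasets is then a short loop. A sketch, assuming the checkpoint is wrapped with BEIR's SentenceBERT adapter so it exposes encode_queries/encode_corpus, and using BEIR's public dataset identifiers:

from beir.retrieval import models

# Wrap the fine-tuned checkpoint so BEIR can encode queries and corpus
beir_model = models.SentenceBERT("lokeshch19/ModernPubMedBERT")

for dataset in ["trec-covid", "nfcorpus", "scifact"]:
    metrics = evaluate_on_beir(beir_model, dataset)
    print(dataset, metrics["ndcg"])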
Results Comparison
| Dataset | Generic BERT | PubMedBERT | ModernPubMedBERT |
|---|---|---|---|
| TREC-COVID | 0.59 | 0.68 | 0.74 |
| BioASQ | 0.43 | 0.51 | 0.58 |
| NFCorpus | 0.31 | 0.35 | 0.41 |
| SciFact | 0.62 | 0.67 | 0.72 |
The improvements are consistent across all biomedical benchmarks. More importantly, performance on general benchmarks (MS MARCO, Natural Questions) only dropped 2-3%—an acceptable trade-off.
Deployment to HuggingFace
Once trained and evaluated, deployment is straightforward:
from huggingface_hub import HfApi

def push_to_hub(model, model_name: str, readme_content: str):
    # Save model locally first
    model.save("./model_output")

    # Create model card
    with open("./model_output/README.md", "w") as f:
        f.write(readme_content)

    # Push to Hub (create the repo first if it doesn't exist yet)
    api = HfApi()
    repo_id = f"lokeshch19/{model_name}"
    api.create_repo(repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(
        folder_path="./model_output",
        repo_id=repo_id,
        repo_type="model"
    )
The model card matters for discoverability:
# ModernPubMedBERT
A sentence-transformer model fine-tuned on PubMed abstracts for
biomedical semantic search.
## Usage
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("lokeshch19/ModernPubMedBERT")
sentences = ["myocardial infarction treatment", "heart attack therapy"]
embeddings = model.encode(sentences)
## Training
- Base model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract
- Training data: 1M PubMed abstracts + medical ontologies
- Loss: MultipleNegativesRankingLoss + TripletLoss
## Benchmarks
[Include BEIR results table]
Lessons Learned
After shipping ModernPubMedBERT (now at 1,000+ monthly downloads), here’s what I’ve learned:
- Data quality trumps quantity. 100k well-curated pairs beat 1M noisy pairs.
- Hard negatives are essential. Without them, the model learns nothing useful.
- Catastrophic forgetting is real. Plan for it from the start.
- Benchmark on your actual use case. BEIR is great, but your retrieval task might have different characteristics.
- Iterate quickly. Train small models first, validate the approach, then scale.
Domain-specific embeddings aren’t just nice to have—they’re often the difference between a system that works and one that frustrates users. The investment in fine-tuning pays dividends in retrieval quality.
Interested in building custom embedding models for your domain? Get in touch—I’d love to hear about your use case.