OpenAI's CLIP model is pretty nifty at connecting text with images: it efficiently learns visual concepts from natural language supervision. Can we do the same for protein sequences and structures? Instead of text and images, we pair protein sequences with their structural information.
In this notebook we walk through multimodal training with a contrastive learning framework tailor-made for sequence-structure pretraining. By aligning antibody sequences with their corresponding structural embeddings, we aim to build a model better suited for optimizing antibody therapeutic properties, such as binding to a cognate antigen.
Key Steps

- Load the paired antibody sequence and frozen ESM-IF1 structure-embedding data prepared in the previous notebooks into an AntibodyDataset.
- Define ESMForCLIP, an ESM-2 encoder with a CLIP-style projection head for sequences and structure embeddings.
- Implement a CLIP-style contrastive loss and evaluation metrics (alignment, uniformity, contrastive accuracy, top-k accuracy).
- Train and evaluate with a custom ContrastiveTrainer built on the Hugging Face Trainer class.

!pip install torch transformers accelerate &> /dev/null
!pip install --upgrade huggingface_hub &> /dev/null
# To share your model with the community
# First store your authentication token from the Hugging Face website and then execute this cell
# Make sure to get token with WRITE access
from huggingface_hub import notebook_login
notebook_login()
import os
from pathlib import Path
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
See the previous two notebooks for obtaining and processing the antibody sequence-structure data, and for generating antibody structure embeddings with a frozen ESM-IF1 model.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)
path = Path("/content/gdrive/")
path_data = Path("/content/gdrive/MyDrive/data/proteinflow_esmif1_20240520-0899946")
Mounted at /content/gdrive
import pandas as pd
import torch
from torch.utils.data import Dataset
class AntibodyDataset(Dataset):
    """Dataset pairing antibody sequences with precomputed structure embeddings."""

    def __init__(self, data_path, tokenizer):
        """
        Initialize the dataset.

        Args:
            data_path (str): Path to the pickle file containing data.
            tokenizer (transformers.PreTrainedTokenizer): Tokenizer to process the sequences.
        """
        self.data = pd.read_pickle(data_path)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Get item by index."""
        row = self.data.iloc[idx]
        sequence = row['sequence']
        embedding = torch.tensor(row['embedding'], dtype=torch.float32)
        inputs = self.tokenizer(sequence, return_tensors='pt', padding=False, truncation=False)
        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': embedding
        }
GitHub Copilot comment: the conversion of row['embedding'] to a tensor happens directly inside __getitem__. This is generally fine, but if the dataset is large and the conversion is costly, consider pre-processing or caching, as sketched below.
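For illustration only, here is a minimal sketch of such a cache. CachedAntibodyDataset is a hypothetical helper that is not used anywhere in this notebook; it assumes the same DataFrame layout (columns 'sequence' and 'embedding') as AntibodyDataset.

class CachedAntibodyDataset(AntibodyDataset):
    """Variant of AntibodyDataset that pre-converts all embeddings to tensors once."""

    def __init__(self, data_path, tokenizer):
        super().__init__(data_path, tokenizer)
        # Convert every embedding up front so __getitem__ only does a list lookup.
        self._cached_embeddings = [
            torch.tensor(e, dtype=torch.float32) for e in self.data['embedding']
        ]

    def __getitem__(self, idx):
        item = super().__getitem__(idx)            # tokenize the sequence as before
        item['labels'] = self._cached_embeddings[idx]
        return item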
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader
# Initialize tokenizer
model_ckpt = 'facebook/esm2_t6_8M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# Initialize datasets
train_ds = AntibodyDataset(path_data/'train_data.pkl', tokenizer)
valid_ds = AntibodyDataset(path_data/'valid_data.pkl', tokenizer)
test_ds = AntibodyDataset(path_data/'test_data.pkl', tokenizer)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets. warnings.warn(
# Initialize DataCollator
data_collator = DataCollatorWithPadding(tokenizer)
# Initialize DataLoader with DataCollator
batch_size=5
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
batch = next(iter(train_dl))
#print(batch)
print(batch['input_ids'].shape)
print(batch['attention_mask'].shape)
print(batch['labels'].shape)
torch.Size([5, 634])
torch.Size([5, 634])
torch.Size([5, 512])
import numpy as np
import torch
import torch.nn as nn
class ESMCLIPHead(nn.Module):
    """Head for CLIP multimodal tasks."""

    def __init__(self, config, projection_dim=512, dropout_prob=0.0):
        super().__init__()
        self.sequence_projection = nn.Sequential(
            nn.Linear(config.hidden_size, projection_dim),
            nn.GELU(),
            nn.Dropout(dropout_prob)
        )
        self.structure_projection = nn.Sequential(
            nn.Linear(projection_dim, projection_dim),
            nn.GELU(),
            nn.Dropout(dropout_prob)
        )

    def forward(self, sequence_output, structure_embeddings):
        sequence_projected = self.sequence_projection(sequence_output[:, 0, :])  # [CLS] token pooling
        structure_projected = self.structure_projection(structure_embeddings)
        return sequence_projected, structure_projected
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import EsmPreTrainedModel, EsmModel, EsmConfig
class ESMForCLIP(EsmPreTrainedModel):
    config_class = EsmConfig

    def __init__(self, config):
        super().__init__(config)
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.clip_head = ESMCLIPHead(config)
        # Learnable temperature parameter, initialized to log(1/0.07) as in CLIP
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        outputs = self.esm(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        sequence_output = outputs.last_hidden_state
        # 'labels' carries the precomputed structure embeddings
        sequence_projected, structure_projected = self.clip_head(sequence_output, labels)
        # Clamp the logit scale value to ensure it does not exceed log(100)
        self.logit_scale.data.clamp_(max=np.log(100.0))
        return {
            'logits': (sequence_projected, structure_projected)
        }
from transformers import AutoConfig
clip_config = AutoConfig.from_pretrained(model_ckpt)
#clip_config.projection_dim = 512
#clip_config.dropout_rate = 0.0
#print(clip_config)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. warnings.warn(
model = ESMForCLIP.from_pretrained(model_ckpt, config=clip_config)
model
Some weights of ESMForCLIP were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['clip_head.sequence_projection.0.bias', 'clip_head.sequence_projection.0.weight', 'clip_head.structure_projection.0.bias', 'clip_head.structure_projection.0.weight', 'logit_scale'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
ESMForCLIP( (esm): EsmModel( (embeddings): EsmEmbeddings( (word_embeddings): Embedding(33, 320, padding_idx=1) (dropout): Dropout(p=0.0, inplace=False) (position_embeddings): Embedding(1026, 320, padding_idx=1) ) (encoder): EsmEncoder( (layer): ModuleList( (0-5): 6 x EsmLayer( (attention): EsmAttention( (self): EsmSelfAttention( (query): Linear(in_features=320, out_features=320, bias=True) (key): Linear(in_features=320, out_features=320, bias=True) (value): Linear(in_features=320, out_features=320, bias=True) (dropout): Dropout(p=0.0, inplace=False) (rotary_embeddings): RotaryEmbedding() ) (output): EsmSelfOutput( (dense): Linear(in_features=320, out_features=320, bias=True) (dropout): Dropout(p=0.0, inplace=False) ) (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True) ) (intermediate): EsmIntermediate( (dense): Linear(in_features=320, out_features=1280, bias=True) ) (output): EsmOutput( (dense): Linear(in_features=1280, out_features=320, bias=True) (dropout): Dropout(p=0.0, inplace=False) ) (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True) ) ) (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwise_affine=True) ) (contact_head): EsmContactPredictionHead( (regression): Linear(in_features=120, out_features=1, bias=True) (activation): Sigmoid() ) ) (clip_head): ESMCLIPHead( (sequence_projection): Sequential( (0): Linear(in_features=320, out_features=512, bias=True) (1): GELU(approximate='none') (2): Dropout(p=0.0, inplace=False) ) (structure_projection): Sequential( (0): Linear(in_features=512, out_features=512, bias=True) (1): GELU(approximate='none') (2): Dropout(p=0.0, inplace=False) ) ) )
with torch.no_grad():
    outputs = model.forward(**batch)
logits = outputs['logits']
print(logits[0].shape)
print(logits[1].shape)
torch.Size([5, 512]) torch.Size([5, 512])
Scaled pairwise cosine similarities are a crucial component in many contrastive learning frameworks, including CLIP (Contrastive Language-Image Pre-training).
Cosine similarity measures the cosine of the angle between two vectors in an inner product space. It is a measure of similarity between two non-zero vectors, giving a value between -1 and 1.
For two vectors $A$ and $B$: $$ \text{cosine\_similarity}(A, B) = \frac{A \cdot B}{\|A\| \|B\|} $$
In PyTorch, cosine similarity between two sets of embeddings can be computed using:
import torch.nn.functional as F
cos_sim = F.cosine_similarity(embedding1, embedding2, dim=-1)
Pairwise cosine similarity computes the cosine similarity between each pair of vectors from two sets of vectors. This is useful in comparing all possible pairs in a batch.
For L2-normalized sequence embeddings seq_embeddings and structure embeddings struct_embeddings:
cos_sim_matrix = torch.mm(seq_embeddings, struct_embeddings.t())
This gives a matrix of cosine similarities between every sequence and every structure embedding; the dot product equals the cosine similarity only because the embeddings are unit vectors.
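To make the two snippets concrete, here is a minimal self-contained sketch using toy tensors (the batch size and dimension are arbitrary and not taken from the notebook's data):

import torch
import torch.nn.functional as F

# Toy batch: 4 sequence embeddings and 4 structure embeddings of dimension 512
seq_embeddings = torch.randn(4, 512)
struct_embeddings = torch.randn(4, 512)

# Normalize to unit length so that dot products equal cosine similarities
seq_embeddings = F.normalize(seq_embeddings, dim=1)
struct_embeddings = F.normalize(struct_embeddings, dim=1)

# 4x4 matrix of pairwise cosine similarities; the diagonal holds the matching pairs
cos_sim_matrix = torch.mm(seq_embeddings, struct_embeddings.t())
print(cos_sim_matrix.shape)  # torch.Size([4, 4])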
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLoss, self).__init__()

    def forward(self, seq_embeddings, struct_embeddings, logit_scale):
        # Normalize embeddings to unit vectors
        seq_embeddings = F.normalize(seq_embeddings, dim=1)
        struct_embeddings = F.normalize(struct_embeddings, dim=1)

        # Compute pairwise cosine similarities and scale with the learned temperature
        temperature = torch.exp(logit_scale)
        logits_per_seq = torch.mm(seq_embeddings, struct_embeddings.t()) * temperature
        logits_per_struct = logits_per_seq.t()

        # Labels for contrastive loss: the i-th sequence matches the i-th structure
        labels = torch.arange(seq_embeddings.size(0)).to(seq_embeddings.device)

        # Symmetric contrastive loss as described in the CLIP paper
        loss_seq = F.cross_entropy(logits_per_seq, labels)
        loss_struct = F.cross_entropy(logits_per_struct, labels)
        loss = (loss_seq + loss_struct) / 2
        return loss
The ContrastiveLoss class implements a contrastive learning objective for sequence and structure embeddings, much like CLIP (Contrastive Language-Image Pre-training) does for text and images. The goal is to pull corresponding sequence and structure embeddings closer together in the embedding space while pushing non-corresponding pairs apart.
Normalization: both sequence and structure embeddings are normalized to unit vectors, so that dot products correspond to cosine similarities.
Cosine Similarity: the similarity between embeddings is computed with a dot product. The temperature parameter τ scales the logits (cosine similarities) before the softmax is applied and controls the sharpness of the resulting distribution.
Making τ a learnable parameter (via logit_scale) lets the model adapt the scaling dynamically during training.
Contrastive Loss: the loss is computed with cross-entropy in both directions (sequence-to-structure and structure-to-sequence), treating matching as a classification task in which the correct pair should receive the highest similarity score.
Numerical stability: temperature = torch.exp(logit_scale) could overflow if logit_scale grows too large, so the model's forward pass clamps it in place (logit_scale.data.clamp_(max=np.log(100.0))), capping the temperature at 100.
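As a quick numeric check (a standalone sketch, not a cell from the original run), CLIP's initialization of logit_scale to log(1/0.07) gives a temperature of roughly 14.3, and the clamp keeps the temperature from ever exceeding 100:

import numpy as np
import torch

logit_scale = torch.ones([]) * np.log(1 / 0.07)  # CLIP's initialization
print(torch.exp(logit_scale))                    # tensor(14.2857) -> initial temperature

# In ESMForCLIP.forward the parameter is clamped in place, so exp() never exceeds 100
logit_scale.clamp_(max=np.log(100.0))
print(torch.exp(logit_scale))                    # still ~14.29 here; at most 100 during training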
# Test case
loss_fct = ContrastiveLoss() # Instantiate the loss function
with torch.no_grad():
    loss = loss_fct(logits[0], logits[1], model.logit_scale)
print(loss)
tensor(1.6297)
We will implement the Alignment and Uniformity metrics as proposed by Wang & Isola (2020), in addition to Contrastive Accuracy and Top-K Accuracy.
$$ \text{Alignment} = \mathbb{E}_{(x, y) \sim p_{\text{pos}}} \left[ \| f(x) - f(y) \|^2 \right] $$ This is calculated as the average squared Euclidean distance between embeddings of positive pairs. A good alignment score is close to 0, indicating that positive pairs are nearly identical in the embedding space.
$$ \text{Uniformity} = \log \mathbb{E}_{(x, y) \sim p_{\text{data}}} \left[ e^{-2 \| f(x) - f(y) \|^2} \right] $$ This is calculated as the logarithm of the expected exponential of the negative squared Euclidean distance (scaled by 2) over all pairs of embeddings. A lower (more negative) uniformity value indicates that embeddings are spread out more evenly across the embedding space, which is desirable.
Cosine Similarity measures the cosine of the angle between two non-zero vectors.
Contrastive Accuracy measures how often the model correctly identifies the matching pair among a set of negatives.
Top-K Accuracy measures whether the true positive is within the top K closest predictions.
import torch
import torch.nn.functional as F
# Function to compute the Alignment metric
def compute_alignment(seq_embeddings, struct_embeddings):
    distances = (seq_embeddings - struct_embeddings).pow(2).sum(dim=1)
    alignment = distances.mean().item()
    return alignment

# Function to compute the Uniformity metric
# This function is computationally intensive and may take a while to run
def compute_uniformity(embeddings):
    pairwise_distances = torch.cdist(embeddings, embeddings, p=2).pow(2)
    uniformity = torch.log(torch.exp(-2 * pairwise_distances).mean()).item()
    return uniformity

# Function to compute the Cosine Similarity matrix
def compute_cosine_similarity(seq_embeddings, struct_embeddings):
    """
    Computes the cosine similarity between each pair of (already L2-normalized)
    sequence and structure embeddings via matrix multiplication.
    """
    cosine_sim = torch.mm(seq_embeddings, struct_embeddings.t())
    return cosine_sim

# Function to compute the Contrastive Accuracy
def compute_contrastive_accuracy(cosine_sim):
    """
    Finds the index of the maximum cosine similarity for each sequence embedding,
    compares it to the correct index, and returns the mean accuracy.
    """
    correct_preds = cosine_sim.argmax(dim=1)
    correct = correct_preds == torch.arange(cosine_sim.size(0), device=cosine_sim.device)
    return correct.float().mean().item()

# Function to compute the Top-K Accuracy
def compute_top_k_accuracy(cosine_sim, k=1):
    """
    Finds the top K predictions for each sequence embedding, checks whether the
    correct match is among them, and returns the mean accuracy.
    """
    top_k_preds = cosine_sim.topk(k, dim=1)[1]
    correct = torch.arange(cosine_sim.size(0)).unsqueeze(1).expand_as(top_k_preds)
    correct = correct == top_k_preds
    return correct.any(dim=1).float().mean().item()
def compute_metrics(eval_pred):
    seq_embeddings, struct_embeddings = eval_pred.predictions
    seq_embeddings = torch.tensor(seq_embeddings)
    struct_embeddings = torch.tensor(struct_embeddings)

    # Normalize embeddings
    seq_embeddings = F.normalize(seq_embeddings, dim=1)
    struct_embeddings = F.normalize(struct_embeddings, dim=1)

    # Compute metrics
    alignment = compute_alignment(seq_embeddings, struct_embeddings)
    combined_embeddings = torch.cat((seq_embeddings, struct_embeddings), dim=0)
    uniformity = compute_uniformity(combined_embeddings)
    cosine_sim = compute_cosine_similarity(seq_embeddings, struct_embeddings)
    contrastive_accuracy = compute_contrastive_accuracy(cosine_sim)
    top_3_accuracy = compute_top_k_accuracy(cosine_sim, k=3)

    metrics = {
        "alignment": alignment,
        "uniformity": uniformity,
        "contrastive_accuracy": contrastive_accuracy,
        "top_3_accuracy": top_3_accuracy
    }
    return metrics
# Test case
class EvalPrediction:
    def __init__(self, predictions, label_ids):
        self.predictions = predictions
        self.label_ids = label_ids

eval_pred = EvalPrediction(predictions=(logits[0], logits[1]), label_ids=None)
# Compute metrics
metrics = compute_metrics(eval_pred)
metrics
<ipython-input-17-a9da5a3d8f02>:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). seq_embeddings = torch.tensor(seq_embeddings) <ipython-input-17-a9da5a3d8f02>:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). struct_embeddings = torch.tensor(struct_embeddings)
{'alignment': 2.111138105392456, 'uniformity': -0.9696311354637146, 'contrastive_accuracy': 0.0, 'top_3_accuracy': 0.800000011920929}
from transformers import Trainer
class ContrastiveTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)  # Forward pass to get logits
        logits = outputs['logits']
        # Access logit_scale from the underlying model
        logit_scale = model.logit_scale
        loss_fct = ContrastiveLoss()  # Instantiate the loss function
        loss = loss_fct(logits[0], logits[1], logit_scale)  # Compute the loss
        return (loss, outputs) if return_outputs else loss
import gc  # Python's garbage collection module

def clear_memory():
    gc.collect()  # explicitly trigger garbage collection to free up memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # clear the PyTorch CUDA memory cache
# Original CLIP parameters from **Learning Transferable Visual Models From Natural Language Supervision**
original_dataset_size = 400 * 10**6 # 400 million pairs
original_batch_size = 32768
original_epochs = 32
original_warmup_steps = 2000
# Calculate total training steps for the original setup
original_total_training_steps = (original_dataset_size * original_epochs) / original_batch_size
# Calculate the warmup ratio
original_warmup_ratio = original_warmup_steps / original_total_training_steps
print(f"Original total training steps: {original_total_training_steps}")
print(f"Original warmup ratio: {original_warmup_ratio}")
# Our training setup parameters
our_dataset_size = 1571  # our dataset size
our_batch_size = 8
our_epochs = 5

# Calculate total training steps for our setup
our_total_training_steps = (our_dataset_size * our_epochs) / our_batch_size

# Calculate our warmup steps using the original warmup ratio
our_warmup_steps = int(original_warmup_ratio * our_total_training_steps)
our_warmup_steps = max(1, our_warmup_steps)  # Ensure at least 1 warmup step
print(f"Our total training steps: {our_total_training_steps}")
print(f"Our warmup steps: {our_warmup_steps}")
Original total training steps: 390625.0
Original warmup ratio: 0.00512
Our total training steps: 981.875
Our warmup steps: 5
from transformers import TrainingArguments, set_seed
set_seed(42)
num_epochs = 5
batch_size = 8
logging_steps = len(train_ds) // batch_size
model_name = f"{model_ckpt}-Ab-CLIP-v0"
# Training arguments
training_args = TrainingArguments(
    output_dir=model_name,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4,
    weight_decay=0.25,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    fp16=True,                    # Mixed-precision training
    lr_scheduler_type="cosine",
    warmup_steps=5,               # for stability during the initial phase of training
    load_best_model_at_end=True,
    disable_tqdm=False,
    logging_steps=logging_steps,
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",        # Save the model at the end of each epoch
    push_to_hub=True,
)
/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
from transformers import Trainer
# Initialize DataCollator
data_collator = DataCollatorWithPadding(tokenizer)
trainer = ContrastiveTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)
# Optionally set the max_split_size_mb to avoid fragmentation issues
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
clear_memory()
# Train the model
#print(torch.cuda.memory_summary())
trainer.train()
trainer.push_to_hub(commit_message="Training completed!")
| Epoch | Training Loss | Validation Loss | Alignment | Uniformity | Contrastive Accuracy | Top 3 Accuracy |
|---|---|---|---|---|---|---|
| 1 | 1.015300 | 1.295794 | 1.272869 | -2.411410 | 0.068027 | 0.183673 |
| 2 | 0.462300 | 1.072861 | 1.175534 | -2.800844 | 0.095238 | 0.265306 |
| 3 | 0.264600 | 1.086700 | 1.151074 | -2.963047 | 0.136054 | 0.326531 |
| 4 | 0.181500 | 1.076758 | 1.149029 | -3.040295 | 0.129252 | 0.367347 |
| 5 | 0.140300 | 1.052683 | 1.142588 | -3.066312 | 0.122449 | 0.360544 |
CommitInfo(commit_url='https://huggingface.co/arjan-hada/esm2_t6_8M_UR50D-Ab-CLIP-v0/commit/8b840c6c73f563ccece9b4c599846cfeaa1dc2f2', commit_message='Training completed!', commit_description='', oid='8b840c6c73f563ccece9b4c599846cfeaa1dc2f2', pr_url=None, pr_revision=None, pr_num=None)
# Evaluate the model on the test set
test_result = trainer.evaluate(eval_dataset=test_ds)
# Print the results
print(f"Test Loss: {test_result['eval_loss']}")
for key, value in test_result.items():
    if key != 'eval_loss':
        print(f"{key}: {value}")
Test Loss: 0.8743131756782532
eval_alignment: 1.021303653717041
eval_uniformity: -3.077335834503174
eval_contrastive_accuracy: 0.2266666740179062
eval_top_3_accuracy: 0.5066666603088379
eval_runtime: 1.4628
eval_samples_per_second: 102.546
eval_steps_per_second: 12.989
epoch: 5.0
# Saves the best model
trainer.save_model("models/esm2_t6_8M_UR50D-Ab-CLIP-v0")
model.config.save_pretrained("models/esm2_t6_8M_UR50D-Ab-CLIP-v0")
tokenizer.save_pretrained("models/esm2_t6_8M_UR50D-Ab-CLIP-v0")
('models/esm2_t6_8M_UR50D-Ab-CLIP-v0/tokenizer_config.json', 'models/esm2_t6_8M_UR50D-Ab-CLIP-v0/special_tokens_map.json', 'models/esm2_t6_8M_UR50D-Ab-CLIP-v0/vocab.txt', 'models/esm2_t6_8M_UR50D-Ab-CLIP-v0/added_tokens.json')
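For reference, here is a hedged sketch (not a cell from the original run) of how the saved checkpoint could be reloaded for inference with the same ESMForCLIP class defined above; the antibody fragment shown is a made-up placeholder.

# Assumptions: ESMForCLIP is defined as above and the local save directory exists.
save_dir = "models/esm2_t6_8M_UR50D-Ab-CLIP-v0"
reloaded_tokenizer = AutoTokenizer.from_pretrained(save_dir)
reloaded_model = ESMForCLIP.from_pretrained(save_dir)
reloaded_model.eval()

example_sequence = "EVQLVESGGGLVQPGGSLRLSCAAS"  # placeholder antibody fragment, illustration only
inputs = reloaded_tokenizer(example_sequence, return_tensors="pt")

with torch.no_grad():
    outputs = reloaded_model.esm(**inputs, return_dict=True)
    # Project the [CLS] token into the shared sequence-structure space
    seq_embedding = reloaded_model.clip_head.sequence_projection(
        outputs.last_hidden_state[:, 0, :]
    )
print(seq_embedding.shape)  # torch.Size([1, 512])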
The training and evaluation results above show how the model's performance progresses across epochs. Here's a detailed analysis and some insights:
Training Loss and Validation Loss
Metrics