This notebook contains the code samples from the video below, which is part of the Hugging Face course.
#@title
# Embed the accompanying course video inline in the notebook output.
from IPython.display import HTML
HTML('<iframe width="560" height="315" src="https://www.youtube.com/embed/Hm8_PgVTFuc?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>')
Install the Transformers and Datasets libraries to run this notebook.
! pip install datasets transformers[sentencepiece]
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator
# Accelerate abstracts away device placement (CPU / single GPU / multi-GPU).
accelerator = Accelerator()
# Tokenizer trained on CodeSearchNet and the CodeParrot model built earlier
# in the course; both are downloaded from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("huggingface-course/code-search-net-tokenizer")
model = AutoModelForCausalLM.from_pretrained("huggingface-course/codeparrot-ds")
# First token id of each data-science keyword, both without and with a
# leading space — most BPE tokenizers encode " plt" and "plt" differently.
keywords = [
    "plt",
    "pd",
    "sk",
    "fit",
    "predict",
    " plt",
    " pd",
    " sk",
    " fit",
    " predict",
]
keytoken_ids = [tokenizer([keyword]).input_ids[0][0] for keyword in keywords]
# Tokenize a tiny one-sample batch as PyTorch tensors, then let Accelerate
# move the model to the device(s) it manages before the forward pass below.
batch = tokenizer(["import numpy as np"], return_tensors="pt")
model = accelerator.prepare(model)
from torch.nn import CrossEntropyLoss
import torch
def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    """Causal-LM cross-entropy loss that up-weights samples containing key tokens.

    Args:
        inputs: LongTensor of token ids, shape (batch, seq).
        logits: FloatTensor of model outputs, shape (batch, seq, vocab).
        keytoken_ids: iterable of token ids whose presence boosts a sample's weight.
        alpha: scale applied to the per-sample weights (default 1.0).

    Returns:
        Scalar tensor: mean over the batch of (per-sample loss * weight), where
        weight = alpha * (1 + number of key-token occurrences in the sample).
    """
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Per-token loss: reduction="none" keeps one loss value per token
    # (the old `reduce=False` argument is deprecated in PyTorch).
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Reshape back to (batch, seq-1) and average per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(dim=1)
    # Count key-token occurrences per sample; the +1 keeps samples with no
    # key tokens contributing to the loss instead of being zeroed out.
    weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum(
        dim=[0, 2]
    )
    weights = alpha * (1.0 + weights)
    # Weighted mean over the batch
    weighted_loss = (loss_per_sample * weights).mean()
    return weighted_loss
# Demo of one manual training step: forward pass without labels (the custom
# loss above replaces the model's built-in loss), then backprop through
# Accelerate so gradients are computed on whatever device prepared the model.
logits = model(batch["input_ids"]).logits
loss = keytoken_weighted_loss(batch["input_ids"], logits, keytoken_ids)
accelerator.backward(loss)
from transformers import Trainer


class MyTrainer(Trainer):
    """Trainer that swaps the default causal-LM loss for the key-token weighted loss."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted loss for one batch.

        `num_items_in_batch` is accepted (and ignored) for compatibility with
        transformers >= 4.46, which passes it to `compute_loss`; older versions
        simply never supply it, so the default keeps the signature compatible.
        """
        input_ids = inputs.get("input_ids")
        outputs = model(input_ids)
        # keytoken_weighted_loss and keytoken_ids are defined earlier in this file.
        loss = keytoken_weighted_loss(input_ids, outputs.logits, keytoken_ids)
        return (loss, outputs) if return_outputs else loss