# code for loading notebook's format
import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
os.chdir(path)
%load_ext watermark
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import pandas as pd
import transformers
import pytorch_lightning as pl
from dataclasses import dataclass
from torch.utils.data import DataLoader
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizerBase,
)
from typing import Any, Dict, List, Optional, Union
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
%watermark -a 'Ethen' -d -v -u -p pytorch_lightning,transformers,datasets,torch
Author: Ethen

Last updated: 2024-10-12

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.26.0

pytorch_lightning: 2.1.4
transformers     : 4.41.1
datasets         : 3.0.0
torch            : 2.1.2+cu121
In this article, we'll be implementing an LLM pairwise judge, where an LLM is presented with a question and two answers, and tasked with determining which answer is better or declaring a tie. Using LLMs as judges for evaluation offers a scalable and relatively inexpensive alternative to collecting human preference annotations.
Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena [5] provides a thorough examination of using LLMs as judges. The researchers curated two distinct benchmark suites for this purpose: MT-Bench, a set of multi-turn open-ended questions, and Chatbot Arena, a crowdsourced platform where users vote between responses from anonymous models.
They verify that a state-of-the-art LLM, GPT-4, used as a judge is capable of matching human evaluation at an agreement rate exceeding 80%.
We'll first implement a generation module for generating responses from an LLM. We use the Qwen 2.5 collection [2] in this article; feel free to pick your favorite LLM. While doing so, be sure to set the correct padding token and padding side, as well as configure max_new_tokens [1].
Note that generated responses will be cut short at a small default length if max_new_tokens is not explicitly specified in GenerationConfig.
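A minimal sketch of these settings, assuming the Qwen2.5 instruct checkpoint used throughout this article:

from transformers import AutoTokenizer, GenerationConfig

# left padding with eos as the pad token, as needed for batched decoder-only generation
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

# cap the number of newly generated tokens explicitly
generation_config = GenerationConfig(max_new_tokens=250)

The data collator below bakes the padding configuration into its __post_init__.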
@dataclass
class DataCollatorForGeneration(DataCollatorMixin):
"""
tokenize raw text (prompt) as well as padding while forming a batch for data loader.
"""
tokenizer: PreTrainedTokenizerBase
max_seq_len: int = 512
padding: Union[bool, str, PaddingStrategy] = True
return_tensors: str = "pt"
prompt_col_name: str = "prompt"
def __post_init__(self):
self.tokenizer.padding_side = "left"
self.tokenizer.pad_token = self.tokenizer.eos_token
def __call__(
self, features: List[Dict[str, Any]], return_tensors=None
) -> Dict[str, Any]:
prompts = [feature[self.prompt_col_name] for feature in features]
tokenized_text = self.tokenizer(
prompts,
padding=self.padding,
max_length=self.max_seq_len,
truncation=True,
return_attention_mask=True,
return_tensors=self.return_tensors,
)
batch = {
"prompts": prompts,
"input_ids": tokenized_text["input_ids"],
"attention_mask": tokenized_text["attention_mask"],
}
return batch
examples = [{"prompt": "What is the capital of France?"}, {"prompt": "What is the biggest planet in the solar system?"}]
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
data_collator = DataCollatorForGeneration(tokenizer)
data_loader = DataLoader(examples, batch_size=2, num_workers=2, collate_fn=data_collator)
batch = next(iter(data_loader))
batch
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
{'prompts': ['What is the capital of France?', 'What is the biggest planet in the solar system?'], 'input_ids': tensor([[151645, 151645, 151645, 3838, 374, 279, 6722, 315, 9625, 30], [ 3838, 374, 279, 8538, 11580, 304, 279, 12941, 1849, 30]]), 'attention_mask': tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
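One thing to note is that this collator tokenizes the raw prompt text directly. For instruction-tuned checkpoints such as Qwen2.5-Instruct, we could optionally wrap each prompt with the tokenizer's chat template before tokenizing; a minimal sketch (not used in the collator above):

# optional variant: format the prompt with the model's chat template, which
# instruction-tuned checkpoints generally expect
messages = [{"role": "user", "content": "What is the capital of France?"}]
prompt_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)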
class LLMGenerateLightningModule(pl.LightningModule):
"""
Generate responses from LLM. Expects input prompts, tokenized input_ids, attention_mask
"""
def __init__(
self,
pretrained_model_name_or_path,
generation_config,
prediction_config,
cache_dir="/data",
):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, cache_dir=cache_dir
)
self.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, padding_side="left", cache_dir=cache_dir
)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = generation_config
self._setup_prediction(prediction_config)
def predict_step(self, batch, batch_idx, dataloader_idx=None):
prompts = batch["prompts"]
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
responses = self.generate(input_ids, attention_mask)
prediction_output = {
"prompts": prompts,
"responses": responses,
}
self.prediction_outputs.append(prediction_output)
return prediction_output
def generate(self, input_ids, attention_mask):
model_output = self.model.generate(
input_ids,
attention_mask=attention_mask,
generation_config=self.generation_config
)
# crop input prompt from generated response
input_seq_length = input_ids.shape[-1]
model_output_answer_only = model_output[:, input_seq_length:]
responses = self.tokenizer.batch_decode(model_output_answer_only, skip_special_tokens=True)
return responses
def _setup_prediction(self, prediction_config):
if prediction_config:
self.prediction_outputs = []
self._prediction_partition_idx = 0
self.prediction_partition_format = prediction_config["prediction_partition_format"]
self.prediction_output_path = prediction_config["prediction_output_path"]
self.prediction_accumulation_steps = prediction_config.get("prediction_accumulation_steps", 10)
def _save_prediction_outputs(self):
if self.prediction_output_path:
data = {field: [] for field in self.prediction_outputs[0]}
for prediction_output in self.prediction_outputs:
for field in data:
data[field].extend(prediction_output[field])
partition_file_name = self.prediction_partition_format.format(
rank=self.global_rank, partition=self._prediction_partition_idx
)
formatted_output_path = os.path.join(
self.prediction_output_path, partition_file_name
)
# saves prediction batch locally via pandas data frame
df_prediction_outputs = pd.DataFrame.from_dict(data)
os.makedirs(self.prediction_output_path, exist_ok=True)
df_prediction_outputs.to_parquet(formatted_output_path, index=False)
self._prediction_partition_idx += 1
self.prediction_outputs.clear()
def on_predict_batch_end(self, outputs, batch, batch_idx):
if len(self.prediction_outputs) == self.prediction_accumulation_steps:
self._save_prediction_outputs()
def on_predict_epoch_end(self):
if len(self.prediction_outputs) > 0:
self._save_prediction_outputs()
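The module buffers predictions and flushes them to disk every prediction_accumulation_steps batches, writing one parquet partition per flush. As a small illustration, the partition format string used in the prediction config below expands like so:

# sketch: partition file name for global rank 0's first flushed chunk
partition_format = "rank-{rank:02d}-partition-{partition:06d}.parquet"
partition_format.format(rank=0, partition=0)
# 'rank-00-partition-000000.parquet'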
generation_config = GenerationConfig(
max_new_tokens=250
)
llm_generate_module = LLMGenerateLightningModule(
pretrained_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",
generation_config=generation_config,
prediction_config={
"prediction_output_path": "prediction",
"prediction_partition_format": "rank-{rank:02d}-partition-{partition:06d}.parquet"
}
)
trainer = pl.Trainer()
prediction_output = trainer.predict(llm_generate_module, data_loader)
prediction_output
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 1/1 [00:10<00:00, 0.09it/s]
[{'prompts': ['What is the capital of France?', 'What is the biggest planet in the solar system?'], 'responses': [' The capital of France is Paris. It is located in the north of the country and is the largest city in France. Paris is known for its famous landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is also the political, cultural, and financial center of France. Paris is home to many museums, theaters, and other cultural institutions, and is known for its fashion, cuisine, and nightlife. The city is also famous for its art, music, and literature, and is home to many famous artists and writers. Paris is also known for its beautiful parks, including the Luxembourg Gardens and the Bois de Boulogne. The city is also home to many universities and research institutions, including the Sorbonne University and the École Normale Supérieure. Paris is also known for its fashion, cuisine, and nightlife, and is home to many famous artists and writers. The city is also known for its beautiful parks, including the Luxembourg Gardens and the Bois de Boulogne. The city is also home to many universities and research institutions, including the Sorbonne University and the École Normale Supérieure. Paris is also known for its fashion, cuisine, and nightlife', ' The biggest planet in the solar system is Jupiter. It is a gas giant planet with a diameter of about 86,881 miles (139,822 kilometers) and a mass of about 1.90 x 10^27 kilograms. Jupiter is also the fifth planet from the sun and is the largest planet in the solar system. It is composed mostly of hydrogen and helium, with a small amount of other elements. Jupiter has a strong magnetic field and a ring system, and it has at least 79 known moons. It is also known for its Great Red Spot, a giant storm that has been raging on Jupiter for at least 400 years. Jupiter is also known for its powerful winds, which can reach speeds of up to 430 miles per hour (690 kilometers per hour). Jupiter is also known for its powerful winds, which can reach speeds of up to 430 miles per hour (690 kilometers per hour). Jupiter is also known for its powerful winds, which can reach speeds of up to 430 miles per hour (690 kilometers per hour). Jupiter is also known for its powerful winds, which can reach speeds']}]
df_prediction_output = pd.read_parquet("prediction")
df_prediction_output
|   | prompts | responses |
|---|---------|-----------|
| 0 | What is the capital of France? | The capital of France is Paris. It is located... |
| 1 | What is the biggest planet in the solar system? | The biggest planet in the solar system is Jup... |
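In a real comparison, we would typically run the generation module for two different models and join their outputs on the prompt to form the judge's input. A hypothetical sketch (the prediction_model_a / prediction_model_b paths are illustrative only):

# hypothetical sketch: build the prompts / responses1 / responses2 records the
# pairwise judge collator below expects from two models' saved generation outputs
df_model_a = pd.read_parquet("prediction_model_a").rename(columns={"responses": "responses1"})
df_model_b = pd.read_parquet("prediction_model_b").rename(columns={"responses": "responses2"})
df_judge_input = df_model_a.merge(df_model_b, on="prompts")
examples = df_judge_input.to_dict(orient="records")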
The pairwise judge's implementation (specifically, the prompt used) is inspired by Hugging Face's HfPairwiseJudge. At the time of writing, its backend relies on their own inference client, which imposes restrictions on the model sizes free-tier users are allowed to use.
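For reference, a rough sketch of calling that judge directly (this assumes the trl package and a Hugging Face API token are available; argument names and defaults may vary across versions):

# rough sketch, not used in this article: the hosted pairwise judge returns, for
# each prompt, the index of the preferred completion (0 or 1)
from trl import HfPairwiseJudge

judge = HfPairwiseJudge()
judge.judge(
    prompts=["What is the capital of France?"],
    completions=[["Paris", "Taipei"]],
)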
Our judge will also attempt to handle position bias. Position bias is when an LLM exhibits a propensity to favor certain positions over others, regardless of the actual content or quality of the answers. A conservative approach for addressing this issue is to call the judge twice, swapping the two answers' order, and only declare a win when an answer is preferred in both orders; if the results are inconsistent after swapping, a tie is declared. A more aggressive approach is to assign positions randomly, which can be effective at a large scale with the correct expectations. In the following experiments, we use the conservative approach.
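To make the conservative approach concrete, here is a small hypothetical helper (not part of the modules below) that turns the judge's two replies into a final verdict. Since the prompt instructs the judge to answer with a single identifier, only the first non-whitespace character matters.

# hypothetical helper: conservative position-bias handling. Answer 1 only wins when
# the judge picks "1" in the original order and "2" once the answers are swapped;
# inconsistent picks fall back to a tie.
def resolve_verdict(response: str, response_swapped: str) -> str:
    first = response.strip()[:1]
    first_swapped = response_swapped.strip()[:1]
    if first == "1" and first_swapped == "2":
        return "answer 1"
    if first == "2" and first_swapped == "1":
        return "answer 2"
    return "tie"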
@dataclass
class DataCollatorForPairwiseJudge(DataCollatorMixin):
"""
tokenize raw text (prompt) as well as padding while forming a batch for data loader.
Parameters
----------
system_prompt :
System prompt to be used for the judge. If not provided, a default prompt is used.
System prompt should contain following placeholders: `{prompt}`, `{response1}`, and `{response2}`.
"""
default_system_prompt = '''
I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.
Instruction: {prompt}
Model Outputs: Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.
"model_identifier": "1", "output": """{response1}""" "model_identifier": "2", "output": """{response2}"""
Task Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).
'''
tokenizer: PreTrainedTokenizerBase
max_seq_len: int = 1024
padding: Union[bool, str, PaddingStrategy] = True
return_tensors: str = "pt"
prompt_col_name: str = "prompts"
response1_col_name: str = "responses1"
response2_col_name: str = "responses2"
system_prompt: Optional[str] = None
def __post_init__(self):
self.tokenizer.padding_side = "left"
self.tokenizer.pad_token = self.tokenizer.eos_token
self.system_prompt = self.system_prompt if self.system_prompt is not None else self.default_system_prompt
def __call__(
self, features: List[Dict[str, Any]], return_tensors=None
) -> Dict[str, Any]:
judge_prompts = []
judge_swapped_position_prompts = []
for feature in features:
prompt = feature[self.prompt_col_name]
response1 = feature[self.response1_col_name]
response2 = feature[self.response2_col_name]
judge_prompt = self.system_prompt.format(
prompt=prompt, response1=response1, response2=response2
)
judge_swapped_position_prompt = self.system_prompt.format(
prompt=prompt, response1=response2, response2=response1
)
judge_prompts.append(judge_prompt)
judge_swapped_position_prompts.append(judge_swapped_position_prompt)
tokenized_text = self.tokenizer(
judge_prompts,
padding=self.padding,
max_length=self.max_seq_len,
truncation=True,
return_attention_mask=True,
return_tensors=self.return_tensors,
)
tokenized_swapped_position_text = self.tokenizer(
judge_swapped_position_prompts,
padding=self.padding,
max_length=self.max_seq_len,
truncation=True,
return_attention_mask=True,
return_tensors=self.return_tensors,
)
batch = {
"prompts": judge_prompts,
"input_ids": tokenized_text["input_ids"],
"attention_mask": tokenized_text["attention_mask"],
"input_ids_swapped_position": tokenized_swapped_position_text["input_ids"],
"attention_mask_swapped_position": tokenized_swapped_position_text["attention_mask"],
}
return batch
examples = [
{"prompts": "What is the capital of France?", "responses1": "Paris", "responses2": "Taipei"},
{"prompts": "What is the biggest planet in the solar system?", "responses1": "Saturn", "responses2": "Jupiter"}
]
data_collator = DataCollatorForPairwiseJudge(tokenizer)
data_loader = DataLoader(examples, batch_size=2, num_workers=2, collate_fn=data_collator)
batch = next(iter(data_loader))
batch.keys()
dict_keys(['prompts', 'input_ids', 'attention_mask', 'input_ids_swapped_position', 'attention_mask_swapped_position'])
class PairwiseLLMJudgeLightningModule(pl.LightningModule):
def __init__(
self,
pretrained_model_name_or_path,
generation_config,
prediction_config,
cache_dir="/data",
):
super().__init__()
self.model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, cache_dir=cache_dir
)
self.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, padding_side="left", cache_dir=cache_dir
)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = generation_config
self._setup_prediction(prediction_config)
def predict_step(self, batch, batch_idx, dataloader_idx=None):
prompts = batch["prompts"]
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
responses = self.generate(input_ids, attention_mask)
input_ids_swapped_position = batch["input_ids_swapped_position"]
attention_mask_swapped_position = batch["attention_mask_swapped_position"]
responses_swapped_position = self.generate(input_ids_swapped_position, attention_mask_swapped_position)
prediction_output = {
"prompts": prompts,
"responses": responses,
"responses_swapped_position": responses_swapped_position,
}
self.prediction_outputs.append(prediction_output)
return prediction_output
def generate(self, input_ids, attention_mask):
model_output = self.model.generate(
input_ids,
attention_mask=attention_mask,
generation_config=self.generation_config
)
# crop input prompt from generated response
input_seq_length = input_ids.shape[-1]
model_output_answer_only = model_output[:, input_seq_length:]
responses = self.tokenizer.batch_decode(model_output_answer_only, skip_special_tokens=True)
return responses
def _setup_prediction(self, prediction_config):
if prediction_config:
self.prediction_outputs = []
self._prediction_partition_idx = 0
self.prediction_partition_format = prediction_config["prediction_partition_format"]
self.prediction_output_path = prediction_config["prediction_output_path"]
self.prediction_accumulation_steps = prediction_config.get("prediction_accumulation_steps", 10)
def _save_prediction_outputs(self):
if self.prediction_output_path:
data = {field: [] for field in self.prediction_outputs[0]}
for prediction_output in self.prediction_outputs:
for field in data:
data[field].extend(prediction_output[field])
partition_file_name = self.prediction_partition_format.format(
rank=self.global_rank, partition=self._prediction_partition_idx
)
formatted_output_path = os.path.join(
self.prediction_output_path, partition_file_name
)
# saves prediction batch locally via pandas data frame
df_prediction_outputs = pd.DataFrame.from_dict(data)
os.makedirs(self.prediction_output_path, exist_ok=True)
df_prediction_outputs.to_parquet(formatted_output_path, index=False)
self._prediction_partition_idx += 1
self.prediction_outputs.clear()
def on_predict_batch_end(self, outputs, batch, batch_idx):
if len(self.prediction_outputs) == self.prediction_accumulation_steps:
self._save_prediction_outputs()
def on_predict_epoch_end(self):
if len(self.prediction_outputs) > 0:
self._save_prediction_outputs()
generation_config = GenerationConfig(
max_new_tokens=2,
)
pairwise_judge_module = PairwiseLLMJudgeLightningModule(
pretrained_model_name_or_path="Qwen/Qwen2.5-3B-Instruct",
generation_config=generation_config,
prediction_config={
"prediction_output_path": "judge",
"prediction_partition_format": "rank-{rank:02d}-partition-{partition:06d}.parquet"
}
)
trainer = pl.Trainer()
prediction_output = trainer.predict(pairwise_judge_module, data_loader)
prediction_output
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.31s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.83it/s]
[{'prompts': ['\n I require a leaderboard for various large language models. I\'ll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.\n\n Instruction: What is the capital of France?\n\n Model Outputs: Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.\n\n "model_identifier": "1", "output": """Paris""" "model_identifier": "2", "output": """Taipei"""\n\n Task Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).\n ', '\n I require a leaderboard for various large language models. I\'ll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.\n\n Instruction: What is the biggest planet in the solar system?\n\n Model Outputs: Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.\n\n "model_identifier": "1", "output": """Saturn""" "model_identifier": "2", "output": """Jupiter"""\n\n Task Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).\n '], 'responses': [' 1', ' 2'], 'responses_swapped_position': [' 2', ' 1']}]
df_prediction_output = pd.read_parquet("judge")
df_prediction_output
|   | prompts | responses | responses_swapped_position |
|---|---------|-----------|----------------------------|
| 0 | \n I require a leaderboard for various larg... | 1 | 2 |
| 1 | \n I require a leaderboard for various larg... | 2 | 1 |
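Using the resolve_verdict helper sketched earlier, we can collapse the two judge calls into a final preference for each prompt; in this toy example both rows yield a consistent winner (Paris and Jupiter, respectively).

# apply the conservative position-bias rule to the judge's saved outputs
df_prediction_output["verdict"] = [
    resolve_verdict(response, response_swapped)
    for response, response_swapped in zip(
        df_prediction_output["responses"],
        df_prediction_output["responses_swapped_position"],
    )
]
df_prediction_output["verdict"]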