# Import the required libraries.
import dspy
from dsp.utils import deduplicate
from dspy.datasets import HotPotQA
from dspy.evaluate.evaluate import Evaluate
from dspy.predict.retry import Retry
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch

# Assertion-related utilities.
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
# Libraries needed to configure the OpenAI client.
import os

import openai

# Take the OpenAI API key from the OPENAI_API_KEY environment variable
# (None when the variable is not set).
openai.api_key = os.environ.get('OPENAI_API_KEY')
# Retrieval model: a ColBERTv2 server exposing the wiki17_abstracts index.
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
    url='http://20.102.90.50:2017/wiki17_abstracts'
)
# Register it as DSPy's global retrieval model (rm).
dspy.settings.configure(rm=colbertv2_wiki17_abstracts)

# Language model: OpenAI 'gpt-3.5-turbo', capped at 500 tokens per completion.
turbo = dspy.OpenAI(
    model='gpt-3.5-turbo',
    max_tokens=500,
)
# Register it as DSPy's global LM, with an empty trace and temperature 0.7.
dspy.settings.configure(
    lm=turbo,
    trace=[],
    temperature=0.7,
)
# Load HotPotQA: 300 training and 300 dev examples, no test split, with fixed
# seeds for the train and eval samples.
dataset = HotPotQA(
    train_seed=1,
    train_size=300,
    eval_seed=2023,
    dev_size=300,
    test_size=0,
)

# Mark `question` as the sole input field of every training / dev example.
trainset = [example.with_inputs('question') for example in dataset.train]
devset = [example.with_inputs('question') for example in dataset.dev]
# Suggestion helper and teleprompter metric.
def validate_query_distinction_local(previous_queries, query):
    """Check whether ``query`` differs from every previously issued query.

    Args:
        previous_queries: Queries issued so far. May be empty (or None).
        query: The candidate query to check.

    Returns:
        True when there are no previous queries, or when
        ``dspy.evaluate.answer_exact_match_str(query, previous_queries,
        frac=0.8)`` does not report a match; False otherwise.
    """
    # No earlier queries means nothing to collide with. A truthiness test
    # (rather than `== []`) also treats an empty tuple or None as "none yet"
    # instead of handing them to the matcher.
    if not previous_queries:
        return True
    return not dspy.evaluate.answer_exact_match_str(query, previous_queries, frac=0.8)


def validate_context_and_answer_and_hops(example, pred, trace=None):
    """Teleprompter metric used when bootstrapping demonstrations.

    Args:
        example: The gold example.
        pred: The program's prediction.
        trace: Accepted for the metric signature; not inspected. Despite the
            function's name, no per-hop checks are performed here.

    Returns:
        True only if the prediction passes both
        ``dspy.evaluate.answer_exact_match`` and
        ``dspy.evaluate.answer_passage_match``; False otherwise.
    """
    if not dspy.evaluate.answer_exact_match(example, pred):
        return False
    if not dspy.evaluate.answer_passage_match(example, pred):
        return False
    return True
# Extrinsic metric.
def gold_passages_retrieved(example, pred, trace=None):
    """Return True when every gold passage title appears among the retrieved ones.

    Titles on both sides are normalised with ``dspy.evaluate.normalize_text``.
    The title of a retrieved passage is taken to be the text before its first
    ``' | '`` separator (presumably passages are formatted as
    ``"<title> | <body>"``). ``trace`` is accepted for the metric signature and
    ignored.
    """
    normalize = dspy.evaluate.normalize_text
    expected = {normalize(title) for title in example['gold_titles']}
    retrieved = {normalize(passage.partition(' | ')[0]) for passage in pred.context}
    # Subset test: no gold title may be missing from what was retrieved.
    return expected <= retrieved
# dspy模块的签名class GenerateAnswer(dspy.Signature): """用简短的事实性答案回答问题。""" context = dspy.InputField(desc="可能包含相关事实") question = dspy.InputField() answer = dspy.OutputField(desc="通常在1到5个单词之间")class GenerateSearchQuery(dspy.Signature): """编写一个简单的搜索查询,以帮助回答复杂问题。""" context = dspy.InputField(desc="可能包含相关事实") question = dspy.InputField() query = dspy.OutputField()
def all_queries_distinct(prev_queries):
    """Return True if each query is distinct from all queries that precede it.

    ``prev_queries[i]`` is checked against ``prev_queries[:i]`` with
    ``validate_query_distinction_local``. The result is False as soon as one
    check fails, and True for an empty list.

    Args:
        prev_queries: Sequence of query strings, in the order they were issued.

    Returns:
        bool
    """
    # all() short-circuits on the first non-distinct query, matching the
    # original flag-and-break loop without comparing a bool to False.
    return all(
        validate_query_distinction_local(prev_queries[:i], query)
        for i, query in enumerate(prev_queries)
    )
class SimplifiedBaleen(dspy.Module):
    """Multi-hop retrieve-then-answer program, without DSPy assertions.

    On each of ``max_hops`` hops it generates a search query (one
    ``ChainOfThought(GenerateSearchQuery)`` predictor per hop) and retrieves
    ``passages_per_hop`` passages. The passages are merged into a
    de-duplicated context, and the question is then answered from that
    context.

    ``passed_suggestions`` counts forward calls whose queries were all
    distinct according to ``all_queries_distinct``. It is used only for
    evaluating the suggestions.
    """

    def __init__(self, passages_per_hop=2, max_hops=2):
        super().__init__()
        # A separate query-generation predictor for every hop.
        self.generate_query = [
            dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)
        ]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops
        # Evaluation-only counter.
        self.passed_suggestions = 0

    def forward(self, question):
        gathered = []
        # The question itself is the first entry, so generated queries are
        # also compared against it.
        queries_so_far = [question]

        for step in range(self.max_hops):
            generator = self.generate_query[step]
            new_query = generator(context=gathered, question=question).query
            queries_so_far.append(new_query)

            hits = self.retrieve(new_query).passages
            gathered = deduplicate(gathered + hits)

        # NOTE(review): the evaluator is built with num_threads=10, and `+=`
        # on an attribute is not atomic across threads, so this tally may be
        # slightly off under contention -- confirm if exact counts matter.
        if all_queries_distinct(queries_so_far):
            self.passed_suggestions += 1

        final = self.generate_answer(context=gathered, question=question)
        return dspy.Prediction(context=gathered, answer=final.answer)
class SimplifiedBaleenAssertions(dspy.Module):
    """Multi-hop retrieve-then-answer program with DSPy suggestions on queries.

    Same pipeline as the assertion-free Baleen program. In addition, after each
    query is generated, two ``dspy.Suggest`` constraints are raised:

    * the query is at most 100 characters long;
    * the query is distinct from the question and from earlier queries.

    ``passed_suggestions`` counts forward calls whose final query list was all
    distinct. It is used only for evaluating the suggestions.
    """

    def __init__(self, passages_per_hop=2, max_hops=2):
        super().__init__()
        # A separate query-generation predictor for every hop.
        self.generate_query = [
            dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)
        ]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops
        # Evaluation-only counter.
        self.passed_suggestions = 0

    def forward(self, question):
        gathered = []
        queries_so_far = [question]

        for step in range(self.max_hops):
            new_query = self.generate_query[step](context=gathered, question=question).query

            # Soft constraints. When this module is wrapped with
            # assert_transform_module(..., backtrack_handler), a failed
            # suggestion presumably triggers a retry of the query generator
            # with the message as feedback -- see the DSPy assertions docs.
            # The message strings are prompt text and must stay as-is.
            # English: "The query should be short and under 100 characters".
            is_short = len(new_query) <= 100
            dspy.Suggest(is_short, "查询应该简短且少于100个字符")

            # English: "The query should differ from: 1) ...; 2) ...".
            is_distinct = validate_query_distinction_local(queries_so_far, new_query)
            earlier = "; ".join(f"{i+1}) {q}" for i, q in enumerate(queries_so_far))
            dspy.Suggest(is_distinct, "查询应该与以下内容不同: " + earlier)

            queries_so_far.append(new_query)
            hits = self.retrieve(new_query).passages
            gathered = deduplicate(gathered + hits)

        # NOTE(review): non-atomic increment under a multi-threaded evaluator;
        # the count may be slightly off under contention -- confirm if exactness matters.
        if all_queries_distinct(queries_so_far):
            self.passed_suggestions += 1

        final = self.generate_answer(context=gathered, question=question)
        return dspy.Prediction(context=gathered, answer=final.answer)
# Shared evaluator over the HotPotQA dev set: 10 worker threads, progress
# display on, per-example table off. The metric is supplied per call.
evaluate_on_hotpotqa = Evaluate(
    devset=devset,
    num_threads=10,
    display_progress=True,
    display_table=False,
)
def evaluate(module):
    """Score ``module`` on the HotPotQA dev set, print the scores, and return them.

    Three scores are produced:

    * Suggestions score: ``module.passed_suggestions`` after the retrieval run,
      as a percentage of ``len(devset)``.
    * Retrieval score: ``evaluate_on_hotpotqa`` with ``gold_passages_retrieved``.
    * Accuracy score: ``evaluate_on_hotpotqa`` with
      ``dspy.evaluate.answer_exact_match``.

    Args:
        module: A program exposing a writable ``passed_suggestions`` counter
            that its forward pass increments.

    Returns:
        dict with keys ``"suggestions"``, ``"retrieval"`` and ``"accuracy"``.
        Returning the scores lets callers compare programs programmatically;
        existing callers that ignore the result are unaffected.
    """
    # Reset the counter so only this evaluation's forward calls are tallied.
    module.passed_suggestions = 0

    retrieval_score = evaluate_on_hotpotqa(module, metric=gold_passages_retrieved)

    # Read the counter now, before the accuracy run calls forward() again and
    # keeps incrementing it.
    # NOTE(review): if backtracking re-runs forward() for an example, that
    # example may be counted more than once -- confirm before reading this as
    # a strict percentage.
    suggestions_score = module.passed_suggestions / len(devset) * 100

    accuracy_score = evaluate_on_hotpotqa(module, metric=dspy.evaluate.answer_exact_match)

    print(f"## Suggestions Score: {suggestions_score}")
    print(f"## Retrieval Score: {retrieval_score}")
    print(f"## Accuracy Score: {accuracy_score}")

    return {
        "suggestions": suggestions_score,
        "retrieval": retrieval_score,
        "accuracy": accuracy_score,
    }
# Baseline: no compilation, no assertions.
baleen = SimplifiedBaleen()
evaluate(baleen)
# No compilation, with assertions: wrap each named predictor of the
# assertion-enabled program in Retry, then transform the program with
# backtrack_handler as its assertion handler.
baleen_with_assertions = assert_transform_module(
    SimplifiedBaleenAssertions().map_named_predictors(Retry),
    backtrack_handler,
)
evaluate(baleen_with_assertions)
# Maximum number of bootstrapped demonstrations per predictor, shared by both
# compiled runs.
max_bootstrapped_demos: int = 2
# Compiled, without assertions.
baleen = SimplifiedBaleen()
teleprompter = BootstrapFewShotWithRandomSearch(
    metric=validate_context_and_answer_and_hops,
    max_bootstrapped_demos=max_bootstrapped_demos,
    num_candidate_programs=6,
)
# Compile the Baleen program. The freshly built `baleen` is used as the teacher
# instead of constructing a second, identical SimplifiedBaleen and leaving
# `baleen` unused.
# NOTE(review): valset=devset means candidate programs are selected on the same
# dev set that evaluate() reports on (the evaluator's devset), so the reported
# scores are optimistic; use a held-out split if they are meant to estimate
# generalization.
compiled_baleen = teleprompter.compile(
    student=SimplifiedBaleen(),
    teacher=baleen,
    trainset=trainset,
    valset=devset,
)
# Evaluate the compiled program.
evaluate(compiled_baleen)
# Compiled, with assertions.
# Teacher: the assertion-free program.
baleen = SimplifiedBaleen()
teleprompter = BootstrapFewShotWithRandomSearch(
    metric=validate_context_and_answer_and_hops,
    max_bootstrapped_demos=max_bootstrapped_demos,
    num_candidate_programs=6,
)
# Student: the assertion-enabled program, with Retry around every named
# predictor and backtrack_handler as the assertion handler.
# NOTE(review): valset=devset selects on the same dev set that evaluate()
# reports on, so the resulting scores are optimistic.
compiled_baleen = teleprompter.compile(
    teacher=baleen,
    student=assert_transform_module(
        SimplifiedBaleenAssertions().map_named_predictors(Retry),
        backtrack_handler,
    ),
    trainset=trainset,
    valset=devset,
)
evaluate(compiled_baleen)