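In this notebook we will attempt to classify a public dataset of transactions into a number of categories that we have predefined. These approaches should be replicable to any multiclass classification use case where we are trying to fit transactional data into predefined categories, and by the end of running through this notebook you should have a few approaches for dealing with both labelled and unlabelled datasets.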
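In this notebook we will take three different approaches: zero-shot classification, classification with embeddings, and fine-tuned classification.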
%load_ext autoreload
%autoreload
%pip install openai 'openai[datalib]' 'openai[embeddings]' scikit-learn transformers
import openai
import pandas as pd
import numpy as np
import json
import os
COMPLETIONS_MODEL = "gpt-4"
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if you didn't set as an env var>"))
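We'll use a public dataset of transactions over £25k for the National Library of Scotland. The dataset has three features that we'll be using: Supplier, Description and Transaction value (£).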
Source:
https://data.nls.uk/data/organisational-data/transactions-over-25k/
transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')
len(transactions)
359
transactions.head()
| | Date | Supplier | Description | Transaction value (£) |
---|---|---|---|---|
0 | 21/04/2016 | M & J Ballantyne Ltd | George IV Bridge Work | 35098.0 |
1 | 26/04/2016 | Private Sale | Literary & Archival Items | 30000.0 |
2 | 30/04/2016 | City Of Edinburgh Council | Non Domestic Rates | 40800.0 |
3 | 09/05/2016 | Computacenter Uk | Kelvin Hall | 72835.0 |
4 | 09/05/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 64361.0 |
def request_completion(prompt):
    # Send the prompt as a single user message to the Chat Completions API
    completion_response = client.chat.completions.create(
        model=COMPLETIONS_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=5,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return completion_response
def classify_transaction(transaction, prompt):
    # Substitute this transaction's fields into the prompt template
    prompt = prompt.replace('SUPPLIER_NAME', transaction['Supplier'])
    prompt = prompt.replace('DESCRIPTION_TEXT', transaction['Description'])
    prompt = prompt.replace('TRANSACTION_VALUE', str(transaction['Transaction value (£)']))
    # Return the model's reply as a single line of text
    classification = request_completion(prompt).choices[0].message.content.replace('\n', '')
    return classification
# This function takes the training and validation outputs of the prepare_data
# tool from the fine-tuning CLI and confirms that both files contain the same classes.
# If they don't, the fine-tune will fail and return an error.
def check_finetune_classes(train_file, valid_file):
    train_classes = set()
    valid_classes = set()
    with open(train_file, 'r') as json_file:
        json_list = list(json_file)
        print(len(json_list))
    for json_str in json_list:
        result = json.loads(json_str)
        train_classes.add(result['completion'])
    with open(valid_file, 'r') as json_file:
        json_list = list(json_file)
        print(len(json_list))
    for json_str in json_list:
        result = json.loads(json_str)
        valid_classes.add(result['completion'])
    # Compare the class sets themselves, not just their sizes
    if train_classes == valid_classes:
        print('All good')
    else:
        print('Classes do not match, please prepare data again')
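To see which classes are missing from each file, rather than only a pass/fail message, a minimal sketch follows. It assumes, as the function above does, that each prepared JSONL record stores its label under a `completion` key; `compare_finetune_classes` is a hypothetical helper, and the toy files are written to a temporary directory.

```python
import json
import os
import tempfile


def compare_finetune_classes(train_file, valid_file, key="completion"):
    """Return the classes missing from each JSONL file (hypothetical helper)."""

    def read_classes(path):
        # Each non-empty line of a prepared JSONL file is one JSON record
        with open(path, "r") as f:
            return {json.loads(line)[key] for line in f if line.strip()}

    train_classes = read_classes(train_file)
    valid_classes = read_classes(valid_file)
    return {
        "missing_from_valid": sorted(train_classes - valid_classes),
        "missing_from_train": sorted(valid_classes - train_classes),
    }


# Usage with two small toy files: class " 1" only appears in training
with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "train.jsonl")
    valid_path = os.path.join(tmp, "valid.jsonl")
    with open(train_path, "w") as f:
        for c in [" 0", " 1", " 2"]:
            f.write(json.dumps({"prompt": "x\n\n###\n\n", "completion": c}) + "\n")
    with open(valid_path, "w") as f:
        for c in [" 0", " 2"]:
            f.write(json.dumps({"prompt": "y\n\n###\n\n", "completion": c}) + "\n")
    result = compare_finetune_classes(train_path, valid_path)

print(result)
```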
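We'll first assess the performance of the base model at classifying these transactions using a simple prompt. We'll give the model the five categories and a catch-all of "Could not classify" for transactions it cannot place.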
zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.
You are analysing all transactions over £25,000 in value and classifying them into one of five categories.
The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.
If you can't tell what it is, say Could not classify
Transaction:
Supplier: SUPPLIER_NAME
Description: DESCRIPTION_TEXT
Value: TRANSACTION_VALUE
The classification is:'''
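The model replies in free text, so a reply may not exactly match one of the category names (a trailing full stop, different capitalisation, or an unexpected label). A minimal sketch of mapping replies back onto the allowed labels, assuming anything unrecognised should fall into the catch-all; `normalise_label` and `ALLOWED_LABELS` are hypothetical names, not part of the original notebook.

```python
# Hypothetical list of the labels the prompt allows, including the catch-all
ALLOWED_LABELS = [
    "Building Improvement",
    "Literature & Archive",
    "Utility Bills",
    "Professional Services",
    "Software/IT",
    "Could not classify",
]


def normalise_label(raw_reply, allowed=ALLOWED_LABELS, fallback="Could not classify"):
    """Map a free-text model reply onto one of the allowed labels."""
    # Trim whitespace and a trailing full stop, then compare case-insensitively
    cleaned = raw_reply.strip().rstrip(".").strip().lower()
    for label in allowed:
        if cleaned == label.lower():
            return label
    # Anything unrecognised goes to the catch-all
    return fallback


print(normalise_label(" Building Improvement.\n"))
print(normalise_label("utility bills"))
print(normalise_label("Catering"))
```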
# Get a test transaction
transaction = transactions.iloc[0]
# Interpolate the values into the prompt
prompt = zero_shot_prompt.replace('SUPPLIER_NAME', transaction['Supplier'])
prompt = prompt.replace('DESCRIPTION_TEXT', transaction['Description'])
prompt = prompt.replace('TRANSACTION_VALUE', str(transaction['Transaction value (£)']))
# Use our completion function to return a prediction
completion_response = request_completion(prompt)
# Chat Completions responses carry the text in message.content
print(completion_response.choices[0].message.content)
Building Improvement
Our first attempt is correct: M & J Ballantyne Ltd are a house builder, and the work they performed is indeed Building Improvement.
Let's expand the sample size to 25 and see how it performs, again with just a simple prompt to guide it.
# Take a copy so that adding a column doesn't trigger pandas' SettingWithCopyWarning
test_transactions = transactions.iloc[:25].copy()
test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x, zero_shot_prompt), axis=1)
test_transactions['Classification'].value_counts()
Building Improvement    14
Could not classify       5
Literature & Archive     3
Software/IT              2
Utility Bills            1
Name: Classification, dtype: int64
test_transactions.head(25)
| | Date | Supplier | Description | Transaction value (£) | Classification |
---|---|---|---|---|---|
0 | 21/04/2016 | M & J Ballantyne Ltd | George IV Bridge Work | 35098.0 | Building Improvement |
1 | 26/04/2016 | Private Sale | Literary & Archival Items | 30000.0 | Literature & Archive |
2 | 30/04/2016 | City Of Edinburgh Council | Non Domestic Rates | 40800.0 | Utility Bills |
3 | 09/05/2016 | Computacenter Uk | Kelvin Hall | 72835.0 | Software/IT |
4 | 09/05/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 64361.0 | Building Improvement |
5 | 09/05/2016 | A McGillivray | Causewayside Refurbishment | 53690.0 | Building Improvement |
6 | 16/05/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 365344.0 | Building Improvement |
7 | 23/05/2016 | Computacenter Uk | Kelvin Hall | 26506.0 | Software/IT |
8 | 23/05/2016 | ECG Facilities Service | Facilities Management Charge | 32777.0 | Building Improvement |
9 | 23/05/2016 | ECG Facilities Service | Facilities Management Charge | 32777.0 | Building Improvement |
10 | 30/05/2016 | ALDL | ALDL Charges | 32317.0 | Could not classify |
11 | 10/06/2016 | Wavetek Ltd | Kelvin Hall | 87589.0 | Could not classify |
12 | 10/06/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 381803.0 | Building Improvement |
13 | 28/06/2016 | ECG Facilities Service | Facilities Management Charge | 32832.0 | Building Improvement |
14 | 30/06/2016 | Glasgow City Council | Kelvin Hall | 1700000.0 | Building Improvement |
15 | 11/07/2016 | Wavetek Ltd | Kelvin Hall | 65692.0 | Could not classify |
16 | 11/07/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 139845.0 | Building Improvement |
17 | 15/07/2016 | Sotheby'S | Literary & Archival Items | 28500.0 | Literature & Archive |
18 | 18/07/2016 | Christies | Literary & Archival Items | 33800.0 | Literature & Archive |
19 | 25/07/2016 | A McGillivray | Causewayside Refurbishment | 30113.0 | Building Improvement |
20 | 31/07/2016 | ALDL | ALDL Charges | 32317.0 | Could not classify |
21 | 08/08/2016 | ECG Facilities Service | Facilities Management Charge | 32795.0 | Building Improvement |
22 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866.0 | Could not classify |
23 | 15/08/2016 | John Graham Construction Ltd | Causewayside Refurbishment | 196807.0 | Building Improvement |
24 | 24/08/2016 | ECG Facilities Service | Facilities Management Charge | 32795.0 | Building Improvement |
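The initial results are pretty good even with no labelled examples! The ones it could not classify were tougher cases with few clues as to their topic, but if we clean up the labelled dataset to give more examples, we may be able to get better performance.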
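Let's create embeddings from the small set that we've classified so far. We made a set of labelled examples by running the zero-shot classifier on 101 transactions from our dataset and manually correcting the 15 "Could not classify" results that we got.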
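The manual correction step can be sketched in pandas. This assumes a hypothetical dictionary of hand-assigned labels keyed by row index and only overwrites rows still marked "Could not classify"; the rows and labels below are toy examples, not the notebook's actual corrections.

```python
import pandas as pd

# Toy zero-shot output: the last two rows could not be placed by the model
zero_shot = pd.DataFrame({
    "Supplier": ["M & J Ballantyne Ltd", "ALDL", "Creative Video Productions Ltd"],
    "Description": ["George IV Bridge Work", "ALDL Charges", "Kelvin Hall"],
    "Classification": ["Building Improvement", "Could not classify", "Could not classify"],
})

# Hypothetical hand-assigned labels, keyed by row index
manual_labels = {1: "Other", 2: "Other"}

labelled = zero_shot.copy()
for idx, label in manual_labels.items():
    # Only overwrite rows the zero-shot classifier could not place
    if labelled.at[idx, "Classification"] == "Could not classify":
        labelled.at[idx, "Classification"] = label

print(labelled["Classification"].value_counts().to_dict())
```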
This initial section reuses the approach from the Get_embeddings_from_dataset notebook to create embeddings from a combined field concatenating all of our features.
df = pd.read_csv('./data/labelled_transactions.csv')
df.head()
| | Date | Supplier | Description | Transaction value (£) | Classification |
---|---|---|---|---|---|
0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other |
1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement |
2 | 29/05/2017 | Morris & Spottiswood Ltd | George IV Bridge Work | 56448 | Building Improvement |
3 | 31/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 164691 | Building Improvement |
4 | 24/07/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 27926 | Building Improvement |
# Concatenate the features row by row; .astype(str) converts each value individually
df['combined'] = "Supplier: " + df['Supplier'].str.strip() + "; Description: " + df['Description'].str.strip() + "; Value: " + df['Transaction value (£)'].astype(str).str.strip()
df.head(2)
| | Date | Supplier | Description | Transaction value (£) | Classification | combined |
---|---|---|---|---|---|---|
0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... |
1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... |
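When building the `combined` field, the value column has to be converted element-wise. `str(df['Transaction value (£)'])` stringifies the whole column once (index, values and dtype line included) and appends that same multi-line text to every row, whereas `.astype(str)` keeps each row's own value. A small sketch on a toy frame shows the difference:

```python
import pandas as pd

toy = pd.DataFrame({
    "Supplier": [" ALDL ", "British Library"],
    "Description": ["Legal Deposit Services", "Legal Deposit Services "],
    "Transaction value (£)": [27067, 50056],
})

# Element-wise conversion: each row gets its own value
toy["combined"] = (
    "Supplier: " + toy["Supplier"].str.strip()
    + "; Description: " + toy["Description"].str.strip()
    + "; Value: " + toy["Transaction value (£)"].astype(str).str.strip()
)

# Stringifying the whole Series instead yields one multi-line string
whole_column = str(toy["Transaction value (£)"])

print(toy["combined"].tolist())
print(whole_column)
```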
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))
len(df)
101
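embedding_path = './data/transactions_with_embeddings_100.csv'
from utils.embeddings_utils import get_embedding
# gpt-4 is a chat model, not an embedding model. The babbage embedding models these
# column names refer to have been deprecated, so a single current embedding model is
# assumed here and reused for both columns (the names are kept for the code below).
EMBEDDING_MODEL = "text-embedding-3-small"
df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, model=EMBEDDING_MODEL))
df['babbage_search'] = df['babbage_similarity']
df.to_csv(embedding_path)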
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from ast import literal_eval
fs_df = pd.read_csv(embedding_path)
fs_df["babbage_similarity"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)
fs_df.head()
| | Unnamed: 0 | Date | Supplier | Description | Transaction value (£) | Classification | combined | n_tokens | babbage_similarity | babbage_search |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... | 136 | [-0.009802100248634815, 0.022551486268639565, ... | [-0.00232666521333158, 0.019198870286345482, 0... |
1 | 1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.009065819904208183, 0.012094118632376194, ... | [0.005169447045773268, 0.00473341578617692, -0... |
2 | 2 | 29/05/2017 | Morris & Spottiswood Ltd | George IV Bridge Work | 56448 | Building Improvement | Supplier: Morris & Spottiswood Ltd; Descriptio... | 141 | [-0.009000026620924473, 0.02405017428100109, -... | [0.0028343256562948227, 0.021166473627090454, ... |
3 | 3 | 31/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 164691 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.009065819904208183, 0.012094118632376194, ... | [0.005169447045773268, 0.00473341578617692, -0... |
4 | 4 | 24/07/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 27926 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 140 | [-0.009065819904208183, 0.012094118632376194, ... | [0.005169447045773268, 0.00473341578617692, -0... |
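Saving to CSV stores each embedding list as its string representation, which is why the column has to be parsed back with `literal_eval` after `read_csv`. A self-contained sketch of that round trip, with short toy vectors standing in for real embeddings:

```python
import io
from ast import literal_eval

import numpy as np
import pandas as pd

# Toy "embeddings": real ones have many more dimensions
frame = pd.DataFrame({"combined": ["a", "b"],
                      "embedding": [[0.1, -0.2, 0.3], [0.0, 0.5, -0.5]]})

buffer = io.StringIO()
frame.to_csv(buffer, index=False)
buffer.seek(0)
loaded = pd.read_csv(buffer)

# After read_csv the column holds strings such as "[0.1, -0.2, 0.3]"
print(type(loaded.loc[0, "embedding"]))

# Parse each string safely back into a list, then into a NumPy array
loaded["embedding"] = loaded["embedding"].apply(literal_eval).apply(np.array)
print(loaded.loc[0, "embedding"].shape)
```

Writing with `index=False`, as in the sketch, also avoids the extra `Unnamed: 0` column that appears in the outputs above.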
X_train, X_test, y_train, y_test = train_test_split(
list(fs_df.babbage_similarity.values), fs_df.Classification, test_size=0.2, random_state=42
)
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)
# zero_division=0 reports precision/F1 as 0.0 for classes with no predicted samples
# (here, "Other") instead of emitting an UndefinedMetricWarning
report = classification_report(y_test, preds, zero_division=0)
print(report)
                      precision    recall  f1-score   support

Building Improvement       0.92      1.00      0.96        11
Literature & Archive       1.00      1.00      1.00         3
               Other       0.00      0.00      0.00         1
         Software/IT       1.00      1.00      1.00         1
       Utility Bills       1.00      1.00      1.00         5

            accuracy                           0.95        21
           macro avg       0.78      0.80      0.79        21
        weighted avg       0.91      0.95      0.93        21
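This model performs strongly (0.95 accuracy on the 21 held-out transactions, although the test set is small and its single "Other" example was misclassified), so creating embeddings and using even a simple classifier looks like an effective approach as well, with the zero-shot classifier helping us do the initial classification of the unlabelled dataset.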
Let's take it one step further and see whether a fine-tuned model trained on the same labelled dataset gives comparable results.
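For this use case we'll try to improve on the classification above by training a fine-tuned model on the same labelled set of 101 transactions, and then apply this fine-tuned model to a group of unseen transactions.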
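First we'll need to do some data prep to get our data ready. This comprises the following steps: replacing each class with a numeric identifier so the model only has to predict a single token, adding a leading space to each target class, and appending a common separator (`\n\n###\n\n`) to the end of each prompt so the model knows where the prompt ends.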
ft_prep_df = fs_df.copy()
len(ft_prep_df)
101
ft_prep_df.head()
| | Unnamed: 0 | Date | Supplier | Description | Transaction value (£) | Classification | combined | n_tokens | babbage_similarity | babbage_search |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... | 12 | [-0.009630300104618073, 0.009887108579277992, ... | [-0.008217384107410908, 0.025170527398586273, ... |
1 | 1 | 29/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 74806 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 16 | [-0.006144719664007425, -0.0018709596479311585... | [-0.007424891460686922, 0.008475713431835175, ... |
2 | 2 | 29/05/2017 | Morris & Spottiswood Ltd | George IV Bridge Work | 56448 | Building Improvement | Supplier: Morris & Spottiswood Ltd; Descriptio... | 17 | [-0.005225738976150751, 0.015156379900872707, ... | [-0.007611643522977829, 0.030322374776005745, ... |
3 | 3 | 31/05/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 164691 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 16 | [-0.006144719664007425, -0.0018709596479311585... | [-0.007424891460686922, 0.008475713431835175, ... |
4 | 4 | 24/07/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 27926 | Building Improvement | Supplier: John Graham Construction Ltd; Descri... | 16 | [-0.006144719664007425, -0.0018709596479311585... | [-0.007424891460686922, 0.008475713431835175, ... |
classes = list(set(ft_prep_df['Classification']))
class_df = pd.DataFrame(classes).reset_index()
class_df.columns = ['class_id','class']
class_df , len(class_df)
(   class_id                 class
 0         0  Literature & Archive
 1         1         Utility Bills
 2         2  Building Improvement
 3         3           Software/IT
 4         4                 Other,
 5)
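Python's `set` has no guaranteed iteration order, and string hashing is randomised per process by default, so `list(set(...))` can assign different `class_id` values on different runs. A minimal sketch of a reproducible mapping, assuming alphabetical order is acceptable. Note that it would give different ids from those shown above (for example, `Other` would become 2 instead of 4), so the rest of the notebook's outputs would change accordingly.

```python
import pandas as pd

labels = pd.Series(["Other", "Building Improvement", "Utility Bills",
                    "Literature & Archive", "Software/IT", "Building Improvement"])

# Sort the unique class names so the id assignment is identical on every run
classes = sorted(labels.unique())
class_df = pd.DataFrame({"class_id": range(len(classes)), "class": classes})
class_to_id = dict(zip(class_df["class"], class_df["class_id"]))

print(class_to_id)
```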
ft_df_with_class = ft_prep_df.merge(class_df,left_on='Classification',right_on='class',how='inner')
# Add a leading whitespace to each completion to help the model
ft_df_with_class['class_id'] = ft_df_with_class.apply(lambda x: ' ' + str(x['class_id']),axis=1)
ft_df_with_class = ft_df_with_class.drop('class', axis=1)
# Add a common separator to the end of each prompt so the model knows where the prompt ends
ft_df_with_class['prompt'] = ft_df_with_class.apply(lambda x: x['combined'] + '\n\n###\n\n',axis=1)
ft_df_with_class.head()
| | Unnamed: 0 | Date | Supplier | Description | Transaction value (£) | Classification | combined | n_tokens | babbage_similarity | babbage_search | class_id | prompt |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 15/08/2016 | Creative Video Productions Ltd | Kelvin Hall | 26866 | Other | Supplier: Creative Video Productions Ltd; Desc... | 12 | [-0.009630300104618073, 0.009887108579277992, ... | [-0.008217384107410908, 0.025170527398586273, ... | 4 | Supplier: Creative Video Productions Ltd; Desc... |
1 | 51 | 31/03/2017 | NLS Foundation | Grant Payment | 177500 | Other | Supplier: NLS Foundation; Description: Grant P... | 11 | [-0.022305507212877274, 0.008543581701815128, ... | [-0.020519884303212166, 0.01993306167423725, -... | 4 | Supplier: NLS Foundation; Description: Grant P... |
2 | 70 | 26/06/2017 | British Library | Legal Deposit Services | 50056 | Other | Supplier: British Library; Description: Legal ... | 11 | [-0.01019938476383686, 0.015277703292667866, -... | [-0.01843327097594738, 0.03343546763062477, -0... | 4 | Supplier: British Library; Description: Legal ... |
3 | 71 | 24/07/2017 | ALDL | Legal Deposit Services | 27067 | Other | Supplier: ALDL; Description: Legal Deposit Ser... | 11 | [-0.008471488021314144, 0.004098685923963785, ... | [-0.012966590002179146, 0.01299362163990736, 0... | 4 | Supplier: ALDL; Description: Legal Deposit Ser... |
4 | 100 | 24/07/2017 | AM Phillip | Vehicle Purchase | 26604 | Other | Supplier: AM Phillip; Description: Vehicle Pur... | 10 | [-0.003459023078903556, 0.004626389592885971, ... | [-0.0010945454705506563, 0.008626140654087067,... | 4 | Supplier: AM Phillip; Description: Vehicle Pur... |
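# This step is unnecessary if you have a number of observations in each class.
# In our case we don't, so we shuffle the data to give us a better chance of getting every class into both the training and validation sets.
# Our fine-tuned model will error if the validation set has fewer classes, so this is a necessary step.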
import random
labels = [x for x in ft_df_with_class['class_id']]
text = [x for x in ft_df_with_class['prompt']]
ft_df = pd.DataFrame(zip(text, labels), columns = ['prompt','class_id']) #[:300]
ft_df.columns = ['prompt','completion']
ft_df['ordering'] = ft_df.apply(lambda x: random.randint(0,len(ft_df)), axis = 1)
ft_df.set_index('ordering',inplace=True)
ft_df_sorted = ft_df.sort_index(ascending=True)
ft_df_sorted.head()
ordering | prompt | completion |
---|---|---|
0 | Supplier: Sothebys; Description: Literary & Ar... | 0 |
1 | Supplier: Sotheby'S; Description: Literary & A... | 0 |
2 | Supplier: City Of Edinburgh Council; Descripti... | 1 |
2 | Supplier: John Graham Construction Ltd; Descri... | 2 |
3 | Supplier: John Graham Construction Ltd; Descri... | 2 |
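Shuffling improves the odds but does not guarantee that every class lands in both files. A minimal sketch of a per-class split instead, assuming every class has at least two examples; `split_per_class` is a hypothetical helper, and its two outputs could be written to JSONL directly instead of relying on the split made by `prepare_data`.

```python
import pandas as pd


def split_per_class(df, label_col="completion", n_valid=1, seed=42):
    """Hold out n_valid rows per class for validation (hypothetical helper)."""
    # Shuffle once with a fixed seed so the split is reproducible
    shuffled = df.sample(frac=1, random_state=seed)
    # The first n_valid rows of each class (in shuffled order) go to validation
    valid = shuffled.groupby(label_col).head(n_valid)
    # Everything else stays in training (requires a unique index)
    train = shuffled.drop(valid.index)
    return train, valid


toy = pd.DataFrame({
    "prompt": [f"example {i}\n\n###\n\n" for i in range(6)],
    "completion": [" 0", " 0", " 1", " 1", " 2", " 2"],
})
train, valid = split_per_class(toy)

print(sorted(train["completion"].unique()))
print(sorted(valid["completion"].unique()))
```

The helper uses `drop(valid.index)`, so it assumes a unique index; the `ordering` index above can contain duplicates, so reset it (for example with `reset_index(drop=True)`) before splitting this way. A class with a single example would end up only in the validation set.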
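# This step removes any existing files in case we've already produced training/validation sets for this classifier
#!rm transactions_grouped*
# We output our shuffled dataframe to a .jsonl file and run the prepare_data tool to generate our input files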
ft_df_sorted.to_json("transactions_grouped.jsonl", orient='records', lines=True)
!openai tools fine_tunes.prepare_data -f transactions_grouped.jsonl -q
# This function checks that all of your classes are present in both prepared files.
# If they aren't, creating the fine-tuned model will fail.
check_finetune_classes('transactions_grouped_prepared_train.jsonl','transactions_grouped_prepared_valid.jsonl')
31
8
All good
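# This step creates your model. It uses the legacy fine-tunes CLI from openai<1.0, and the curie base model has since been retired.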
!openai api fine_tunes.create -t "transactions_grouped_prepared_train.jsonl" -v "transactions_grouped_prepared_valid.jsonl" --compute_classification_metrics --classification_n_classes 5 -m curie
# Use the following command to get the fine-tuning job's status and model name; replace the job ID with your own.
#!openai api fine_tunes.get -i ft-YBIc01t4hxYBC7I5qhRF3Qdx
# Congrats, you've got a fine-tuned model!
# Copy/paste the name provided into the variable below and we'll take it for a spin
fine_tuned_model = 'curie:ft-personal-2022-10-20-10-42-56'
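The `openai api fine_tunes.*` commands above belong to the legacy fine-tunes endpoint. A hedged sketch of roughly equivalent steps with the current Python SDK (`openai>=1.0`), assuming `babbage-002` as a base model that still accepts the same prompt/completion JSONL format. As far as we know the newer jobs API has no counterpart to `--compute_classification_metrics`, so accuracy has to be computed yourself, as in the evaluation that follows. `start_fine_tune` and `get_fine_tuned_model` are hypothetical helper names; both make network calls, so they are only defined here, not run.

```python
def start_fine_tune(train_path, valid_path, base_model="babbage-002"):
    """Upload prepared JSONL files and start a fine-tuning job (sketch)."""
    from openai import OpenAI  # requires openai>=1.0 and OPENAI_API_KEY

    client = OpenAI()
    # Upload both prepared files with the fine-tune purpose
    with open(train_path, "rb") as f:
        train_file = client.files.create(file=f, purpose="fine-tune")
    with open(valid_path, "rb") as f:
        valid_file = client.files.create(file=f, purpose="fine-tune")

    job = client.fine_tuning.jobs.create(
        training_file=train_file.id,
        validation_file=valid_file.id,
        model=base_model,
    )
    return job.id


def get_fine_tuned_model(job_id):
    """Return (status, model name); the name stays None until the job succeeds."""
    from openai import OpenAI

    client = OpenAI()
    job = client.fine_tuning.jobs.retrieve(job_id)
    return job.status, job.fine_tuned_model


# Usage (network calls, not executed here):
# job_id = start_fine_tune("transactions_grouped_prepared_train.jsonl",
#                          "transactions_grouped_prepared_valid.jsonl")
# status, fine_tuned_model = get_fine_tuned_model(job_id)
```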
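Now we'll apply our classifier to see how it performs. We only had 31 unique observations in our training set and 8 in our validation set, so let's see how it does.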
test_set = pd.read_json('transactions_grouped_prepared_valid.jsonl', lines=True)
test_set.head()
| | prompt | completion |
---|---|---|
0 | Supplier: Wavetek Ltd; Description: Kelvin Hal... | 2 |
1 | Supplier: ECG Facilities Service; Description:... | 1 |
2 | Supplier: M & J Ballantyne Ltd; Description: G... | 2 |
3 | Supplier: Private Sale; Description: Literary ... | 0 |
4 | Supplier: Ex Libris; Description: IT equipment... | 3 |
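# The fine-tuned model was trained on prompt/completion pairs, so query it through the Completions endpoint rather than Chat Completions
test_set['predicted_class'] = test_set.apply(lambda x: client.completions.create(model=fine_tuned_model, prompt=x['prompt'], max_tokens=1, temperature=0, logprobs=5), axis=1)
# openai>=1.0 returns response objects, so use attribute access instead of dict keys
test_set['pred'] = test_set.apply(lambda x: x['predicted_class'].choices[0].text, axis=1)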
test_set['result'] = test_set.apply(lambda x: str(x['pred']).strip() == str(x['completion']).strip(), axis = 1)
test_set['result'].value_counts()
True     4
False    4
Name: result, dtype: int64
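The calls above request `logprobs=5` but only use the text of the top token. A minimal sketch of turning the top log-probabilities for the single predicted token into a rough per-class confidence, assuming a `{token: logprob}` dictionary such as the first entry of `choices[0].logprobs.top_logprobs` in a Completions response. The values below are made up for illustration and are not outputs from this notebook; `class_confidences` is a hypothetical helper.

```python
import math


def class_confidences(top_logprobs):
    """Turn {token: logprob} into probabilities over class-id tokens only."""
    probs = {}
    for token, logprob in top_logprobs.items():
        # Class ids were trained with a leading space, e.g. " 2"
        key = token.strip()
        if key.isdigit():
            probs[key] = probs.get(key, 0.0) + math.exp(logprob)
    total = sum(probs.values())
    # Renormalise so the retained class probabilities sum to 1
    return {k: v / total for k, v in probs.items()} if total else {}


# Made-up values for a single prediction, in the assumed {token: logprob} shape
example = {" 2": -0.105, " 0": -2.5, " 4": -4.0, "\n": -6.0}
conf = class_confidences(example)
best = max(conf, key=conf.get)
print(best, round(conf[best], 2))
```

A low top-class confidence could be used to route a transaction to manual review instead of accepting the prediction.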
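Performance is not great, which unfortunately is expected. With only a few examples of each class, the approach above using embeddings and a traditional classifier worked better.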
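A fine-tuned model works best with a large number of labelled observations. With a few hundred or a few thousand we might get better results, but let's do one last test on a holdout set to confirm that it doesn't generalise well to a new set of observations.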
holdout_df = transactions.copy().iloc[101:]
holdout_df.head()
| | Date | Supplier | Description | Transaction value (£) |
---|---|---|---|---|
101 | 23/10/2017 | City Building LLP | Causewayside Refurbishment | 53147.0 |
102 | 30/10/2017 | ECG Facilities Service | Facilities Management Charge | 35758.0 |
103 | 30/10/2017 | ECG Facilities Service | Facilities Management Charge | 35758.0 |
104 | 06/11/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 134208.0 |
105 | 06/11/2017 | ALDL | Legal Deposit Services | 27067.0 |
holdout_df['combined'] = "Supplier: " + holdout_df['Supplier'].str.strip() + "; Description: " + holdout_df['Description'].str.strip() + '\n\n###\n\n' # + "; Value: " + str(df['Transaction value (£)']).strip()
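# Query the prompt/completion fine-tuned model through the Completions endpoint
holdout_df['prediction_result'] = holdout_df.apply(lambda x: client.completions.create(model=fine_tuned_model, prompt=x['combined'], max_tokens=1, temperature=0, logprobs=5), axis=1)
# openai>=1.0 returns response objects, so use attribute access instead of dict keys
holdout_df['pred'] = holdout_df.apply(lambda x: x['prediction_result'].choices[0].text, axis=1)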
holdout_df.head(10)
| | Date | Supplier | Description | Transaction value (£) | combined | prediction_result | pred |
---|---|---|---|---|---|---|---|
101 | 23/10/2017 | City Building LLP | Causewayside Refurbishment | 53147.0 | Supplier: City Building LLP; Description: Caus... | {'id': 'cmpl-63YDadbYLo8xKsGY2vReOFCMgTOvG', '... | 2 |
102 | 30/10/2017 | ECG Facilities Service | Facilities Management Charge | 35758.0 | Supplier: ECG Facilities Service; Description:... | {'id': 'cmpl-63YDbNK1D7UikDc3xi5ATihg5kQEt', '... | 2 |
103 | 30/10/2017 | ECG Facilities Service | Facilities Management Charge | 35758.0 | Supplier: ECG Facilities Service; Description:... | {'id': 'cmpl-63YDbwfiHjkjMWsfTKNt6naeqPzOe', '... | 2 |
104 | 06/11/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 134208.0 | Supplier: John Graham Construction Ltd; Descri... | {'id': 'cmpl-63YDbWAndtsRqPTi2ZHZtPodZvOwr', '... | 2 |
105 | 06/11/2017 | ALDL | Legal Deposit Services | 27067.0 | Supplier: ALDL; Description: Legal Deposit Ser... | {'id': 'cmpl-63YDbDu7WM3svYWsRAMdDUKtSFDBu', '... | 2 |
106 | 27/11/2017 | Maggs Bros Ltd | Literary & Archival Items | 26500.0 | Supplier: Maggs Bros Ltd; Description: Literar... | {'id': 'cmpl-63YDbxNNI8ZH5CJJNxQ0IF9Zf925C', '... | 0 |
107 | 30/11/2017 | Glasgow City Council | Kelvin Hall | 42345.0 | Supplier: Glasgow City Council; Description: K... | {'id': 'cmpl-63YDb8R1FWu4bjwM2xE775rouwneV', '... | 2 |
108 | 11/12/2017 | ECG Facilities Service | Facilities Management Charge | 35758.0 | Supplier: ECG Facilities Service; Description:... | {'id': 'cmpl-63YDcAPsp37WhbPs9kwfUX0kBk7Hv', '... | 2 |
109 | 11/12/2017 | John Graham Construction Ltd | Causewayside Refurbishment | 159275.0 | Supplier: John Graham Construction Ltd; Descri... | {'id': 'cmpl-63YDcML2welrC3wF0nuKgcNmVu1oQ', '... | 2 |
110 | 08/01/2018 | ECG Facilities Service | Facilities Management Charge | 35758.0 | Supplier: ECG Facilities Service; Description:... | {'id': 'cmpl-63YDc95SSdOHnIliFB2cjMEEm7Z2u', '... | 2 |
holdout_df['pred'].value_counts()
 2    231
 0     27
Name: pred, dtype: int64
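The fine-tuned model returns class ids with a leading space (for example `' 2'`), so the counts above are easier to read once the ids are mapped back to names. A sketch assuming the `class_df` mapping printed earlier (0 Literature & Archive, 1 Utility Bills, 2 Building Improvement, 3 Software/IT, 4 Other); because that mapping came from an unordered `set`, check it against your own run. The predictions below are toy values.

```python
import pandas as pd

# Mapping as printed by class_df earlier in this notebook
id_to_class = {
    0: "Literature & Archive",
    1: "Utility Bills",
    2: "Building Improvement",
    3: "Software/IT",
    4: "Other",
}

# Toy predictions in the same format as holdout_df['pred']
pred = pd.Series([" 2", " 2", " 0", " 2"])

# Strip the leading space, convert to int, then look up the class name
pred_names = pred.str.strip().astype(int).map(id_to_class)
print(pred_names.value_counts().to_dict())
```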
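These results are equally uninspiring: the model predicts class 2 (Building Improvement) for 231 of the 258 holdout transactions. We conclude that for a dataset with a small number of labelled observations, either zero-shot classification or traditional classification with embeddings returns better results than a fine-tuned model.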
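A fine-tuned model is still a great tool, but it is more effective when you have a larger number of labelled examples for each class you want to classify.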