# Notebook setup (Google Colab): mount Drive, switch to the project directory,
# select TensorFlow 1.x, install bert4keras, then load the LCCC GPT model.
# NOTE(review): the `%` and `!` lines are IPython/Colab syntax, so this file
# only runs as notebook cells, not as a plain Python script.
from google.colab import drive
# Mount Google Drive under /content/gdrive so the project files are reachable.
drive.mount('/content/gdrive')
import os
# Work from the project's `main` directory: the model paths below are
# relative ('../model/...') and resolve against this directory.
os.chdir('/content/gdrive/My Drive/finch/tensorflow1/free_chat/chinese_lccc/main')
# Colab line magic requesting TensorFlow 1.x; it has to run before TensorFlow
# is imported (presumably pulled in by the bert4keras imports below).
# NOTE(review): newer Colab runtimes may no longer support this magic — confirm.
%tensorflow_version 1.x
# Install bert4keras into the runtime before importing it.
!pip install bert4keras
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import AutoRegressiveDecoder
# Imported only for its side effect (the name is not referenced below);
# presumably makes printed Chinese text readable — TODO confirm it is still
# needed on Python 3.
from bert4keras.snippets import uniout
# GPT_LCCC-large model files (TensorFlow format): config, checkpoint and
# vocabulary, relative to the working directory set above.
config_path = '../model/GPT_LCCC-large-tf/gpt_config.json'
checkpoint_path = '../model/GPT_LCCC-large-tf/gpt_model.ckpt'
dict_path = '../model/GPT_LCCC-large-tf/vocab.txt'
# Tokenizer over the model vocabulary; input text is lower-cased.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Vocabulary ids of the two speaker markers. Dialogue turns alternate between
# them, and they double as segment ids (see ChatBot.response).
speakers = [
tokenizer.token_to_id('[speaker1]'),
tokenizer.token_to_id('[speaker2]')
]
# Build the GPT network from the config and load the checkpoint weights,
# using bert4keras's 'gpt_openai' architecture.
model = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
model='gpt_openai'
)
class ChatBot(AutoRegressiveDecoder):
    """Random-sampling chit-chat decoder over the module-level GPT ``model``.

    The dialogue history is laid out as
    ``[start] [speaker1] utt_1 [speaker2] utt_2 ... [next_speaker]``, and
    the decoder samples the utterance that follows the trailing speaker marker.
    """

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Return next-token probabilities for the context plus ``output_ids``.

        Tokens generated so far are given the segment id equal to the last
        context token, i.e. the speaker marker that opens the reply.
        """
        context_ids, context_segments = inputs
        reply_speaker = context_ids[0, -1]
        reply_segments = np.zeros_like(output_ids) + reply_speaker
        full_ids = np.concatenate([context_ids, output_ids], 1)
        full_segments = np.concatenate([context_segments, reply_segments], 1)
        # The whole sequence is re-encoded at every step; only the last
        # position's distribution is needed for the next token.
        probas = model.predict([full_ids, full_segments])
        return probas[:, -1]

    def response(self, texts, topk=5):
        """Sample one reply to the dialogue history ``texts``.

        ``texts`` alternates speakers, starting with ``speakers[0]``. Each
        utterance is followed by the marker of the speaker who talks next, so
        the reply is generated for whoever follows the last utterance.
        Sampling draws one candidate from the ``topk`` most likely tokens.
        """
        header = [tokenizer._token_start_id, speakers[0]]
        token_ids, segment_ids = list(header), list(header)
        for turn, text in enumerate(texts):
            current = speakers[turn % 2]
            following = speakers[(turn + 1) % 2]
            # Strip the tokenizer's start/end ids; the speaker markers
            # replace them as turn separators.
            utterance = tokenizer.encode(text)[0][1:-1]
            token_ids += utterance + [following]
            # The closing speaker marker is tagged with the next speaker's id.
            segment_ids += [current] * len(utterance) + [following]
        samples = self.random_sample([token_ids, segment_ids], 1, topk)
        return tokenizer.decode(samples[0])
# Sampling decoder: no forced start token (the context already ends with the
# speaker marker), stop at the tokenizer's end id. maxlen=32 presumably caps
# the reply length in tokens (bert4keras AutoRegressiveDecoder) — confirm.
chatbot = ChatBot(
    start_id=None,
    end_id=tokenizer._token_end_id,
    maxlen=32,
)

# Single-turn probe prompts, grouped by theme: greetings and questions to the
# bot, weather remarks, feelings, and home/family.
query_li = [
    '你好', '早上好', '晚上好', '再见', '好久不见', '想死你了', '谢谢你', '爱你',
    '你叫什么名字', '你几岁了', '现在几点了', '今天天气怎么样', '我们现在在哪里',
    '你能给我讲个笑话吗', '你是男孩还是女孩呀', '你会几种语言呀', '你能陪我玩吗',
    '说话可以大声一点吗',
    '天气真好', '天气太糟糕了', '下雨了', '雨好大', '我讨厌艳阳天', '好晒啊',
    '今天好冷', '今天好热', '风好大', '雾太大了看不清路', '打雷了好可怕', '下雪了诶',
    '好烦啊', '好开心', '太激动了', '我好难过', '我想哭', '太好笑了', '我好伤心',
    '心好痛', '好累啊', '我好疲惫', '我爱你', '我讨厌你', '你真是太棒啦', '你好厉害啊',
    '吓死我了',
    '我想回家', '我想爸妈了', '不知道小孩在家有没有听话', '想回家撸猫',
]

# Treat each prompt as a one-turn dialogue. The question is printed before
# sampling starts, then the reply, then a blank separator line.
for q in query_li:
    print('Q:', q)
    answer = chatbot.response([q])
    print('A:', answer)
    print()