!pip install BPEmb

import math
import numpy as np
import tensorflow as tf
from bpemb import BPEmb

def scaled_dot_product_attention(query, key, value, mask=None):
    key_dim = tf.cast(tf.shape(key)[-1], tf.float32)
    scaled_scores = tf.matmul(query, key, transpose_b=True) / np.sqrt(key_dim)

    if mask is not None:
        scaled_scores = tf.where(mask == 0, -np.inf, scaled_scores)

    softmax = tf.keras.layers.Softmax()
    weights = softmax(scaled_scores)
    return tf.matmul(weights, value), weights

seq_len = 3
embed_dim = 4

queries = np.random.rand(seq_len, embed_dim)
keys = np.random.rand(seq_len, embed_dim)
values = np.random.rand(seq_len, embed_dim)

print("Queries:\n", queries)

output, attn_weights = scaled_dot_product_attention(queries, keys, values)
print("Output\n", output, "\n")
print("Weights\n", attn_weights)

batch_size = 1
seq_len = 3
embed_dim = 12
num_heads = 3
head_dim = embed_dim // num_heads
print(f"Dimension of each head: {head_dim}")

x = np.random.rand(batch_size, seq_len, embed_dim).round(1)
print("Input shape: ", x.shape, "\n")
print("Input:\n", x)

# The query weights for each head.
wq0 = np.random.rand(embed_dim, head_dim).round(1)
wq1 = np.random.rand(embed_dim, head_dim).round(1)
wq2 = np.random.rand(embed_dim, head_dim).round(1)

# The key weights for each head.
wk0 = np.random.rand(embed_dim, head_dim).round(1)
wk1 = np.random.rand(embed_dim, head_dim).round(1)
wk2 = np.random.rand(embed_dim, head_dim).round(1)

# The value weights for each head.
wv0 = np.random.rand(embed_dim, head_dim).round(1)
wv1 = np.random.rand(embed_dim, head_dim).round(1)
wv2 = np.random.rand(embed_dim, head_dim).round(1)

print("The three sets of query weights (one for each head):")
print("wq0:\n", wq0)
print("wq1:\n", wq1)
print("wq2:\n", wq2)

# Generated queries, keys, and values for the first head.
q0 = np.dot(x, wq0)
k0 = np.dot(x, wk0)
v0 = np.dot(x, wv0)

# Generated queries, keys, and values for the second head.
q1 = np.dot(x, wq1)
k1 = np.dot(x, wk1)
v1 = np.dot(x, wv1)

# Generated queries, keys, and values for the third head.
q2 = np.dot(x, wq2)
k2 = np.dot(x, wk2)
v2 = np.dot(x, wv2)

print("Q, K, and V for first head:\n")
print(f"q0 {q0.shape}:\n", q0, "\n")
print(f"k0 {k0.shape}:\n", k0, "\n")
print(f"v0 {v0.shape}:\n", v0)

out0, attn_weights0 = scaled_dot_product_attention(q0, k0, v0)
print("Output from first attention head: ", out0, "\n")
print("Attention weights from first head: ", attn_weights0)

out1, _ = scaled_dot_product_attention(q1, k1, v1)
out2, _ = scaled_dot_product_attention(q2, k2, v2)
print("Output from second attention head: ", out1, "\n")
print("Output from third attention head: ", out2)

combined_out_a = np.concatenate((out0, out1, out2), axis=-1)
print(f"Combined output from all heads {combined_out_a.shape}:")
print(combined_out_a)

# The final step would be to run combined_out_a through a linear/dense layer
# for further processing.
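# A minimal sketch of that final step (my own illustration, not part of the
# original walkthrough): project the concatenated heads back to embed_dim with
# an untrained Dense layer, so the printed numbers are only illustrative. The
# names final_linear and projected_out_a are made up for this sketch.
final_linear = tf.keras.layers.Dense(embed_dim)
projected_out_a = final_linear(combined_out_a)
print(f"Projected multi-head output {projected_out_a.shape}:")
print(projected_out_a)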
print("Query weights for first head: \n", wq0, "\n") print("Query weights for second head: \n", wq1, "\n") print("Query weights for third head: \n", wq2) wq = np.concatenate((wq0, wq1, wq2), axis=1) print(f"Single query weight matrix {wq.shape}: \n", wq) wk = np.concatenate((wk0, wk1, wk2), axis=1) wv = np.concatenate((wv0, wv1, wv2), axis=1) print(f"Single key weight matrix {wk.shape}:\n", wk, "\n") print(f"Single value weight matrix {wv.shape}:\n", wv) q_s = np.dot(x, wq) k_s = np.dot(x, wk) v_s = np.dot(x, wv) print(f"Query vectors using a single weight matrix {q_s.shape}:\n", q_s) print(q0, "\n") print(q1, "\n") print(q2) # Note: we can achieve the same thing by passing -1 instead of seq_len. q_s_reshaped = tf.reshape(q_s, (batch_size, seq_len, num_heads, head_dim)) print(f"Combined queries: {q_s.shape}\n", q_s, "\n") print(f"Reshaped into separate heads: {q_s_reshaped.shape}\n", q_s_reshaped) q_s_transposed = tf.transpose(q_s_reshaped, perm=[0, 2, 1, 3]).numpy() print(f"Queries transposed into \"separate\" heads {q_s_transposed.shape}:\n", q_s_transposed) print("The separate per-head query matrices from before: ") print(q0, "\n") print(q1, "\n") print(q2) k_s_transposed = tf.transpose(tf.reshape(k_s, (batch_size, -1, num_heads, head_dim)), perm=[0, 2, 1, 3]).numpy() v_s_transposed = tf.transpose(tf.reshape(v_s, (batch_size, -1, num_heads, head_dim)), perm=[0, 2, 1, 3]).numpy() print(f"Keys for all heads in a single matrix {k_s.shape}: \n", k_s_transposed, "\n") print(f"Values for all heads in a single matrix {v_s.shape}: \n", v_s_transposed) all_heads_output, all_attn_weights = scaled_dot_product_attention(q_s_transposed, k_s_transposed, v_s_transposed) print("Self attention output:\n", all_heads_output) print("Per head outputs from using separate sets of weights per head:") print(out0, "\n") print(out1, "\n") print(out2) combined_out_b = tf.reshape(tf.transpose(all_heads_output, perm=[0, 2, 1, 3]), shape=(batch_size, seq_len, embed_dim)) print("Final output from using single query, key, value matrices:\n", combined_out_b, "\n") print("Final output from using separate query, key, value matrices per head:\n", combined_out_a) class MultiHeadSelfAttention(tf.keras.layers.Layer): def __init__(self, d_model, num_heads): super(MultiHeadSelfAttention, self).__init__() self.d_model = d_model self.num_heads = num_heads self.d_head = self.d_model // self.num_heads self.wq = tf.keras.layers.Dense(self.d_model) self.wk = tf.keras.layers.Dense(self.d_model) self.wv = tf.keras.layers.Dense(self.d_model) # Linear layer to generate the final output. 
        self.dense = tf.keras.layers.Dense(self.d_model)

    def split_heads(self, x):
        batch_size = x.shape[0]
        split_inputs = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_head))
        return tf.transpose(split_inputs, perm=[0, 2, 1, 3])

    def merge_heads(self, x):
        batch_size = x.shape[0]
        merged_inputs = tf.transpose(x, perm=[0, 2, 1, 3])
        return tf.reshape(merged_inputs, (batch_size, -1, self.d_model))

    def call(self, q, k, v, mask):
        qs = self.wq(q)
        ks = self.wk(k)
        vs = self.wv(v)

        qs = self.split_heads(qs)
        ks = self.split_heads(ks)
        vs = self.split_heads(vs)

        output, attn_weights = scaled_dot_product_attention(qs, ks, vs, mask)
        output = self.merge_heads(output)

        return self.dense(output), attn_weights

mhsa = MultiHeadSelfAttention(12, 3)
output, attn_weights = mhsa(x, x, x, None)
print(f"MHSA output {output.shape}:")
print(output)

def feed_forward_network(d_model, hidden_dim):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(hidden_dim, activation='relu'),
        tf.keras.layers.Dense(d_model)
    ])

class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.1):
        super(EncoderBlock, self).__init__()

        self.mhsa = MultiHeadSelfAttention(d_model, num_heads)
        self.ffn = feed_forward_network(d_model, hidden_dim)

        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

        self.layernorm1 = tf.keras.layers.LayerNormalization()
        self.layernorm2 = tf.keras.layers.LayerNormalization()

    def call(self, x, training, mask):
        mhsa_output, attn_weights = self.mhsa(x, x, x, mask)
        mhsa_output = self.dropout1(mhsa_output, training=training)
        mhsa_output = self.layernorm1(x + mhsa_output)

        ffn_output = self.ffn(mhsa_output)
        ffn_output = self.dropout2(ffn_output, training=training)
        output = self.layernorm2(mhsa_output + ffn_output)

        return output, attn_weights

encoder_block = EncoderBlock(12, 3, 48)
block_output, _ = encoder_block(x, True, None)
print(f"Output from single encoder block {block_output.shape}:")
print(block_output)

# Load the English tokenizer.
bpemb_en = BPEmb(lang="en")

bpemb_vocab_size, bpemb_embed_size = bpemb_en.vectors.shape
print("Vocabulary size:", bpemb_vocab_size)
print("Embedding size:", bpemb_embed_size)

# Embedding for the word "car".
bpemb_en.vectors[bpemb_en.words.index('car')]

sample_sentence = "Where can I find a pizzeria?"
tokens = bpemb_en.encode(sample_sentence)
print(tokens)

token_seq = np.array(bpemb_en.encode_ids("Where can I find a pizzeria?"))
print(token_seq)

token_embed = tf.keras.layers.Embedding(bpemb_vocab_size, embed_dim)
token_embeddings = token_embed(token_seq)

# The untrained embeddings for our sample sentence.
print("Embeddings for: ", sample_sentence)
print(token_embeddings)

max_seq_len = 256
pos_embed = tf.keras.layers.Embedding(max_seq_len, embed_dim)

# Generate ids for each position of the token sequence.
pos_idx = tf.range(len(token_seq))
print(pos_idx)

# These are our position embeddings.
position_embeddings = pos_embed(pos_idx)
print("Position embeddings for the input sequence\n", position_embeddings)

input = token_embeddings + position_embeddings
print("Input to the initial encoder block:\n", input)

class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_blocks, d_model, num_heads, hidden_dim, src_vocab_size,
                 max_seq_len, dropout_rate=0.1):
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_embed = tf.keras.layers.Embedding(src_vocab_size, self.d_model)
        self.pos_embed = tf.keras.layers.Embedding(max_seq_len, self.d_model)

        # The original Attention Is All You Need paper applied dropout to the
        # input before feeding it to the first encoder block.
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        # Create encoder blocks.
        self.blocks = [EncoderBlock(self.d_model, num_heads, hidden_dim, dropout_rate)
                       for _ in range(num_blocks)]

    def call(self, input, training, mask):
        token_embeds = self.token_embed(input)

        # Generate position indices for a batch of input sequences.
        num_pos = input.shape[0] * self.max_seq_len
        pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)
        pos_idx = np.reshape(pos_idx, input.shape)
        pos_embeds = self.pos_embed(pos_idx)

        x = self.dropout(token_embeds + pos_embeds, training=training)

        # Run input through successive encoder blocks.
        for block in self.blocks:
            x, weights = block(x, training, mask)

        return x, weights

# Batch of 3 sequences, each of length 10 (10 is also the
# maximum sequence length in this case).
seqs = np.random.randint(0, 10000, size=(3, 10))
print(seqs.shape)
print(seqs)

pos_ids = np.resize(np.arange(seqs.shape[1]), seqs.shape[0] * seqs.shape[1])
print(pos_ids)

pos_ids = np.reshape(pos_ids, (3, 10))
print(pos_ids.shape)
print(pos_ids)

pos_embed(pos_ids)

input_batch = [
    "Where can I find a pizzeria?",
    "Mass hysteria over listeria.",
    "I ain't no circle back girl."
]

bpemb_en.encode(input_batch)

input_seqs = bpemb_en.encode_ids(input_batch)
print("Vectorized inputs:")
input_seqs

padded_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(input_seqs, padding="post")
print("Input to the encoder:")
print(padded_input_seqs.shape)
print(padded_input_seqs)

enc_mask = tf.cast(tf.math.not_equal(padded_input_seqs, 0), tf.float32)
print("Input:")
print(padded_input_seqs, '\n')
print("Encoder mask:")
print(enc_mask)

enc_mask = enc_mask[:, tf.newaxis, tf.newaxis, :]
enc_mask

num_encoder_blocks = 6
# d_model is the embedding dimension used throughout.
d_model = 12
num_heads = 3
# Feed-forward network hidden dimension width.
ffn_hidden_dim = 48
src_vocab_size = bpemb_vocab_size
max_input_seq_len = padded_input_seqs.shape[1]

encoder = Encoder(
    num_encoder_blocks,
    d_model,
    num_heads,
    ffn_hidden_dim,
    src_vocab_size,
    max_input_seq_len)

encoder_output, attn_weights = encoder(padded_input_seqs, training=True, mask=enc_mask)
print(f"Encoder output {encoder_output.shape}:")
print(encoder_output)

class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.1):
        super(DecoderBlock, self).__init__()

        self.mhsa1 = MultiHeadSelfAttention(d_model, num_heads)
        self.mhsa2 = MultiHeadSelfAttention(d_model, num_heads)

        self.ffn = feed_forward_network(d_model, hidden_dim)

        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)

        self.layernorm1 = tf.keras.layers.LayerNormalization()
        self.layernorm2 = tf.keras.layers.LayerNormalization()
        self.layernorm3 = tf.keras.layers.LayerNormalization()

    # Note the decoder block takes two masks. One for the first MHSA, another
    # for the second MHSA.
    def call(self, encoder_output, target, training, decoder_mask, memory_mask):
        mhsa_output1, attn_weights = self.mhsa1(target, target, target, decoder_mask)
        mhsa_output1 = self.dropout1(mhsa_output1, training=training)
        mhsa_output1 = self.layernorm1(mhsa_output1 + target)

        mhsa_output2, attn_weights = self.mhsa2(mhsa_output1, encoder_output, encoder_output, memory_mask)
        mhsa_output2 = self.dropout2(mhsa_output2, training=training)
        mhsa_output2 = self.layernorm2(mhsa_output2 + mhsa_output1)

        ffn_output = self.ffn(mhsa_output2)
        ffn_output = self.dropout3(ffn_output, training=training)
        output = self.layernorm3(ffn_output + mhsa_output2)

        return output, attn_weights

class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_blocks, d_model, num_heads, hidden_dim, target_vocab_size,
                 max_seq_len, dropout_rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_embed = tf.keras.layers.Embedding(target_vocab_size, self.d_model)
        self.pos_embed = tf.keras.layers.Embedding(max_seq_len, self.d_model)

        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        self.blocks = [DecoderBlock(self.d_model, num_heads, hidden_dim, dropout_rate)
                       for _ in range(num_blocks)]

    def call(self, encoder_output, target, training, decoder_mask, memory_mask):
        token_embeds = self.token_embed(target)

        # Generate position indices.
        num_pos = target.shape[0] * self.max_seq_len
        pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)
        pos_idx = np.reshape(pos_idx, target.shape)
        pos_embeds = self.pos_embed(pos_idx)

        x = self.dropout(token_embeds + pos_embeds, training=training)

        for block in self.blocks:
            x, weights = block(encoder_output, x, training, decoder_mask, memory_mask)

        return x, weights

# Made up values.
target_input_seqs = [
    [1, 652, 723, 123, 62],
    [1, 25, 98, 129, 248, 215, 359, 249],
    [1, 2369, 1259, 125, 486],
]

padded_target_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(target_input_seqs, padding="post")
print("Padded target inputs to the decoder:")
print(padded_target_input_seqs.shape)
print(padded_target_input_seqs)

dec_padding_mask = tf.cast(tf.math.not_equal(padded_target_input_seqs, 0), tf.float32)
dec_padding_mask = dec_padding_mask[:, tf.newaxis, tf.newaxis, :]
print(dec_padding_mask)

target_input_seq_len = padded_target_input_seqs.shape[1]
look_ahead_mask = tf.linalg.band_part(tf.ones((target_input_seq_len, target_input_seq_len)), -1, 0)
print(look_ahead_mask)

dec_mask = tf.minimum(dec_padding_mask, look_ahead_mask)
print("The decoder mask:")
print(dec_mask)

decoder = Decoder(6, 12, 3, 48, 10000, 8)
decoder_output, _ = decoder(encoder_output, padded_target_input_seqs, True, dec_mask, enc_mask)
print(f"Decoder output {decoder_output.shape}:")
print(decoder_output)

class Transformer(tf.keras.Model):
    def __init__(self, num_blocks, d_model, num_heads, hidden_dim, source_vocab_size,
                 target_vocab_size, max_input_len, max_target_len, dropout_rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_blocks, d_model, num_heads, hidden_dim,
                               source_vocab_size, max_input_len, dropout_rate)

        self.decoder = Decoder(num_blocks, d_model, num_heads, hidden_dim,
                               target_vocab_size, max_target_len, dropout_rate)

        # The final dense layer to generate logits from the decoder output.
        self.output_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, input_seqs, target_input_seqs, training, encoder_mask,
             decoder_mask, memory_mask):
        encoder_output, encoder_attn_weights = self.encoder(input_seqs, training, encoder_mask)

        decoder_output, decoder_attn_weights = self.decoder(encoder_output, target_input_seqs,
                                                            training, decoder_mask, memory_mask)

        return self.output_layer(decoder_output), encoder_attn_weights, decoder_attn_weights

transformer = Transformer(
    num_blocks=6,
    d_model=12,
    num_heads=3,
    hidden_dim=48,
    source_vocab_size=bpemb_vocab_size,
    target_vocab_size=7000,  # made-up target vocab size.
    max_input_len=padded_input_seqs.shape[1],
    max_target_len=padded_target_input_seqs.shape[1])

transformer_output, _, _ = transformer(padded_input_seqs, padded_target_input_seqs, True,
                                       enc_mask, dec_mask, memory_mask=enc_mask)
print(f"Transformer output {transformer_output.shape}:")
print(transformer_output)

# If training, we would use this output to calculate losses.

!pip install transformers
!pip install datasets

import operator
import pandas as pd
import tensorflow as tf
import transformers

from datasets import load_dataset
from tensorflow import keras
from transformers import AutoTokenizer
from transformers import pipeline
from transformers import TFAutoModelForQuestionAnswering

classifier = pipeline("text-classification")
classifier("Alice was excited to go to the island but it didn't live up to the hype.")

classifier("Bob doesn't do well in group situations but he said it wasn't bad.")

summarizer = pipeline("summarization")
text = """
Hans Niemann is launching a counterattack in his dispute with chess world champion
Magnus Carlsen, filing a federal lawsuit that accuses Carlsen of maliciously colluding
with others to defame the 19-year-old grandmaster and ruin his career.
It's the latest move in a scandal that has injected unprecedented levels of drama into
the world of elite chess since early September, when Carlsen suggested Niemann's upset
victory over him at the Sinquefield Cup tournament in St. Louis was the result of
cheating. Niemann wants a federal court in Missouri's eastern district to award him at
least $100 million in damages. Defendants in the lawsuit include Carlsen, his company
Play Magnus Group, the online platform Chess.com and its leader, Danny Rensch, along
with grandmaster Hikaru Nakamura.
"""

summarizer(text)

qa = pipeline("question-answering")

context = """
Hugging Face was founded in 2016 by Clément Delangue, Julien Chaumond, and Thomas Wolf
originally as a company that developed a chatbot app targeted at teenagers.[2] After
open-sourcing the model behind the chatbot, the company pivoted to focus on being a
platform for democratizing machine learning. In March 2021, Hugging Face raised $40
million in a Series B funding round.
"""

question = "Who are the Hugging Face founders?"
qa(question=question, context=context)

question = "What does Hugging Face do?"
qa(question=question, context=context)

ner = pipeline(model="dslim/bert-base-NER")
text = "Panic ensues in Redmond as love child of Microsoft and OpenAI declares humanity obsolete."
ner(text)

data = load_dataset("squad")
data

pd.DataFrame(data['train'][0, 1, 2, 100, 101, 102], columns=["context", "question", "answers"])

model_name = 'distilroberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)

t = "Where can I find a pizzeria?"
print(tokenizer.encode(t))

encoded_t = tokenizer(t)
print(encoded_t)

print(tokenizer.convert_ids_to_tokens(encoded_t['input_ids']))

encoded_pair = tokenizer("this is a question", "this is the context")
print(encoded_pair)

print(tokenizer.convert_ids_to_tokens(encoded_pair['input_ids']))

assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

context = "Sarah went to The Mirthless Cafe last night to meet her friend."
question = "Where did Sarah go?"

# The answer span and the answer's starting character position in the context.
answer = "The Mirthless Cafe" answer_start = 14 x = tokenizer(question, context) x tokenizer.batch_decode(x['input_ids']) example_max_length = 15 x = tokenizer(question, context, max_length=example_max_length, truncation="only_second") x tokenizer.batch_decode(x['input_ids']) x = tokenizer(question, context, max_length=example_max_length, truncation="only_second", return_overflowing_tokens=True, padding="max_length") x len(x['input_ids']) tokenizer.batch_decode(x['input_ids']) tokenizer(['question 1', 'question 2'], ['context 1', 'context 2'], return_overflowing_tokens=True) stride = 5 x = tokenizer(question, context, max_length=example_max_length, truncation="only_second", return_overflowing_tokens=True, stride=stride, padding="max_length") tokenizer.batch_decode(x['input_ids']) print(x.keys(), '\n') x print(answer_start) print(context[answer_start:answer_start+len(answer)]) x = tokenizer(question, context, max_length=example_max_length, truncation="only_second", return_overflowing_tokens=True, stride=stride, return_offsets_mapping=True, padding="max_length") x print(len(x['input_ids'])) print(len(x['offset_mapping'])) print(x['input_ids'][0]) print(x['offset_mapping'][0]) print("First non-special input_id converted to token:") print(tokenizer.convert_ids_to_tokens(x['input_ids'][0][1]), "\n") offset = x['offset_mapping'][0][1] print(f"Span extracted from context using corresponding offset_mapping {offset}:") print(question[offset[0]:offset[1]]) print(x['offset_mapping'][0]) print(x['offset_mapping'][1]) print(x['input_ids'][0]) print(x.sequence_ids(0)) # We can calculate the answer end character position using the answer length. answer_end = answer_start + len(answer) print("Answer start character position:", answer_start) print("Answer end character position:", answer_end) print("Answer pulled from context:", context[answer_start:answer_end]) tokenizer.batch_decode(x['input_ids']) input_ids = x['input_ids'][0] offset_mapping = x['offset_mapping'][0] seq_ids = x.sequence_ids(0) # These are the sequence ids print("Sequence IDs: ", seq_ids) # Get the start index position (i.e. the first occurrence of 1). context_pos_start = seq_ids.index(1) # Utility function to find the *last* occurrence of a sequence. def rindex(lst, value): return len(lst) - operator.indexOf(reversed(lst), value) - 1 # Get the end index position (i.e. the last occurrence of 1). context_pos_end = rindex(seq_ids, 1) print("Context tokens begin at position", context_pos_start) print("Context tokens end at position", context_pos_end) # These are the corresponding offsets. context_offsets = offset_mapping[context_pos_start:context_pos_end+1] print(context_offsets) print("Is the lowest offset value lower than or equal to the starting character position?") print("Answer starting character position:", answer_start) print("First offset:", context_offsets[0]) # Note how we're checking the first tuple value. print(context_offsets[0][0] <= answer_start) print("Is the highest offset value higher than or equal to the ending character position?") print("Answer ending character position:", answer_end) print("Last offset:", context_offsets[-1]) # Note how how we're checking the second tuple value. 
print(context_offsets[-1][1] >= answer_end)

print(tokenizer.batch_decode(input_ids))

input_ids = x['input_ids'][2]
offset_mapping = x['offset_mapping'][2]
seq_ids = x.sequence_ids(2)

context_pos_start = seq_ids.index(1)
context_pos_end = rindex(seq_ids, 1)

context_offsets = offset_mapping[context_pos_start:context_pos_end+1]

print("Is the lowest offset value lower than or equal to the starting character position?")
print("Answer starting character position:", answer_start)
print("First offset:", context_offsets[0])
# Note how we're checking the first tuple value.
print(context_offsets[0][0] <= answer_start)

print("Is the highest offset value higher than or equal to the ending character position?")
print("Answer ending character position:", answer_end)
print("Last offset:", context_offsets[-1])
# Note how we're checking the second tuple value.
print(context_offsets[-1][1] >= answer_end)

s = e = 0

# Start scanning the offset_mapping from the left to find the token position
# where the answer starts. It's not guaranteed a tokenizer will output a token
# whose starting character matches the first answer character. When this
# happens, we take the previous token's position as our start.
i = context_pos_start
while offset_mapping[i][0] < answer_start:
    i += 1

if offset_mapping[i][0] == answer_start:
    s = i
else:
    s = i - 1

# Same idea when finding the ending token position.
j = context_pos_end
while offset_mapping[j][1] > answer_end:
    j -= 1

if offset_mapping[j][1] == answer_end:
    e = j
else:
    e = j + 1

print("Answer start token position in context:", s)
print("Answer end token position in context:", e)

print("Answer lifted from context:")
tokenizer.batch_decode(input_ids[s:e+1])

def prepare_dataset(examples):
    # Some tokenizers don't strip spaces. If there happens to be question text
    # with excessive spaces, the context may not get encoded at all.
    examples["question"] = [q.lstrip() for q in examples["question"]]
    examples["context"] = [c.lstrip() for c in examples["context"]]

    # Tokenize.
    tokenized_examples = tokenizer(
        examples['question'],
        examples['context'],
        truncation="only_second",
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length"
    )

    # We'll collect a list of starting positions and ending positions.
    tokenized_examples['start_positions'] = []
    tokenized_examples['end_positions'] = []

    # Work through every sequence.
    for seq_idx in range(len(tokenized_examples['input_ids'])):
        seq_ids = tokenized_examples.sequence_ids(seq_idx)
        offset_mappings = tokenized_examples['offset_mapping'][seq_idx]

        cur_example_idx = tokenized_examples['overflow_to_sample_mapping'][seq_idx]
        answer = examples['answers'][cur_example_idx]
        answer_text = answer['text'][0]
        answer_start = answer['answer_start'][0]
        answer_end = answer_start + len(answer_text)

        context_pos_start = seq_ids.index(1)
        context_pos_end = rindex(seq_ids, 1)

        s = e = 0
        if (offset_mappings[context_pos_start][0] <= answer_start and
                offset_mappings[context_pos_end][1] >= answer_end):
            i = context_pos_start
            while offset_mappings[i][0] < answer_start:
                i += 1
            if offset_mappings[i][0] == answer_start:
                s = i
            else:
                s = i - 1

            j = context_pos_end
            while offset_mappings[j][1] > answer_end:
                j -= 1
            if offset_mappings[j][1] == answer_end:
                e = j
            else:
                e = j + 1

        tokenized_examples['start_positions'].append(s)
        tokenized_examples['end_positions'].append(e)

    return tokenized_examples

max_length = 400
stride = 100
batch_size = 32

tokenized_datasets = data.map(
    prepare_dataset,
    batched=True,
    remove_columns=data["train"].column_names,
    num_proc=2,
)

data = tokenized_datasets.remove_columns(["offset_mapping", "overflow_to_sample_mapping"])

train_set = data['train'].to_tf_dataset(batch_size=batch_size)
validation_set = data['validation'].to_tf_dataset(batch_size=batch_size)

model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)

def get_answer(tokenizer, model, question, context):
    inputs = tokenizer([question], [context], return_tensors="np")
    outputs = model(inputs)

    start_position = tf.argmax(outputs.start_logits, axis=1)
    end_position = tf.argmax(outputs.end_logits, axis=1)

    answer = inputs["input_ids"][0, int(start_position) : int(end_position) + 1]
    return tokenizer.decode(answer).strip()

c = "Sarah went to The Mirthless Cafe last night to meet her friend."
q = "Where did Sarah go?"
get_answer(tokenizer, model, q, c)

# https://www.tensorflow.org/guide/mixed_precision
keras.mixed_precision.set_global_policy("mixed_float16")

# Use a learning rate recommended by the BERT authors.
# https://github.com/google-research/bert
model.compile(optimizer=keras.optimizers.Adam(learning_rate=3e-5))

model.fit(train_set, validation_data=validation_set, epochs=1)

c = "Sarah went to The Mirthless Cafe last night to meet her friend."
q = "Where did Sarah go?"
get_answer(tokenizer, model, q, c)

q = "Who did Sarah meet?"
get_answer(tokenizer, model, q, c)

q = "When did Sarah meet her friend?"
get_answer(tokenizer, model, q, c)

q = "Who went to the restaurant?"
get_answer(tokenizer, model, q, c)

# Asking a logic teaser question is difficult despite the
# answer being available. To be fair, there is ambiguity here.
q = "Who did Sarah's friend meet?"
get_answer(tokenizer, model, q, c)

# The model can't determine when a question can't be
# answered. Some question answering datasets explicitly
# train for this.
q = "How did Sarah get to the restaurant?"
get_answer(tokenizer, model, q, c)

# The model isn't generative, either.
q = "What is a possible reason for why Sarah met her friend?"
get_answer(tokenizer, model, q, c)
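# A minimal sketch (my addition, not part of the original walkthrough): the
# fine-tuned model and tokenizer can also be wrapped in the same pipeline API
# used earlier. The pipeline returns a confidence score alongside the answer,
# which gives a rough signal for flagging questionable extractions like the
# ones above. The name qa_finetuned is made up for this sketch.
qa_finetuned = pipeline("question-answering", model=model, tokenizer=tokenizer)
qa_finetuned(question="Where did Sarah go?", context=c)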