To set a custom kernel for this notebook, see https://scicomp.aalto.fi/triton/apps/jupyter/#installing-kernels-from-virtualenvs-or-anaconda-environments
import warnings
warnings.filterwarnings('ignore')
import os
import sys
import time
## Set environment variables; this must be done before importing transformers
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_HOME']='/scratch/shareddata/dldata/huggingface-hub-cache'
print(os.environ['TRANSFORMERS_OFFLINE'])
1
from transformers import pipeline
pipeline = pipeline("sentiment-analysis")
# Prepare input text
inputs = ["What a lovely day today!","It is freezing outside."]
results = pipeline(inputs)
print("Results:", results)
No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english). Using a pipeline without specifying a model name and revision in production is not recommended.
Results: [{'label': 'POSITIVE', 'score': 0.999874472618103}, {'label': 'NEGATIVE', 'score': 0.9937865734100342}]
from transformers import pipeline
pipeline = pipeline("text-generation")
# Prepare input text
input_text = "The capital of France is"
output = pipeline(input_text, max_length=50)
generated_text = output[0]['generated_text']
print("Generated text:", generated_text)
No model was supplied, defaulted to gpt2 and revision 6c0e608 (https://huggingface.co/gpt2). Using a pipeline without specifying a model name and revision in production is not recommended. Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Generated text: The capital of France is that of the great French historian and poet Maurice de Loewie. His books, the History of France, which are the result of the same years and during the latter years of his life, are regarded as the definitive work
What happens in the pipeline?
Tokenization => Model => Post Processing
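A minimal sketch of these three steps done manually (assuming the pipeline's default sentiment checkpoint, distilbert-base-uncased-finetuned-sst-2-english, is available in the local cache):
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# 1. Tokenization: text -> input IDs and attention mask
batch = tokenizer(["What a lovely day today!"], padding=True, truncation=True, return_tensors="pt")
# 2. Model: input IDs -> logits
with torch.no_grad():
    logits = model(**batch).logits
# 3. Post-processing: logits -> probabilities -> label
probs = torch.softmax(logits, dim=-1)
print(model.config.id2label[probs.argmax(dim=-1).item()], probs.max().item())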
# Print relevant tokenizer information
print("Tokenizer Name:", pipeline.tokenizer.name_or_path)
print("Vocabulary Size:", pipeline.tokenizer.vocab_size)
print("Max Model Input Sizes:", pipeline.tokenizer.model_max_length)
print("Special Tokens:", pipeline.tokenizer.special_tokens_map)
Tokenizer Name: gpt2
Vocabulary Size: 50257
Max Model Input Sizes: 1024
Special Tokens: {'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}
# checkout the model architecture
pipeline.model
GPT2LMHeadModel( (transformer): GPT2Model( (wte): Embedding(50257, 768) (wpe): Embedding(1024, 768) (drop): Dropout(p=0.1, inplace=False) (h): ModuleList( (0-11): 12 x GPT2Block( (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): GPT2Attention( (c_attn): Conv1D() (c_proj): Conv1D() (attn_dropout): Dropout(p=0.1, inplace=False) (resid_dropout): Dropout(p=0.1, inplace=False) ) (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): GPT2MLP( (c_fc): Conv1D() (c_proj): Conv1D() (act): NewGELUActivation() (dropout): Dropout(p=0.1, inplace=False) ) ) ) (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) (lm_head): Linear(in_features=768, out_features=50257, bias=False) )
# checkout the model config
pipeline.model.config
GPT2Config { "_name_or_path": "gpt2", "activation_function": "gelu_new", "architectures": [ "GPT2LMHeadModel" ], "attn_pdrop": 0.1, "bos_token_id": 50256, "do_sample": true, "embd_pdrop": 0.1, "eos_token_id": 50256, "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, "max_length": 50, "model_type": "gpt2", "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_inner": null, "n_layer": 12, "n_positions": 1024, "reorder_and_upcast_attn": false, "resid_pdrop": 0.1, "scale_attn_by_inverse_layer_idx": false, "scale_attn_weights": true, "summary_activation": null, "summary_first_dropout": 0.1, "summary_proj_to_labels": true, "summary_type": "cls_index", "summary_use_proj": true, "task_specific_params": { "text-generation": { "do_sample": true, "max_length": 50 } }, "transformers_version": "4.36.0", "use_cache": true, "vocab_size": 50257 }
A tokenizer prepares text data for processing by Transformer models.
A tokenizer's main functions:
Text Preprocessing: Splitting Text into Tokens
Convert Tokens to IDs: Each token is mapped to a unique integer ID.
Add Special Tokens: Special tokens such as [CLS] and [SEP] for BERT, or <|endoftext|> for GPT-2, are inserted so the model can recognize sequence boundaries and structure.
Handle Fixed Sequence Lengths: Inputs in a batch must all have the same length, and each model has a maximum input size. Tokenizers pad shorter inputs with [PAD] tokens and truncate longer ones to meet these length requirements (see the sketch after this list).
Attention Mask: The tokenizer generates an attention mask to differentiate real tokens from padding tokens ([PAD]) such that the model will pay attention only to the relevant parts of the input.
Consistency Across Languages: For multilingual models, tokenizers ensure consistent tokenization across different languages, maintaining a balanced and shared vocabulary.
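A minimal sketch of padding and the attention mask in action (bert-base-uncased is used purely for illustration):
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tok(["Hello, world!", "A much longer sentence that needs more tokens."], padding=True)
print(enc["input_ids"][0])       # the shorter input is padded with [PAD] (ID 0) at the end
print(enc["attention_mask"][0])  # 0s mark the padded positions the model should ignore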
Three tokenizer types: Word-based, Subword-based, Character-based.
BERT (Bidirectional Encoder Representations from Transformers): Uses the WordPiece tokenizer.
GPT-2 and GPT-3 (Generative Pre-trained Transformer): Utilize a variant of Byte Pair Encoding (BPE).
T5 (Text-To-Text Transfer Transformer): Employs the SentencePiece tokenizer, which is versatile and can be used across different languages and scripts.
from transformers import BertTokenizer
# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Example text
text = "Hello, how many GPUs do you need?"
# Tokenize the text
tokens = tokenizer.tokenize(text)
print(tokens)
# Convert tokens to token IDs
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)
['hello', ',', 'how', 'many', 'gp', '##us', 'do', 'you', 'need', '?']
[7592, 1010, 2129, 2116, 14246, 2271, 2079, 2017, 2342, 1029]
from transformers import GPT2Tokenizer
# Load the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Example text
text = "Hello, how many GPUs do you need?"
# Tokenize the text
tokens = tokenizer.tokenize(text)
print(tokens)
# Convert tokens to token IDs
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)
['Hello', ',', 'Ġhow', 'Ġmany', 'ĠGPUs', 'Ġdo', 'Ġyou', 'Ġneed', '?']
[15496, 11, 703, 867, 32516, 466, 345, 761, 30]
from transformers import T5Tokenizer
# Initialize the tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Example text
text = "Hello, how many GPUs do you need?"
# Tokenize the text
tokens = tokenizer.tokenize(text,add_special_tokens=True)
print(tokens)
# Convert tokens to token IDs
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
['▁Hello', ',', '▁how', '▁many', '▁GPU', 's', '▁do', '▁you', '▁need', '?']
[8774, 6, 149, 186, 23356, 7, 103, 25, 174, 58]
NOTE: A pretrained model only performs properly when its input is tokenized with the same rules that were used to tokenize its training data.
from transformers import PreTrainedTokenizer
# Directly calling the base class PreTrainedTokenizer will throw errors.
tokenizer = PreTrainedTokenizer.from_pretrained('bert-base-uncased')
encoded_input = tokenizer("Hello, Hugging Face!")
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. The tokenizer class you load from this checkpoint is 'BertTokenizer'. The class this function is called from is 'PreTrainedTokenizer'.
--------------------------------------------------------------------------- NotImplementedError Traceback (most recent call last) Cell In[16], line 4 1 from transformers import PreTrainedTokenizer 3 #Directely call a PreTrainedTokenizer, this will throw errors. ----> 4 tokenizer = PreTrainedTokenizer.from_pretrained('bert-base-uncased') 5 encoded_input = tokenizer("Hello, Hugging Face!") File /scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2028, in PreTrainedTokenizerBase.from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, *init_inputs, **kwargs) 2025 else: 2026 logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") -> 2028 return cls._from_pretrained( 2029 resolved_vocab_files, 2030 pretrained_model_name_or_path, 2031 init_configuration, 2032 *init_inputs, 2033 token=token, 2034 cache_dir=cache_dir, 2035 local_files_only=local_files_only, 2036 _commit_hash=commit_hash, 2037 _is_local=is_local, 2038 **kwargs, 2039 ) File /scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2260, in PreTrainedTokenizerBase._from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, *init_inputs, **kwargs) 2258 # Instantiate the tokenizer. 2259 try: -> 2260 tokenizer = cls(*init_inputs, **init_kwargs) 2261 except OSError: 2262 raise OSError( 2263 "Unable to load vocabulary from file. " 2264 "Please check that the provided vocabulary is accessible and not corrupted." 2265 ) File /scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils.py:367, in PreTrainedTokenizer.__init__(self, **kwargs) 363 super().__init__(**kwargs) 365 # 4. If some of the special tokens are not part of the vocab, we add them, at the end. 366 # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers` --> 367 self._add_tokens( 368 [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder], 369 special_tokens=True, 370 ) 372 self._decode_use_source_tokenizer = False File /scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils.py:467, in PreTrainedTokenizer._add_tokens(self, new_tokens, special_tokens) 465 return added_tokens 466 # TODO this is fairly slow to improve! --> 467 current_vocab = self.get_vocab().copy() 468 new_idx = len(current_vocab) # only call this once, len gives the last index + 1 469 for token in new_tokens: File /scratch/shareddata/LLMs_tools/conda-llm/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1675, in PreTrainedTokenizerBase.get_vocab(self) 1665 def get_vocab(self) -> Dict[str, int]: 1666 """ 1667 Returns the vocabulary as a dictionary of token to index. 1668 (...) 1673 `Dict[str, int]`: The vocabulary. 1674 """ -> 1675 raise NotImplementedError() NotImplementedError:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',padding=True,truncation=True,max_length=20)
# Example text
text = "The capital of Finland is?"
# Tokenize the text
tokens = tokenizer.tokenize(text)
print(tokens)
['the', 'capital', 'of', 'finland', 'is', '?']
Hyperparameters in the tokenizer:
tokens = tokenizer.tokenize(text, padding=True,truncation=True,max_length=20)
tokens
Keyword arguments {'padding': True, 'truncation': True, 'max_length': 20} not recognized.
['the', 'capital', 'of', 'finland', 'is', '?']
NOTE: Calling a tokenizer directly is used when you're preparing data for model input (like training or inference), whereas the tokenize() method is used when you need token-level analysis or manipulation of the text.
Hyperparameters like padding, truncation, and max_length are not recognized by the tokenize() method.
text = ["Hello, Hugging Face! Tell me about all your tokenizer types.", "Hello, world!"]
# call a tokenizer directly, invoking its __call__ method
encoded_input = tokenizer(text, padding=True,truncation=True,max_length=20)
for item in encoded_input.items():
print(item)
('input_ids', [[101, 7592, 1010, 17662, 2227, 999, 2425, 2033, 2055, 2035, 2115, 19204, 17629, 4127, 1012, 102], [101, 7592, 1010, 2088, 999, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
('token_type_ids', [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
('attention_mask', [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
https://huggingface.co/docs/transformers/model_doc/auto
A base model, also referred to as a pre-trained model, is a model that has already been trained on a large, generic dataset. Its primary purpose is to capture a wide range of language features and understanding, such as grammar, context, and basic associations. A base model provides a robust foundation of language understanding that can be adapted for specific tasks.
Base models on Hugging Face are often named after the architecture they use, like bert-base-uncased, gpt2-medium, t5-base, etc.
A fine-tuned model is a model that has undergone additional training (fine-tuning) on a smaller, task-specific dataset. This can include tasks like sentiment analysis, question answering, or domain-specific language understanding.
Fine-tuned models usually have additional descriptors in their names indicating the specific task or dataset they are fine-tuned for. For instance, "bert-base-uncased-finetuned-squad" is a BERT model fine-tuned on the SQuAD dataset for question answering, whereas "bert-base-uncased" is a base model.
More information can usually be found in the README or model description in the model repo. In addition, inspecting the model's configuration or architecture can also give hints.
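For example, a minimal sketch of inspecting configurations to tell a base checkpoint from a fine-tuned one (checkpoints chosen only for illustration):
from transformers import AutoConfig
base_cfg = AutoConfig.from_pretrained("bert-base-uncased")
tuned_cfg = AutoConfig.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
print(base_cfg.architectures)   # ['BertForMaskedLM'] -> pre-training head only
print(tuned_cfg.architectures)  # ['DistilBertForSequenceClassification'] -> task-specific head
print(tuned_cfg.id2label)       # {0: 'NEGATIVE', 1: 'POSITIVE'} -> labels added during fine-tuning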
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Initialize the tokenizer for GPT-2
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Load the pre-trained GPT-2 model
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Prepare input text
input_text = "The capital of France is"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
# Generate attention mask
attention_mask = tokenizer(input_text, return_tensors="pt").attention_mask
# Set pad token ID if it's not already set
model.config.pad_token_id = model.config.eos_token_id
# Generate output
outputs = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
Generated text: The capital of France is the capital of the French Republic, and the capital of the French Republic is the capital of the French Republic. The French Republic is the capital of the French Republic. The French Republic is the capital of the
attention_mask
tensor([[1, 1, 1, 1, 1]])
Do I need to look for the specific tokenizer and model classes for my tasks every time?
In many cases, no. The architecture you want to use can be guessed from the name or the path of the pretrained model. Hugging Face provides AutoClasses that automatically retrieve the relevant model class given the name/path of the pretrained weights/config/vocabulary.
## NOTE: AutoModel will instantiate a base model class without a specific head, so we still need
## a "relatively specific" class AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
# Initialize the tokenizer for GPT-2
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Load the pre-trained GPT-2 model
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Prepare input text
input_text = "The capital of France is"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
# Generate attention mask
attention_mask = tokenizer(input_text, return_tensors="pt").attention_mask
# Set pad token ID if it's not already set
model.config.pad_token_id = model.config.eos_token_id
# Generate output
outputs = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)
Generated text: The capital of France is the capital of the French Republic, and the capital of the French Republic is the capital of the French Republic. The French Republic is the capital of the French Republic. The French Republic is the capital of the
Set output_hidden_states=True in the configuration or when calling the model to obtain hidden states.
Set output_attentions=True in the configuration or when calling the model to obtain attentions.
model
BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(28996, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0-11): 12 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) )
from transformers import AutoTokenizer, AutoModel
model = AutoModel.from_pretrained("bert-base-cased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# Prepare input text
input_text = "The capital of France is"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
# get hidden state
outputs = model(input_ids)
print(outputs.last_hidden_state)
tensor([[[ 0.2629, 0.0496, 0.1699, ..., -0.0339, 0.2812, -0.0489], [-0.3381, -0.2910, 0.2394, ..., 0.4664, -0.4263, 0.2448], [-0.3315, -0.1127, -0.1425, ..., 0.6752, -0.1898, 0.5174], ..., [-0.1510, 0.4374, -0.2816, ..., 0.3068, 0.4450, 0.4092], [ 0.0758, 0.1059, 0.0871, ..., 0.3782, 0.2463, -0.2250], [-0.0174, -0.1541, -1.0330, ..., 0.4842, 0.6491, 0.2534]]], grad_fn=<NativeLayerNormBackward0>)
from transformers import AutoTokenizer, AutoModelForCausalLM
# Initialize the tokenizer for GPT-2
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Load the pre-trained GPT-2 model
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Prepare input text
input_text = "The capital of France is"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
# Generate attention mask
attention_mask = tokenizer(input_text, return_tensors="pt").attention_mask
# Set pad token ID if it's not already set
model.config.pad_token_id = model.config.eos_token_id
# Generate output
outputs = model(input_ids, output_hidden_states=True, output_attentions=True)
print("logits:",outputs.logits)
print("Attentions:",outputs.attentions)
logits: tensor([[[ -36.2874, -35.0114, -38.0793, ..., -40.5164, -41.3760, -34.9193], [ -75.1021, -75.6483, -82.6827, ..., -82.5961, -79.3913, -76.2687], [ -80.0968, -78.6868, -81.2341, ..., -83.7548, -85.6541, -79.8042], [ -86.0085, -86.4618, -91.0184, ..., -98.6912, -93.3734, -87.9286], [-108.9542, -108.9327, -112.5793, ..., -118.3345, -113.1505, -110.3779]]], grad_fn=<UnsafeViewBackward0>) Attentions: (tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.4640e-01, 1.5360e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.0135e-01, 2.2373e-01, 7.4919e-02, 0.0000e+00, 0.0000e+00], [6.0768e-01, 1.7884e-01, 1.4391e-01, 6.9565e-02, 0.0000e+00], [6.0990e-01, 1.5188e-01, 6.2560e-02, 9.4493e-02, 8.1164e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.4739e-04, 9.9985e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [2.7310e-03, 8.2182e-04, 9.9645e-01, 0.0000e+00, 0.0000e+00], [3.3176e-04, 2.4361e-03, 1.4651e-03, 9.9577e-01, 0.0000e+00], [3.2342e-03, 2.1838e-03, 1.5252e-02, 1.1193e-03, 9.7821e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.1416e-01, 8.5841e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.7615e-01, 1.3776e-01, 1.8609e-01, 0.0000e+00, 0.0000e+00], [4.5474e-01, 2.0124e-01, 1.1568e-01, 2.2834e-01, 0.0000e+00], [4.6935e-01, 7.1428e-02, 2.0851e-01, 7.0084e-02, 1.8062e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.5014e-01, 8.4986e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.3307e-01, 9.9187e-02, 7.6774e-01, 0.0000e+00, 0.0000e+00], [2.8450e-02, 1.4408e-02, 1.2834e-03, 9.5586e-01, 0.0000e+00], [1.0551e-01, 1.1721e-02, 6.7463e-02, 1.7100e-02, 7.9820e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [3.9975e-01, 6.0025e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [4.2455e-01, 5.0290e-01, 7.2550e-02, 0.0000e+00, 0.0000e+00], [5.8611e-02, 1.8143e-02, 1.3441e-02, 9.0980e-01, 0.0000e+00], [2.7160e-01, 1.7796e-01, 1.3677e-01, 1.3924e-01, 2.7444e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [2.7893e-02, 9.7211e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [2.9158e-01, 2.6920e-02, 6.8150e-01, 0.0000e+00, 0.0000e+00], [1.4353e-02, 2.2892e-04, 7.7237e-06, 9.8541e-01, 0.0000e+00], [6.2411e-02, 9.3040e-03, 8.2062e-03, 1.0775e-03, 9.1900e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.9259e-01, 3.0741e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [4.7506e-01, 4.8846e-01, 3.6488e-02, 0.0000e+00, 0.0000e+00], [4.7012e-01, 2.7895e-01, 3.9287e-02, 2.1164e-01, 0.0000e+00], [2.1669e-01, 3.4696e-01, 3.3619e-02, 3.5308e-01, 4.9655e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7262e-01, 2.7382e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.1730e-01, 3.3254e-01, 1.5016e-01, 0.0000e+00, 0.0000e+00], [4.0525e-01, 9.0645e-02, 3.4848e-01, 1.5563e-01, 0.0000e+00], [2.3696e-01, 1.1449e-01, 1.5967e-01, 1.5589e-01, 3.3299e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.0300e-01, 9.6995e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [4.8718e-01, 1.0044e-01, 4.1238e-01, 0.0000e+00, 0.0000e+00], [6.1475e-01, 1.2612e-01, 1.7063e-01, 8.8498e-02, 0.0000e+00], [2.3666e-01, 4.0859e-02, 2.1303e-01, 3.6777e-02, 4.7267e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5478e-01, 4.5223e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.4656e-01, 1.0486e-01, 2.4857e-01, 0.0000e+00, 0.0000e+00], [6.0311e-01, 1.5732e-01, 1.9775e-01, 4.1819e-02, 0.0000e+00], [4.2266e-01, 1.1131e-01, 2.1519e-01, 8.6133e-02, 1.6471e-01]], [[1.0000e+00, 
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.1178e-01, 2.8822e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.8995e-01, 9.0134e-02, 3.1991e-01, 0.0000e+00, 0.0000e+00], [4.8390e-01, 1.0257e-01, 1.5057e-01, 2.6296e-01, 0.0000e+00], [4.0053e-01, 9.5700e-02, 1.9708e-01, 3.7479e-02, 2.6921e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.4934e-01, 1.5066e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.3224e-01, 2.2437e-01, 1.4339e-01, 0.0000e+00, 0.0000e+00], [3.9314e-01, 2.7727e-01, 1.5651e-01, 1.7308e-01, 0.0000e+00], [4.6852e-01, 1.1897e-01, 1.0104e-01, 1.2253e-01, 1.8895e-01]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7209e-01, 2.7912e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.8621e-01, 3.8064e-01, 3.3142e-02, 0.0000e+00, 0.0000e+00], [4.8232e-01, 1.2759e-01, 2.5068e-01, 1.3941e-01, 0.0000e+00], [4.5573e-01, 2.2592e-01, 1.1976e-01, 4.9527e-02, 1.4906e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8905e-01, 1.0951e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.3572e-01, 8.3821e-02, 8.0454e-02, 0.0000e+00, 0.0000e+00], [6.8450e-01, 4.7232e-02, 2.0742e-01, 6.0846e-02, 0.0000e+00], [7.3180e-01, 3.9842e-02, 5.9887e-02, 5.9164e-02, 1.0930e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9220e-01, 7.7975e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.6141e-01, 1.8003e-02, 2.2059e-01, 0.0000e+00, 0.0000e+00], [6.0938e-01, 2.2223e-02, 2.3404e-01, 1.3435e-01, 0.0000e+00], [4.5839e-01, 1.5597e-02, 1.6847e-01, 1.0259e-01, 2.5495e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.4533e-01, 3.5467e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.1803e-01, 1.8478e-01, 1.9720e-01, 0.0000e+00, 0.0000e+00], [5.5501e-01, 1.2625e-01, 1.3541e-01, 1.8333e-01, 0.0000e+00], [5.1175e-01, 8.8049e-02, 1.0100e-01, 1.3373e-01, 1.6547e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.9888e-01, 1.0112e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.8684e-01, 1.3459e-01, 7.8575e-02, 0.0000e+00, 0.0000e+00], [7.1967e-01, 9.3313e-02, 6.5694e-02, 1.2132e-01, 0.0000e+00], [6.3908e-01, 8.8730e-02, 5.7430e-02, 1.1279e-01, 1.0197e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6850e-01, 3.1500e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.3495e-01, 3.9359e-03, 4.6111e-01, 0.0000e+00, 0.0000e+00], [7.2299e-01, 1.3465e-01, 3.9143e-02, 1.0322e-01, 0.0000e+00], [6.0384e-01, 1.3601e-02, 5.6788e-02, 1.7283e-02, 3.0849e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3237e-01, 6.7633e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.1456e-01, 1.3563e-01, 4.9807e-02, 0.0000e+00, 0.0000e+00], [6.8662e-01, 9.9697e-02, 6.8591e-02, 1.4509e-01, 0.0000e+00], [7.6440e-01, 1.0466e-01, 5.6559e-02, 2.9471e-02, 4.4903e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.0040e-01, 9.9604e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.5084e-01, 1.5419e-01, 9.4970e-02, 0.0000e+00, 0.0000e+00], [6.1832e-01, 1.9061e-01, 9.1609e-02, 9.9462e-02, 0.0000e+00], [5.6564e-01, 1.3307e-01, 7.9954e-02, 1.1634e-01, 1.0500e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6794e-01, 3.2058e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.7170e-01, 4.3231e-02, 8.5070e-02, 0.0000e+00, 0.0000e+00], [7.7739e-01, 4.0126e-02, 1.2773e-01, 5.4759e-02, 0.0000e+00], [6.4928e-01, 5.4109e-02, 8.0593e-02, 7.3304e-02, 1.4271e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9020e-01, 9.8014e-03, 
0.0000e+00, 0.0000e+00, 0.0000e+00], [9.2868e-01, 1.3217e-02, 5.8099e-02, 0.0000e+00, 0.0000e+00], [8.9125e-01, 1.5728e-02, 7.7379e-02, 1.5646e-02, 0.0000e+00], [8.5727e-01, 1.4522e-02, 3.8854e-02, 1.3800e-02, 7.5554e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [2.0297e-04, 9.9980e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.7313e-04, 5.9022e-01, 4.0891e-01, 0.0000e+00, 0.0000e+00], [5.2686e-04, 3.3926e-01, 2.4788e-01, 4.1233e-01, 0.0000e+00], [8.0960e-04, 2.2636e-01, 1.7240e-01, 2.9428e-01, 3.0615e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [4.9741e-01, 5.0259e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [4.5825e-01, 1.1647e-03, 5.4058e-01, 0.0000e+00, 0.0000e+00], [5.6682e-02, 2.2267e-03, 2.0238e-02, 9.2085e-01, 0.0000e+00], [1.2164e-01, 1.7738e-03, 2.0815e-02, 2.8863e-03, 8.5288e-01]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.8816, 0.1184, 0.0000, 0.0000, 0.0000], [0.3900, 0.3470, 0.2630, 0.0000, 0.0000], [0.4344, 0.0606, 0.4109, 0.0942, 0.0000], [0.3666, 0.0940, 0.3683, 0.1146, 0.0565]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9685, 0.0315, 0.0000, 0.0000, 0.0000], [0.8868, 0.0496, 0.0636, 0.0000, 0.0000], [0.7343, 0.0683, 0.1390, 0.0584, 0.0000], [0.7296, 0.0639, 0.0412, 0.1309, 0.0345]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9432, 0.0568, 0.0000, 0.0000, 0.0000], [0.1469, 0.8471, 0.0060, 0.0000, 0.0000], [0.1572, 0.2076, 0.5838, 0.0514, 0.0000], [0.2244, 0.0963, 0.1862, 0.3246, 0.1686]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9846, 0.0154, 0.0000, 0.0000, 0.0000], [0.8928, 0.0777, 0.0295, 0.0000, 0.0000], [0.6239, 0.0908, 0.1662, 0.1192, 0.0000], [0.6062, 0.0426, 0.2135, 0.0812, 0.0565]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9782, 0.0218, 0.0000, 0.0000, 0.0000], [0.1871, 0.7836, 0.0293, 0.0000, 0.0000], [0.6109, 0.1238, 0.2466, 0.0187, 0.0000], [0.6416, 0.1750, 0.0845, 0.0081, 0.0908]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9765, 0.0235, 0.0000, 0.0000, 0.0000], [0.7527, 0.1426, 0.1048, 0.0000, 0.0000], [0.4404, 0.1803, 0.2962, 0.0830, 0.0000], [0.4287, 0.0690, 0.1540, 0.1683, 0.1800]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9733, 0.0267, 0.0000, 0.0000, 0.0000], [0.8905, 0.0347, 0.0747, 0.0000, 0.0000], [0.8629, 0.0255, 0.0468, 0.0648, 0.0000], [0.7898, 0.0282, 0.0614, 0.0545, 0.0661]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.7351, 0.2649, 0.0000, 0.0000, 0.0000], [0.5648, 0.1821, 0.2531, 0.0000, 0.0000], [0.3673, 0.1207, 0.1130, 0.3989, 0.0000], [0.2760, 0.0600, 0.0901, 0.2262, 0.3478]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9666, 0.0334, 0.0000, 0.0000, 0.0000], [0.5883, 0.2326, 0.1791, 0.0000, 0.0000], [0.3643, 0.2940, 0.3139, 0.0278, 0.0000], [0.2934, 0.0450, 0.4518, 0.0157, 0.1942]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.8553, 0.1447, 0.0000, 0.0000, 0.0000], [0.2809, 0.6997, 0.0194, 0.0000, 0.0000], [0.2087, 0.5914, 0.0692, 0.1307, 0.0000], [0.1291, 0.2955, 0.0696, 0.4771, 0.0287]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.8377, 0.1623, 0.0000, 0.0000, 0.0000], [0.7713, 0.1540, 0.0747, 0.0000, 0.0000], [0.5854, 0.1238, 0.0504, 0.2403, 0.0000], [0.5171, 0.0990, 0.0531, 0.1615, 0.1693]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9215, 0.0785, 0.0000, 0.0000, 0.0000], [0.8768, 0.0633, 0.0599, 0.0000, 0.0000], [0.8145, 0.0514, 0.0311, 0.1030, 0.0000], [0.7703, 0.0372, 0.0291, 0.0807, 0.0826]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8268e-01, 
1.7316e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9714e-01, 2.4256e-03, 4.3351e-04, 0.0000e+00, 0.0000e+00], [9.4433e-01, 3.0325e-05, 1.0623e-04, 5.5529e-02, 0.0000e+00], [9.6740e-01, 1.8683e-04, 2.7851e-04, 4.5622e-04, 3.1674e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7486e-01, 2.5143e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.1439e-01, 3.3117e-02, 5.2495e-02, 0.0000e+00, 0.0000e+00], [7.0372e-01, 1.6742e-01, 7.7478e-02, 5.1381e-02, 0.0000e+00], [8.5533e-01, 2.5359e-02, 7.8012e-03, 1.3220e-02, 9.8292e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.0629e-01, 9.3707e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.0309e-01, 2.8151e-01, 1.1539e-01, 0.0000e+00, 0.0000e+00], [1.7294e-01, 7.0645e-02, 4.8604e-01, 2.7038e-01, 0.0000e+00], [2.1087e-01, 4.5951e-02, 4.1121e-01, 2.4295e-01, 8.9014e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9231e-01, 7.6922e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.8290e-02, 9.6893e-01, 1.2778e-02, 0.0000e+00, 0.0000e+00], [3.3402e-01, 3.9228e-01, 2.6383e-01, 9.8732e-03, 0.0000e+00], [7.7857e-01, 8.7052e-02, 6.7886e-02, 2.3009e-02, 4.3481e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9965e-01, 3.5409e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9085e-01, 3.1320e-03, 6.0142e-03, 0.0000e+00, 0.0000e+00], [9.3759e-01, 3.5951e-03, 1.7021e-02, 4.1793e-02, 0.0000e+00], [9.4960e-01, 2.0509e-03, 2.6982e-03, 1.1689e-02, 3.3964e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6235e-01, 3.7648e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.4446e-01, 8.6601e-02, 6.8941e-02, 0.0000e+00, 0.0000e+00], [8.4127e-01, 6.9890e-02, 5.7835e-02, 3.1001e-02, 0.0000e+00], [6.2594e-01, 1.5199e-01, 1.1569e-01, 5.1549e-02, 5.4835e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.7570e-01, 1.2430e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [3.5125e-01, 4.5002e-01, 1.9873e-01, 0.0000e+00, 0.0000e+00], [3.4573e-01, 2.2871e-01, 1.0734e-01, 3.1822e-01, 0.0000e+00], [1.4163e-01, 2.1510e-01, 4.3367e-02, 5.0629e-01, 9.3614e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7368e-01, 2.6315e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.0697e-01, 8.3537e-01, 5.7653e-02, 0.0000e+00, 0.0000e+00], [2.4387e-01, 2.1662e-01, 3.3404e-01, 2.0547e-01, 0.0000e+00], [1.8795e-01, 2.9717e-01, 8.5084e-02, 3.4794e-01, 8.1855e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7094e-01, 2.9060e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.5981e-01, 1.0658e-01, 3.3361e-01, 0.0000e+00, 0.0000e+00], [2.1676e-01, 4.6400e-02, 6.6238e-01, 7.4458e-02, 0.0000e+00], [2.9181e-01, 1.1980e-02, 3.1150e-01, 9.6777e-02, 2.8793e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6866e-01, 3.1337e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.3343e-01, 1.9364e-01, 7.2931e-02, 0.0000e+00, 0.0000e+00], [6.4080e-01, 2.5656e-01, 7.0099e-02, 3.2543e-02, 0.0000e+00], [7.2064e-01, 7.5143e-02, 5.0348e-02, 5.4231e-03, 1.4845e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7586e-01, 2.4136e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.2307e-01, 4.7648e-02, 1.2928e-01, 0.0000e+00, 0.0000e+00], [7.8726e-01, 1.5413e-02, 2.2299e-02, 1.7503e-01, 0.0000e+00], [7.7859e-01, 1.0238e-02, 9.9621e-03, 9.4805e-02, 1.0640e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.3870e-01, 3.6130e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [2.9023e-02, 9.5808e-01, 1.2895e-02, 0.0000e+00, 0.0000e+00], 
[1.6684e-01, 7.1874e-01, 9.9462e-02, 1.4962e-02, 0.0000e+00], [8.6535e-02, 8.2408e-01, 5.8495e-02, 7.2601e-03, 2.3631e-02]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8763e-01, 1.2369e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.5513e-01, 1.9957e-01, 4.5302e-02, 0.0000e+00, 0.0000e+00], [1.6612e-01, 2.8464e-01, 5.0885e-01, 4.0393e-02, 0.0000e+00], [2.7830e-01, 4.6737e-02, 6.3905e-01, 1.1467e-02, 2.4444e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5025e-01, 4.9747e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.2971e-01, 8.2469e-01, 4.5595e-02, 0.0000e+00, 0.0000e+00], [5.5166e-01, 2.8164e-01, 1.4566e-01, 2.1041e-02, 0.0000e+00], [8.4285e-01, 1.1891e-02, 1.3288e-02, 7.6044e-03, 1.2436e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9410e-01, 5.8977e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3703e-01, 3.3505e-02, 2.9464e-02, 0.0000e+00, 0.0000e+00], [9.5604e-01, 1.2286e-02, 2.3039e-02, 8.6327e-03, 0.0000e+00], [9.1609e-01, 1.2754e-02, 1.0708e-02, 3.6108e-02, 2.4342e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3213e-01, 6.7874e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.0969e-01, 8.7633e-01, 1.3974e-02, 0.0000e+00, 0.0000e+00], [3.6800e-01, 4.9800e-01, 2.5786e-02, 1.0822e-01, 0.0000e+00], [7.8568e-02, 4.8116e-01, 3.8334e-01, 4.9499e-02, 7.4341e-03]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8429e-01, 1.5711e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.1697e-01, 3.0108e-01, 8.1957e-02, 0.0000e+00, 0.0000e+00], [7.6427e-01, 6.6728e-02, 9.3641e-02, 7.5364e-02, 0.0000e+00], [7.8823e-01, 6.3681e-02, 2.6464e-02, 5.1123e-02, 7.0499e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5242e-01, 4.7581e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.2807e-01, 2.4630e-01, 1.2563e-01, 0.0000e+00, 0.0000e+00], [5.1321e-01, 9.6500e-02, 3.4471e-01, 4.5581e-02, 0.0000e+00], [3.8266e-01, 1.7605e-01, 3.3798e-01, 6.5144e-02, 3.8159e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8426e-01, 1.5742e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.0698e-01, 1.3276e-01, 6.0261e-02, 0.0000e+00, 0.0000e+00], [7.5616e-01, 6.8283e-02, 7.8978e-02, 9.6574e-02, 0.0000e+00], [5.0845e-01, 6.3645e-02, 8.5731e-02, 2.4365e-01, 9.8529e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6902e-01, 3.0983e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.3681e-01, 4.5272e-02, 4.1791e-01, 0.0000e+00, 0.0000e+00], [4.9830e-01, 2.4730e-02, 9.5733e-02, 3.8124e-01, 0.0000e+00], [3.9466e-01, 1.4528e-02, 9.0154e-02, 4.5654e-02, 4.5500e-01]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9760e-01, 2.4048e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5771e-01, 1.1905e-02, 3.0384e-02, 0.0000e+00, 0.0000e+00], [9.3648e-01, 3.0840e-03, 9.9071e-03, 5.0528e-02, 0.0000e+00], [8.9304e-01, 6.8501e-03, 3.8995e-02, 2.6777e-02, 3.4342e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6133e-01, 3.8670e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5277e-01, 3.9294e-02, 7.9354e-03, 0.0000e+00, 0.0000e+00], [8.8129e-01, 7.9166e-02, 3.2179e-02, 7.3699e-03, 0.0000e+00], [4.4197e-01, 1.8336e-01, 3.2659e-01, 2.6132e-02, 2.1946e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9621e-01, 3.7885e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5926e-01, 1.7418e-02, 2.3323e-02, 0.0000e+00, 0.0000e+00], [8.7619e-01, 7.0181e-03, 3.9708e-02, 7.7079e-02, 0.0000e+00], [9.3684e-01, 
8.5476e-03, 1.8532e-02, 1.4737e-02, 2.1345e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.0000e+00, 3.7350e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00], [1.1243e-05, 9.9694e-01, 3.0456e-03, 0.0000e+00, 0.0000e+00], [5.4056e-08, 1.0553e-09, 1.0000e+00, 2.4611e-07, 0.0000e+00], [2.7906e-08, 7.2148e-08, 2.3880e-07, 9.9999e-01, 4.6058e-06]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9132e-01, 8.6850e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9452e-01, 2.3545e-03, 3.1232e-03, 0.0000e+00, 0.0000e+00], [9.6897e-01, 1.1418e-03, 1.0935e-03, 2.8795e-02, 0.0000e+00], [9.7149e-01, 1.3023e-04, 5.7703e-04, 6.0039e-03, 2.1799e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9957e-01, 4.2876e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9992e-01, 1.5292e-05, 6.4213e-05, 0.0000e+00, 0.0000e+00], [9.9904e-01, 5.1686e-06, 1.8187e-06, 9.5816e-04, 0.0000e+00], [9.9772e-01, 2.1845e-05, 4.3966e-07, 1.9277e-04, 2.0637e-03]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.4495e-01, 5.5048e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [3.9356e-01, 5.7606e-01, 3.0387e-02, 0.0000e+00, 0.0000e+00], [3.6411e-01, 3.0299e-01, 3.1687e-01, 1.6032e-02, 0.0000e+00], [3.2875e-01, 2.4410e-01, 3.0265e-01, 7.0624e-02, 5.3878e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3013e-01, 6.9867e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.7017e-01, 8.9004e-02, 4.0831e-02, 0.0000e+00, 0.0000e+00], [6.9796e-01, 1.5881e-01, 1.1179e-01, 3.1447e-02, 0.0000e+00], [6.5844e-01, 1.5417e-01, 4.4342e-02, 6.0779e-02, 8.2266e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6758e-01, 3.2418e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.3914e-01, 2.9498e-01, 6.5881e-02, 0.0000e+00, 0.0000e+00], [3.9811e-01, 9.0394e-02, 4.8018e-01, 3.1310e-02, 0.0000e+00], [6.7694e-01, 8.3124e-02, 1.0220e-01, 6.6097e-02, 7.1634e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9676e-01, 3.2360e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9158e-01, 4.4951e-03, 3.9272e-03, 0.0000e+00, 0.0000e+00], [9.9506e-01, 2.2144e-03, 6.9065e-05, 2.6588e-03, 0.0000e+00], [9.5271e-01, 6.7518e-03, 1.3870e-02, 3.2582e-03, 2.3410e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9196e-01, 8.0396e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.6724e-01, 1.3005e-01, 2.7055e-03, 0.0000e+00, 0.0000e+00], [9.1523e-01, 3.7084e-02, 2.4366e-02, 2.3315e-02, 0.0000e+00], [9.3293e-01, 4.1201e-03, 4.5888e-04, 4.6213e-02, 1.6283e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9209e-01, 7.9106e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [5.6535e-01, 3.9672e-01, 3.7929e-02, 0.0000e+00, 0.0000e+00], [6.8777e-01, 2.3920e-01, 6.3899e-02, 9.1290e-03, 0.0000e+00], [3.2970e-01, 4.7910e-01, 1.4188e-01, 2.3569e-02, 2.5749e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8486e-01, 1.5137e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8760e-01, 6.8198e-03, 5.5837e-03, 0.0000e+00, 0.0000e+00], [9.5438e-01, 1.0198e-02, 4.4678e-03, 3.0952e-02, 0.0000e+00], [9.5295e-01, 6.5464e-03, 2.5009e-03, 1.1826e-02, 2.6180e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9748e-01, 2.5237e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8592e-01, 1.0216e-02, 3.8620e-03, 0.0000e+00, 0.0000e+00], [9.7857e-01, 9.8701e-03, 1.9850e-03, 9.5784e-03, 0.0000e+00], [9.3577e-01, 2.1244e-02, 6.9420e-03, 1.6773e-02, 1.9270e-02]], [[1.0000e+00, 0.0000e+00, 
0.0000e+00, 0.0000e+00, 0.0000e+00], [9.4617e-01, 5.3829e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [6.6554e-01, 1.8921e-01, 1.4525e-01, 0.0000e+00, 0.0000e+00], [8.4404e-01, 4.7671e-02, 4.6438e-02, 6.1854e-02, 0.0000e+00], [4.1475e-01, 1.2054e-01, 1.0108e-01, 2.6399e-01, 9.9632e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7725e-01, 2.2754e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.4521e-01, 4.2486e-02, 1.1230e-01, 0.0000e+00, 0.0000e+00], [7.0981e-01, 1.4360e-01, 6.2041e-02, 8.4544e-02, 0.0000e+00], [6.3260e-01, 6.4462e-02, 4.2408e-02, 5.3823e-02, 2.0671e-01]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3107e-01, 6.8927e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.2165e-01, 1.8779e-01, 9.0556e-02, 0.0000e+00, 0.0000e+00], [4.6368e-01, 1.8552e-01, 2.9214e-01, 5.8654e-02, 0.0000e+00], [4.4191e-01, 2.1471e-01, 1.9086e-01, 1.1681e-01, 3.5719e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6630e-01, 3.3704e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.5391e-01, 1.1163e-01, 3.4456e-02, 0.0000e+00, 0.0000e+00], [9.2002e-01, 2.2570e-02, 4.1658e-02, 1.5755e-02, 0.0000e+00], [3.4370e-01, 3.0937e-01, 3.2868e-01, 4.5510e-03, 1.3704e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8087e-01, 1.9129e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.4840e-01, 1.1681e-02, 3.9920e-02, 0.0000e+00, 0.0000e+00], [8.9529e-01, 1.0272e-02, 3.1628e-02, 6.2805e-02, 0.0000e+00], [8.2158e-01, 1.2511e-02, 2.0896e-02, 8.3904e-02, 6.1105e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8491e-01, 1.5086e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6406e-01, 1.9128e-02, 1.6817e-02, 0.0000e+00, 0.0000e+00], [8.3050e-01, 3.6039e-02, 8.6933e-02, 4.6528e-02, 0.0000e+00], [7.5217e-01, 2.4834e-02, 7.3416e-02, 8.6154e-02, 6.3430e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.4821e-01, 1.5179e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.3873e-01, 9.3524e-02, 6.7741e-02, 0.0000e+00, 0.0000e+00], [9.3704e-01, 2.1367e-02, 1.2340e-02, 2.9252e-02, 0.0000e+00], [7.2025e-01, 2.1460e-02, 2.7114e-02, 1.6127e-01, 6.9903e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7658e-01, 2.3415e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.4204e-01, 3.6862e-02, 2.1100e-02, 0.0000e+00, 0.0000e+00], [8.7403e-01, 2.0384e-02, 2.8869e-02, 7.6717e-02, 0.0000e+00], [8.1515e-01, 3.4231e-02, 2.6072e-02, 6.2564e-02, 6.1980e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6897e-01, 3.1028e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3794e-01, 2.7844e-02, 3.4217e-02, 0.0000e+00, 0.0000e+00], [9.2826e-01, 1.8394e-02, 2.4454e-03, 5.0899e-02, 0.0000e+00], [8.8630e-01, 2.5514e-02, 5.7392e-03, 2.7406e-02, 5.5041e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.1927e-01, 8.0730e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.3202e-01, 1.1687e-01, 5.1114e-02, 0.0000e+00, 0.0000e+00], [8.1200e-01, 4.5452e-02, 6.5906e-02, 7.6644e-02, 0.0000e+00], [6.6297e-01, 6.4572e-02, 8.3910e-02, 1.6866e-01, 1.9883e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.2020e-01, 7.9799e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [3.2156e-01, 6.7010e-01, 8.3458e-03, 0.0000e+00, 0.0000e+00], [6.2888e-01, 1.7730e-01, 1.7765e-01, 1.6179e-02, 0.0000e+00], [7.1778e-01, 1.0738e-01, 5.3128e-02, 9.8006e-02, 2.3706e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9811e-01, 1.8942e-03, 0.0000e+00, 
0.0000e+00, 0.0000e+00], [9.9769e-01, 9.5311e-04, 1.3615e-03, 0.0000e+00, 0.0000e+00], [9.9786e-01, 1.8052e-04, 8.9000e-05, 1.8703e-03, 0.0000e+00], [9.9209e-01, 2.7830e-04, 1.6944e-04, 3.0577e-03, 4.4005e-03]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8963e-01, 1.0372e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7655e-01, 1.2386e-02, 1.1060e-02, 0.0000e+00, 0.0000e+00], [9.2873e-01, 7.9086e-03, 6.7334e-03, 5.6627e-02, 0.0000e+00], [9.4687e-01, 3.3349e-03, 3.8526e-03, 2.0949e-02, 2.4993e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6317e-01, 3.6828e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.1453e-01, 1.5719e-01, 2.8284e-02, 0.0000e+00, 0.0000e+00], [8.4460e-01, 3.3552e-02, 1.1012e-01, 1.1731e-02, 0.0000e+00], [8.2286e-01, 5.7689e-02, 4.7373e-02, 4.5768e-02, 2.6306e-02]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6141e-01, 3.8593e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.0344e-01, 6.1702e-02, 3.4861e-02, 0.0000e+00, 0.0000e+00], [7.9124e-01, 2.2201e-02, 1.6933e-01, 1.7236e-02, 0.0000e+00], [9.5321e-01, 3.3564e-03, 6.9881e-03, 1.9175e-02, 1.7265e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6476e-01, 3.5243e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8655e-01, 9.8743e-03, 3.5791e-03, 0.0000e+00, 0.0000e+00], [9.4783e-01, 5.8394e-03, 3.1383e-03, 4.3189e-02, 0.0000e+00], [9.5562e-01, 6.2320e-03, 6.4283e-03, 2.5439e-02, 6.2816e-03]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9179e-01, 8.2106e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9778e-01, 7.7677e-04, 1.4401e-03, 0.0000e+00, 0.0000e+00], [9.9611e-01, 7.4620e-05, 7.2848e-04, 3.0882e-03, 0.0000e+00], [9.9262e-01, 2.0443e-04, 5.5199e-04, 5.4467e-03, 1.1742e-03]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3731e-01, 6.2687e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [7.8342e-01, 1.4043e-01, 7.6153e-02, 0.0000e+00, 0.0000e+00], [8.3782e-01, 4.1296e-02, 4.8374e-02, 7.2512e-02, 0.0000e+00], [7.7558e-01, 7.1647e-02, 2.1223e-02, 9.3007e-02, 3.8544e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7783e-01, 2.2168e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8155e-01, 1.4647e-02, 3.8022e-03, 0.0000e+00, 0.0000e+00], [5.0584e-01, 1.1217e-01, 3.3681e-01, 4.5181e-02, 0.0000e+00], [8.6466e-01, 4.3985e-02, 5.2316e-02, 1.8824e-02, 2.0212e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.4112e-01, 5.8879e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.1998e-01, 4.8161e-02, 3.1862e-02, 0.0000e+00, 0.0000e+00], [7.8413e-01, 5.7792e-02, 1.1763e-01, 4.0456e-02, 0.0000e+00], [8.0141e-01, 3.4749e-02, 3.2050e-02, 1.0286e-01, 2.8935e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9054e-01, 9.4629e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6498e-01, 9.4434e-03, 2.5576e-02, 0.0000e+00, 0.0000e+00], [8.5708e-01, 8.2129e-03, 2.5883e-02, 1.0882e-01, 0.0000e+00], [8.8442e-01, 4.7589e-03, 1.1856e-02, 8.4370e-02, 1.4590e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.5244e-01, 4.7557e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.8452e-01, 8.4211e-03, 7.0570e-03, 0.0000e+00, 0.0000e+00], [9.5362e-01, 9.5855e-03, 5.3842e-03, 3.1412e-02, 0.0000e+00], [9.6968e-01, 4.2234e-03, 3.7680e-03, 1.0003e-02, 1.2325e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.4719e-01, 5.2808e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [8.9683e-01, 6.0842e-02, 4.2329e-02, 0.0000e+00, 
0.0000e+00], [4.7391e-01, 2.8253e-01, 2.1125e-01, 3.2311e-02, 0.0000e+00], [5.3445e-01, 2.6445e-01, 1.2724e-01, 5.9709e-02, 1.4149e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.6439e-01, 3.5606e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.3692e-01, 5.0612e-02, 1.2466e-02, 0.0000e+00, 0.0000e+00], [9.5147e-01, 2.5873e-02, 1.2778e-02, 9.8748e-03, 0.0000e+00], [8.9290e-01, 2.6920e-02, 4.0316e-03, 3.8151e-02, 3.7993e-02]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9627e-01, 3.7340e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9756e-01, 1.4468e-03, 9.9798e-04, 0.0000e+00, 0.0000e+00], [9.9382e-01, 4.3808e-04, 4.7226e-04, 5.2677e-03, 0.0000e+00], [9.8542e-01, 1.1592e-03, 1.3844e-03, 8.3883e-03, 3.6457e-03]], [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.7623e-01, 2.3771e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00], [9.9240e-01, 2.1715e-03, 5.4276e-03, 0.0000e+00, 0.0000e+00], [9.7309e-01, 7.1527e-04, 1.4431e-02, 1.1764e-02, 0.0000e+00], [9.8076e-01, 4.1391e-04, 1.7371e-03, 5.9987e-03, 1.1089e-02]]]], grad_fn=<SoftmaxBackward0>), tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9763, 0.0237, 0.0000, 0.0000, 0.0000], [0.8533, 0.0129, 0.1338, 0.0000, 0.0000], [0.8869, 0.0167, 0.0442, 0.0523, 0.0000], [0.8437, 0.0111, 0.0380, 0.0377, 0.0695]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9941, 0.0059, 0.0000, 0.0000, 0.0000], [0.9844, 0.0065, 0.0091, 0.0000, 0.0000], [0.9806, 0.0017, 0.0087, 0.0090, 0.0000], [0.9638, 0.0018, 0.0092, 0.0168, 0.0084]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9541, 0.0459, 0.0000, 0.0000, 0.0000], [0.9257, 0.0401, 0.0343, 0.0000, 0.0000], [0.9612, 0.0068, 0.0181, 0.0138, 0.0000], [0.9195, 0.0095, 0.0134, 0.0390, 0.0186]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9704, 0.0296, 0.0000, 0.0000, 0.0000], [0.9800, 0.0112, 0.0088, 0.0000, 0.0000], [0.9829, 0.0071, 0.0032, 0.0068, 0.0000], [0.8489, 0.0203, 0.0402, 0.0247, 0.0660]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9750, 0.0250, 0.0000, 0.0000, 0.0000], [0.8638, 0.0884, 0.0478, 0.0000, 0.0000], [0.7929, 0.0489, 0.1298, 0.0284, 0.0000], [0.6916, 0.0867, 0.0752, 0.1165, 0.0300]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9613, 0.0387, 0.0000, 0.0000, 0.0000], [0.9194, 0.0337, 0.0470, 0.0000, 0.0000], [0.7719, 0.0633, 0.1241, 0.0407, 0.0000], [0.6144, 0.0881, 0.0239, 0.2148, 0.0588]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9615, 0.0385, 0.0000, 0.0000, 0.0000], [0.9602, 0.0233, 0.0165, 0.0000, 0.0000], [0.9305, 0.0045, 0.0343, 0.0308, 0.0000], [0.8467, 0.0134, 0.0293, 0.0560, 0.0545]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9363, 0.0637, 0.0000, 0.0000, 0.0000], [0.4824, 0.1020, 0.4156, 0.0000, 0.0000], [0.3589, 0.0568, 0.5517, 0.0326, 0.0000], [0.4467, 0.0870, 0.0771, 0.1530, 0.2362]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9491, 0.0509, 0.0000, 0.0000, 0.0000], [0.9092, 0.0701, 0.0207, 0.0000, 0.0000], [0.9470, 0.0234, 0.0119, 0.0177, 0.0000], [0.7555, 0.0474, 0.0371, 0.0684, 0.0917]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9607, 0.0393, 0.0000, 0.0000, 0.0000], [0.8362, 0.0832, 0.0806, 0.0000, 0.0000], [0.9208, 0.0155, 0.0268, 0.0369, 0.0000], [0.8085, 0.0206, 0.0381, 0.0475, 0.0853]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9299, 0.0701, 0.0000, 0.0000, 0.0000], [0.8491, 0.0858, 0.0651, 0.0000, 0.0000], [0.9007, 0.0214, 0.0235, 0.0544, 0.0000], [0.7305, 0.0466, 0.0381, 0.1136, 0.0712]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9450, 0.0550, 0.0000, 0.0000, 0.0000], [0.8148, 0.1320, 
... (output truncated: the remaining per-layer attention weight tensors, each of shape [1, 12, 5, 5] with grad_fn=<SoftmaxBackward0>, are omitted here for brevity)
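Rather than reading the raw values, it is usually clearer to inspect the shapes. A minimal sketch, assuming the tuple printed above is outputs.attentions from a forward pass with output_attentions=True on the 5-token input:
# One attention tensor per layer, each of shape (batch_size, num_heads, seq_len, seq_len)
print("Number of layers:", len(outputs.attentions))
print("Attention shape per layer:", outputs.attentions[0].shape)
# Attention weights of head 0 in the first layer (a 5 x 5 lower-triangular matrix due to causal masking)
print(outputs.attentions[0][0, 0])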
outputs.logits.shape
torch.Size([1, 5, 50257])
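The logits have shape (batch_size, sequence_length, vocab_size): one score per vocabulary token at every position. As a minimal sketch of the post-processing step (assuming a tokenizer object named tokenizer was used to encode the input in the earlier cells), the most likely next token can be read from the last position:
import torch
# Greedy choice: pick the highest-scoring token after the last input position
next_token_id = outputs.logits[0, -1].argmax(dim=-1).item()
print("Predicted next token:", tokenizer.decode(next_token_id))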
Hyperparameters that change a model's architecture.
from transformers import GPT2Model, GPT2Config
# Default configuration
model = GPT2Model.from_pretrained("gpt2")
model
GPT2Model( (wte): Embedding(50257, 768) (wpe): Embedding(1024, 768) (drop): Dropout(p=0.1, inplace=False) (h): ModuleList( (0-11): 12 x GPT2Block( (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): GPT2Attention( (c_attn): Conv1D() (c_proj): Conv1D() (attn_dropout): Dropout(p=0.1, inplace=False) (resid_dropout): Dropout(p=0.1, inplace=False) ) (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): GPT2MLP( (c_fc): Conv1D() (c_proj): Conv1D() (act): NewGELUActivation() (dropout): Dropout(p=0.1, inplace=False) ) ) ) (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True) )
model.config
GPT2Config { "_name_or_path": "gpt2", "activation_function": "gelu_new", "architectures": [ "GPT2LMHeadModel" ], "attn_pdrop": 0.1, "bos_token_id": 50256, "embd_pdrop": 0.1, "eos_token_id": 50256, "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, "model_type": "gpt2", "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_inner": null, "n_layer": 12, "n_positions": 1024, "reorder_and_upcast_attn": false, "resid_pdrop": 0.1, "scale_attn_by_inverse_layer_idx": false, "scale_attn_weights": true, "summary_activation": null, "summary_first_dropout": 0.1, "summary_proj_to_labels": true, "summary_type": "cls_index", "summary_use_proj": true, "task_specific_params": { "text-generation": { "do_sample": true, "max_length": 50 } }, "transformers_version": "4.36.0", "use_cache": true, "vocab_size": 50257 }
# Create a custom configuration
config = GPT2Config(
    n_layer=6,
    n_head=8
)
# Load model with custom configuration
model = GPT2Model.from_pretrained("gpt2", config=config)
model
GPT2Model( (wte): Embedding(50257, 768) (wpe): Embedding(1024, 768) (drop): Dropout(p=0.1, inplace=False) (h): ModuleList( (0-5): 6 x GPT2Block( (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): GPT2Attention( (c_attn): Conv1D() (c_proj): Conv1D() (attn_dropout): Dropout(p=0.1, inplace=False) (resid_dropout): Dropout(p=0.1, inplace=False) ) (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): GPT2MLP( (c_fc): Conv1D() (c_proj): Conv1D() (act): NewGELUActivation() (dropout): Dropout(p=0.1, inplace=False) ) ) ) (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True) )
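With from_pretrained, weights are only loaded where names and shapes match, so this 6-layer model reuses the first six pretrained GPT-2 blocks and the remaining checkpoint layers are discarded (Transformers prints a warning about unused weights). If a randomly initialized model with the custom architecture is wanted instead, it can be built directly from the config; a minimal sketch:
# Same 6-layer, 8-head architecture, but with randomly initialized weights (no pretrained checkpoint)
scratch_model = GPT2Model(config)
print(sum(p.numel() for p in scratch_model.parameters()), "parameters")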
from transformers import pipeline
import torch
model = "gpt2"
pipeline = pipeline(
"text-generation",
model=model,
trust_remote_code=True,
torch_dtype=torch.float32
)
sequences = pipeline(
    'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n',
    do_sample=True,
    top_k=20,
    pad_token_id=pipeline.tokenizer.eos_token_id,
    temperature=1.0,
    max_length=50,
    num_return_sequences=3
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}\n")
Result: I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like? Brock (Nrama): What's your favorite show? Hollywood Reporter: The only show that really made
Result: I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like? I've had the urge to watch a lot of shows and I have a few. I've seen "Pulp Fiction
Result: I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like? I haven't even tried any other shows yet. I haven't tried it for a while. I didn't see a
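Because do_sample=True draws tokens at random, re-running the cell produces different continuations. For reproducible experiments the random seed can be fixed; a minimal sketch using transformers.set_seed with the pipeline created above (the prompt and parameter values are only examples):
from transformers import set_seed
set_seed(42)  # fixes the Python, NumPy and PyTorch seeds so sampling is repeatable
sequences = pipeline(
    "The best thing about living in Finland is",
    do_sample=True,
    top_k=20,
    max_length=30,
    num_return_sequences=2,
    pad_token_id=pipeline.tokenizer.eos_token_id
)
for seq in sequences:
    print(seq["generated_text"])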
Exercise 1: Exploring Pre-trained Models
Objective: Familiarize yourself with the Hugging Face Model Hub.
Task: Browse the Hugging Face Model Hub and find a pre-trained model suitable for sentiment analysis. Write a short script to explore the model's architecture, configuration, and outputs.
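A possible starting point (a hedged sketch; the checkpoint name below is just one example of a sentiment-analysis model from the Hub):
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # example checkpoint, swap in any Hub model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
print(model.config)   # configuration: number of layers, hidden size, label names, ...
print(model)          # architecture
inputs = tokenizer("What a great workshop!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print("Predicted label:", model.config.id2label[logits.argmax(dim=-1).item()])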
Exercise 2: Text Generation
Objective: Understand the capabilities of text generation models.
Task: Use a text generation model to generate a short text based on a given prompt. Experiment with different temperature settings and observe how they affect the creativity of the output.
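A possible starting point (a sketch using the gpt2 checkpoint; the prompt and temperature values are only examples):
from transformers import pipeline
generator = pipeline("text-generation", model="gpt2")
prompt = "Once upon a time in Helsinki,"
for temperature in [0.3, 0.7, 1.2]:
    output = generator(
        prompt,
        do_sample=True,
        temperature=temperature,
        max_length=40,
        pad_token_id=generator.tokenizer.eos_token_id
    )
    print(f"temperature={temperature}: {output[0]['generated_text']}\n")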