import nltk
from nltk.corpus import gutenberg

nltk.download('gutenberg')
austen_emma = gutenberg.raw('austen-emma.txt')
print(austen_emma[:200])

import string
from collections import Counter

# Normalise the text: uppercase, keep only A-Z and spaces
austen_emma_cleaned = austen_emma.upper()
austen_emma_cleaned = ''.join(c for c in austen_emma_cleaned if c in string.ascii_uppercase + ' ')

# Count all n-grams of length n
n = 6
ngrams = [''.join(ngram) for ngram in zip(*(austen_emma_cleaned[i:-(n-i)] for i in range(n)))]
freqs = Counter(ngrams)
freqs

# Unigram character counts, used as a fallback for unseen prefixes
english_freqs = Counter(austen_emma_cleaned)

def marginal(prefix, counts):
    # Distribution over the next character given the previous n-1 characters
    marginal_freqs = {ngram[-1]: count for ngram, count in counts.items() if ngram[:-1] == prefix}
    if not marginal_freqs:
        marginal_freqs = english_freqs
    characters, counts = zip(*marginal_freqs.items())
    sum_counts = sum(counts)
    probabilities = [count / sum_counts for count in counts]
    return characters, probabilities

import random

# Sample 100 characters from the n-gram model
s = 'EMMA '
for i in range(100):
    characters, probabilities = marginal(s[-(n-1):], freqs)
    s += random.choices(characters, weights=probabilities, k=1)[0]
s

import math
from tqdm.notebook import tqdm
from numba import jit

@jit(nopython=True)
def float_to_binary_numba(number, length):
    # Binary expansion of a float in [0, 1), truncated to `length` bits
    output = ""
    for i in range(1, length + 1):
        digit = int(number >= 0.5)
        output += str(digit)
        number = 2 * (number - 0.5 * digit)
    return output

def float_to_binary(number, length):
    output = ""
    for i in range(1, length + 1):
        digit = int(number >= 0.5)
        output += str(digit)
        number = 2 * (number - 0.5 * digit)
    return output

def shannon_encoding(freq_table):
    # Shannon coding: sort symbols by decreasing probability and give each one
    # the first ceil(log2(1/p)) bits of the cumulative probability before it
    items = sorted(freq_table.items(), key=lambda x: x[1], reverse=True)
    encoding = {}
    Pi = 0.
    for char, prob in items:
        encoding[char] = float_to_binary_numba(Pi, math.ceil(math.log2(1/prob)))
        Pi += prob
    return encoding

def encode(string, encoding):
    return ''.join(encoding[char] for char in string)

def expected_length(freq_table, encoding):
    return sum(freq_table[char] * len(encoding[char]) for char in freq_table)

def shannon_entropy(freq_table):
    return sum(-freq_table[char] * math.log2(freq_table[char]) for char in freq_table)

freq_table = {'a': 0.5, 'b': 0.25, 'c': 0.25}
encoding = shannon_encoding(freq_table)
encoding

print(f'expected length: {expected_length(freq_table, encoding)}')
print(f'shannon entropy: {shannon_entropy(freq_table)}')

def decode(text, encoding):
    output = ""
    while text:
        token, i = decode_token(text, encoding)
        output += token
        text = text[i:]
    return output

def decode_token(text, encoding):
    # Read bits until they match a codeword; Shannon codes are prefix-free,
    # so the first match is the correct one
    inverted_encoding = {value: key for key, value in encoding.items()}
    for i in range(len(text)):
        if text[:i+1] in inverted_encoding:
            return inverted_encoding[text[:i+1]], i+1

decode(encode('aabc', encoding), encoding)

%pip install -q torch numpy transformers datasets tiktoken wandb tqdm
!git clone https://github.com/karpathy/nanoGPT
%cd nanoGPT

from model import GPT
import tiktoken
import torch

gpt2 = GPT.from_pretrained('gpt2', dict(dropout=0.0)).to(dtype=torch.float64)
enc = tiktoken.get_encoding("gpt2")
tokenize = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
detokenize = lambda l: enc.decode(l)

import sys

max_context = 128

def compress(text):
    tokens = tokenize('\n' + text)
    tokens = torch.tensor(tokens, dtype=torch.long)
    output = ''
    for i in tqdm(range(1, len(tokens))):
        # get marginal from model, truncating the context to max_context tokens
        cond_tokens = tokens[None, :i] if i <= max_context else tokens[None, i-max_context:i]
        probs = gpt2(cond_tokens)[0].flatten().softmax(0)
        prob_next = probs[tokens[i]]  # probability of next token
        # sort probabilities
        probs, indices = torch.sort(probs, descending=True)
        idx = torch.argwhere(indices == tokens[i]).item()
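        # Same Shannon coding step as shannon_encoding above: the true next token,
        # at rank `idx` in the sorted distribution, is coded with the first
        # ceil(log2(1/prob_next)) bits of the binary expansion of the cumulative
        # probability of all higher-ranked tokens. Computing that cumulative sum
        # directly avoids building a full 50257-entry code table at every position.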
        # Pi is the cumulative sum up until idx (not including idx)
        zero_tensor = torch.tensor([0])
        probs = torch.cat((zero_tensor, probs), dim=0)
        Pi = probs.cumsum(0)[idx].item()
        # encode as a binary string using Shannon's method
        encoded_token = float_to_binary_numba(Pi, math.ceil(math.log2(1/prob_next)))
        output += encoded_token
    return output

def decompress(text):
    tokens = tokenize('\n')
    tokens = torch.tensor(tokens, dtype=torch.long)
    output = []
    processed = 0
    text_length = len(text)
    while text:
        # feed in tokens and truncate if number of tokens exceeds max_context
        cond_tokens = tokens[None, :] if len(tokens) <= max_context else tokens[None, -max_context:]
        probs = gpt2(cond_tokens)[0].flatten().softmax(0)
        # construct frequency table over the whole vocabulary
        freq_table = {i: prob.item() for i, prob in enumerate(probs)}
        encoding = shannon_encoding(freq_table)
        token, i = decode_token(text, encoding)
        # decompress the rest of the text
        text = text[i:]
        output.append(token)
        tokens = torch.cat((tokens, torch.tensor([token], dtype=torch.long)), dim=0)
        processed += i
        print(f'\rDecompressed {processed/text_length*100:.2f}% of text...', end='')
    return detokenize(output)

text = austen_emma[:10000]
print(text)

%%time
compressed_string = compress(text)
compressed_string

%%time
decompressed_string = decompress(compressed_string)

assert decompressed_string == text, 'Decompress(Compress(X)) != X'

print(f'Size of uncompressed text: {len(text) * 8} bits')

import zlib

print(f'Size of compressed text using zlib: {len(zlib.compress(text.encode())) * 8} bits')
print(f'Size of compressed text using our compression algorithm: {len(compressed_string)} bits')
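The same numbers are easier to compare as bits per character. This is a minimal follow-up sketch, assuming `text`, `compressed_string`, and `zlib` from the cells above are still in scope.

# Bits-per-character comparison (rough summary; assumes `text` and
# `compressed_string` from the cells above are still defined)
raw_bpc = 8.0  # one byte per character for the uncompressed text
zlib_bpc = len(zlib.compress(text.encode())) * 8 / len(text)
gpt2_bpc = len(compressed_string) / len(text)
print(f'raw:                   {raw_bpc:.2f} bits/char')
print(f'zlib:                  {zlib_bpc:.2f} bits/char')
print(f'gpt2 + Shannon coding: {gpt2_bpc:.2f} bits/char')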