In this tutorial, we build a BERT4Rec model in PyTorch and then train it on the MovieLens 1M dataset.
!pip install -q wget
import os
import sys
import wget
import math
import json
import random
import zipfile
import shutil
import pickle
import tempfile
from abc import *
import numpy as np
import pandas as pd
import pprint as pp
from pathlib import Path
from datetime import date
from tqdm import tqdm, trange
import torch
import torch.backends.cudnn as cudnn
from torch import optim as optim
import torch.utils.data as data_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
# Enable `progress_apply` on pandas objects (progress bars for groupby/apply).
tqdm.pandas()
# Keys used when saving/restoring training checkpoints.
STATE_DICT_KEY = 'model_state_dict'
OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'
# Root folder where the raw dataset is downloaded and preprocessed.
RAW_DATASET_ROOT_FOLDER = '/content/ml-1m'
class Args:
    """Configuration container for the whole pipeline (dataset, dataloader,
    negative sampling, trainer and model hyper-parameters)."""
    mode = 'train'
    test_model_path = '/content/models'  # model path used when mode == 'test'
    # Dataset
    dataset_code = 'ml-1m'
    min_rating = 0        # keep ratings >= this value when making data implicit
    min_uc = 5            # minimum ratings per user (must be >= 2 for leave-one-out)
    min_sc = 0            # minimum ratings per item (0 disables item filtering)
    split = 'leave_one_out'
    dataset_split_seed = 42
    eval_set_size = 500   # users held out for val/test under the 'holdout' split
    # Dataloader
    dataloader_code = 'bert'
    dataloader_random_seed = 0.0  # seed for the dataloader's random.Random instance
    train_batch_size = 128
    val_batch_size = 128
    test_batch_size = 128
    # NegativeSampler
    train_negative_sampler_code = 'random'
    train_negative_sample_size = 0    # cloze training needs no negatives
    train_negative_sampling_seed = 0
    test_negative_sampler_code = 'random'
    test_negative_sample_size = 100   # candidates ranked against the true item at eval
    test_negative_sampling_seed = 42
    # Trainer
    trainer_code = 'bert'
    device = 'cuda'
    num_gpu = 1
    device_idx = '0'      # comma-separated GPU indices for CUDA_VISIBLE_DEVICES
    optimizer='Adam'
    lr=0.001
    weight_decay=0
    momentum=None         # only used when optimizer == 'SGD'
    enable_lr_schedule = True
    decay_step=25         # StepLR step size
    gamma=1.0             # StepLR decay factor (1.0 keeps the lr constant)
    num_epochs=10
    log_period_as_iter=12800  # log to tensorboard every this many training samples
    metric_ks=[1, 5, 10, 20, 50, 100]
    best_metric='NDCG@10'     # metric used to pick the best checkpoint
    find_best_beta=False      # NOTE(review): not referenced in the code shown here
    total_anneal_steps=2000   # NOTE(review): not referenced in the code shown here
    anneal_cap=0.2            # NOTE(review): not referenced in the code shown here
    # Model
    model_code='bert'
    model_init_seed=0
    bert_max_len=100      # maximum sequence length fed to BERT
    bert_num_items=None   # NOTE(review): the model reads args.num_items, which the dataloader sets
    bert_hidden_units=256
    bert_num_blocks=2
    bert_num_heads=4
    bert_dropout=0.1
    bert_mask_prob=0.15   # cloze masking probability for training
    # Experiment
    experiment_dir='experiments'
    experiment_description='test'
# Global configuration instance used by the rest of the script.
args = Args()
def download(url, savepath):
    """Download `url` to `savepath` (any path-like object)."""
    wget.download(url, out=str(savepath))
def unzip(zippath, savepath):
    """Extract every member of the zip archive at `zippath` into `savepath`.

    Fix: use a context manager so the archive handle is closed even if
    extraction raises, and stop shadowing the builtin `zip`.
    """
    with zipfile.ZipFile(zippath) as archive:
        archive.extractall(savepath)
def get_count(tp, id):
    """Return per-value counts of column `id` in DataFrame `tp`."""
    return tp[[id]].groupby(id, as_index=False).size()
def recall(scores, labels, k):
    """Recall@k averaged over the batch, returned as a python float.

    scores: (batch, n_items) predicted scores; labels: (batch, n_items)
    binary relevance. The denominator is min(k, #positives) so recall can
    reach 1.0 even when a row has more than k positives.

    Fix: removed the dead no-op assignments (`scores = scores`, etc.).
    """
    rank = (-scores).argsort(dim=1)   # item indices, best first
    cut = rank[:, :k]
    hit = labels.gather(1, cut)       # relevance of the top-k items
    denom = torch.min(torch.Tensor([k]).to(hit.device), labels.sum(1).float())
    return (hit.sum(1).float() / denom).mean().cpu().item()
def ndcg(scores, labels, k):
    """NDCG@k averaged over the batch (computed on CPU, returns a 0-dim tensor)."""
    scores, labels = scores.cpu(), labels.cpu()
    top_k = (-scores).argsort(dim=1)[:, :k]
    gains = labels.gather(1, top_k)
    # Discount 1/log2(rank + 1) for ranks 1..k.
    discounts = 1 / torch.log2(torch.arange(2, 2 + k).float())
    dcg = (gains.float() * discounts).sum(1)
    # Ideal DCG: all positives (capped at k) placed at the top ranks.
    idcg = torch.Tensor([discounts[:min(int(n), k)].sum() for n in labels.sum(1)])
    return (dcg / idcg).mean()
def recalls_and_ndcgs_for_ks(scores, labels, ks):
    """Compute Recall@k and NDCG@k for every k in `ks`.

    scores/labels: (batch, n_items) tensors; labels are binary relevance.
    Returns {'Recall@k': float, 'NDCG@k': float, ...}.

    Fix: removed the dead no-op assignments (`scores = scores`, etc.).
    """
    metrics = {}
    answer_count = labels.sum(1)
    labels_float = labels.float()
    rank = (-scores).argsort(dim=1)
    cut = rank
    # Iterate from the largest k down so `cut` only ever needs narrowing.
    for k in sorted(ks, reverse=True):
        cut = cut[:, :k]
        hits = labels_float.gather(1, cut)
        denom = torch.min(torch.Tensor([k]).to(labels.device), labels.sum(1).float())
        metrics['Recall@%d' % k] = (hits.sum(1) / denom).mean().cpu().item()
        position = torch.arange(2, 2 + k)
        weights = 1 / torch.log2(position.float())
        dcg = (hits * weights.to(hits.device)).sum(1)
        idcg = torch.Tensor([weights[:min(int(n), k)].sum() for n in answer_count]).to(dcg.device)
        metrics['NDCG@%d' % k] = (dcg / idcg).mean().cpu().item()
    return metrics
def setup_train(args):
    """Prepare GPUs and a fresh experiment folder; returns the export root."""
    set_up_gpu(args)
    export_root = create_experiment_export_folder(args)
    export_experiments_config_as_json(args, export_root)
    # Echo the effective (non-None) configuration to stdout.
    config = {k: v for k, v in vars(args).items() if v is not None}
    pp.pprint(config, width=1)
    return export_root
def create_experiment_export_folder(args):
    """Create and return a uniquely-named folder for this experiment run.

    Fix: `os.makedirs(..., exist_ok=True)` replaces the racy
    exists-then-mkdir check and also works when `experiment_dir` is nested.
    """
    experiment_dir, experiment_description = args.experiment_dir, args.experiment_description
    os.makedirs(experiment_dir, exist_ok=True)
    experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description)
    os.mkdir(experiment_path)
    print('Folder created: ' + os.path.abspath(experiment_path))
    return experiment_path
def get_name_of_experiment_path(experiment_dir, experiment_description):
    """Build a unique path '<dir>/<description>_<today>_<index>'."""
    base = os.path.join(experiment_dir, experiment_description + "_" + str(date.today()))
    return base + "_" + str(_get_experiment_index(base))
def _get_experiment_index(experiment_path):
idx = 0
while os.path.exists(experiment_path + "_" + str(idx)):
idx += 1
return idx
def load_weights(model, path):
    # NOTE(review): intentionally a stub in this tutorial — checkpoint loading
    # is handled by load_pretrained_weights / setup_to_resume instead.
    pass
def save_test_result(export_root, result):
    """Dump the test metrics dict as JSON to <export_root>/test_result.txt."""
    target = Path(export_root).joinpath('test_result.txt')
    with target.open('w') as handle:
        json.dump(result, handle, indent=2)
def export_experiments_config_as_json(args, experiment_path):
    """Write the full configuration (vars of args) to <experiment_path>/config.json."""
    config_file = os.path.join(experiment_path, 'config.json')
    with open(config_file, 'w') as outfile:
        json.dump(vars(args), outfile, indent=2)
def fix_random_seed_as(random_seed):
    """Seed python/torch/numpy RNGs and force deterministic cuDNN behaviour."""
    for seeder in (random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all, np.random.seed):
        seeder(random_seed)
    cudnn.deterministic = True
    cudnn.benchmark = False
def set_up_gpu(args):
    """Expose the requested GPUs via CUDA_VISIBLE_DEVICES and record their count."""
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_idx
    # Number of requested devices == number of comma-separated indices.
    args.num_gpu = args.device_idx.count(',') + 1
def load_pretrained_weights(model, path):
    """Load model weights from the checkpoint file at `path` into `model`."""
    checkpoint = torch.load(os.path.abspath(path))
    # Checkpoints saved by this script use STATE_DICT_KEY; fall back to the
    # conventional 'state_dict' key otherwise.
    key = STATE_DICT_KEY if STATE_DICT_KEY in checkpoint else 'state_dict'
    model.load_state_dict(checkpoint[key])
def setup_to_resume(args, model, optimizer):
    """Restore model and optimizer state from the most recent checkpoint."""
    checkpoint_file = os.path.join(os.path.abspath(args.resume_training),
                                   'models/checkpoint-recent.pth')
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint[STATE_DICT_KEY])
    optimizer.load_state_dict(checkpoint[OPTIMIZER_STATE_DICT_KEY])
def create_optimizer(model, args):
    """Build the optimizer selected by args.optimizer ('Adam', else SGD)."""
    params = model.parameters()
    if args.optimizer == 'Adam':
        return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    return optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay,
                     momentum=args.momentum)
class AverageMeterSet(object):
    """A named collection of AverageMeter objects."""

    def __init__(self, meters=None):
        self.meters = meters if meters else {}

    def __getitem__(self, key):
        """Return the meter for `key`, creating it on demand.

        Fix: the lazily-created meter is now registered in self.meters, so
        later updates and the report methods see the same object instead of
        a fresh throwaway instance on every access.
        """
        if key not in self.meters:
            meter = AverageMeter()
            meter.update(0)
            self.meters[key] = meter
        return self.meters[key]

    def update(self, name, value, n=1):
        """Record `value` (with weight `n`) under `name`, creating the meter if needed."""
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        """Reset every meter in the set."""
        for meter in self.meters.values():
            meter.reset()

    def values(self, format_string='{}'):
        """Map of formatted name -> most recent value."""
        return {format_string.format(name): meter.val for name, meter in self.meters.items()}

    def averages(self, format_string='{}'):
        """Map of formatted name -> running average."""
        return {format_string.format(name): meter.avg for name, meter in self.meters.items()}

    def sums(self, format_string='{}'):
        """Map of formatted name -> running sum."""
        return {format_string.format(name): meter.sum for name, meter in self.meters.items()}

    def counts(self, format_string='{}'):
        """Map of formatted name -> observation count."""
        return {format_string.format(name): meter.count for name, meter in self.meters.items()}
class AverageMeter(object):
    """Tracks the most recent value plus a running sum/count/average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation `val` with weight `n`."""
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        # Render as "current (average)" using the caller's format spec.
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)
def save_state_dict(state_dict, path, filename):
    """Serialize `state_dict` to <path>/<filename> with torch.save."""
    target = os.path.join(path, filename)
    torch.save(state_dict, target)
class LoggerService(object):
    """Fans log records out to separate train and validation logger lists."""

    def __init__(self, train_loggers=None, val_loggers=None):
        self.train_loggers = train_loggers if train_loggers else []
        self.val_loggers = val_loggers if val_loggers else []

    def complete(self, log_data):
        """Notify every logger (train first, then val) that logging is finished."""
        for logger in self.train_loggers + self.val_loggers:
            logger.complete(**log_data)

    def log_train(self, log_data):
        """Forward a training record to all train loggers."""
        for logger in self.train_loggers:
            logger.log(**log_data)

    def log_val(self, log_data):
        """Forward a validation record to all val loggers."""
        for logger in self.val_loggers:
            logger.log(**log_data)
class AbstractBaseLogger(metaclass=ABCMeta):
    # Interface for loggers used by LoggerService: `log` is called on every
    # logging event, `complete` once at the end of training (optional hook).
    @abstractmethod
    def log(self, *args, **kwargs):
        raise NotImplementedError

    def complete(self, *args, **kwargs):
        # Default: nothing to finalize.
        pass
class RecentModelLogger(AbstractBaseLogger):
    """Writes the most recent checkpoint to disk, at most once per epoch."""

    def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):
        self.checkpoint_path = checkpoint_path
        if not os.path.exists(self.checkpoint_path):
            os.mkdir(self.checkpoint_path)
        self.recent_epoch = None
        self.filename = filename

    def log(self, *args, **kwargs):
        """Save the given state dict unless this epoch was already saved."""
        epoch = kwargs['epoch']
        if self.recent_epoch == epoch:
            return
        self.recent_epoch = epoch
        state_dict = kwargs['state_dict']
        state_dict['epoch'] = epoch
        save_state_dict(state_dict, self.checkpoint_path, self.filename)

    def complete(self, *args, **kwargs):
        """Save a final copy alongside the rolling checkpoint."""
        save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename + '.final')
class BestModelLogger(AbstractBaseLogger):
    """Saves a checkpoint whenever the watched metric reaches a new best."""

    def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):
        self.checkpoint_path = checkpoint_path
        if not os.path.exists(self.checkpoint_path):
            os.mkdir(self.checkpoint_path)
        self.best_metric = 0.
        self.metric_key = metric_key
        self.filename = filename

    def log(self, *args, **kwargs):
        """Compare the current metric against the best so far; save on improvement."""
        current = kwargs[self.metric_key]
        if current > self.best_metric:
            print("Update Best {} Model at {}".format(self.metric_key, kwargs['epoch']))
            self.best_metric = current
            save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename)
class MetricGraphPrinter(AbstractBaseLogger):
    """Forwards a single scalar metric to a tensorboard SummaryWriter."""

    def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):
        self.key = key
        self.graph_label = graph_name
        self.group_name = group_name
        self.writer = writer

    def log(self, *args, **kwargs):
        """Write the watched metric at the current accum_iter (0 if missing)."""
        # Missing metrics are recorded as 0 so the graph stays continuous.
        value = kwargs.get(self.key, 0)
        tag = self.group_name + '/' + self.graph_label
        self.writer.add_scalar(tag, value, kwargs['accum_iter'])

    def complete(self, *args, **kwargs):
        """Flush and close the underlying SummaryWriter."""
        self.writer.close()
class AbstractDataset(metaclass=ABCMeta):
    """Base class for datasets: download raw data, filter, densify and split.

    Subclasses supply `code`, `url` and `load_ratings_df`; this class caches
    the preprocessed result as a pickle under
    RAW_DATASET_ROOT_FOLDER/preprocessed/.
    """
    def __init__(self, args):
        self.args = args
        self.min_rating = args.min_rating  # implicit-feedback threshold
        self.min_uc = args.min_uc          # minimum interactions per user
        self.min_sc = args.min_sc          # minimum interactions per item
        self.split = args.split            # 'leave_one_out' or 'holdout'
        assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'

    @classmethod
    @abstractmethod
    def code(cls):
        # Short string identifying the dataset (e.g. 'ml-1m').
        pass

    @classmethod
    def raw_code(cls):
        # Name of the raw-data folder; defaults to the dataset code.
        return cls.code()

    @classmethod
    @abstractmethod
    def url(cls):
        # Download URL of the raw dataset.
        pass

    @classmethod
    def is_zipfile(cls):
        # True when the download is a zip archive (the default path below).
        return True

    @classmethod
    def zip_file_content_is_folder(cls):
        # True when the archive contains a single top-level folder.
        return True

    @classmethod
    def all_raw_file_names(cls):
        # Files expected inside the raw-data folder (used to skip downloads).
        return []

    @abstractmethod
    def load_ratings_df(self):
        # Return a DataFrame with columns ['uid', 'sid', 'rating', 'timestamp'].
        pass

    def load_dataset(self):
        """Preprocess if needed, then load the cached dataset dict."""
        self.preprocess()
        dataset_path = self._get_preprocessed_dataset_path()
        dataset = pickle.load(dataset_path.open('rb'))
        return dataset

    def preprocess(self):
        """Run the full pipeline and pickle the result (no-op if cached)."""
        dataset_path = self._get_preprocessed_dataset_path()
        if dataset_path.is_file():
            print('Already preprocessed. Skip preprocessing')
            return
        if not dataset_path.parent.is_dir():
            dataset_path.parent.mkdir(parents=True)
        self.maybe_download_raw_dataset()
        df = self.load_ratings_df()
        df = self.make_implicit(df)
        df = self.filter_triplets(df)
        df, umap, smap = self.densify_index(df)
        train, val, test = self.split_df(df, len(umap))
        dataset = {'train': train,
                   'val': val,
                   'test': test,
                   'umap': umap,
                   'smap': smap}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)

    def maybe_download_raw_dataset(self):
        """Download (and unzip if necessary) the raw dataset unless present."""
        folder_path = self._get_rawdata_folder_path()
        if folder_path.is_dir() and\
           all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
            print('Raw data already exists. Skip downloading')
            return
        print("Raw file doesn't exist. Downloading...")
        if self.is_zipfile():
            tmproot = Path(tempfile.mkdtemp())
            tmpzip = tmproot.joinpath('file.zip')
            tmpfolder = tmproot.joinpath('folder')
            download(self.url(), tmpzip)
            unzip(tmpzip, tmpfolder)
            if self.zip_file_content_is_folder():
                # Hoist the single inner folder out of the extraction dir.
                tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])
            shutil.move(tmpfolder, folder_path)
            shutil.rmtree(tmproot)
            print()
        else:
            tmproot = Path(tempfile.mkdtemp())
            tmpfile = tmproot.joinpath('file')
            download(self.url(), tmpfile)
            folder_path.mkdir(parents=True)
            shutil.move(tmpfile, folder_path.joinpath('ratings.csv'))
            shutil.rmtree(tmproot)
            print()

    def make_implicit(self, df):
        """Keep only interactions with rating >= min_rating (implicit feedback)."""
        print('Turning into implicit ratings')
        df = df[df['rating'] >= self.min_rating]
        # return df[['uid', 'sid', 'timestamp']]
        return df

    def filter_triplets(self, df):
        """Drop rare items (min_sc) first, then inactive users (min_uc)."""
        print('Filtering triplets')
        if self.min_sc > 0:
            item_sizes = df.groupby('sid').size()
            good_items = item_sizes.index[item_sizes >= self.min_sc]
            df = df[df['sid'].isin(good_items)]
        if self.min_uc > 0:
            user_sizes = df.groupby('uid').size()
            good_users = user_sizes.index[user_sizes >= self.min_uc]
            df = df[df['uid'].isin(good_users)]
        return df

    def densify_index(self, df):
        """Remap uids/sids to contiguous 0-based ids; returns (df, umap, smap)."""
        print('Densifying index')
        umap = {u: i for i, u in enumerate(set(df['uid']))}
        smap = {s: i for i, s in enumerate(set(df['sid']))}
        df['uid'] = df['uid'].map(umap)
        df['sid'] = df['sid'].map(smap)
        return df, umap, smap

    def split_df(self, df, user_count):
        """Split into (train, val, test) dicts keyed by user id."""
        if self.args.split == 'leave_one_out':
            # Last item per user -> test, second-to-last -> val, rest -> train.
            print('Splitting')
            user_group = df.groupby('uid')
            user2items = user_group.progress_apply(lambda d: list(d.sort_values(by='timestamp')['sid']))
            train, val, test = {}, {}, {}
            for user in range(user_count):
                items = user2items[user]
                train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]
            return train, val, test
        elif self.args.split == 'holdout':
            # Disjoint user sets: eval_set_size users each for val and test.
            print('Splitting')
            np.random.seed(self.args.dataset_split_seed)
            eval_set_size = self.args.eval_set_size
            # Generate user indices
            permuted_index = np.random.permutation(user_count)
            train_user_index = permuted_index[:-2*eval_set_size]
            val_user_index = permuted_index[-2*eval_set_size:-eval_set_size]
            test_user_index = permuted_index[-eval_set_size:]
            # Split DataFrames
            train_df = df.loc[df['uid'].isin(train_user_index)]
            val_df = df.loc[df['uid'].isin(val_user_index)]
            test_df = df.loc[df['uid'].isin(test_user_index)]
            # DataFrame to dict => {uid : list of sid's}
            train = dict(train_df.groupby('uid').progress_apply(lambda d: list(d['sid'])))
            val = dict(val_df.groupby('uid').progress_apply(lambda d: list(d['sid'])))
            test = dict(test_df.groupby('uid').progress_apply(lambda d: list(d['sid'])))
            return train, val, test
        else:
            raise NotImplementedError

    def _get_rawdata_root_path(self):
        return Path(RAW_DATASET_ROOT_FOLDER)

    def _get_rawdata_folder_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath(self.raw_code())

    def _get_preprocessed_root_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath('preprocessed')

    def _get_preprocessed_folder_path(self):
        # Folder name encodes the preprocessing parameters so different
        # configurations get distinct caches.
        preprocessed_root = self._get_preprocessed_root_path()
        folder_name = '{}_min_rating{}-min_uc{}-min_sc{}-split{}' \
            .format(self.code(), self.min_rating, self.min_uc, self.min_sc, self.split)
        return preprocessed_root.joinpath(folder_name)

    def _get_preprocessed_dataset_path(self):
        folder = self._get_preprocessed_folder_path()
        return folder.joinpath('dataset.pkl')
class ML1MDataset(AbstractDataset):
    """MovieLens-1M: downloads the grouplens zip and reads ratings.dat."""

    @classmethod
    def code(cls):
        return 'ml-1m'

    @classmethod
    def url(cls):
        return 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'

    @classmethod
    def zip_file_content_is_folder(cls):
        return True

    @classmethod
    def all_raw_file_names(cls):
        return ['README',
                'movies.dat',
                'ratings.dat',
                'users.dat']

    def load_ratings_df(self):
        """Parse ratings.dat ('uid::sid::rating::timestamp') into a DataFrame.

        Fix: the multi-character '::' separator requires pandas' python
        engine; passing engine='python' explicitly avoids the ParserWarning
        and implicit fallback.
        """
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath('ratings.dat')
        df = pd.read_csv(file_path, sep='::', header=None, engine='python')
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
# Registry mapping dataset code -> dataset class, used by dataset_factory.
DATASETS = {
    ML1MDataset.code(): ML1MDataset
}
def dataset_factory(args):
    """Instantiate the dataset class registered under args.dataset_code."""
    dataset_cls = DATASETS[args.dataset_code]
    return dataset_cls(args)
class AbstractNegativeSampler(metaclass=ABCMeta):
    """Base class for negative samplers with on-disk caching of the samples."""

    def __init__(self, train, val, test, user_count, item_count, sample_size, seed, save_folder):
        self.train = train
        self.val = val
        self.test = test
        self.user_count = user_count
        self.item_count = item_count
        self.sample_size = sample_size
        self.seed = seed
        self.save_folder = save_folder

    @classmethod
    @abstractmethod
    def code(cls):
        pass

    @abstractmethod
    def generate_negative_samples(self):
        pass

    def get_negative_samples(self):
        """Load cached negatives if available, otherwise generate and cache them."""
        savefile_path = self._get_save_path()
        if savefile_path.is_file():
            print('Negatives samples exist. Loading.')
            return pickle.load(savefile_path.open('rb'))
        print("Negative samples don't exist. Generating.")
        samples = self.generate_negative_samples()
        with savefile_path.open('wb') as f:
            pickle.dump(samples, f)
        return samples

    def _get_save_path(self):
        """Cache file path encoding the sampler code, sample size and seed."""
        filename = '{}-sample_size{}-seed{}.pkl'.format(self.code(), self.sample_size, self.seed)
        return Path(self.save_folder).joinpath(filename)
class RandomNegativeSampler(AbstractNegativeSampler):
    """Uniformly samples items the user has never interacted with."""

    @classmethod
    def code(cls):
        return 'random'

    def generate_negative_samples(self):
        """Draw `sample_size` unseen, distinct negatives per user.

        Fix: track already-drawn samples in a set so the rejection test is
        O(1) instead of an O(n) list scan; the draw sequence (and therefore
        the output) is unchanged.
        """
        assert self.seed is not None, 'Specify seed for random sampling'
        np.random.seed(self.seed)
        negative_samples = {}
        print('Sampling negative items')
        for user in trange(self.user_count):
            # Sequences may store (item, rating) tuples or bare item ids.
            if isinstance(self.train[user][1], tuple):
                seen = set(x[0] for x in self.train[user])
                seen.update(x[0] for x in self.val[user])
                seen.update(x[0] for x in self.test[user])
            else:
                seen = set(self.train[user])
                seen.update(self.val[user])
                seen.update(self.test[user])
            samples = []
            sampled = set()
            for _ in range(self.sample_size):
                # Items are 1-based: choice gives 0..item_count-1, so add 1.
                item = np.random.choice(self.item_count) + 1
                while item in seen or item in sampled:
                    item = np.random.choice(self.item_count) + 1
                samples.append(item)
                sampled.add(item)
            negative_samples[user] = samples
        return negative_samples
# Registry mapping sampler code -> sampler class, used by negative_sampler_factory.
NEGATIVE_SAMPLERS = {
    RandomNegativeSampler.code(): RandomNegativeSampler,
}
def negative_sampler_factory(code, train, val, test, user_count, item_count, sample_size, seed, save_folder):
    """Instantiate the negative sampler registered under `code`."""
    sampler_cls = NEGATIVE_SAMPLERS[code]
    return sampler_cls(train, val, test, user_count, item_count, sample_size, seed, save_folder)
class AbstractDataloader(metaclass=ABCMeta):
    """Base dataloader: loads the preprocessed dataset and basic statistics."""

    def __init__(self, args, dataset):
        self.args = args
        self.rng = random.Random(args.dataloader_random_seed)
        self.save_folder = dataset._get_preprocessed_folder_path()
        loaded = dataset.load_dataset()
        self.train = loaded['train']
        self.val = loaded['val']
        self.test = loaded['test']
        self.umap = loaded['umap']
        self.smap = loaded['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)

    @classmethod
    @abstractmethod
    def code(cls):
        pass

    @abstractmethod
    def get_pytorch_dataloaders(self):
        pass
class BertDataloader(AbstractDataloader):
    """Dataloader for the BERT trainer.

    Wires up train/test negative samplers and wraps the user sequences in
    BertTrainDataset / BertEvalDataset pytorch datasets.
    """
    def __init__(self, args, dataset):
        super().__init__(args, dataset)
        # Expose the item count to the model (BERT reads args.num_items).
        args.num_items = len(self.smap)
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        # Mask token id is one past the largest item id (items are 1..item_count).
        self.CLOZE_MASK_TOKEN = self.item_count + 1
        code = args.train_negative_sampler_code
        train_negative_sampler = negative_sampler_factory(code, self.train, self.val, self.test,
                                                          self.user_count, self.item_count,
                                                          args.train_negative_sample_size,
                                                          args.train_negative_sampling_seed,
                                                          self.save_folder)
        code = args.test_negative_sampler_code
        test_negative_sampler = negative_sampler_factory(code, self.train, self.val, self.test,
                                                         self.user_count, self.item_count,
                                                         args.test_negative_sample_size,
                                                         args.test_negative_sampling_seed,
                                                         self.save_folder)
        self.train_negative_samples = train_negative_sampler.get_negative_samples()
        self.test_negative_samples = test_negative_sampler.get_negative_samples()

    @classmethod
    def code(cls):
        return 'bert'

    def get_pytorch_dataloaders(self):
        # Returns (train, val, test) torch DataLoader objects.
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader

    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,
                                           shuffle=True, pin_memory=True)
        return dataloader

    def _get_train_dataset(self):
        dataset = BertTrainDataset(self.train, self.max_len, self.mask_prob, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)
        return dataset

    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=False, pin_memory=True)
        return dataloader

    def _get_eval_dataset(self, mode):
        # NOTE(review): both val and test modes use the *test* negative samples.
        answers = self.val if mode == 'val' else self.test
        dataset = BertEvalDataset(self.train, answers, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples)
        return dataset
class BertTrainDataset(data_utils.Dataset):
    """Cloze-task training dataset: randomly masks items in each user sequence.

    For each position, with probability `mask_prob` the item becomes a
    prediction target (80% mask token / 10% random item / 10% unchanged,
    as in BERT); otherwise it is kept with label 0 (ignored by the loss).
    """

    def __init__(self, u2seq, max_len, mask_prob, mask_token, num_items, rng):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys())
        self.max_len = max_len
        self.mask_prob = mask_prob
        self.mask_token = mask_token
        self.num_items = num_items
        self.rng = rng

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        seq = self._getseq(self.users[index])
        tokens, labels = [], []
        for item in seq:
            draw = self.rng.random()
            if draw < self.mask_prob:
                # Renormalize the draw to split 80/10/10 inside the mask branch.
                draw /= self.mask_prob
                if draw < 0.8:
                    tokens.append(self.mask_token)
                elif draw < 0.9:
                    tokens.append(self.rng.randint(1, self.num_items))
                else:
                    tokens.append(item)
                labels.append(item)
            else:
                tokens.append(item)
                labels.append(0)
        # Keep only the most recent max_len positions, left-padding with 0.
        tokens = tokens[-self.max_len:]
        labels = labels[-self.max_len:]
        pad = self.max_len - len(tokens)
        return torch.LongTensor([0] * pad + tokens), torch.LongTensor([0] * pad + labels)

    def _getseq(self, user):
        return self.u2seq[user]
class BertEvalDataset(data_utils.Dataset):
    """Evaluation dataset: append a mask token and rank the answer vs negatives."""

    def __init__(self, u2seq, u2answer, max_len, mask_token, negative_samples):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys())
        self.u2answer = u2answer
        self.max_len = max_len
        self.mask_token = mask_token
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        answer = self.u2answer[user]
        negs = self.negative_samples[user]
        # Candidate list is the true item(s) followed by the sampled negatives.
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)
        # History plus the mask token at the position to predict, left-padded.
        history = self.u2seq[user] + [self.mask_token]
        history = history[-self.max_len:]
        pad = self.max_len - len(history)
        history = [0] * pad + history
        return torch.LongTensor(history), torch.LongTensor(candidates), torch.LongTensor(labels)
# Registry mapping dataloader code -> class, used by dataloader_factory.
DATALOADERS = {
    BertDataloader.code(): BertDataloader,
}
def dataloader_factory(args):
    """Build the dataset and return its (train, val, test) torch dataloaders."""
    dataset = dataset_factory(args)
    loader = DATALOADERS[args.dataloader_code](args, dataset)
    return loader.get_pytorch_dataloaders()
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # Learnable per-feature gain and bias.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        normalized = (x - mu) / (sigma + self.eps)
        return self.a_2 * normalized + self.b_2
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        residual = self.dropout(sublayer(self.norm(x)))
        return x + residual
class GELU(nn.Module):
    """
    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
    """

    def forward(self, x):
        # tanh approximation of the Gaussian error linear unit.
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        # Expand, activate, regularize, then project back to d_model.
        hidden = self.activation(self.w_1(x))
        return self.w_2(self.dropout(hidden))
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Large negative fill so masked positions get ~0 attention weight.
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Project q/k/v and reshape each to (batch, heads, seq, d_k).
        projected = []
        for layer, tensor in zip(self.linear_layers, (query, key, value)):
            projected.append(layer(tensor).view(batch_size, -1, self.h, self.d_k).transpose(1, 2))
        query, key, value = projected
        # Scaled dot-product attention over every head in parallel.
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        # Merge heads back into (batch, seq, d_model) and apply the output projection.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.output_linear(x)
class PositionalEmbedding(nn.Module):
    """Learned positional embedding, broadcast across the batch dimension."""

    def __init__(self, max_len, d_model):
        super().__init__()
        # Learned position table of shape (max_len, d_model).
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # The same position table is returned for every sequence in the batch.
        return self.pe.weight.unsqueeze(0).repeat(x.size(0), 1, 1)
class SegmentEmbedding(nn.Embedding):
    """Embedding for sentence-segment ids; index 0 is reserved for padding."""

    def __init__(self, embed_size=512):
        super().__init__(num_embeddings=3, embedding_dim=embed_size, padding_idx=0)
class TokenEmbedding(nn.Embedding):
    """Item/token embedding with index 0 reserved for padding."""

    def __init__(self, vocab_size, embed_size=512):
        super().__init__(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=0)
class BERTEmbedding(nn.Module):
    """
    BERT input embedding, the sum of:
        1. TokenEmbedding : normal embedding matrix
        2. PositionalEmbedding : adding positional information
    (The segment embedding is unused here since inputs are single sequences.)
    """

    def __init__(self, vocab_size, embed_size, max_len, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param max_len: maximum sequence length (size of the position table)
        :param dropout: dropout rate
        """
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(max_len=max_len, d_model=embed_size)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence):
        # Sum of token and positional embeddings, then dropout.
        embedded = self.token(sequence) + self.position(sequence)
        return self.dropout(embedded)
class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        # Self-attention sublayer (pre-norm residual), then the feed-forward sublayer.
        attend = lambda _x: self.attention.forward(_x, _x, _x, mask=mask)
        x = self.input_sublayer(x, attend)
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)
class BERT(nn.Module):
    """BERT encoder: token+position embedding followed by transformer blocks."""

    def __init__(self, args):
        super().__init__()
        # Make parameter initialization reproducible.
        fix_random_seed_as(args.model_init_seed)
        max_len = args.bert_max_len
        num_items = args.num_items
        n_layers = args.bert_num_blocks
        heads = args.bert_num_heads
        # +2 for the padding token (0) and the cloze mask token (num_items + 1).
        vocab_size = num_items + 2
        self.hidden = args.bert_hidden_units
        dropout = args.bert_dropout
        # embedding for BERT: sum of positional and token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden,
                                       max_len=max_len, dropout=dropout)
        # multi-layer transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(self.hidden, heads, self.hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        # Attention mask: positions with token id 0 (padding) are masked out.
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.embedding(x)
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x

    def init_weights(self):
        # NOTE(review): intentionally empty — initialization is left to the
        # submodules' default initializers.
        pass
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract base for all models; subclasses must provide a code() id."""

    def __init__(self, args):
        super().__init__()
        self.args = args

    @classmethod
    @abstractmethod
    def code(cls):
        pass
class BERTModel(BaseModel):
    """BERT encoder plus a linear head scoring every item id."""

    def __init__(self, args):
        super().__init__(args)
        self.bert = BERT(args)
        # num_items + 1 outputs: item ids are 1..num_items, index 0 is padding.
        self.out = nn.Linear(self.bert.hidden, args.num_items + 1)

    @classmethod
    def code(cls):
        return 'bert'

    def forward(self, x):
        hidden_states = self.bert(x)
        return self.out(hidden_states)
# Registry mapping model code -> model class, used by model_factory.
MODELS = {
    BERTModel.code(): BERTModel,
}
def model_factory(args):
    """Instantiate the model class registered under args.model_code."""
    return MODELS[args.model_code](args)
class AbstractTrainer(metaclass=ABCMeta):
def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):
self.args = args
self.device = args.device
self.model = model.to(self.device)
self.is_parallel = args.num_gpu > 1
if self.is_parallel:
self.model = nn.DataParallel(self.model)
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.optimizer = self._create_optimizer()
if args.enable_lr_schedule:
self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=args.decay_step, gamma=args.gamma)
self.num_epochs = args.num_epochs
self.metric_ks = args.metric_ks
self.best_metric = args.best_metric
self.export_root = export_root
self.writer, self.train_loggers, self.val_loggers = self._create_loggers()
self.add_extra_loggers()
self.logger_service = LoggerService(self.train_loggers, self.val_loggers)
self.log_period_as_iter = args.log_period_as_iter
@abstractmethod
def add_extra_loggers(self):
pass
@abstractmethod
def log_extra_train_info(self, log_data):
pass
@abstractmethod
def log_extra_val_info(self, log_data):
pass
@classmethod
@abstractmethod
def code(cls):
pass
@abstractmethod
def calculate_loss(self, batch):
pass
@abstractmethod
def calculate_metrics(self, batch):
pass
def train(self):
accum_iter = 0
self.validate(0, accum_iter)
for epoch in range(self.num_epochs):
accum_iter = self.train_one_epoch(epoch, accum_iter)
self.validate(epoch, accum_iter)
self.logger_service.complete({
'state_dict': (self._create_state_dict()),
})
self.writer.close()
def train_one_epoch(self, epoch, accum_iter):
self.model.train()
if self.args.enable_lr_schedule:
self.lr_scheduler.step()
average_meter_set = AverageMeterSet()
tqdm_dataloader = tqdm(self.train_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
batch_size = batch[0].size(0)
batch = [x.to(self.device) for x in batch]
self.optimizer.zero_grad()
loss = self.calculate_loss(batch)
loss.backward()
self.optimizer.step()
average_meter_set.update('loss', loss.item())
tqdm_dataloader.set_description(
'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))
accum_iter += batch_size
if self._needs_to_log(accum_iter):
tqdm_dataloader.set_description('Logging to Tensorboard')
log_data = {
'state_dict': (self._create_state_dict()),
'epoch': epoch+1,
'accum_iter': accum_iter,
}
log_data.update(average_meter_set.averages())
self.log_extra_train_info(log_data)
self.logger_service.log_train(log_data)
return accum_iter
def validate(self, epoch, accum_iter):
    """Evaluate on the validation set and push averaged metrics + model state
    to the logger service (which handles checkpointing the best model)."""
    self.model.eval()
    meters = AverageMeterSet()
    with torch.no_grad():
        loader = tqdm(self.val_loader)
        for batch in loader:
            batch = [t.to(self.device) for t in batch]
            for name, value in self.calculate_metrics(batch).items():
                meters.update(name, value)
            # Progress bar shows running averages for the first three cutoffs,
            # abbreviated N@k / R@k.
            shown = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                    ['Recall@%d' % k for k in self.metric_ks[:3]]
            template = 'Val: ' + ', '.join(m + ' {:.3f}' for m in shown)
            template = template.replace('NDCG', 'N').replace('Recall', 'R')
            loader.set_description(template.format(*(meters[m].avg for m in shown)))
        log_data = {
            'state_dict': self._create_state_dict(),
            'epoch': epoch + 1,
            'accum_iter': accum_iter,
        }
        log_data.update(meters.averages())
        self.log_extra_val_info(log_data)
        self.logger_service.log_val(log_data)
def test(self):
    """Load the best checkpoint (by self.best_metric) and evaluate it on the
    held-out test set, writing averaged metrics to logs/test_metrics.json."""
    print('Test best model with test set!')
    # FIX: map_location lets a checkpoint saved on GPU load on a CPU-only
    # machine (and vice versa); also use the shared STATE_DICT_KEY constant
    # instead of repeating the 'model_state_dict' literal.
    best_model = torch.load(
        os.path.join(self.export_root, 'models', 'best_acc_model.pth'),
        map_location=self.device,
    ).get(STATE_DICT_KEY)
    self.model.load_state_dict(best_model)
    self.model.eval()
    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.test_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch = [x.to(self.device) for x in batch]
            metrics = self.calculate_metrics(batch)
            for k, v in metrics.items():
                average_meter_set.update(k, v)
            description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]] +\
                ['Recall@%d' % k for k in self.metric_ks[:3]]
            # BUG FIX: the progress bar previously said 'Val:' during testing
            # (copy-paste from validate); label it as the test pass.
            description = 'Test: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
            description = description.replace('NDCG', 'N').replace('Recall', 'R')
            description = description.format(*(average_meter_set[k].avg for k in description_metrics))
            tqdm_dataloader.set_description(description)
        average_metrics = average_meter_set.averages()
        with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
            json.dump(average_metrics, f, indent=4)
        print(average_metrics)
def _create_optimizer(self):
args = self.args
if args.optimizer.lower() == 'adam':
return optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'sgd':
return optim.SGD(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
else:
raise ValueError
def _create_loggers(self):
    """Create the TensorBoard writer plus the train and validation logger lists.

    Returns:
        (writer, train_loggers, val_loggers) — val_loggers also includes the
        recent- and best-model checkpoint loggers.
    """
    root = Path(self.export_root)
    writer = SummaryWriter(root.joinpath('logs'))
    checkpoint_dir = root.joinpath('models')
    train_loggers = [
        MetricGraphPrinter(writer, key='epoch', graph_name='Epoch', group_name='Train'),
        MetricGraphPrinter(writer, key='loss', graph_name='Loss', group_name='Train'),
    ]
    val_loggers = []
    # One NDCG@k graph followed by one Recall@k graph per cutoff k.
    for k in self.metric_ks:
        for metric in ('NDCG@%d' % k, 'Recall@%d' % k):
            val_loggers.append(
                MetricGraphPrinter(writer, key=metric, graph_name=metric, group_name='Validation'))
    # Checkpoint loggers: always keep the most recent model, and track the
    # best model according to self.best_metric.
    val_loggers.append(RecentModelLogger(checkpoint_dir))
    val_loggers.append(BestModelLogger(checkpoint_dir, metric_key=self.best_metric))
    return writer, train_loggers, val_loggers
def _create_state_dict(self):
    """Snapshot model and optimizer state for checkpointing.

    Unwraps DataParallel (``self.model.module``) when running multi-GPU so the
    saved keys are loadable on a plain, unwrapped model.
    """
    bare_model = self.model.module if self.is_parallel else self.model
    return {
        STATE_DICT_KEY: bare_model.state_dict(),
        OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
    }
def _needs_to_log(self, accum_iter):
return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0
class BERTTrainer(AbstractTrainer):
    """Concrete trainer for BERT4Rec: masked-item cross-entropy loss for
    training and candidate-ranking metrics (Recall@k / NDCG@k) for evaluation."""

    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):
        super().__init__(args, model, train_loader, val_loader, test_loader, export_root)
        # Label 0 marks unmasked/padding positions; they must not contribute
        # to the loss, hence ignore_index=0.
        self.ce = nn.CrossEntropyLoss(ignore_index=0)

    @classmethod
    def code(cls):
        return 'bert'

    def add_extra_loggers(self):
        pass

    def log_extra_train_info(self, log_data):
        pass

    def log_extra_val_info(self, log_data):
        pass

    def calculate_loss(self, batch):
        """Cross-entropy over every (batch, position) token prediction."""
        seqs, labels = batch
        logits = self.model(seqs)                        # B x T x V
        flat_logits = logits.view(-1, logits.size(-1))   # (B*T) x V
        flat_labels = labels.view(-1)                    # B*T
        return self.ce(flat_logits, flat_labels)

    def calculate_metrics(self, batch):
        """Rank the candidate items by the final position's scores and compute
        Recall@k / NDCG@k for each k in self.metric_ks."""
        seqs, candidates, labels = batch
        last_scores = self.model(seqs)[:, -1, :]              # B x V
        candidate_scores = last_scores.gather(1, candidates)  # B x C
        return recalls_and_ndcgs_for_ks(candidate_scores, labels, self.metric_ks)
# Registry mapping each trainer's code string (args.trainer_code) to its class;
# consumed by trainer_factory below.
TRAINERS = {
BERTTrainer.code(): BERTTrainer,
}
def trainer_factory(args, model, train_loader, val_loader, test_loader, export_root):
    """Instantiate the trainer class registered under ``args.trainer_code``."""
    trainer_cls = TRAINERS[args.trainer_code]
    return trainer_cls(args, model, train_loader, val_loader, test_loader, export_root)
def train():
    """End-to-end training entry point: set up the experiment directory,
    build data loaders, model and trainer, train, and optionally run the
    test set on interactive confirmation."""
    export_root = setup_train(args)
    train_loader, val_loader, test_loader = dataloader_factory(args)
    model = model_factory(args)
    trainer = trainer_factory(args, model, train_loader, val_loader, test_loader, export_root)
    trainer.train()
    # Only evaluate on the held-out test set if the user explicitly asks.
    if input('Test model with test dataset? y/[n]: ') == 'y':
        trainer.test()
if __name__ == '__main__':
    # Guard clause: reject any mode other than 'train' up front.
    if args.mode != 'train':
        raise ValueError('Invalid mode')
    train()
Folder created: /content/experiments/test_2021-09-26_1 {'num_gpu': 1} Already preprocessed. Skip preprocessing Negatives samples exist. Loading. Negatives samples exist. Loading.
Val: N@1 0.013, N@5 0.035, N@10 0.051, R@1 0.013, R@5 0.059, R@10 0.111: 100%|██████████| 48/48 [01:30<00:00, 1.88s/it] /usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
Update Best NDCG@10 Model at 1
Epoch 1, loss 7.841 : 100%|██████████| 48/48 [05:01<00:00, 6.29s/it] Val: N@1 0.083, N@5 0.185, N@10 0.232, R@1 0.083, R@5 0.285, R@10 0.430: 100%|██████████| 48/48 [01:36<00:00, 2.00s/it]
Update Best NDCG@10 Model at 1
Epoch 2, loss 7.427 : 100%|██████████| 48/48 [04:59<00:00, 6.24s/it] Val: N@1 0.099, N@5 0.208, N@10 0.255, R@1 0.099, R@5 0.315, R@10 0.461: 100%|██████████| 48/48 [01:35<00:00, 2.00s/it]
Update Best NDCG@10 Model at 2
Epoch 3, loss 7.192 : 100%|██████████| 48/48 [04:59<00:00, 6.24s/it] Val: N@1 0.104, N@5 0.223, N@10 0.276, R@1 0.104, R@5 0.336, R@10 0.501: 100%|██████████| 48/48 [01:36<00:00, 2.00s/it]
Update Best NDCG@10 Model at 3
Epoch 4, loss 6.957 : 100%|██████████| 48/48 [05:07<00:00, 6.40s/it] Val: N@1 0.132, N@5 0.263, N@10 0.315, R@1 0.132, R@5 0.390, R@10 0.553: 100%|██████████| 48/48 [01:41<00:00, 2.11s/it]
Update Best NDCG@10 Model at 4
Epoch 5, loss 6.783 : 100%|██████████| 48/48 [05:07<00:00, 6.41s/it] Val: N@1 0.154, N@5 0.302, N@10 0.357, R@1 0.154, R@5 0.444, R@10 0.616: 100%|██████████| 48/48 [01:44<00:00, 2.17s/it]
Update Best NDCG@10 Model at 5
Epoch 6, loss 6.606 : 100%|██████████| 48/48 [05:08<00:00, 6.43s/it] Val: N@1 0.174, N@5 0.335, N@10 0.383, R@1 0.174, R@5 0.484, R@10 0.634: 100%|██████████| 48/48 [01:44<00:00, 2.17s/it]
Update Best NDCG@10 Model at 6
Epoch 7, loss 6.440 : 100%|██████████| 48/48 [05:08<00:00, 6.44s/it] Val: N@1 0.197, N@5 0.367, N@10 0.414, R@1 0.197, R@5 0.521, R@10 0.668: 100%|██████████| 48/48 [01:44<00:00, 2.17s/it]
Update Best NDCG@10 Model at 7
Epoch 8, loss 6.323 : 100%|██████████| 48/48 [05:08<00:00, 6.42s/it] Val: N@1 0.219, N@5 0.388, N@10 0.434, R@1 0.219, R@5 0.545, R@10 0.687: 100%|██████████| 48/48 [01:44<00:00, 2.18s/it]
Update Best NDCG@10 Model at 8
Epoch 9, loss 6.222 : 79%|███████▉ | 38/48 [04:11<01:05, 6.57s/it]