Abstract
Scope
We formalize the problem with the following settings to define the research scope:
Tasks
!apt-get install libarchive-dev
!pip install faiss-cpu --no-cache
!apt-get install libomp-dev
!pip install wget
!pip install libarchive
import argparse
import torch
import pickle
import random
import shutil
import tempfile
import os
from pathlib import Path
import gzip
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
from abc import *
import wget
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
# from torch.autograd.gradcheck import zero_gradients
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
import faiss
import numpy as np
from abc import *
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
import faiss
import numpy as np
from abc import *
from pathlib import Path
import json
import os
import pprint as pp
import random
from datetime import date
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch import optim as optim
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
import faiss
import numpy as np
from abc import *
from pathlib import Path
from pathlib import Path
import zipfile
import libarchive
import sys
from datetime import date
from pathlib import Path
import pickle
import shutil
import tempfile
import os
from tqdm import trange
from collections import Counter
import numpy as np
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn as nn
import math
import torch
import random
import torch.utils.data as data_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Folder holding raw and preprocessed recommendation datasets.
RAW_DATASET_ROOT_FOLDER = 'data'
# Folder holding generated (distillation) datasets.
GEN_DATASET_ROOT_FOLDER = 'gen_data'
# Keys used when checkpointing model / optimizer state dicts.
STATE_DICT_KEY = 'model_state_dict'
OPTIMIZER_STATE_DICT_KEY = 'optimizer_state_dict'
def fix_random_seed_as(random_seed):
    """Seed every RNG in use (python, numpy, torch) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(random_seed)
    # cuDNN: trade autotuner speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def set_template(args):
    """Populate `args` in place with the experiment template.

    Dataset-specific BERT hyperparameters come from a lookup table; the
    remaining sampler/trainer/model settings are fixed defaults.
    """
    args.min_uc = 5
    args.min_sc = 5
    args.split = 'leave_one_out'
    # Dataset selection is fixed here (interactive prompt disabled).
    args.dataset_code = 'ml-1m'
    # (sliding_window_size, hidden_units, dropout, attn_dropout,
    #  max_len, mask_prob, max_predictions) per dataset.
    dataset_specs = {
        'ml-1m':        (0.5, 64, 0.1, 0.1, 200, 0.2, 40),
        'ml-20m':       (0.5, 64, 0.1, 0.1, 200, 0.2, 20),
        'beauty':       (0.5, 64, 0.5, 0.2, 50, 0.6, 30),
        'beauty_dense': (0.5, 64, 0.5, 0.2, 50, 0.6, 30),
        'games':        (0.5, 64, 0.5, 0.5, 50, 0.5, 25),
        'steam':        (0.5, 64, 0.2, 0.2, 50, 0.4, 20),
        'yoochoose':    (0.5, 256, 0.2, 0.2, 50, 0.4, 20),
    }
    if args.dataset_code in dataset_specs:
        (args.sliding_window_size, args.bert_hidden_units, args.bert_dropout,
         args.bert_attn_dropout, args.bert_max_len, args.bert_mask_prob,
         args.bert_max_predictions) = dataset_specs[args.dataset_code]
    batch = 128
    args.train_batch_size = batch
    args.val_batch_size = batch
    args.test_batch_size = batch
    # Negative sampling: none during training, 100 random negatives for eval.
    args.train_negative_sampler_code = 'random'
    args.train_negative_sample_size = 0
    args.train_negative_sampling_seed = 0
    args.test_negative_sampler_code = 'random'
    args.test_negative_sample_size = 100
    args.test_negative_sampling_seed = 98765
    # Model selection is fixed here (interactive prompt disabled).
    args.model_code = 'bert'
    args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # Optimizer / schedule defaults.
    args.optimizer = 'AdamW'
    args.lr = 0.001
    args.weight_decay = 0.01
    args.enable_lr_schedule = True
    args.decay_step = 10000
    args.gamma = 1.
    args.enable_lr_warmup = False
    args.warmup_steps = 100
    args.num_epochs = 1000
    # Evaluation and model-shape defaults.
    args.metric_ks = [1, 5, 10]
    args.best_metric = 'NDCG@10'
    args.model_init_seed = 98765
    args.bert_num_blocks = 2
    args.bert_num_heads = 2
    args.bert_head_size = None
# Top-level CLI definition. NOTE(review): `choices=DATASETS.keys()` below
# references DATASETS, which is defined further down in this file — fine in
# the original notebook cell order, but a NameError if this flattened script
# is executed top to bottom.
parser = argparse.ArgumentParser()
################
# Dataset
################
parser.add_argument('--dataset_code', type=str, default='ml-1m', choices=DATASETS.keys())
parser.add_argument('--min_rating', type=int, default=0)
parser.add_argument('--min_uc', type=int, default=5)  # min interactions per user
parser.add_argument('--min_sc', type=int, default=5)  # min interactions per item
parser.add_argument('--split', type=str, default='leave_one_out')
parser.add_argument('--dataset_split_seed', type=int, default=0)
################
# Dataloader
################
parser.add_argument('--dataloader_random_seed', type=float, default=0)
parser.add_argument('--train_batch_size', type=int, default=64)
parser.add_argument('--val_batch_size', type=int, default=64)
parser.add_argument('--test_batch_size', type=int, default=64)
parser.add_argument('--sliding_window_size', type=float, default=0.5)
################
# NegativeSampler
################
parser.add_argument('--train_negative_sampler_code', type=str, default='random', choices=['popular', 'random'])
parser.add_argument('--train_negative_sample_size', type=int, default=0)
parser.add_argument('--train_negative_sampling_seed', type=int, default=0)
parser.add_argument('--test_negative_sampler_code', type=str, default='random', choices=['popular', 'random'])
parser.add_argument('--test_negative_sample_size', type=int, default=100)
parser.add_argument('--test_negative_sampling_seed', type=int, default=0)
################
# Trainer
################
# device #
parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda'])
parser.add_argument('--num_gpu', type=int, default=1)
# optimizer & lr#
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'Adam', 'SGD'])
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--adam_epsilon', type=float, default=1e-9)
parser.add_argument('--momentum', type=float, default=None)
parser.add_argument('--lr', type=float, default=0.001)
# NOTE(review): type=bool on a CLI flag is misleading — bool('False') is True;
# harmless here because argv is never actually parsed (args={} below).
parser.add_argument('--enable_lr_schedule', type=bool, default=True)
parser.add_argument('--decay_step', type=int, default=100)
parser.add_argument('--gamma', type=float, default=1)
parser.add_argument('--enable_lr_warmup', type=bool, default=True)
parser.add_argument('--warmup_steps', type=int, default=100)
# epochs #
parser.add_argument('--num_epochs', type=int, default=100)
# logger #
parser.add_argument('--log_period_as_iter', type=int, default=12800)
# evaluation #
parser.add_argument('--metric_ks', nargs='+', type=int, default=[1, 5, 10, 20])
parser.add_argument('--best_metric', type=str, default='NDCG@10')
################
# Model
################
parser.add_argument('--model_code', type=str, default='bert', choices=['bert', 'sas', 'narm'])
# BERT specs, used for SASRec and NARM as well #
parser.add_argument('--bert_max_len', type=int, default=None)
parser.add_argument('--bert_hidden_units', type=int, default=64)
parser.add_argument('--bert_num_blocks', type=int, default=2)
parser.add_argument('--bert_num_heads', type=int, default=2)
parser.add_argument('--bert_head_size', type=int, default=32)
parser.add_argument('--bert_dropout', type=float, default=0.1)
parser.add_argument('--bert_attn_dropout', type=float, default=0.1)
parser.add_argument('--bert_mask_prob', type=float, default=0.2)
################
# Distillation & Retraining
################
parser.add_argument('--num_generated_seqs', type=int, default=3000)
parser.add_argument('--num_original_seqs', type=int, default=0)
parser.add_argument('--num_poisoned_seqs', type=int, default=100)
parser.add_argument('--num_alter_items', type=int, default=10)
################
# Parse an empty argv so only the defaults above are used (notebook style);
# set_template() then overwrites most of these fields.
args = parser.parse_args(args={})
def download(url, savepath):
    """Fetch `url` into `savepath` via wget, then terminate the progress line."""
    destination = str(savepath)
    wget.download(url, destination)
    print()
def unzip(zippath, savepath):
    """Extract every member of the zip archive at `zippath` into `savepath`."""
    print("Extracting data...")
    # Context manager guarantees the archive handle is closed even on error;
    # also avoids shadowing the builtin `zip`.
    with zipfile.ZipFile(zippath) as archive:
        archive.extractall(savepath)
def unzip7z(filename):
    # Extract a 7z (or any libarchive-supported) archive into the current
    # working directory using the third-party `libarchive` package.
    print("Extracting data...")
    libarchive.extract_file(filename)
def ndcg(scores, labels, k):
    """Mean NDCG@k over a batch of score rows with 0/1 relevance labels."""
    scores, labels = scores.cpu(), labels.cpu()
    # Item indices ranked by descending score, truncated to the top k.
    top_k = (-scores).argsort(dim=1)[:, :k]
    gains = labels.gather(1, top_k).float()
    # Standard log2 position discount: positions 1..k -> 1/log2(2..k+1).
    discount = 1 / torch.log2(torch.arange(2, 2 + k).float())
    dcg = (gains * discount).sum(1)
    # Ideal DCG: all relevant items (capped at k) placed at the top.
    idcg = torch.Tensor([discount[:min(int(n), k)].sum()
                         for n in labels.sum(1)])
    return (dcg / idcg).mean()
def recalls_and_ndcgs_for_ks(scores, labels, ks):
    """Compute Recall@k and NDCG@k for every k in `ks`.

    Returns a dict mapping 'Recall@k' / 'NDCG@k' to python floats.
    """
    metrics = {}
    answer_count = labels.sum(1)
    labels_float = labels.float()
    # Rank everything once; each smaller k reuses a prefix of the previous
    # ranking (ks are processed from largest to smallest).
    ranking = (-scores).argsort(dim=1)
    for k in sorted(ks, reverse=True):
        ranking = ranking[:, :k]
        hits = labels_float.gather(1, ranking)
        denom = torch.min(torch.Tensor([k]).to(labels.device),
                          answer_count.float())
        metrics['Recall@%d' % k] = (hits.sum(1) / denom).mean().cpu().item()
        discount = 1 / torch.log2(torch.arange(2, 2 + k).float())
        dcg = (hits * discount.to(hits.device)).sum(1)
        idcg = torch.Tensor([discount[:min(int(n), k)].sum()
                             for n in answer_count]).to(dcg.device)
        metrics['NDCG@%d' % k] = (dcg / idcg).mean().cpu().item()
    return metrics
def em_and_agreement(scores_rank, labels_rank):
    """Exact-match rate and average overlap between two top-k rankings.

    `em` is the fraction of positions where both rankings hold the same item;
    `agreement` is the mean number of items shared by the two lists per row.
    """
    em = (scores_rank == labels_rank).float().mean()
    # Concatenate and sort each row: any item present in both rankings
    # yields exactly one adjacent duplicate pair, which we count.
    combined = np.sort(np.hstack((scores_rank.numpy(), labels_rank.numpy())), axis=1)
    agreement = np.mean((combined[:, 1:] == combined[:, :-1]).sum(axis=1))
    return em, agreement
def kl_agreements_and_intersctions_for_ks(scores, soft_labels, ks, k_kl=100):
    """Distillation-quality metrics between student scores and teacher logits.

    Returns 'KL-Div' over the teacher's top-k_kl items, plus 'EM@k' and
    'Agr@k' (ranking overlap, normalized by k) for each k in `ks`.
    """
    metrics = {}
    scores, soft_labels = scores.cpu(), soft_labels.cpu()
    scores_rank = (-scores).argsort(dim=1)
    labels_rank = (-soft_labels).argsort(dim=1)
    # KL divergence is computed only over the teacher's top-k_kl items.
    teacher_top = labels_rank[:, :k_kl]
    log_student = F.log_softmax(scores.gather(1, teacher_top), dim=-1)
    teacher = F.softmax(soft_labels.gather(1, teacher_top), dim=-1)
    metrics['KL-Div'] = F.kl_div(log_student, teacher, reduction='batchmean').item()
    for k in sorted(ks, reverse=True):
        em, agreement = em_and_agreement(scores_rank[:, :k], labels_rank[:, :k])
        metrics['EM@%d' % k] = em.item()
        metrics['Agr@%d' % k] = (agreement / k).item()
    return metrics
class AverageMeterSet(object):
    """A keyed collection of AverageMeter instances."""

    def __init__(self, meters=None):
        self.meters = meters if meters else {}

    def __getitem__(self, key):
        # Unknown keys yield a fresh zero-initialized meter; note that it is
        # NOT stored, so updates made to it are not tracked by this set.
        if key not in self.meters:
            placeholder = AverageMeter()
            placeholder.update(0)
            return placeholder
        return self.meters[key]

    def update(self, name, value, n=1):
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        for meter in self.meters.values():
            meter.reset()

    def values(self, format_string='{}'):
        return self._collect(format_string, 'val')

    def averages(self, format_string='{}'):
        return self._collect(format_string, 'avg')

    def sums(self, format_string='{}'):
        return self._collect(format_string, 'sum')

    def counts(self, format_string='{}'):
        return self._collect(format_string, 'count')

    def _collect(self, format_string, attr):
        # Shared implementation for values/averages/sums/counts.
        return {format_string.format(name): getattr(meter, attr)
                for name, meter in self.meters.items()}
class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record observation `val`; `n` increments the count only."""
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        # Render as "current (average)" using the supplied format spec.
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)
def save_state_dict(state_dict, path, filename):
    """Serialize `state_dict` to `path`/`filename` via torch.save."""
    target = os.path.join(path, filename)
    torch.save(state_dict, target)
class LoggerService(object):
    """Fans log records out to train-time and validation-time logger groups."""

    def __init__(self, train_loggers=None, val_loggers=None):
        self.train_loggers = train_loggers if train_loggers else []
        self.val_loggers = val_loggers if val_loggers else []

    def _dispatch(self, loggers, method, log_data):
        # Invoke `method` on each logger, passing the record's fields as kwargs.
        for logger in loggers:
            getattr(logger, method)(**log_data)

    def complete(self, log_data):
        """Signal completion to every logger (train then val)."""
        self._dispatch(self.train_loggers, 'complete', log_data)
        self._dispatch(self.val_loggers, 'complete', log_data)

    def log_train(self, log_data):
        self._dispatch(self.train_loggers, 'log', log_data)

    def log_val(self, log_data):
        self._dispatch(self.val_loggers, 'log', log_data)
class AbstractBaseLogger(metaclass=ABCMeta):
    # Logger interface: `log` is called per logging event and must be
    # implemented; `complete` is an optional end-of-training hook.
    @abstractmethod
    def log(self, *args, **kwargs):
        raise NotImplementedError

    def complete(self, *args, **kwargs):
        # Default: nothing to finalize.
        pass
class RecentModelLogger(AbstractBaseLogger):
    """Keeps a rolling most-recent-epoch checkpoint on disk."""

    def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):
        self.checkpoint_path = checkpoint_path
        if not os.path.exists(self.checkpoint_path):
            os.mkdir(self.checkpoint_path)
        self.recent_epoch = None
        self.filename = filename

    def log(self, *args, **kwargs):
        epoch = kwargs['epoch']
        if self.recent_epoch == epoch:
            # A checkpoint for this epoch has already been written.
            return
        self.recent_epoch = epoch
        state_dict = kwargs['state_dict']
        state_dict['epoch'] = epoch
        save_state_dict(state_dict, self.checkpoint_path, self.filename)

    def complete(self, *args, **kwargs):
        # Final snapshot is saved under '<filename>.final'.
        save_state_dict(kwargs['state_dict'], self.checkpoint_path,
                        self.filename + '.final')
class BestModelLogger(AbstractBaseLogger):
    """Saves a checkpoint whenever the watched metric reaches a new maximum."""

    def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):
        self.checkpoint_path = checkpoint_path
        if not os.path.exists(self.checkpoint_path):
            os.mkdir(self.checkpoint_path)
        self.best_metric = 0.
        self.metric_key = metric_key
        self.filename = filename

    def log(self, *args, **kwargs):
        current = kwargs[self.metric_key]
        # Only strict improvements trigger a save.
        if current <= self.best_metric:
            return
        print("Update Best {} Model at {}".format(
            self.metric_key, kwargs['epoch']))
        self.best_metric = current
        save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename)
class MetricGraphPrinter(AbstractBaseLogger):
    """Writes a single scalar metric to a TensorBoard-style writer."""

    def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):
        self.key = key
        self.graph_label = graph_name
        self.group_name = group_name
        self.writer = writer

    def log(self, *args, **kwargs):
        # Missing keys are plotted as 0 so the graph stays continuous.
        value = kwargs[self.key] if self.key in kwargs else 0
        self.writer.add_scalar(self.group_name + '/' + self.graph_label,
                               value, kwargs['accum_iter'])

    def complete(self, *args, **kwargs):
        self.writer.close()
class AbstractDataset(metaclass=ABCMeta):
    """Base class for raw recommendation datasets.

    Subclasses supply download/parse logic; this class implements the shared
    preprocessing pipeline: filter -> densify ids -> leave-one-out split.
    """

    def __init__(self, args):
        self.args = args
        self.min_rating = args.min_rating  # minimum rating kept (used by subclasses)
        self.min_uc = args.min_uc          # minimum interactions per user
        self.min_sc = args.min_sc          # minimum interactions per item
        self.split = args.split            # split strategy name
        assert self.min_uc >= 2, 'Need at least 2 ratings per user for validation and test'

    @classmethod
    @abstractmethod
    def code(cls):
        # Canonical short name of the dataset (e.g. 'ml-1m').
        pass

    @classmethod
    def raw_code(cls):
        # Folder name of the raw download; defaults to code().
        return cls.code()

    @classmethod
    def zip_file_content_is_folder(cls):
        # True if the downloaded archive wraps its files in a single folder.
        return True

    @classmethod
    def all_raw_file_names(cls):
        # Files expected in the raw folder (used to decide whether to re-download).
        return []

    @classmethod
    @abstractmethod
    def url(cls):
        # Download URL of the raw dataset archive.
        pass

    @classmethod
    @abstractmethod
    def is_zipfile(cls):
        pass

    @classmethod
    @abstractmethod
    def is_7zfile(cls):
        pass

    @abstractmethod
    def preprocess(self):
        pass

    @abstractmethod
    def load_ratings_df(self):
        # Must return a DataFrame with columns ['uid', 'sid', 'rating', 'timestamp'].
        pass

    @abstractmethod
    def maybe_download_raw_dataset(self):
        pass

    def load_dataset(self):
        """Preprocess if needed, then return the pickled dataset dict."""
        self.preprocess()
        dataset_path = self._get_preprocessed_dataset_path()
        # NOTE(review): the handle from .open('rb') is never closed explicitly.
        dataset = pickle.load(dataset_path.open('rb'))
        return dataset

    def filter_triplets(self, df):
        # Drop cold items first, then cold users; user counts are therefore
        # evaluated on the already item-filtered frame.
        print('Filtering triplets')
        if self.min_sc > 0:
            item_sizes = df.groupby('sid').size()
            good_items = item_sizes.index[item_sizes >= self.min_sc]
            df = df[df['sid'].isin(good_items)]
        if self.min_uc > 0:
            user_sizes = df.groupby('uid').size()
            good_users = user_sizes.index[user_sizes >= self.min_uc]
            df = df[df['uid'].isin(good_users)]
        return df

    def densify_index(self, df):
        # Re-map raw user/item ids to contiguous integers starting at 1
        # (0 is reserved for padding elsewhere in the pipeline).
        print('Densifying index')
        umap = {u: i for i, u in enumerate(set(df['uid']), start=1)}
        smap = {s: i for i, s in enumerate(set(df['sid']), start=1)}
        df['uid'] = df['uid'].map(umap)
        df['sid'] = df['sid'].map(smap)
        return df, umap, smap

    def split_df(self, df, user_count):
        """Leave-one-out split: last item -> test, second-to-last -> val."""
        if self.args.split == 'leave_one_out':
            print('Splitting')
            user_group = df.groupby('uid')
            # Per-user chronological item list (ties broken by sid).
            user2items = user_group.progress_apply(
                lambda d: list(d.sort_values(by=['timestamp', 'sid'])['sid']))
            train, val, test = {}, {}, {}
            for i in range(user_count):
                user = i + 1
                items = user2items[user]
                train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]
            return train, val, test
        else:
            raise NotImplementedError

    def _get_rawdata_root_path(self):
        return Path(RAW_DATASET_ROOT_FOLDER)

    def _get_rawdata_folder_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath(self.raw_code())

    def _get_preprocessed_root_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath('preprocessed')

    def _get_preprocessed_folder_path(self):
        # The folder name encodes filtering/split settings so different
        # configurations do not collide on disk.
        preprocessed_root = self._get_preprocessed_root_path()
        folder_name = '{}_min_rating{}-min_uc{}-min_sc{}-split{}' \
            .format(self.code(), self.min_rating, self.min_uc, self.min_sc, self.split)
        return preprocessed_root.joinpath(folder_name)

    def _get_preprocessed_dataset_path(self):
        folder = self._get_preprocessed_folder_path()
        return folder.joinpath('dataset.pkl')
class ML1MDataset(AbstractDataset):
    """MovieLens-1M dataset, downloaded from grouplens.org."""

    @classmethod
    def code(cls):
        return 'ml-1m'

    @classmethod
    def url(cls):
        return 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'

    @classmethod
    def zip_file_content_is_folder(cls):
        # The archive expands into a single 'ml-1m' directory.
        return True

    @classmethod
    def all_raw_file_names(cls):
        return ['README',
                'movies.dat',
                'ratings.dat',
                'users.dat']

    @classmethod
    def is_zipfile(cls):
        return True

    @classmethod
    def is_7zfile(cls):
        return False

    def maybe_download_raw_dataset(self):
        """Download and unpack the raw archive unless all files are present."""
        folder_path = self._get_rawdata_folder_path()
        if folder_path.is_dir() and \
                all(folder_path.joinpath(filename).is_file() for filename in self.all_raw_file_names()):
            print('Raw data already exists. Skip downloading')
            return
        print("Raw file doesn't exist. Downloading...")
        # Download into a temp dir, unzip, then move the content folder into place.
        tmproot = Path(tempfile.mkdtemp())
        tmpzip = tmproot.joinpath('file.zip')
        tmpfolder = tmproot.joinpath('folder')
        download(self.url(), tmpzip)
        unzip(tmpzip, tmpfolder)
        if self.zip_file_content_is_folder():
            tmpfolder = tmpfolder.joinpath(os.listdir(tmpfolder)[0])
        shutil.move(tmpfolder, folder_path)
        shutil.rmtree(tmproot)
        print()

    def preprocess(self):
        """Create and pickle {train, val, test, umap, smap} if not done yet."""
        dataset_path = self._get_preprocessed_dataset_path()
        if dataset_path.is_file():
            print('Already preprocessed. Skip preprocessing')
            return
        if not dataset_path.parent.is_dir():
            dataset_path.parent.mkdir(parents=True)
        self.maybe_download_raw_dataset()
        df = self.load_ratings_df()
        df = self.filter_triplets(df)
        df, umap, smap = self.densify_index(df)
        train, val, test = self.split_df(df, len(umap))
        dataset = {'train': train,
                   'val': val,
                   'test': test,
                   'umap': umap,
                   'smap': smap}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)

    def load_ratings_df(self):
        """Parse ratings.dat into a DataFrame [uid, sid, rating, timestamp]."""
        folder_path = self._get_rawdata_folder_path()
        file_path = folder_path.joinpath('ratings.dat')
        # '::' is a multi-char separator: request the python engine explicitly
        # instead of relying on pandas' warning-emitting automatic fallback.
        df = pd.read_csv(file_path, sep='::', header=None, engine='python')
        df.columns = ['uid', 'sid', 'rating', 'timestamp']
        return df
# Registry of dataset code -> dataset class; consumed by dataset_factory and
# the CLI '--dataset_code' choices.
DATASETS = {
    ML1MDataset.code(): ML1MDataset
}
def dataset_factory(args):
    """Instantiate the dataset class registered under args.dataset_code."""
    dataset_cls = DATASETS[args.dataset_code]
    return dataset_cls(args)
class AbstractDistillationDataset(metaclass=ABCMeta):
    """Base class for generated (distillation) datasets.

    Stores/loads dicts of {'seqs', 'logits', 'candidates'} pickled under
    GEN_DATASET_ROOT_FOLDER/<dataset>/<bb_model>_<num_seqs>/<mode>_dataset.pkl.
    """

    def __init__(self, args, bb_model_code, mode='random'):
        self.args = args
        # Code of the black-box (teacher) model the data was generated from.
        self.bb_model_code = bb_model_code
        self.mode = mode
        assert self.mode in ['random', 'autoregressive', 'adversarial']

    @classmethod
    @abstractmethod
    def code(cls):
        # Canonical short name of the underlying dataset.
        pass

    @classmethod
    def raw_code(cls):
        return cls.code()

    def check_data_present(self):
        """True if the distillation pickle for this config already exists."""
        dataset_path = self._get_distillation_dataset_path()
        return dataset_path.is_file()

    def load_dataset(self):
        """Load the pickled distillation dataset.

        Returns None (after printing a notice) when it has not been
        generated yet.
        """
        dataset_path = self._get_distillation_dataset_path()
        if not dataset_path.is_file():
            print('Dataset not found, please generate distillation dataset first')
            return
        # Context manager so the file handle is closed deterministically
        # (original left it dangling).
        with dataset_path.open('rb') as f:
            dataset = pickle.load(f)
        return dataset

    def save_dataset(self, tokens, logits, candidates):
        """Pickle generated sequences, teacher logits and candidate items."""
        dataset_path = self._get_distillation_dataset_path()
        if not dataset_path.parent.is_dir():
            dataset_path.parent.mkdir(parents=True)
        dataset = {'seqs': tokens,
                   'logits': logits,
                   'candidates': candidates}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)

    def _get_rawdata_root_path(self):
        return Path(GEN_DATASET_ROOT_FOLDER)

    def _get_folder_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath(self.raw_code())

    def _get_subfolder_path(self):
        root = self._get_folder_path()
        return root.joinpath(self.bb_model_code + '_' + str(self.args.num_generated_seqs))

    def _get_distillation_dataset_path(self):
        folder = self._get_subfolder_path()
        return folder.joinpath(self.mode + '_dataset.pkl')
class ML1MDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by MovieLens-1M.
    @classmethod
    def code(cls):
        return 'ml-1m'
class ML20MDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by MovieLens-20M.
    @classmethod
    def code(cls):
        return 'ml-20m'
class BeautyDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by Amazon Beauty.
    @classmethod
    def code(cls):
        return 'beauty'
class BeautyDenseDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by the dense variant of Amazon Beauty.
    @classmethod
    def code(cls):
        return 'beauty_dense'
class GamesDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by Amazon Games.
    @classmethod
    def code(cls):
        return 'games'
class SteamDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by Steam.
    @classmethod
    def code(cls):
        return 'steam'
class YooChooseDistillationDataset(AbstractDistillationDataset):
    # Distillation dataset backed by YooChoose.
    @classmethod
    def code(cls):
        return 'yoochoose'
class AbstractNegativeSampler(metaclass=ABCMeta):
    """Base class for negative-item samplers with on-disk pickle caching."""

    def __init__(self, train, val, test, user_count, item_count, sample_size, seed, flag, save_folder):
        self.train = train
        self.val = val
        self.test = test
        self.user_count = user_count
        self.item_count = item_count
        self.sample_size = sample_size  # negatives per user
        self.seed = seed
        self.flag = flag                # e.g. 'val' / 'test'; part of the cache filename
        self.save_folder = save_folder

    @classmethod
    @abstractmethod
    def code(cls):
        pass

    @abstractmethod
    def generate_negative_samples(self):
        # Must return (seen_samples, negative_samples) dicts keyed by user id.
        pass

    def get_negative_samples(self):
        """Return cached samples if present, else generate and cache them."""
        savefile_path = self._get_save_path()
        if savefile_path.is_file():
            print('Negatives samples exist. Loading.')
            # Context manager so the cache file handle is closed
            # (original left it dangling).
            with savefile_path.open('rb') as f:
                seen_samples, negative_samples = pickle.load(f)
            return seen_samples, negative_samples
        print("Negative samples don't exist. Generating.")
        seen_samples, negative_samples = self.generate_negative_samples()
        with savefile_path.open('wb') as f:
            pickle.dump([seen_samples, negative_samples], f)
        return seen_samples, negative_samples

    def _get_save_path(self):
        # Cache filename encodes sampler code, size, seed and split flag.
        folder = Path(self.save_folder)
        filename = '{}-sample_size{}-seed{}-{}.pkl'.format(
            self.code(), self.sample_size, self.seed, self.flag)
        return folder.joinpath(filename)
class RandomNegativeSampler(AbstractNegativeSampler):
    """Samples negatives uniformly at random, excluding items the user has seen."""

    @classmethod
    def code(cls):
        return 'random'

    def generate_negative_samples(self):
        """Return (seen_samples, negative_samples) dicts keyed by user id.

        Draws from a pre-generated pool of uniform item ids in 1..item_count,
        skipping items the user interacted with and duplicates.
        """
        assert self.seed is not None, 'Specify seed for random sampling'
        np.random.seed(self.seed)
        # Over-sample (2x) up front; the cursor j wraps around if the pool
        # runs out for some user.
        num_samples = 2 * self.user_count * self.sample_size
        all_samples = np.random.choice(self.item_count, num_samples) + 1
        seen_samples = {}
        negative_samples = {}
        print('Sampling negative items randomly...')
        j = 0
        for i in trange(self.user_count):
            user = i + 1
            seen = set(self.train[user])
            seen.update(self.val[user])
            seen.update(self.test[user])
            seen_samples[user] = seen
            samples = []
            picked = set()  # O(1) duplicate check instead of scanning the list
            while len(samples) < self.sample_size:
                item = all_samples[j % num_samples]
                j += 1
                if item in seen or item in picked:
                    continue
                samples.append(item)
                picked.add(item)
            negative_samples[user] = samples
        return seen_samples, negative_samples
class PopularNegativeSampler(AbstractNegativeSampler):
    """Samples negatives proportionally to overall item popularity."""

    @classmethod
    def code(cls):
        return 'popular'

    def generate_negative_samples(self):
        """Return (seen_samples, negative_samples) dicts keyed by user id."""
        assert self.seed is not None, 'Specify seed for random sampling'
        np.random.seed(self.seed)
        popularity = self.items_by_popularity()
        items = list(popularity.keys())
        # Normalize interaction counts into a probability distribution.
        total = sum(popularity.values())
        probs = [popularity[item] / total for item in items]
        # Over-sample (2x) up front; the cursor j wraps around if needed.
        num_samples = 2 * self.user_count * self.sample_size
        all_samples = np.random.choice(items, num_samples, p=probs)
        seen_samples = {}
        negative_samples = {}
        print('Sampling negative items by popularity...')
        j = 0
        for i in trange(self.user_count):
            user = i + 1
            seen = set(self.train[user])
            seen.update(self.val[user])
            seen.update(self.test[user])
            seen_samples[user] = seen
            samples = []
            picked = set()  # O(1) duplicate check instead of scanning the list
            while len(samples) < self.sample_size:
                item = all_samples[j % num_samples]
                j += 1
                if item in seen or item in picked:
                    continue
                samples.append(item)
                picked.add(item)
            negative_samples[user] = samples
        return seen_samples, negative_samples

    def items_by_popularity(self):
        """Item -> interaction count over train+val+test, sorted descending."""
        popularity = Counter()
        self.users = sorted(self.train.keys())
        for user in self.users:
            popularity.update(self.train[user])
            popularity.update(self.val[user])
            popularity.update(self.test[user])
        return {k: v for k, v in
                sorted(popularity.items(), key=lambda item: item[1], reverse=True)}
# Registry of sampler code -> sampler class for negative_sampler_factory.
NEGATIVE_SAMPLERS = {
    PopularNegativeSampler.code(): PopularNegativeSampler,
    RandomNegativeSampler.code(): RandomNegativeSampler,
}
def negative_sampler_factory(code, train, val, test, user_count, item_count, sample_size, seed, flag, save_folder):
    """Build the negative sampler registered under `code`."""
    sampler_cls = NEGATIVE_SAMPLERS[code]
    return sampler_cls(train, val, test, user_count, item_count, sample_size, seed, flag, save_folder)
class AbstractDataloader(metaclass=ABCMeta):
    # Base dataloader: unpacks the preprocessed dataset dict into
    # train/val/test maps and the dense id-mapping tables.
    def __init__(self, args, dataset):
        self.args = args
        self.rng = random.Random()
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.umap = dataset['umap']  # raw uid -> dense uid
        self.smap = dataset['smap']  # raw sid -> dense sid
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)

    @classmethod
    @abstractmethod
    def code(cls):
        pass

    @abstractmethod
    def get_pytorch_dataloaders(self):
        # Must return (train_loader, val_loader, test_loader).
        pass
class RNNDataloader():
    # Dataloader for the RNN (NARM) model. NOTE(review): duplicates
    # AbstractDataloader's initialisation instead of subclassing it.
    def __init__(self, args, dataset):
        self.args = args
        self.rng = random.Random()
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.umap = dataset['umap']  # raw uid -> dense uid
        self.smap = dataset['smap']  # raw sid -> dense sid
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        args.num_items = len(self.smap)
        self.max_len = args.bert_max_len
        # Val and test share the sampler configuration but are cached under
        # different flags ('val' / 'test').
        val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                        self.train, self.val, self.test,
                                                        self.user_count, self.item_count,
                                                        args.test_negative_sample_size,
                                                        args.test_negative_sampling_seed,
                                                        'val', self.save_folder)
        test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                         self.train, self.val, self.test,
                                                         self.user_count, self.item_count,
                                                         args.test_negative_sample_size,
                                                         args.test_negative_sampling_seed,
                                                         'test', self.save_folder)
        # NOTE(review): seen_samples is assigned twice; only the value from
        # the test sampler persists.
        self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()
        self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()

    @classmethod
    def code(cls):
        return 'rnn'

    def get_pytorch_dataloaders(self):
        # Returns (train, val, test) torch DataLoaders.
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader

    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,
                                           shuffle=True, pin_memory=True)
        return dataloader

    def _get_train_dataset(self):
        dataset = RNNTrainDataset(
            self.train, self.max_len)
        return dataset

    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
        return dataloader

    def _get_eval_dataset(self, mode):
        # 'val' predicts the held-out val item from the train history; 'test'
        # additionally appends the val item to the input sequence.
        if mode == 'val':
            dataset = RNNValidDataset(self.train, self.val, self.max_len, self.val_negative_samples)
        elif mode == 'test':
            dataset = RNNTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples)
        return dataset
class RNNTrainDataset(data_utils.Dataset):
    """Next-item training set: every prefix of each user's sequence becomes
    one (prefix, next_item) training example."""

    def __init__(self, u2seq, max_len):
        self.max_len = max_len
        self.all_seqs = []
        self.all_labels = []
        for user in sorted(u2seq.keys()):
            seq = u2seq[user]
            # For a sequence [a, b, c] this yields ([a, b] -> c), ([a] -> b).
            for cut in range(1, len(seq)):
                self.all_seqs.append(seq[:-cut])
                self.all_labels.append(seq[-cut])
        assert len(self.all_seqs) == len(self.all_labels)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        # Keep the most recent max_len items and right-pad with 0.
        tokens = self.all_seqs[index][-self.max_len:]
        length = len(tokens)
        padded = tokens + [0] * (self.max_len - length)
        return (torch.LongTensor(padded),
                torch.LongTensor([length]),
                torch.LongTensor([self.all_labels[index]]))
class RNNValidDataset(data_utils.Dataset):
    """Validation set: predict the held-out validation item from the user's
    training history, scored against answer + sampled negatives."""

    def __init__(self, u2seq, u2answer, max_len, negative_samples, valid_users=None):
        self.u2seq = u2seq  # user -> train sequence
        # BUGFIX: the original unconditionally re-assigned self.users after
        # this branch, so an explicit `valid_users` argument was ignored.
        if not valid_users:
            self.users = sorted(self.u2seq.keys())
        else:
            self.users = valid_users
        self.u2answer = u2answer
        self.max_len = max_len
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        # Most recent max_len training items, right-padded with 0.
        tokens = self.u2seq[user][-self.max_len:]
        length = len(tokens)
        tokens = tokens + [0] * (self.max_len - length)
        answer = self.u2answer[user]
        negs = self.negative_samples[user]
        # Candidates = true answer(s) followed by negatives; labels mark the
        # answer positions with 1.
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)
        return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor(candidates), torch.LongTensor(labels)
class RNNTestDataset(data_utils.Dataset):
    """Test set: the validation item is appended to the training history and
    the held-out test item is scored against answer + sampled negatives."""

    def __init__(self, u2seq, u2val, u2answer, max_len, negative_samples, test_users=None):
        self.u2seq = u2seq  # user -> train sequence
        self.u2val = u2val  # user -> validation item(s)
        # BUGFIX: the original unconditionally re-assigned self.users after
        # this branch, so an explicit `test_users` argument was ignored.
        if not test_users:
            self.users = sorted(self.u2seq.keys())
        else:
            self.users = test_users
        self.u2answer = u2answer  # user -> test item(s)
        self.max_len = max_len
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        # Train history plus the validation item, truncated to max_len.
        tokens = (self.u2seq[user] + self.u2val[user])[-self.max_len:]
        length = len(tokens)
        tokens = tokens + [0] * (self.max_len - length)
        answer = self.u2answer[user]
        negs = self.negative_samples[user]
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)
        return torch.LongTensor(tokens), torch.LongTensor([length]), torch.LongTensor(candidates), torch.LongTensor(labels)
class SASDataloader():
    """Builds train/val/test DataLoaders for SASRec from a preprocessed
    dataset, drawing (and caching) negative samples for evaluation."""
    def __init__(self, args, dataset):
        self.args = args
        self.rng = random.Random()  # unseeded: train-time negatives vary across runs
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        args.num_items = self.item_count  # side effect: model construction reads this
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        self.max_predictions = args.bert_max_predictions
        self.sliding_size = args.sliding_window_size
        self.CLOZE_MASK_TOKEN = self.item_count + 1
        # samplers persist their draws under save_folder, keyed 'val'/'test'
        val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                        self.train, self.val, self.test,
                                                        self.user_count, self.item_count,
                                                        args.test_negative_sample_size,
                                                        args.test_negative_sampling_seed,
                                                        'val', self.save_folder)
        test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                         self.train, self.val, self.test,
                                                         self.user_count, self.item_count,
                                                         args.test_negative_sample_size,
                                                         args.test_negative_sampling_seed,
                                                         'test', self.save_folder)
        # NOTE(review): seen_samples from the val sampler is immediately
        # overwritten by the test sampler's result — presumably both samplers
        # compute identical seen-item sets; confirm in negative_sampler_factory.
        self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()
        self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()

    @classmethod
    def code(cls):
        return 'sas'

    def get_pytorch_dataloaders(self):
        """Return the (train, val, test) DataLoader triple."""
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader

    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,
                                           shuffle=True, pin_memory=True)
        return dataloader

    def _get_train_dataset(self):
        dataset = SASTrainDataset(
            self.train, self.max_len, self.sliding_size, self.seen_samples, self.item_count, self.rng)
        return dataset

    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        # evaluation batches are not shuffled so metrics align with user order
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
        return dataloader

    def _get_eval_dataset(self, mode):
        if mode == 'val':
            dataset = SASValidDataset(self.train, self.val, self.max_len, self.val_negative_samples)
        elif mode == 'test':
            dataset = SASTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples)
        return dataset
class SASTrainDataset(data_utils.Dataset):
    """Training dataset for SASRec.

    Long user histories are expanded into overlapping windows (stride =
    ``sliding_size * max_len``). Each sample pairs input tokens with
    next-item labels and freshly drawn negatives, all left-padded with 0.
    """
    def __init__(self, u2seq, max_len, sliding_size, seen_samples, num_items, rng):
        self.max_len = max_len
        self.sliding_step = int(sliding_size * max_len)
        self.num_items = num_items
        self.rng = rng
        assert self.sliding_step > 0
        self.all_seqs = []
        self.seen_samples = []
        for user in sorted(u2seq.keys()):
            seq = u2seq[user]
            seen = seen_samples[user]
            if len(seq) < self.max_len + self.sliding_step:
                self.all_seqs.append(seq)
                self.seen_samples.append(seen)
            else:
                # slide a max_len window from the end toward the start
                starts = range(len(seq) - max_len, -1, -self.sliding_step)
                self.all_seqs.extend(seq[i:i + max_len] for i in starts)
                self.seen_samples.extend(seen for _ in starts)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        seq = self.all_seqs[index]
        labels = seq[-self.max_len:]
        tokens = seq[:-1][-self.max_len:]
        tokens = [0] * (self.max_len - len(tokens)) + tokens
        pad = self.max_len - len(labels)
        # draw one negative per (unpadded) label, avoiding seen items and repeats
        negatives = []
        while len(negatives) < len(labels):
            candidate = self.rng.randint(1, self.num_items)
            if candidate in self.seen_samples[index] or candidate in negatives:
                continue
            negatives.append(candidate)
        labels = [0] * pad + labels
        negatives = [0] * pad + negatives
        return torch.LongTensor(tokens), torch.LongTensor(labels), torch.LongTensor(negatives)
class SASValidDataset(data_utils.Dataset):
    """Validation dataset for SASRec: left-padded train sequence plus
    (answer + negatives) candidates with 1/0 relevance labels.

    Fix: removed the stray ``self.users = sorted(self.u2seq.keys())`` after
    the if/else, which made the ``valid_users`` argument dead (consistent
    with BERTValidDataset).
    """
    def __init__(self, u2seq, u2answer, max_len, negative_samples, valid_users=None):
        self.u2seq = u2seq  # training sequence per user
        if not valid_users:
            self.users = sorted(self.u2seq.keys())
        else:
            self.users = valid_users
        self.u2answer = u2answer
        self.max_len = max_len
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user]
        negs = self.negative_samples[user]
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)
        # no mask token here (SASRec predicts from the last position directly)
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq  # left padding
        return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)
class SASTestDataset(data_utils.Dataset):
    """Test dataset for SASRec: the validation item is appended to the train
    sequence and the test item is scored against negatives.

    Fix: removed the stray ``self.users = sorted(self.u2seq.keys())`` that
    made the ``test_users`` argument dead (consistent with BERTValidDataset).
    """
    def __init__(self, u2seq, u2val, u2answer, max_len, negative_samples, test_users=None):
        self.u2seq = u2seq  # training sequences
        self.u2val = u2val  # validation item(s), appended to the input
        if not test_users:
            self.users = sorted(self.u2seq.keys())
        else:
            self.users = test_users
        self.u2answer = u2answer  # test item(s)
        self.max_len = max_len
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user] + self.u2val[user]  # append validation item after train seq
        answer = self.u2answer[user]
        negs = self.negative_samples[user]
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)
        # no mask token here (SASRec predicts from the last position directly)
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq  # left padding
        return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)
class BERTDataloader():
    """Builds train/val/test DataLoaders for BERT4Rec from a preprocessed
    dataset, drawing (and caching) negative samples for evaluation."""
    def __init__(self, args, dataset):
        self.args = args
        self.rng = random.Random()  # unseeded: cloze masking varies across runs
        self.save_folder = dataset._get_preprocessed_folder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.umap = dataset['umap']
        self.smap = dataset['smap']
        self.user_count = len(self.umap)
        self.item_count = len(self.smap)
        args.num_items = self.item_count  # side effect: model construction reads this
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        self.max_predictions = args.bert_max_predictions
        self.sliding_size = args.sliding_window_size
        # [MASK] id sits one past the item vocabulary (0 is padding)
        self.CLOZE_MASK_TOKEN = self.item_count + 1
        # samplers persist their draws under save_folder, keyed 'val'/'test'
        val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                        self.train, self.val, self.test,
                                                        self.user_count, self.item_count,
                                                        args.test_negative_sample_size,
                                                        args.test_negative_sampling_seed,
                                                        'val', self.save_folder)
        test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                         self.train, self.val, self.test,
                                                         self.user_count, self.item_count,
                                                         args.test_negative_sample_size,
                                                         args.test_negative_sampling_seed,
                                                         'test', self.save_folder)
        # NOTE(review): seen_samples from the val sampler is immediately
        # overwritten by the test sampler's result — presumably both samplers
        # compute identical seen-item sets; confirm in negative_sampler_factory.
        self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()
        self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()

    @classmethod
    def code(cls):
        return 'bert'

    def get_pytorch_dataloaders(self):
        """Return the (train, val, test) DataLoader triple."""
        train_loader = self._get_train_loader()
        val_loader = self._get_val_loader()
        test_loader = self._get_test_loader()
        return train_loader, val_loader, test_loader

    def _get_train_loader(self):
        dataset = self._get_train_dataset()
        dataloader = data_utils.DataLoader(dataset, batch_size=self.args.train_batch_size,
                                           shuffle=True, pin_memory=True)
        return dataloader

    def _get_train_dataset(self):
        dataset = BERTTrainDataset(
            self.train, self.max_len, self.mask_prob, self.max_predictions, self.sliding_size, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)
        return dataset

    def _get_val_loader(self):
        return self._get_eval_loader(mode='val')

    def _get_test_loader(self):
        return self._get_eval_loader(mode='test')

    def _get_eval_loader(self, mode):
        # evaluation batches are not shuffled so metrics align with user order
        batch_size = self.args.val_batch_size if mode == 'val' else self.args.test_batch_size
        dataset = self._get_eval_dataset(mode)
        dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
        return dataloader

    def _get_eval_dataset(self, mode):
        if mode == 'val':
            dataset = BERTValidDataset(self.train, self.val, self.max_len, self.CLOZE_MASK_TOKEN, self.val_negative_samples)
        elif mode == 'test':
            dataset = BERTTestDataset(self.train, self.val, self.test, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples)
        return dataset
class BERTTrainDataset(data_utils.Dataset):
    """Cloze-style training dataset for BERT4Rec.

    Long histories are expanded into overlapping windows (stride =
    ``sliding_size * max_len``). Each position is randomly masked with the
    BERT 80/10/10 scheme; the final position gets a boosted masking
    probability so the model also learns the next-item objective.
    """
    def __init__(self, u2seq, max_len, mask_prob, max_predictions, sliding_size, mask_token, num_items, rng):
        # self.u2seq = u2seq
        # self.users = sorted(self.u2seq.keys())
        self.max_len = max_len
        self.mask_prob = mask_prob
        self.max_predictions = max_predictions
        self.sliding_step = int(sliding_size * max_len)
        self.mask_token = mask_token
        self.num_items = num_items
        self.rng = rng
        assert self.sliding_step > 0
        self.all_seqs = []
        for u in sorted(u2seq.keys()):
            seq = u2seq[u]
            if len(seq) < self.max_len + self.sliding_step:
                self.all_seqs.append(seq)
            else:
                # slide a max_len window from the end toward the start
                start_idx = range(len(seq) - max_len, -1, -self.sliding_step)
                self.all_seqs = self.all_seqs + [seq[i:i + max_len] for i in start_idx]

    def __len__(self):
        return len(self.all_seqs)
        # return len(self.users)

    def __getitem__(self, index):
        # user = self.users[index]
        # seq = self._getseq(user)
        seq = self.all_seqs[index]
        tokens = []
        labels = []
        covered_items = set()  # item ids already chosen as prediction targets
        for i in range(len(seq)):
            s = seq[i]
            # cap the number of masked positions; never mask the same item twice
            if (len(covered_items) >= self.max_predictions) or (s in covered_items):
                tokens.append(s)
                labels.append(0)
                continue
            temp_mask_prob = self.mask_prob
            if i == (len(seq) - 1):
                # boost masking odds for the final position (next-item signal)
                temp_mask_prob += 0.1 * (1 - self.mask_prob)
            prob = self.rng.random()
            if prob < temp_mask_prob:
                covered_items.add(s)
                prob /= temp_mask_prob  # renormalize for the 80/10/10 split
                if prob < 0.8:
                    tokens.append(self.mask_token)  # 80%: replace with [MASK]
                elif prob < 0.9:
                    tokens.append(self.rng.randint(1, self.num_items))  # 10%: random item
                else:
                    tokens.append(s)  # 10%: keep the original item
                labels.append(s)
            else:
                tokens.append(s)
                labels.append(0)  # 0 = not a prediction target
        tokens = tokens[-self.max_len:]
        labels = labels[-self.max_len:]
        mask_len = self.max_len - len(tokens)
        # left-pad to max_len with the padding id 0
        tokens = [0] * mask_len + tokens
        labels = [0] * mask_len + labels
        return torch.LongTensor(tokens), torch.LongTensor(labels)

    def _getseq(self, user):
        # NOTE(review): self.u2seq is never assigned (see the commented-out
        # line in __init__), so calling this would raise AttributeError —
        # appears to be dead code kept from an older version.
        return self.u2seq[user]
class BERTValidDataset(data_utils.Dataset):
    """Validation dataset for BERT4Rec: the train sequence plus a trailing
    mask token (left-padded), with (answer + negatives) candidates and
    1/0 relevance labels."""
    def __init__(self, u2seq, u2answer, max_len, mask_token, negative_samples, valid_users=None):
        self.u2seq = u2seq  # training sequence per user
        self.users = valid_users if valid_users else sorted(self.u2seq.keys())
        self.u2answer = u2answer
        self.max_len = max_len
        self.mask_token = mask_token
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        history = self.u2seq[user]
        answer = self.u2answer[user]
        negatives = self.negative_samples[user]
        candidates = answer + negatives
        labels = [1] * len(answer) + [0] * len(negatives)
        # append [MASK] where the model should predict, then left-pad
        padded = (history + [self.mask_token])[-self.max_len:]
        padded = [0] * (self.max_len - len(padded)) + padded
        return torch.LongTensor(padded), torch.LongTensor(candidates), torch.LongTensor(labels)
class BERTTestDataset(data_utils.Dataset):
    """Test dataset for BERT4Rec: train sequence + validation item + mask
    token (left-padded), with the test item scored against negatives.

    Fix: removed the stray ``self.users = sorted(self.u2seq.keys())`` after
    the if/else, which made the ``test_users`` argument dead (consistent
    with BERTValidDataset).
    """
    def __init__(self, u2seq, u2val, u2answer, max_len, mask_token, negative_samples, test_users=None):
        self.u2seq = u2seq  # training sequences
        self.u2val = u2val  # validation item(s), appended to the input
        if not test_users:
            self.users = sorted(self.u2seq.keys())
        else:
            self.users = test_users
        self.u2answer = u2answer  # test item(s)
        self.max_len = max_len
        self.mask_token = mask_token
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user] + self.u2val[user]  # append validation item after train seq
        answer = self.u2answer[user]
        negs = self.negative_samples[user]
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)
        # append [MASK] where the model should predict, then left-pad
        seq = seq + [self.mask_token]
        seq = seq[-self.max_len:]
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq
        return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)
def dataloader_factory(args):
    """Build the (train, val, test) DataLoaders for the model chosen by
    ``args.model_code`` ('bert', 'sas', or anything else -> RNN loader)."""
    dataset = dataset_factory(args)
    if args.model_code == 'bert':
        loader = BERTDataloader(args, dataset)
    elif args.model_code == 'sas':
        loader = SASDataloader(args, dataset)
    else:
        # default: RNN-based (e.g. NARM) dataloader
        loader = RNNDataloader(args, dataset)
    return loader.get_pytorch_dataloaders()
# Registry of available distillation datasets, keyed by their dataset code.
DIS_DATASETS = {
    ML1MDistillationDataset.code(): ML1MDistillationDataset
}
def dis_dataset_factory(args, bb_model_code, mode='random'):
    """Instantiate the distillation dataset registered for args.dataset_code."""
    dataset_cls = DIS_DATASETS[args.dataset_code]
    return dataset_cls(args, bb_model_code, mode)
def dis_train_loader_factory(args, bb_model_code, mode='random'):
    """Return (train, val) distillation loaders, or None when the distilled
    data has not been generated yet."""
    dataset = dis_dataset_factory(args, bb_model_code, mode)
    if not dataset.check_data_present():
        return None
    train, val = DistillationLoader(args, dataset).get_loaders()
    return train, val
class DistillationLoader():
    """Wraps pre-computed teacher outputs (sequences, logits, candidates)
    into shuffled train/validation DataLoaders for the student model
    selected by args.model_code."""
    def __init__(self, args, dataset):
        self.args = args
        data = dataset.load_dataset()
        self.tokens = data['seqs']
        self.logits = data['logits']
        self.candidates = data['candidates']

    @classmethod
    def code(cls):
        return 'distillation_loader'

    def get_loaders(self):
        """Return (train_loader, val_loader), both shuffled."""
        train_set, val_set = self._get_datasets()
        batch_size = self.args.train_batch_size
        train_loader = data_utils.DataLoader(train_set, batch_size=batch_size,
                                             shuffle=True, pin_memory=True)
        val_loader = data_utils.DataLoader(val_set, batch_size=batch_size,
                                           shuffle=True, pin_memory=True)
        return train_loader, val_loader

    def _get_datasets(self):
        # pick the dataset pair matching the student architecture
        code = self.args.model_code
        if code == 'bert':
            train_cls, val_cls = BERTDistillationTrainingDataset, BERTDistillationValidationDataset
        elif code == 'sas':
            train_cls, val_cls = SASDistillationTrainingDataset, SASDistillationValidationDataset
        elif code == 'narm':
            train_cls, val_cls = NARMDistillationTrainingDataset, NARMDistillationValidationDataset
        train_dataset = train_cls(self.args, self.tokens, self.logits, self.candidates)
        valid_dataset = val_cls(self.args, self.tokens, self.logits, self.candidates)
        return train_dataset, valid_dataset
class BERTDistillationTrainingDataset(data_utils.Dataset):
    """Distillation training set for a BERT student: every proper prefix of
    each teacher sequence becomes one sample — the prefix plus a trailing
    mask token, paired with the teacher's candidates/soft labels for the
    next position."""
    def __init__(self, args, tokens, labels, candidates):
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        self.max_predictions = args.bert_max_predictions
        self.num_items = args.num_items
        self.mask_token = args.num_items + 1  # [MASK] sits past the item vocab
        self.all_seqs = []
        self.all_labels = []
        self.all_candidates = []
        for i, seq in enumerate(tokens):
            label = labels[i]
            candidate = candidates[i]
            for cut in range(len(seq) - 1):
                self.all_seqs.append(seq[:cut + 1] + [self.mask_token])
                self.all_labels.append(label[cut])
                self.all_candidates.append(candidate[cut])
        assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        seq = self.all_seqs[index][-self.max_len:]
        seq = [0] * (self.max_len - len(seq)) + seq  # left padding
        return torch.LongTensor(seq), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])
class BERTDistillationValidationDataset(data_utils.Dataset):
    """Distillation validation set for a BERT student: one sample per full
    sequence, with the teacher's last-position candidates and a one-hot
    relevance vector (first candidate is the positive)."""
    def __init__(self, args, tokens, labels, candidates):
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        self.max_predictions = args.bert_max_predictions
        self.num_items = args.num_items
        self.mask_token = args.num_items + 1  # [MASK] sits past the item vocab
        self.all_seqs = []
        self.all_labels = []
        self.all_candidates = []
        for i, seq in enumerate(tokens):
            label = labels[i]
            candidate = candidates[i]
            self.all_seqs.append(seq + [self.mask_token])
            # first candidate is the positive; the rest are negatives
            self.all_labels.append([1] + [0] * (len(label[-1]) - 1))
            self.all_candidates.append(candidate[-1])
        assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        seq = self.all_seqs[index][-self.max_len:]
        seq = [0] * (self.max_len - len(seq)) + seq  # left padding
        return torch.LongTensor(seq), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])
class SASDistillationTrainingDataset(data_utils.Dataset):
    """Distillation training set for a SAS student: every proper prefix of
    each teacher sequence becomes one sample (no mask token; SASRec predicts
    from the final position), paired with the teacher's candidates/labels
    for the position after the prefix."""
    def __init__(self, args, tokens, labels, candidates):
        self.max_len = args.bert_max_len
        self.all_seqs = []
        self.all_labels = []
        self.all_candidates = []
        for i, seq in enumerate(tokens):
            label = labels[i]
            candidate = candidates[i]
            for drop in range(1, len(seq)):
                # prefix seq[:-drop] predicts position -drop-1's teacher output
                self.all_seqs.append(seq[:-drop])
                self.all_labels.append(label[-drop - 1])
                self.all_candidates.append(candidate[-drop - 1])
        assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        prefix = self.all_seqs[index][-self.max_len:]
        prefix = [0] * (self.max_len - len(prefix)) + prefix  # left padding
        return torch.LongTensor(prefix), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])
class SASDistillationValidationDataset(data_utils.Dataset):
    """Distillation validation set for a SAS student: one sample per full
    sequence, with the teacher's last-position candidates and a one-hot
    relevance vector (first candidate is the positive)."""
    def __init__(self, args, tokens, labels, candidates):
        self.max_len = args.bert_max_len
        self.all_seqs = []
        self.all_labels = []
        self.all_candidates = []
        for i, seq in enumerate(tokens):
            label = labels[i]
            candidate = candidates[i]
            self.all_seqs.append(seq)
            # first candidate is the positive; the rest are negatives
            self.all_labels.append([1] + [0] * (len(label[-1]) - 1))
            self.all_candidates.append(candidate[-1])
        assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        seq = self.all_seqs[index][-self.max_len:]
        seq = [0] * (self.max_len - len(seq)) + seq  # left padding
        return torch.LongTensor(seq), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])
class NARMDistillationTrainingDataset(data_utils.Dataset):
    """Distillation training set for a NARM student: every proper prefix of
    each teacher sequence becomes one sample, right-padded with the true
    length returned (as the RNN expects), paired with the teacher's
    candidates/labels for the position after the prefix."""
    def __init__(self, args, tokens, labels, candidates):
        self.max_len = args.bert_max_len
        self.all_seqs = []
        self.all_labels = []
        self.all_candidates = []
        for i, seq in enumerate(tokens):
            label = labels[i]
            candidate = candidates[i]
            for drop in range(1, len(seq)):
                # prefix seq[:-drop] predicts position -drop-1's teacher output
                self.all_seqs.append(seq[:-drop])
                self.all_labels.append(label[-drop - 1])
                self.all_candidates.append(candidate[-drop - 1])
        assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        prefix = self.all_seqs[index][-self.max_len:]
        prefix_len = len(prefix)
        prefix = prefix + [0] * (self.max_len - prefix_len)  # right padding
        return torch.LongTensor(prefix), torch.LongTensor([prefix_len]), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])
class NARMDistillationValidationDataset(data_utils.Dataset):
    """Distillation validation set for a NARM student: one sample per full
    sequence (right-padded, true length returned), with the teacher's
    last-position candidates and a one-hot relevance vector."""
    def __init__(self, args, tokens, labels, candidates):
        self.max_len = args.bert_max_len
        self.all_seqs = []
        self.all_labels = []
        self.all_candidates = []
        for i, seq in enumerate(tokens):
            label = labels[i]
            candidate = candidates[i]
            self.all_seqs.append(seq)
            # first candidate is the positive; the rest are negatives
            self.all_labels.append([1] + [0] * (len(label[-1]) - 1))
            self.all_candidates.append(candidate[-1])
        assert len(self.all_seqs) == len(self.all_labels) == len(self.all_candidates)

    def __len__(self):
        return len(self.all_seqs)

    def __getitem__(self, index):
        seq = self.all_seqs[index][-self.max_len:]
        seq_len = len(seq)
        seq = seq + [0] * (self.max_len - seq_len)  # right padding
        return torch.LongTensor(seq), torch.LongTensor([seq_len]), torch.LongTensor(self.all_candidates[index]), torch.tensor(self.all_labels[index])
class TokenEmbedding(nn.Embedding):
    """Item-id embedding table; index 0 is reserved for padding."""
    def __init__(self, vocab_size, embed_size=512):
        super().__init__(vocab_size, embed_size, padding_idx=0)
class PositionalEmbedding(nn.Module):
    """Learned positional embedding anchored at the sequence end.

    Non-padding tokens are numbered 1..count from the left of the valid
    region, so the last token always gets index ``count`` regardless of how
    much left-padding precedes it; padded positions map to index 0.
    NOTE(review): assumes left-padded id inputs (padding id 0) — confirm.
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        self.d_model = d_model
        # max_len+1 slots so index 0 can be reserved for padded positions
        self.pe = nn.Embedding(max_len+1, d_model)
    def forward(self, x):
        # per-row count of non-padding tokens, broadcast across the length dim
        pose = (x > 0) * (x > 0).sum(dim=-1).unsqueeze(1).repeat(1, x.size(-1))
        # add offsets -(L-1)..0 so the final column carries the full count
        pose += torch.arange(start=-(x.size(1)-1), end=1, step=1, device=x.device)
        # zero out padded positions so they look up embedding index 0
        pose = pose * (x > 0)
        return self.pe(pose)
class GELU(nn.Module):
    """GELU activation via the tanh approximation (kept as-is: the exact
    constant ordering is floating-point sensitive)."""
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network: w_2(GELU(w_1(x)))."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.activation = GELU()

    def forward(self, x):
        hidden = self.activation(self.w_1(x))
        return self.w_2(hidden)
# layer norm
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned affine.

    Note: uses the sample (unbiased) std and adds eps to the std — not the
    variance — matching the original BERT4Rec-style implementation.
    """
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.weight * (x - mu) / (sigma + self.eps) + self.bias
# layer norm and dropout (dropout and then layer norm)
class SublayerConnection(nn.Module):
    """Post-LN residual wrapper: layer_norm(x + dropout(sublayer(x))),
    following the BERT4Rec variant rather than the pre-norm original."""
    def __init__(self, size, dropout):
        super().__init__()
        self.layer_norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # return x + self.dropout(sublayer(self.norm(x))) # original implementation
        residual = self.dropout(sublayer(x))
        return self.layer_norm(x + residual)  # BERT4Rec implementation
class Attention(nn.Module):
    """Scaled dot-product attention with optional padding mask, optional
    causal (lower-triangular) masking for SASRec, and attention dropout."""
    def forward(self, query, key, value, mask=None, dropout=None, sas=False):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        if sas:
            # causal mask: each position may only attend to itself and the past
            causal = torch.tril(torch.ones_like(scores))
            scores = scores.masked_fill(causal == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights
class MultiHeadedAttention(nn.Module):
    """Standard multi-head attention with separate Q/K/V projections and an
    output projection; per-head size may differ from d_model // h."""
    def __init__(self, h, d_model, head_size=None, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.h = h
        self.d_k = d_model // h
        self.head_size = head_size if head_size is not None else d_model // h
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, self.h * self.head_size) for _ in range(3)])
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
        self.output_linear = nn.Linear(self.h * self.head_size, d_model)

    def forward(self, query, key, value, mask=None):
        n = query.size(0)
        # project and reshape to (batch, heads, length, head_size)
        q, k, v = [layer(t).view(n, -1, self.h, self.head_size).transpose(1, 2)
                   for layer, t in zip(self.linear_layers, (query, key, value))]
        # attend per head
        context, _ = self.attention(q, k, v, mask=mask, dropout=self.dropout)
        # merge heads back and apply the final projection
        context = context.transpose(1, 2).contiguous().view(n, -1, self.h * self.head_size)
        return self.output_linear(context)
class TransformerBlock(nn.Module):
    """One post-LN transformer encoder layer: self-attention then a
    position-wise FFN, each wrapped in a residual/dropout/norm sublayer."""
    def __init__(self, hidden, attn_heads, head_size, feed_forward_hidden, dropout, attn_dropout=0.1):
        super().__init__()
        self.attention = MultiHeadedAttention(
            h=attn_heads, d_model=hidden, head_size=head_size, dropout=attn_dropout)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, x, mask):
        self_attend = lambda t: self.attention.forward(t, t, t, mask=mask)
        x = self.input_sublayer(x, self_attend)
        return self.output_sublayer(x, self.feed_forward)
class SASMultiHeadedAttention(nn.Module):
    """Multi-head attention for SASRec: causal attention plus a residual
    connection on the raw query, finished with layer norm instead of an
    output projection."""
    def __init__(self, h, d_model, head_size=None, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.h = h
        self.d_k = d_model // h
        self.head_size = head_size if head_size is not None else d_model // h
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, self.h * self.head_size) for _ in range(3)])
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = LayerNorm(d_model)

    def forward(self, query, key, value, mask=None):
        n = query.size(0)
        # project and reshape to (batch, heads, length, head_size)
        projected = [layer(t).view(n, -1, self.h, self.head_size).transpose(1, 2)
                     for layer, t in zip(self.linear_layers, (query, key, value))]
        # causal attention (sas=True adds the lower-triangular mask)
        context, _ = self.attention(*projected, mask=mask, dropout=self.dropout, sas=True)
        context = context.transpose(1, 2).contiguous().view(n, -1, self.h * self.head_size)
        # residual on the raw (pre-projection) query, then layer norm
        return self.layer_norm(context + query)
class SASPositionwiseFeedForward(nn.Module):
    """SASRec feed-forward: two 1x1 convolutions with ReLU and dropout, a
    residual connection, and a final layer norm."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.layer_norm = LayerNorm(d_model)

    def forward(self, x):
        # Conv1d expects (batch, channels, length): move hidden dim to channels
        hidden = self.dropout(self.activation(self.conv1(x.permute(0, 2, 1))))
        projected = self.dropout(self.conv2(hidden)).permute(0, 2, 1)
        return self.layer_norm(projected + x)
class SASTransformerBlock(nn.Module):
    """One SASRec layer: causal self-attention (query pre-normed only,
    matching the original SASRec code) followed by the conv feed-forward."""
    def __init__(self, hidden, attn_heads, head_size, feed_forward_hidden, dropout, attn_dropout=0.1):
        super().__init__()
        self.layer_norm = LayerNorm(hidden)
        self.attention = SASMultiHeadedAttention(
            h=attn_heads, d_model=hidden, head_size=head_size, dropout=attn_dropout)
        self.feed_forward = SASPositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)

    def forward(self, x, mask):
        # only the query is layer-normed; key/value use the raw input
        attended = self.attention(self.layer_norm(x), x, x, mask)
        return self.feed_forward(attended)
class SASRec(nn.Module):
    """SASRec: self-attentive sequential recommender. Embeds item sequences,
    runs the causal transformer stack, and scores items against the shared
    (tied) token-embedding matrix."""
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embedding = SASEmbedding(self.args)
        self.model = SASModel(self.args)
        self.truncated_normal_init()
    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
        # Truncated-normal init via inverse-CDF sampling: draw uniformly inside
        # the normal CDF interval [l, u], then map back through erfinv.
        with torch.no_grad():
            l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.
            u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.
            for n, p in self.model.named_parameters():
                if not 'layer_norm' in n:  # keep layer-norm weights/biases at ones/zeros
                    p.uniform_(2 * l - 1, 2 * u - 1)
                    p.erfinv_()
                    p.mul_(std * math.sqrt(2.))
                    p.add_(mean)
    def forward(self, x):
        x, mask = self.embedding(x)
        # output layer is tied to the input embedding weights
        scores = self.model(x, self.embedding.token.weight, mask)
        return scores
class SASEmbedding(nn.Module):
    """Token + positional embedding (with dropout) for SASRec, also producing
    the attention padding mask. Accepts either integer id tensors or soft
    distributions over items (rank > 2 input)."""
    def __init__(self, args):
        super().__init__()
        vocab_size = args.num_items + 1
        hidden = args.bert_hidden_units
        max_len = args.bert_max_len
        dropout = args.bert_dropout
        self.token = TokenEmbedding(
            vocab_size=vocab_size, embed_size=hidden)
        self.position = PositionalEmbedding(
            max_len=max_len, d_model=hidden)
        self.dropout = nn.Dropout(p=dropout)

    def get_mask(self, x):
        # soft inputs carry no padding ids: treat every position as valid
        if len(x.shape) > 2:
            x = torch.ones(x.shape[:2]).to(x.device)
        return (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

    def forward(self, x):
        mask = self.get_mask(x)
        if len(x.shape) > 2:
            # soft distributions: project through the embedding matrix directly
            ones = torch.ones(x.shape[:2]).to(x.device)
            embedded = torch.matmul(x, self.token.weight) + self.position(ones)
        else:
            embedded = self.token(x) + self.position(x)
        return self.dropout(embedded), mask
class SASModel(nn.Module):
    """Stack of SASRec transformer blocks; scores every item via dot product
    with the tied embedding matrix."""
    def __init__(self, args):
        super().__init__()
        hidden = args.bert_hidden_units
        heads = args.bert_num_heads
        head_size = args.bert_head_size
        dropout = args.bert_dropout
        attn_dropout = args.bert_attn_dropout
        layers = args.bert_num_blocks
        blocks = [SASTransformerBlock(hidden, heads, head_size, hidden * 4,
                                      dropout, attn_dropout) for _ in range(layers)]
        self.transformer_blocks = nn.ModuleList(blocks)

    def forward(self, x, embedding_weight, mask):
        for block in self.transformer_blocks:
            x = block.forward(x, mask)
        # tied-weight output layer: logits over all items
        return torch.matmul(x, embedding_weight.permute(1, 0))
class BERT(nn.Module):
    """BERT4Rec: bidirectional transformer over item sequences; output scores
    are tied to the input embedding matrix."""
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embedding = BERTEmbedding(self.args)
        self.model = BERTModel(self.args)
        self.truncated_normal_init()
    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
        # Truncated-normal init via inverse-CDF sampling: draw uniformly inside
        # the normal CDF interval [l, u], then map back through erfinv.
        with torch.no_grad():
            l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.
            u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.
            for n, p in self.model.named_parameters():
                if not 'layer_norm' in n:  # keep layer-norm weights/biases at ones/zeros
                    p.uniform_(2 * l - 1, 2 * u - 1)
                    p.erfinv_()
                    p.mul_(std * math.sqrt(2.))
                    p.add_(mean)
    def forward(self, x):
        x, mask = self.embedding(x)
        # output layer is tied to the input embedding weights
        scores = self.model(x, self.embedding.token.weight, mask)
        return scores
class BERTEmbedding(nn.Module):
    """Token + positional embedding for BERT4Rec with layer norm and dropout;
    the vocabulary has two extra slots (padding 0 and the [MASK] token).
    Accepts integer ids or soft distributions over items (rank > 2)."""
    def __init__(self, args):
        super().__init__()
        vocab_size = args.num_items + 2
        hidden = args.bert_hidden_units
        max_len = args.bert_max_len
        dropout = args.bert_dropout
        self.token = TokenEmbedding(
            vocab_size=vocab_size, embed_size=hidden)
        self.position = PositionalEmbedding(
            max_len=max_len, d_model=hidden)
        self.layer_norm = LayerNorm(features=hidden)
        self.dropout = nn.Dropout(p=dropout)

    def get_mask(self, x):
        # soft inputs carry no padding ids: treat every position as valid
        if len(x.shape) > 2:
            x = torch.ones(x.shape[:2]).to(x.device)
        return (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

    def forward(self, x):
        mask = self.get_mask(x)
        if len(x.shape) > 2:
            # soft distributions: project through the embedding matrix directly
            ones = torch.ones(x.shape[:2]).to(x.device)
            embedded = torch.matmul(x, self.token.weight) + self.position(ones)
        else:
            embedded = self.token(x) + self.position(x)
        return self.dropout(self.layer_norm(embedded)), mask
class BERTModel(nn.Module):
    """BERT4Rec encoder stack plus output head: GELU(linear) projection,
    tied-embedding scoring, and a learned per-item bias."""
    def __init__(self, args):
        super().__init__()
        hidden = args.bert_hidden_units
        heads = args.bert_num_heads
        head_size = args.bert_head_size
        dropout = args.bert_dropout
        attn_dropout = args.bert_attn_dropout
        layers = args.bert_num_blocks
        blocks = [TransformerBlock(hidden, heads, head_size, hidden * 4,
                                   dropout, attn_dropout) for _ in range(layers)]
        self.transformer_blocks = nn.ModuleList(blocks)
        self.linear = nn.Linear(hidden, hidden)
        # one bias per vocabulary entry (items + padding + mask token)
        self.bias = torch.nn.Parameter(torch.zeros(args.num_items + 2))
        self.bias.requires_grad = True
        self.activation = GELU()

    def forward(self, x, embedding_weight, mask):
        for block in self.transformer_blocks:
            x = block.forward(x, mask)
        projected = self.activation(self.linear(x))
        # tied-weight output layer with per-item bias
        return torch.matmul(projected, embedding_weight.permute(1, 0)) + self.bias
class NARM(nn.Module):
    """NARM: neural attentive session-based recommender (GRU + attention)."""
    def __init__(self, args):
        super(NARM, self).__init__()
        self.args = args
        self.embedding = NARMEmbedding(self.args)
        self.model = NARMModel(self.args)
        self.truncated_normal_init()
    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
        # Same inverse-CDF truncated-normal init as BERT/SASRec, but applied
        # to ALL parameters here (embedding included, no layer-norm exclusion).
        with torch.no_grad():
            l = (1. + math.erf(((lower - mean) / std) / math.sqrt(2.))) / 2.
            u = (1. + math.erf(((upper - mean) / std) / math.sqrt(2.))) / 2.
            for p in self.parameters():
                p.uniform_(2 * l - 1, 2 * u - 1)
                p.erfinv_()
                p.mul_(std * math.sqrt(2.))
                p.add_(mean)
    def forward(self, x, lengths):
        x, mask = self.embedding(x, lengths)
        # output scoring uses the (tied) item embedding weights
        scores = self.model(x, self.embedding.token.weight, lengths, mask)
        return scores
class NARMEmbedding(nn.Module):
    """Item embedding + dropout for NARM; also builds the padding mask,
    trimmed to the longest true length in the batch. Accepts integer ids or
    soft distributions over items (rank > 2 input)."""
    def __init__(self, args):
        super().__init__()
        vocab_size = args.num_items + 1
        embed_size = args.bert_hidden_units
        self.token = nn.Embedding(vocab_size, embed_size)
        self.embed_dropout = nn.Dropout(args.bert_dropout)

    def get_mask(self, x, lengths):
        # soft inputs carry no padding ids: mask is all ones
        if len(x.shape) > 2:
            return torch.ones(x.shape[:2])[:, :max(lengths)].to(x.device)
        return ((x > 0) * 1)[:, :max(lengths)]

    def forward(self, x, lengths):
        mask = self.get_mask(x, lengths)
        if len(x.shape) > 2:
            # soft distributions: project through the embedding matrix directly
            embedded = torch.matmul(x, self.token.weight)
        else:
            embedded = self.token(x)
        return self.embed_dropout(embedded), mask
class NARMModel(nn.Module):
    """NARM session encoder: a GRU over embedded items produces a global
    feature (final hidden state) and an attention-weighted local feature;
    their concatenation is matched against item embeddings via a bilinear
    projection."""
    def __init__(self, args):
        super().__init__()
        embed_size = args.bert_hidden_units
        hidden_size = 2 * args.bert_hidden_units
        self.gru = nn.GRU(embed_size, hidden_size, num_layers=1, batch_first=True)
        self.a_global = nn.Linear(hidden_size, hidden_size, bias=False)
        self.a_local = nn.Linear(hidden_size, hidden_size, bias=False)
        self.act = HardSigmoid()
        self.v_vector = nn.Linear(hidden_size, 1, bias=False)
        self.proj_dropout = nn.Dropout(args.bert_attn_dropout)
        # maps item embeddings into the concatenated [global; local] space
        self.b_vetor = nn.Linear(embed_size, 2 * hidden_size, bias=False)
    def forward(self, x, embedding_weight, lengths, mask):
        # pack so the GRU skips padded timesteps
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        gru_out, hidden = self.gru(x)
        gru_out, _ = pad_packed_sequence(gru_out, batch_first=True)
        c_global = hidden[-1]  # final hidden state = global session feature
        state2 = self.a_local(gru_out)
        state1 = self.a_global(c_global).unsqueeze(1).expand_as(state2)
        state1 = mask.unsqueeze(2).expand_as(state2) * state1  # zero padded positions
        # attention energies via hard-sigmoid of (global + local) projections
        alpha = self.act(state1 + state2).view(-1, state1.size(-1))
        attn = self.v_vector(alpha).view(mask.size())
        # softmax over valid timesteps only (padded ones forced to -1e9)
        attn = F.softmax(attn.masked_fill(mask == 0, -1e9), dim=-1)
        c_local = torch.sum(attn.unsqueeze(2).expand_as(gru_out) * gru_out, 1)
        proj = self.proj_dropout(torch.cat([c_global, c_local], 1))
        # score each item: dot product of [c_global; c_local] with B * e_i
        scores = torch.matmul(proj, self.b_vetor(embedding_weight).permute(1, 0))
        return scores
class HardSigmoid(nn.Module):
    """Piecewise-linear sigmoid approximation: clamp(x/6 + 0.5, 0, 1)."""

    def forward(self, x):
        scaled = x / 6 + 0.5
        return scaled.clamp(0., 1.)
class RNNTrainer(metaclass=ABCMeta):
    """Trainer for RNN-style recommenders (e.g. NARM) whose forward pass takes
    (seqs, lengths).

    Training batches are (seqs, lengths, labels); validation/test batches are
    (seqs, lengths, candidates, labels). Metrics are Recall@k / NDCG@k.
    """

    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):
        self.args = args
        self.device = args.device
        self.model = model.to(self.device)
        self.is_parallel = args.num_gpu > 1
        if self.is_parallel:
            self.model = nn.DataParallel(self.model)

        self.num_epochs = args.num_epochs
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = self._create_optimizer()
        if args.enable_lr_schedule:
            if args.enable_lr_warmup:
                self.lr_scheduler = self.get_linear_schedule_with_warmup(
                    self.optimizer, args.warmup_steps, len(train_loader) * self.num_epochs)
            else:
                self.lr_scheduler = optim.lr_scheduler.StepLR(
                    self.optimizer, step_size=args.decay_step, gamma=args.gamma)

        self.export_root = export_root
        self.writer, self.train_loggers, self.val_loggers = self._create_loggers()
        self.logger_service = LoggerService(
            self.train_loggers, self.val_loggers)
        self.log_period_as_iter = args.log_period_as_iter
        # item id 0 is padding, so it never contributes to the loss
        self.ce = nn.CrossEntropyLoss(ignore_index=0)

    def train(self):
        """Validate once, then alternate train/validate for num_epochs epochs."""
        accum_iter = 0
        self.validate(0, accum_iter)
        for epoch in range(self.num_epochs):
            accum_iter = self.train_one_epoch(epoch, accum_iter)
            self.validate(epoch, accum_iter)
        self.logger_service.complete({
            'state_dict': (self._create_state_dict()),
        })
        self.writer.close()

    def train_one_epoch(self, epoch, accum_iter):
        """Train for one epoch; returns the updated sample counter accum_iter."""
        self.model.train()
        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch_size = batch[0].size(0)
            seqs, lengths, labels = batch
            # lengths stay on CPU: sequence packing expects CPU lengths
            lengths = lengths.flatten()
            seqs, labels = seqs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(seqs, lengths)
            # squeeze(-1) (not squeeze()) so a batch of size 1 stays 1-D
            loss = self.ce(logits, labels.squeeze(-1))
            loss.backward()
            self.clip_gradients(5)
            self.optimizer.step()
            if self.args.enable_lr_schedule:
                self.lr_scheduler.step()

            average_meter_set.update('loss', loss.item())
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))

            accum_iter += batch_size
            if self._needs_to_log(accum_iter):
                tqdm_dataloader.set_description('Logging to Tensorboard')
                log_data = {
                    'state_dict': (self._create_state_dict()),
                    'epoch': epoch + 1,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                self.logger_service.log_train(log_data)

        return accum_iter

    def validate(self, epoch, accum_iter):
        """Evaluate on the validation set and forward metrics to the loggers."""
        self.model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch+1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.logger_service.log_val(log_data)

    def test(self):
        """Evaluate the best checkpoint on the test set; writes and returns metrics."""
        best_model_dict = torch.load(os.path.join(
            self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
        self.model.load_state_dict(best_model_dict)
        self.model.eval()

        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            average_metrics = average_meter_set.averages()
            with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
                json.dump(average_metrics, f, indent=4)

        return average_metrics

    def calculate_metrics(self, batch):
        """Compute Recall@k / NDCG@k over the candidate items of one batch."""
        seqs, lengths, candidates, labels = batch
        lengths = lengths.flatten()
        seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)

        scores = self.model(seqs, lengths)  # B x V
        scores = scores.gather(1, candidates)  # B x C
        metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
        return metrics

    def clip_gradients(self, limit=5):
        """Clip each parameter's gradient norm to `limit`.

        Fix: the previous version ignored `limit` and always clipped to 5.
        """
        for p in self.model.parameters():
            nn.utils.clip_grad_norm_(p, limit)

    def _update_meter_set(self, meter_set, metrics):
        """Fold one batch's metric dict into the running averages."""
        for k, v in metrics.items():
            meter_set.update(k, v)

    def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
        """Render the first three NDCG@k / Recall@k averages into the tqdm bar."""
        description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
                               ] + ['Recall@%d' % k for k in self.metric_ks[:3]]
        description = 'Eval: ' + \
            ', '.join(s + ' {:.3f}' for s in description_metrics)
        description = description.replace('NDCG', 'N').replace('Recall', 'R')
        description = description.format(
            *(meter_set[k].avg for k in description_metrics))
        tqdm_dataloader.set_description(description)

    def _create_optimizer(self):
        """Build the optimizer; bias and layer-norm params get no weight decay."""
        args = self.args
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay,
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        if args.optimizer.lower() == 'adamw':
            return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
        elif args.optimizer.lower() == 'adam':
            return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer.lower() == 'sgd':
            return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        else:
            raise ValueError('Unsupported optimizer: %s' % args.optimizer)

    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        """Linear warmup to lr, then linear decay to 0 over num_training_steps.

        Based on Hugging Face's get_linear_schedule_with_warmup.
        """
        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )
        return LambdaLR(optimizer, lr_lambda, last_epoch)

    def _create_loggers(self):
        """Create the tensorboard writer plus train/val metric and model loggers."""
        root = Path(self.export_root)
        writer = SummaryWriter(root.joinpath('logs'))
        model_checkpoint = root.joinpath('models')

        train_loggers = [
            MetricGraphPrinter(writer, key='epoch',
                               graph_name='Epoch', group_name='Train'),
            MetricGraphPrinter(writer, key='loss',
                               graph_name='Loss', group_name='Train'),
        ]

        val_loggers = []
        for k in self.metric_ks:
            val_loggers.append(
                MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))
            val_loggers.append(
                MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))
        val_loggers.append(RecentModelLogger(model_checkpoint))
        val_loggers.append(BestModelLogger(
            model_checkpoint, metric_key=self.best_metric))
        return writer, train_loggers, val_loggers

    def _create_state_dict(self):
        """Snapshot model (unwrapping DataParallel) and optimizer state."""
        return {
            STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),
            OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
        }

    def _needs_to_log(self, accum_iter):
        # true roughly once every log_period_as_iter samples (within one batch's slack)
        return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0
class SASTrainer(metaclass=ABCMeta):
    """Trainer for SASRec-style recommenders trained with sampled-negative BCE.

    Training batches are (seqs, labels, negs); validation/test batches are
    (seqs, candidates, labels). Metrics are Recall@k / NDCG@k.
    """

    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):
        self.args = args
        self.device = args.device
        self.model = model.to(self.device)
        self.is_parallel = args.num_gpu > 1
        if self.is_parallel:
            self.model = nn.DataParallel(self.model)

        self.num_epochs = args.num_epochs
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = self._create_optimizer()
        if args.enable_lr_schedule:
            if args.enable_lr_warmup:
                self.lr_scheduler = self.get_linear_schedule_with_warmup(
                    self.optimizer, args.warmup_steps, len(train_loader) * self.num_epochs)
            else:
                self.lr_scheduler = optim.lr_scheduler.StepLR(
                    self.optimizer, step_size=args.decay_step, gamma=args.gamma)

        self.export_root = export_root
        self.writer, self.train_loggers, self.val_loggers = self._create_loggers()
        self.logger_service = LoggerService(
            self.train_loggers, self.val_loggers)
        self.log_period_as_iter = args.log_period_as_iter
        # BCE over per-position positive/negative logits
        self.bce = nn.BCEWithLogitsLoss()

    def train(self):
        """Validate once, then alternate train/validate for num_epochs epochs."""
        accum_iter = 0
        self.validate(0, accum_iter)
        for epoch in range(self.num_epochs):
            accum_iter = self.train_one_epoch(epoch, accum_iter)
            self.validate(epoch, accum_iter)
        self.logger_service.complete({
            'state_dict': (self._create_state_dict()),
        })
        self.writer.close()

    def train_one_epoch(self, epoch, accum_iter):
        """Train for one epoch; returns the updated sample counter accum_iter."""
        self.model.train()
        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch_size = batch[0].size(0)
            batch = [x.to(self.device) for x in batch]

            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch)
            loss.backward()
            self.clip_gradients(5)
            self.optimizer.step()
            if self.args.enable_lr_schedule:
                self.lr_scheduler.step()

            average_meter_set.update('loss', loss.item())
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))

            accum_iter += batch_size
            if self._needs_to_log(accum_iter):
                tqdm_dataloader.set_description('Logging to Tensorboard')
                log_data = {
                    'state_dict': (self._create_state_dict()),
                    'epoch': epoch + 1,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                self.logger_service.log_train(log_data)

        return accum_iter

    def validate(self, epoch, accum_iter):
        """Evaluate on the validation set and forward metrics to the loggers."""
        self.model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch+1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.logger_service.log_val(log_data)

    def test(self):
        """Evaluate the best checkpoint on the test set; writes and returns metrics."""
        best_model_dict = torch.load(os.path.join(
            self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
        self.model.load_state_dict(best_model_dict)
        self.model.eval()

        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            average_metrics = average_meter_set.averages()
            with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
                json.dump(average_metrics, f, indent=4)

        return average_metrics

    def calculate_loss(self, batch):
        """BCE over positive vs. sampled-negative logits at non-padding positions."""
        seqs, labels, negs = batch
        logits = self.model(seqs)  # B x T x V

        # keep only positions with a real (non-padding) input item
        pos_logits = logits.gather(-1, labels.unsqueeze(-1))[seqs > 0].squeeze()
        pos_targets = torch.ones_like(pos_logits)
        neg_logits = logits.gather(-1, negs.unsqueeze(-1))[seqs > 0].squeeze()
        neg_targets = torch.zeros_like(neg_logits)

        loss = self.bce(torch.cat((pos_logits, neg_logits), 0),
                        torch.cat((pos_targets, neg_targets), 0))
        return loss

    def calculate_metrics(self, batch):
        """Compute Recall@k / NDCG@k from the last position's candidate scores."""
        seqs, candidates, labels = batch
        scores = self.model(seqs)  # B x T x V
        scores = scores[:, -1, :]  # next-item scores: B x V
        scores = scores.gather(1, candidates)  # B x C
        metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
        return metrics

    def clip_gradients(self, limit=5):
        """Clip each parameter's gradient norm to `limit`.

        Fix: the previous version ignored `limit` and always clipped to 5.
        """
        for p in self.model.parameters():
            nn.utils.clip_grad_norm_(p, limit)

    def _update_meter_set(self, meter_set, metrics):
        """Fold one batch's metric dict into the running averages."""
        for k, v in metrics.items():
            meter_set.update(k, v)

    def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
        """Render the first three NDCG@k / Recall@k averages into the tqdm bar."""
        description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
                               ] + ['Recall@%d' % k for k in self.metric_ks[:3]]
        description = 'Eval: ' + \
            ', '.join(s + ' {:.3f}' for s in description_metrics)
        description = description.replace('NDCG', 'N').replace('Recall', 'R')
        description = description.format(
            *(meter_set[k].avg for k in description_metrics))
        tqdm_dataloader.set_description(description)

    def _create_optimizer(self):
        """Build the optimizer; bias and layer-norm params get no weight decay."""
        args = self.args
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay,
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        if args.optimizer.lower() == 'adamw':
            return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
        elif args.optimizer.lower() == 'adam':
            return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer.lower() == 'sgd':
            return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        else:
            raise ValueError('Unsupported optimizer: %s' % args.optimizer)

    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        """Linear warmup to lr, then linear decay to 0 over num_training_steps.

        Based on Hugging Face's get_linear_schedule_with_warmup.
        """
        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )
        return LambdaLR(optimizer, lr_lambda, last_epoch)

    def _create_loggers(self):
        """Create the tensorboard writer plus train/val metric and model loggers."""
        root = Path(self.export_root)
        writer = SummaryWriter(root.joinpath('logs'))
        model_checkpoint = root.joinpath('models')

        train_loggers = [
            MetricGraphPrinter(writer, key='epoch',
                               graph_name='Epoch', group_name='Train'),
            MetricGraphPrinter(writer, key='loss',
                               graph_name='Loss', group_name='Train'),
        ]

        val_loggers = []
        for k in self.metric_ks:
            val_loggers.append(
                MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))
            val_loggers.append(
                MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))
        val_loggers.append(RecentModelLogger(model_checkpoint))
        val_loggers.append(BestModelLogger(
            model_checkpoint, metric_key=self.best_metric))
        return writer, train_loggers, val_loggers

    def _create_state_dict(self):
        """Snapshot model (unwrapping DataParallel) and optimizer state."""
        return {
            STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),
            OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
        }

    def _needs_to_log(self, accum_iter):
        # true roughly once every log_period_as_iter samples (within one batch's slack)
        return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0
class BERTTrainer(metaclass=ABCMeta):
    """Trainer for BERT4Rec-style recommenders trained with masked-item
    cross-entropy.

    Training batches are (seqs, labels) where label 0 marks unmasked/padded
    positions; validation/test batches are (seqs, candidates, labels).
    """

    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):
        self.args = args
        self.device = args.device
        self.model = model.to(self.device)
        self.is_parallel = args.num_gpu > 1
        if self.is_parallel:
            self.model = nn.DataParallel(self.model)

        self.num_epochs = args.num_epochs
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = self._create_optimizer()
        if args.enable_lr_schedule:
            if args.enable_lr_warmup:
                self.lr_scheduler = self.get_linear_schedule_with_warmup(
                    self.optimizer, args.warmup_steps, len(train_loader) * self.num_epochs)
            else:
                self.lr_scheduler = optim.lr_scheduler.StepLR(
                    self.optimizer, step_size=args.decay_step, gamma=args.gamma)

        self.export_root = export_root
        self.writer, self.train_loggers, self.val_loggers = self._create_loggers()
        self.logger_service = LoggerService(
            self.train_loggers, self.val_loggers)
        self.log_period_as_iter = args.log_period_as_iter
        # label 0 marks unmasked/padding positions; they are excluded from the loss
        self.ce = nn.CrossEntropyLoss(ignore_index=0)

    def train(self):
        """Validate once, then alternate train/validate for num_epochs epochs."""
        accum_iter = 0
        self.validate(0, accum_iter)
        for epoch in range(self.num_epochs):
            accum_iter = self.train_one_epoch(epoch, accum_iter)
            self.validate(epoch, accum_iter)
        self.logger_service.complete({
            'state_dict': (self._create_state_dict()),
        })
        self.writer.close()

    def train_one_epoch(self, epoch, accum_iter):
        """Train for one epoch; returns the updated sample counter accum_iter."""
        self.model.train()
        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            batch_size = batch[0].size(0)
            batch = [x.to(self.device) for x in batch]

            self.optimizer.zero_grad()
            loss = self.calculate_loss(batch)
            loss.backward()
            self.clip_gradients(5)
            self.optimizer.step()
            if self.args.enable_lr_schedule:
                self.lr_scheduler.step()

            average_meter_set.update('loss', loss.item())
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))

            accum_iter += batch_size
            if self._needs_to_log(accum_iter):
                tqdm_dataloader.set_description('Logging to Tensorboard')
                log_data = {
                    'state_dict': (self._create_state_dict()),
                    'epoch': epoch + 1,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                self.logger_service.log_train(log_data)

        return accum_iter

    def validate(self, epoch, accum_iter):
        """Evaluate on the validation set and forward metrics to the loggers."""
        self.model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch+1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            self.logger_service.log_val(log_data)

    def test(self):
        """Evaluate the best checkpoint on the test set; writes and returns metrics."""
        best_model_dict = torch.load(os.path.join(
            self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
        self.model.load_state_dict(best_model_dict)
        self.model.eval()

        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                batch = [x.to(self.device) for x in batch]
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)

            average_metrics = average_meter_set.averages()
            with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
                json.dump(average_metrics, f, indent=4)

        return average_metrics

    def calculate_loss(self, batch):
        """Masked-item cross-entropy; positions with label 0 are ignored."""
        seqs, labels = batch
        logits = self.model(seqs)  # B x T x V

        logits = logits.view(-1, logits.size(-1))  # (B*T) x V
        labels = labels.view(-1)  # B*T
        loss = self.ce(logits, labels)
        return loss

    def calculate_metrics(self, batch):
        """Compute Recall@k / NDCG@k from the last position's candidate scores."""
        seqs, candidates, labels = batch
        scores = self.model(seqs)  # B x T x V
        scores = scores[:, -1, :]  # next-item scores: B x V
        scores = scores.gather(1, candidates)  # B x C
        metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
        return metrics

    def clip_gradients(self, limit=5):
        """Clip each parameter's gradient norm to `limit`.

        Fix: the previous version ignored `limit` and always clipped to 5.
        """
        for p in self.model.parameters():
            nn.utils.clip_grad_norm_(p, limit)

    def _update_meter_set(self, meter_set, metrics):
        """Fold one batch's metric dict into the running averages."""
        for k, v in metrics.items():
            meter_set.update(k, v)

    def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
        """Render the first three NDCG@k / Recall@k averages into the tqdm bar."""
        description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
                               ] + ['Recall@%d' % k for k in self.metric_ks[:3]]
        description = 'Eval: ' + \
            ', '.join(s + ' {:.3f}' for s in description_metrics)
        description = description.replace('NDCG', 'N').replace('Recall', 'R')
        description = description.format(
            *(meter_set[k].avg for k in description_metrics))
        tqdm_dataloader.set_description(description)

    def _create_optimizer(self):
        """Build the optimizer; bias and layer-norm params get no weight decay."""
        args = self.args
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay,
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        if args.optimizer.lower() == 'adamw':
            return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
        elif args.optimizer.lower() == 'adam':
            return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer.lower() == 'sgd':
            return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        else:
            raise ValueError('Unsupported optimizer: %s' % args.optimizer)

    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        """Linear warmup to lr, then linear decay to 0 over num_training_steps.

        Based on Hugging Face's get_linear_schedule_with_warmup.
        """
        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )
        return LambdaLR(optimizer, lr_lambda, last_epoch)

    def _create_loggers(self):
        """Create the tensorboard writer plus train/val metric and model loggers."""
        root = Path(self.export_root)
        writer = SummaryWriter(root.joinpath('logs'))
        model_checkpoint = root.joinpath('models')

        train_loggers = [
            MetricGraphPrinter(writer, key='epoch',
                               graph_name='Epoch', group_name='Train'),
            MetricGraphPrinter(writer, key='loss',
                               graph_name='Loss', group_name='Train'),
        ]

        val_loggers = []
        for k in self.metric_ks:
            val_loggers.append(
                MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))
            val_loggers.append(
                MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))
        val_loggers.append(RecentModelLogger(model_checkpoint))
        val_loggers.append(BestModelLogger(
            model_checkpoint, metric_key=self.best_metric))
        return writer, train_loggers, val_loggers

    def _create_state_dict(self):
        """Snapshot model (unwrapping DataParallel) and optimizer state."""
        return {
            STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),
            OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
        }

    def _needs_to_log(self, accum_iter):
        # true roughly once every log_period_as_iter samples (within one batch's slack)
        return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0
class NoDataRankDistillationTrainer(metaclass=ABCMeta):
def __init__(self, args, model_code, model, bb_model, test_loader, export_root, loss='ranking', tau=1., margin_topk=0.5, margin_neg=0.5):
    """Data-free ranking-distillation trainer.

    Distills a black-box teacher (`bb_model`) into a student (`model`)
    without real training data; sequences are generated instead.

    Args:
        args: experiment configuration namespace.
        model_code: identifier of the student architecture (used when
            generating distillation datasets).
        model: student network to be trained.
        bb_model: black-box teacher network (queried only).
        test_loader: evaluation dataloader.
        export_root: output directory for logs and checkpoints.
        loss: distillation loss variant — 'kl', 'ranking' or 'kl+ct'.
        tau: softmax temperature for the KL-based variants.
        margin_topk: margin between consecutively ranked candidates
            (ranking loss).
        margin_neg: margin between candidates and sampled negatives
            (ranking loss).
    """
    self.args = args
    self.device = args.device
    self.num_items = args.num_items
    self.max_len = args.bert_max_len
    self.batch_size = args.train_batch_size
    self.mask_prob = args.bert_mask_prob
    self.max_predictions = args.bert_max_predictions
    # BERT-style mask token id sits one past the last real item id
    self.CLOZE_MASK_TOKEN = self.num_items + 1
    self.model = model.to(self.device)
    self.model_code = model_code
    self.bb_model = bb_model.to(self.device)
    self.num_epochs = args.num_epochs
    self.metric_ks = args.metric_ks
    self.best_metric = args.best_metric
    self.export_root = export_root
    self.log_period_as_iter = args.log_period_as_iter
    self.is_parallel = args.num_gpu > 1
    if self.is_parallel:
        self.model = nn.DataParallel(self.model)
    self.test_loader = test_loader
    self.optimizer = self._create_optimizer()
    if args.enable_lr_schedule:
        if args.enable_lr_warmup:
            # total steps: batches per generated dataset * epochs * 2 —
            # presumably covering two training stages; TODO confirm
            self.lr_scheduler = self.get_linear_schedule_with_warmup(
                self.optimizer, args.warmup_steps, (args.num_generated_seqs // self.batch_size + 1) * self.num_epochs * 2)
        else:
            self.lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, step_size=args.decay_step, gamma=args.gamma)
    self.loss = loss
    self.tau = tau
    self.margin_topk = margin_topk
    self.margin_neg = margin_neg
    # pick loss modules by variant; any other `loss` string leaves them unset
    if self.loss == 'kl':
        self.loss_func = nn.KLDivLoss(reduction='batchmean')
    elif self.loss == 'ranking':
        self.loss_func_1 = nn.MarginRankingLoss(margin=self.margin_topk)
        self.loss_func_2 = nn.MarginRankingLoss(margin=self.margin_neg)
    elif self.loss == 'kl+ct':
        self.loss_func_1 = nn.KLDivLoss(reduction='batchmean')
        self.loss_func_2 = nn.CrossEntropyLoss(ignore_index=0)
def calculate_loss(self, seqs, labels, candidates, lengths=None):
    """Distillation loss of the student against teacher-generated soft labels.

    `labels` holds the teacher's scores for `candidates` (assumed sorted in
    descending teacher preference — see the comment below); `lengths` is
    only used for NARM students.
    NOTE(review): if `self.model` is none of BERT/SASRec/NARM, `logits` is
    never assigned and the code below raises NameError.
    """
    if isinstance(self.model, BERT) or isinstance(self.model, SASRec):
        logits = self.model(seqs)[:, -1, :]
    elif isinstance(self.model, NARM):
        logits = self.model(seqs, lengths)
    if self.loss == 'kl':
        # temperature-scaled KL between student and teacher candidate scores
        logits = torch.gather(logits, -1, candidates)
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1, labels.size(-1))
        loss = self.loss_func(F.log_softmax(logits/self.tau, dim=-1), F.softmax(labels/self.tau, dim=-1))
    elif self.loss == 'ranking':
        # logits = F.softmax(logits/self.tau, dim=-1)
        # zero out candidate positions so negatives are sampled elsewhere
        weight = torch.ones_like(logits).to(self.device)
        weight[torch.arange(weight.size(0)).unsqueeze(1), candidates] = 0
        # NOTE(review): Distribution.sample_n is deprecated in torch.distributions;
        # sample((n,)) is the modern spelling — verify before upgrading torch
        neg_samples = torch.distributions.Categorical(F.softmax(weight, -1)).sample_n(candidates.size(-1)).permute(1, 0)
        # assume candidates are in descending order w.r.t. true label
        neg_logits = torch.gather(logits, -1, neg_samples)
        logits = torch.gather(logits, -1, candidates)
        # pairwise margin between each candidate and its next-ranked neighbor
        logits_1 = logits[:, :-1].reshape(-1)
        logits_2 = logits[:, 1:].reshape(-1)
        loss = self.loss_func_1(logits_1, logits_2, torch.ones(logits_1.shape).to(self.device))
        # plus margin between every candidate and a sampled negative
        loss += self.loss_func_2(logits, neg_logits, torch.ones(logits.shape).to(self.device))
    elif self.loss == 'kl+ct':
        logits = torch.gather(logits, -1, candidates)
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1, labels.size(-1))
        loss = self.loss_func_1(F.log_softmax(logits/self.tau, dim=-1), F.softmax(labels/self.tau, dim=-1))
        # NOTE(review): F.softmax without an explicit dim, and CrossEntropyLoss
        # normally expects raw logits rather than softmax outputs — confirm intent
        loss += self.loss_func_2(F.softmax(logits), torch.argmax(labels, dim=-1))
    return loss
def calculate_metrics(self, batch, similarity=False):
    """Ranking metrics of the student; optionally agreement with the teacher.

    Batch layout depends on the student type: (seqs, candidates, labels) for
    BERT/SASRec, (seqs, lengths, candidates, labels) for NARM. When
    `similarity` is truthy, the teacher is queried on an input converted to
    its own padding/masking convention and agreement metrics are added.
    """
    self.model.eval()
    self.bb_model.eval()
    if isinstance(self.model, BERT) or isinstance(self.model, SASRec):
        seqs, candidates, labels = batch
        seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
        scores = self.model(seqs)[:, -1, :]
        metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)
    elif isinstance(self.model, NARM):
        seqs, lengths, candidates, labels = batch
        seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
        lengths = lengths.flatten()
        scores = self.model(seqs, lengths)
        metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)
    if similarity:
        # Each (student, teacher) pairing converts the student's input to the
        # teacher's convention: BERT uses pre-padding + trailing mask token,
        # SASRec uses pre-padding, NARM uses post-padding + explicit lengths.
        # pre2post_padding / post2pre_padding are defined elsewhere in this
        # class — presumably they move padding across the sequence; verify.
        if isinstance(self.model, BERT) and isinstance(self.bb_model, BERT):
            soft_labels = self.bb_model(seqs)[:, -1, :]
        elif isinstance(self.model, BERT) and isinstance(self.bb_model, SASRec):
            # drop the trailing mask token and shift right with a zero pad
            temp_seqs = torch.cat((torch.zeros(seqs.size(0)).long().unsqueeze(1).to(self.device), seqs[:, :-1]), dim=1)
            soft_labels = self.bb_model(temp_seqs)[:, -1, :]
        elif isinstance(self.model, BERT) and isinstance(self.bb_model, NARM):
            temp_seqs = torch.cat((torch.zeros(seqs.size(0)).long().unsqueeze(1).to(self.device), seqs[:, :-1]), dim=1)
            temp_seqs = self.pre2post_padding(temp_seqs)
            temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()
            soft_labels = self.bb_model(temp_seqs, temp_lengths)
        elif isinstance(self.model, SASRec) and isinstance(self.bb_model, SASRec):
            soft_labels = self.bb_model(seqs)[:, -1, :]
        elif isinstance(self.model, SASRec) and isinstance(self.bb_model, BERT):
            # append the BERT mask token so the teacher predicts the next item
            temp_seqs = torch.cat((seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)
            soft_labels = self.bb_model(temp_seqs)[:, -1, :]
        elif isinstance(self.model, SASRec) and isinstance(self.bb_model, NARM):
            temp_seqs = self.pre2post_padding(seqs)
            temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()
            soft_labels = self.bb_model(temp_seqs, temp_lengths)
        elif isinstance(self.model, NARM) and isinstance(self.bb_model, NARM):
            soft_labels = self.bb_model(seqs, lengths)
        elif isinstance(self.model, NARM) and isinstance(self.bb_model, BERT):
            temp_seqs = self.post2pre_padding(seqs)
            temp_seqs = torch.cat((temp_seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)
            soft_labels = self.bb_model(temp_seqs)[:, -1, :]
        elif isinstance(self.model, NARM) and isinstance(self.bb_model, SASRec):
            temp_seqs = self.post2pre_padding(seqs)
            soft_labels = self.bb_model(temp_seqs)[:, -1, :]
        similarity = kl_agreements_and_intersctions_for_ks(scores, soft_labels, self.metric_ks)
        metrics = {**metrics, **similarity}
    return metrics
def generate_autoregressive_data(self, k=100, batch_size=50):
    """Generate a synthetic distillation dataset by sampling autoregressively
    from the black-box model.

    Each generated sequence starts from one random item; at every step the
    black-box model ranks the catalog, the top-k items become candidates, a
    random (sorted, normalized) soft-label vector is attached, and the next
    item is drawn uniformly from the candidates.

    Args:
        k: number of top-ranked candidate items kept per step.
        batch_size: number of sequences generated per outer iteration.
    """
    def topk_soft_labels(sorted_items):
        # Random soft labels, row-normalized, then sorted descending so the
        # label order agrees with the (descending) candidate ranking.
        rand = torch.rand(sorted_items.shape).to(self.device)
        rand = rand / rand.sum(dim=-1).unsqueeze(-1)
        rand, _ = torch.sort(rand, dim=-1, descending=True)
        return rand

    def sample_uniform_index(label):
        # Uniform categorical over the k candidate slots
        # (softmax of a constant vector is the uniform distribution).
        probs = F.softmax(torch.ones_like(label), -1).to(label.device)
        return torch.distributions.Categorical(probs).sample()

    def accumulate(logits, candidates, label, items):
        # Append one step's labels/candidates; replaces the original bare
        # `try/except` that silently caught the first-step `logits is None` case.
        if logits is None:
            return label.unsqueeze(1), items.unsqueeze(1)
        return (torch.cat((logits, label.unsqueeze(1)), 1),
                torch.cat((candidates, items.unsqueeze(1)), 1))

    dataset = dis_dataset_factory(self.args, self.model_code, 'autoregressive')
    # if dataset.check_data_present():
    #     print('Dataset already exists. Skip generation')
    #     return
    # NOTE(review): batch_size // num_generated_seqs looks inverted for a
    # "sequence budget" (5 seqs -> 10 batches of 50) — confirm intent.
    batch_num = batch_size // self.args.num_generated_seqs
    print('Generating dataset...')
    for i in tqdm(range(batch_num)):
        # Seed every sequence with one random item id in [1, num_items].
        seqs = torch.randint(1, self.num_items + 1, (batch_size, 1)).to(self.device)
        logits = None
        candidates = None
        self.bb_model.eval()
        with torch.no_grad():
            if isinstance(self.bb_model, BERT):
                mask_items = torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).to(self.device)
                for j in range(self.max_len - 1):
                    # Left-pad the partial sequence and put the mask token last.
                    input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)
                    input_seqs[:, (self.max_len - 2 - j):-1] = seqs
                    input_seqs[:, -1] = mask_items
                    scores = self.bb_model(input_seqs.long())[:, -1, :]
                    # Drop padding (col 0) and mask (last col) before ranking.
                    _, sorted_items = torch.sort(scores[:, 1:-1], dim=-1, descending=True)
                    sorted_items = sorted_items[:, :k] + 1
                    randomized_label = topk_soft_labels(sorted_items)
                    selected = sample_uniform_index(randomized_label)
                    rows = torch.arange(sorted_items.size(0))
                    seqs = torch.cat((seqs, sorted_items[rows, selected].unsqueeze(1)), 1)
                    logits, candidates = accumulate(logits, candidates, randomized_label, sorted_items)
                # Final step on the completed (full-length) sequence.
                input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)
                input_seqs[:, :-1] = seqs[:, 1:]
                input_seqs[:, -1] = mask_items
                scores = self.bb_model(input_seqs.long())[:, -1, :]
                _, sorted_items = torch.sort(scores[:, 1:-1], dim=-1, descending=True)
                sorted_items = sorted_items[:, :k] + 1
                randomized_label = topk_soft_labels(sorted_items)
                logits, candidates = accumulate(logits, candidates, randomized_label, sorted_items)
            elif isinstance(self.bb_model, SASRec):
                for j in range(self.max_len - 1):
                    # Left-pad the partial sequence for the unidirectional model.
                    input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)
                    input_seqs[:, (self.max_len - 1 - j):] = seqs
                    scores = self.bb_model(input_seqs.long())[:, -1, :]
                    # Drop the padding column (index 0) before ranking.
                    _, sorted_items = torch.sort(scores[:, 1:], dim=-1, descending=True)
                    sorted_items = sorted_items[:, :k] + 1
                    randomized_label = topk_soft_labels(sorted_items)
                    selected = sample_uniform_index(randomized_label)
                    rows = torch.arange(sorted_items.size(0))
                    seqs = torch.cat((seqs, sorted_items[rows, selected].unsqueeze(1)), 1)
                    logits, candidates = accumulate(logits, candidates, randomized_label, sorted_items)
                # Final step: sequence is now exactly max_len long.
                scores = self.bb_model(seqs.long())[:, -1, :]
                _, sorted_items = torch.sort(scores[:, 1:], dim=-1, descending=True)
                sorted_items = sorted_items[:, :k] + 1
                randomized_label = topk_soft_labels(sorted_items)
                logits, candidates = accumulate(logits, candidates, randomized_label, sorted_items)
            elif isinstance(self.bb_model, NARM):
                for j in range(self.max_len - 1):
                    # NARM is post-padded and takes explicit lengths.
                    lengths = torch.tensor([j + 1] * seqs.size(0))
                    scores = self.bb_model(seqs.long(), lengths)
                    _, sorted_items = torch.sort(scores[:, 1:], dim=-1, descending=True)
                    sorted_items = sorted_items[:, :k] + 1
                    randomized_label = topk_soft_labels(sorted_items)
                    selected = sample_uniform_index(randomized_label)
                    rows = torch.arange(sorted_items.size(0))
                    seqs = torch.cat((seqs, sorted_items[rows, selected].unsqueeze(1)), 1)
                    logits, candidates = accumulate(logits, candidates, randomized_label, sorted_items)
                lengths = torch.tensor([self.max_len] * seqs.size(0))
                scores = self.bb_model(seqs.long(), lengths)
                _, sorted_items = torch.sort(scores[:, 1:], dim=-1, descending=True)
                sorted_items = sorted_items[:, :k] + 1
                randomized_label = topk_soft_labels(sorted_items)
                logits, candidates = accumulate(logits, candidates, randomized_label, sorted_items)
        if i == 0:
            batch_tokens = seqs.cpu().numpy()
            batch_logits = logits.cpu().numpy()
            batch_candidates = candidates.cpu().numpy()
        else:
            batch_tokens = np.concatenate((batch_tokens, seqs.cpu().numpy()))
            batch_logits = np.concatenate((batch_logits, logits.cpu().numpy()))
            batch_candidates = np.concatenate((batch_candidates, candidates.cpu().numpy()))
    dataset.save_dataset(batch_tokens.tolist(), batch_logits.tolist(), batch_candidates.tolist())
def train_autoregressive(self):
    """Distill the black-box model into the student on generated autoregressive data.

    Generates the synthetic dataset, trains for `num_epochs`, then evaluates
    on the real test set. Returns the final test metrics dict.
    """
    accum_iter = 0
    self.writer, self.train_loggers, self.val_loggers = self._create_loggers()
    self.logger_service = LoggerService(self.train_loggers, self.val_loggers)
    self.generate_autoregressive_data()
    dis_train_loader, dis_val_loader = dis_train_loader_factory(
        self.args, self.model_code, 'autoregressive')
    print('## Distilling model via autoregressive data... ##')
    # Baseline validation pass before any training step.
    self.validate(dis_val_loader, 0, accum_iter)
    for epoch in range(self.num_epochs):
        accum_iter = self.train_one_epoch(
            epoch, accum_iter, dis_train_loader, dis_val_loader, stage=1)
    metrics = self.test()
    self.logger_service.complete({'state_dict': self._create_state_dict()})
    self.writer.close()
    return metrics
def train_one_epoch(self, epoch, accum_iter, train_loader, val_loader, stage=0):
    """Run one training epoch followed by a validation pass.

    Returns the updated cumulative sample counter `accum_iter`.
    """
    self.model.train()
    self.bb_model.train()
    meters = AverageMeterSet()
    progress = tqdm(train_loader)
    for batch_idx, batch in enumerate(progress):
        self.optimizer.zero_grad()
        # Batch layout depends on the student architecture:
        # BERT/SASRec -> (seqs, candidates, labels); NARM adds lengths.
        if isinstance(self.model, BERT) or isinstance(self.model, SASRec):
            seqs, candidates, labels = batch
            seqs = seqs.to(self.device)
            candidates = candidates.to(self.device)
            labels = labels.to(self.device)
            loss = self.calculate_loss(seqs, labels, candidates)
        elif isinstance(self.model, NARM):
            seqs, lengths, candidates, labels = batch
            lengths = lengths.flatten()
            seqs = seqs.to(self.device)
            candidates = candidates.to(self.device)
            labels = labels.to(self.device)
            loss = self.calculate_loss(seqs, labels, candidates, lengths=lengths)
        loss.backward()
        self.clip_gradients(5)
        self.optimizer.step()
        accum_iter += int(seqs.size(0))
        meters.update('loss', loss.item())
        progress.set_description(
            'Epoch {} Stage {}, loss {:.3f} '.format(epoch + 1, stage, meters['loss'].avg))
        if self._needs_to_log(accum_iter):
            entry = {
                'state_dict': self._create_state_dict(),
                'epoch': epoch + 1,
                'accum_iter': accum_iter,
            }
            entry.update(meters.averages())
            self.logger_service.log_train(entry)
    if self.args.enable_lr_schedule:
        self.lr_scheduler.step()
    self.validate(val_loader, epoch, accum_iter)
    return accum_iter
def validate(self, val_loader, epoch, accum_iter):
    """Evaluate on `val_loader` and push averaged metrics to the val loggers."""
    self.model.eval()
    meters = AverageMeterSet()
    with torch.no_grad():
        progress = tqdm(val_loader)
        for _, batch in enumerate(progress):
            batch_metrics = self.calculate_metrics(batch)
            self._update_meter_set(meters, batch_metrics)
            self._update_dataloader_metrics(progress, meters)
        entry = {
            'state_dict': self._create_state_dict(),
            'epoch': epoch + 1,
            'accum_iter': accum_iter,
        }
        entry.update(meters.averages())
        self.logger_service.log_val(entry)
def test(self):
    """Load the best checkpoint and evaluate on the test set.

    Computes ranking metrics plus teacher-agreement metrics
    (`similarity=True`), writes them to logs/test_metrics.json and
    returns the dict.
    """
    # map_location='cpu' for consistency with the other checkpoint loads in
    # this file, so a GPU-saved checkpoint restores on a CPU-only host.
    wb_model = torch.load(
        os.path.join(self.export_root, 'models', 'best_acc_model.pth'),
        map_location='cpu').get(STATE_DICT_KEY)
    self.model.load_state_dict(wb_model)
    self.model.eval()
    self.bb_model.eval()
    average_meter_set = AverageMeterSet()
    with torch.no_grad():
        tqdm_dataloader = tqdm(self.test_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            metrics = self.calculate_metrics(batch, similarity=True)
            self._update_meter_set(average_meter_set, metrics)
            self._update_dataloader_metrics(tqdm_dataloader, average_meter_set)
    average_metrics = average_meter_set.averages()
    with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
        json.dump(average_metrics, f, indent=4)
    return average_metrics
def bb_model_test(self):
    """Evaluate the black-box teacher model on the held-out test set.

    Returns the averaged ranking metrics and writes them to
    logs/test_metrics.json.
    """
    self.bb_model.eval()
    meters = AverageMeterSet()
    with torch.no_grad():
        progress = tqdm(self.test_loader)
        for _, batch in enumerate(progress):
            # Input format is dictated by the STUDENT architecture, since the
            # loaders were built for it.
            if isinstance(self.model, BERT) or isinstance(self.model, SASRec):
                seqs, candidates, labels = batch
                seqs = seqs.to(self.device)
                candidates = candidates.to(self.device)
                labels = labels.to(self.device)
                scores = self.bb_model(seqs)[:, -1, :]
                metrics = recalls_and_ndcgs_for_ks(
                    scores.gather(1, candidates), labels, self.metric_ks)
            elif isinstance(self.model, NARM):
                seqs, lengths, candidates, labels = batch
                seqs = seqs.to(self.device)
                candidates = candidates.to(self.device)
                labels = labels.to(self.device)
                scores = self.bb_model(seqs, lengths.flatten())
                metrics = recalls_and_ndcgs_for_ks(
                    scores.gather(1, candidates), labels, self.metric_ks)
            self._update_meter_set(meters, metrics)
            self._update_dataloader_metrics(progress, meters)
    average_metrics = meters.averages()
    # NOTE(review): this writes the same 'test_metrics.json' path that test()
    # writes, so one overwrites the other — confirm whether that is intended.
    with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
        json.dump(average_metrics, f, indent=4)
    return average_metrics
def pre2post_padding(self, seqs):
    """Convert pre-padded sequences (zeros in front) to post-padded (zeros behind).

    Args:
        seqs: 2-D integer tensor (batch, seq_len); 0 is the padding id.

    Returns:
        A new tensor of the same shape with each row's items moved to the front.
    """
    processed = torch.zeros_like(seqs)
    # .flatten() instead of .squeeze(): squeeze produced a 0-dim tensor for
    # batch size 1, making lengths[i] raise IndexError.
    lengths = (seqs > 0).sum(-1).flatten()
    for i in range(seqs.size(0)):
        processed[i, :lengths[i]] = seqs[i, seqs.size(1) - lengths[i]:]
    return processed
def post2pre_padding(self, seqs):
    """Convert post-padded sequences (zeros behind) to pre-padded (zeros in front).

    Args:
        seqs: 2-D integer tensor (batch, seq_len); 0 is the padding id.

    Returns:
        A new tensor of the same shape with each row's items moved to the back.
    """
    processed = torch.zeros_like(seqs)
    # .flatten() instead of .squeeze(): squeeze produced a 0-dim tensor for
    # batch size 1, making lengths[i] raise IndexError.
    lengths = (seqs > 0).sum(-1).flatten()
    for i in range(seqs.size(0)):
        processed[i, seqs.size(1) - lengths[i]:] = seqs[i, :lengths[i]]
    return processed
def clip_gradients(self, limit=5):
    """Clip each model parameter's gradient norm to `limit`.

    Fix: the `limit` argument was previously ignored (the norm was
    hard-coded to 5); the default keeps the old behavior for existing calls.
    """
    for p in self.model.parameters():
        nn.utils.clip_grad_norm_(p, limit)
def _update_meter_set(self, meter_set, metrics):
for k, v in metrics.items():
meter_set.update(k, v)
def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
] + ['Recall@%d' % k for k in self.metric_ks[:3]]
description = 'Eval: ' + \
', '.join(s + ' {:.3f}' for s in description_metrics)
description = description.replace('NDCG', 'N').replace('Recall', 'R')
description = description.format(
*(meter_set[k].avg for k in description_metrics))
tqdm_dataloader.set_description(description)
def _create_optimizer(self):
args = self.args
param_optimizer = list(self.model.named_parameters())
no_decay = ['bias', 'layer_norm']
optimizer_grouped_parameters = [
{
'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': args.weight_decay,
},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
if args.optimizer.lower() == 'adamw':
return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
elif args.optimizer.lower() == 'adam':
return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'sgd':
return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
else:
raise ValueError
def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """LR multiplier: linear warmup to 1.0, then linear decay to 0.0.

    Mirrors Hugging Face's get_linear_schedule_with_warmup.
    """
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        remaining = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, remaining / decay_span)
    return LambdaLR(optimizer, lr_lambda, last_epoch)
def _create_loggers(self):
    """Create the TensorBoard writer plus train/validation logger lists."""
    root = Path(self.export_root)
    writer = SummaryWriter(root.joinpath('logs'))
    model_checkpoint = root.joinpath('models')
    train_loggers = [
        MetricGraphPrinter(writer, key='epoch', graph_name='Epoch', group_name='Train'),
        MetricGraphPrinter(writer, key='loss', graph_name='Loss', group_name='Train'),
    ]
    val_loggers = []
    # One NDCG and one Recall graph per cutoff k.
    for k in self.metric_ks:
        for prefix in ('NDCG', 'Recall'):
            metric_name = '%s@%d' % (prefix, k)
            val_loggers.append(MetricGraphPrinter(
                writer, key=metric_name, graph_name=metric_name, group_name='Validation'))
    val_loggers.append(RecentModelLogger(model_checkpoint))
    val_loggers.append(BestModelLogger(model_checkpoint, metric_key=self.best_metric))
    return writer, train_loggers, val_loggers
def _create_state_dict(self):
    """Snapshot the model and optimizer state for checkpointing.

    Unwraps DataParallel via `.module` when running in parallel mode.
    """
    return {
        STATE_DICT_KEY: self.model.module.state_dict() if self.is_parallel else self.model.state_dict(),
        OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
    }
def _needs_to_log(self, accum_iter):
return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0
def train(args, export_root=None, resume=False):
    """Train a recommender (`args.model_code` in bert/sas/narm) from scratch.

    Args:
        args: namespace of hyper-parameters; lr is forced to 0.001 here.
        export_root: output dir; defaults to experiments/<model>/<dataset>.
        resume: if True, try to warm-start from a previous best checkpoint.

    Raises:
        ValueError: on an unknown model code (previously this fell through
        to a NameError on `model`/`trainer`).
    """
    args.lr = 0.001
    fix_random_seed_as(args.model_init_seed)
    train_loader, val_loader, test_loader = dataloader_factory(args)
    if args.model_code == 'bert':
        model = BERT(args)
    elif args.model_code == 'sas':
        model = SASRec(args)
    elif args.model_code == 'narm':
        model = NARM(args)
    else:
        raise ValueError('Unknown model code: ' + str(args.model_code))
    if export_root is None:
        export_root = 'experiments/' + args.model_code + '/' + args.dataset_code
    if resume:
        try:
            model.load_state_dict(torch.load(os.path.join(export_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))
        except FileNotFoundError:
            print('Failed to load old model, continue training new model...')
    if args.model_code == 'bert':
        args.num_epochs = 10
        trainer = BERTTrainer(args, model, train_loader, val_loader, test_loader, export_root)
    elif args.model_code == 'sas':
        trainer = SASTrainer(args, model, train_loader, val_loader, test_loader, export_root)
    else:  # 'narm' — guaranteed by the guard above
        args.num_epochs = 100
        trainer = RNNTrainer(args, model, train_loader, val_loader, test_loader, export_root)
    trainer.train()
    trainer.test()
if __name__ == "__main__":
    # Fill in dataset/model-dependent defaults on `args` before training.
    set_template(args)
    # when use k-core beauty and k is not 5 (beauty-dense)
    # args.min_uc = k
    # args.min_sc = k
    train(args, resume=True)
Raw file doesn't exist. Downloading... Extracting data...
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:70: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
Filtering triplets Densifying index Splitting
100%|██████████| 6040/6040 [00:10<00:00, 595.20it/s]
Negative samples don't exist. Generating. Sampling negative items randomly...
100%|██████████| 6040/6040 [00:01<00:00, 5155.44it/s]
Negative samples don't exist. Generating. Sampling negative items randomly...
100%|██████████| 6040/6040 [00:01<00:00, 5303.09it/s]
Failed to load old model, continue training new model...
Eval: N@1 0.009, N@5 0.025, N@10 0.040, R@1 0.009, R@5 0.043, R@10 0.091: 100%|██████████| 48/48 [00:03<00:00, 13.01it/s]
Update Best NDCG@10 Model at 1
Epoch 1, loss 8.047 : 100%|██████████| 68/68 [00:12<00:00, 5.43it/s] Eval: N@1 0.026, N@5 0.066, N@10 0.095, R@1 0.026, R@5 0.108, R@10 0.196: 100%|██████████| 48/48 [00:03<00:00, 14.29it/s]
Update Best NDCG@10 Model at 1
Epoch 2, loss 7.721 : 100%|██████████| 68/68 [00:12<00:00, 5.44it/s] Eval: N@1 0.041, N@5 0.103, N@10 0.142, R@1 0.041, R@5 0.166, R@10 0.287: 100%|██████████| 48/48 [00:03<00:00, 14.39it/s]
Update Best NDCG@10 Model at 2
Epoch 3, loss 7.419 : 100%|██████████| 68/68 [00:12<00:00, 5.42it/s] Eval: N@1 0.055, N@5 0.130, N@10 0.172, R@1 0.055, R@5 0.205, R@10 0.337: 100%|██████████| 48/48 [00:03<00:00, 14.23it/s]
Update Best NDCG@10 Model at 3
Epoch 4, loss 7.230 : 100%|██████████| 68/68 [00:12<00:00, 5.46it/s] Eval: N@1 0.070, N@5 0.153, N@10 0.200, R@1 0.070, R@5 0.235, R@10 0.380: 100%|██████████| 48/48 [00:03<00:00, 14.27it/s]
Update Best NDCG@10 Model at 4
Epoch 5, loss 7.126 : 100%|██████████| 68/68 [00:12<00:00, 5.46it/s] Eval: N@1 0.076, N@5 0.167, N@10 0.215, R@1 0.076, R@5 0.256, R@10 0.405: 100%|██████████| 48/48 [00:03<00:00, 14.48it/s]
Update Best NDCG@10 Model at 5
Epoch 6, loss 7.071 : 100%|██████████| 68/68 [00:12<00:00, 5.43it/s] Eval: N@1 0.076, N@5 0.169, N@10 0.218, R@1 0.076, R@5 0.260, R@10 0.412: 100%|██████████| 48/48 [00:03<00:00, 14.16it/s]
Update Best NDCG@10 Model at 6
Epoch 7, loss 7.043 : 100%|██████████| 68/68 [00:12<00:00, 5.48it/s] Eval: N@1 0.077, N@5 0.174, N@10 0.223, R@1 0.077, R@5 0.269, R@10 0.424: 100%|██████████| 48/48 [00:03<00:00, 14.32it/s]
Update Best NDCG@10 Model at 7
Epoch 8, loss 7.026 : 100%|██████████| 68/68 [00:12<00:00, 5.46it/s] Eval: N@1 0.080, N@5 0.177, N@10 0.226, R@1 0.080, R@5 0.272, R@10 0.425: 100%|██████████| 48/48 [00:03<00:00, 14.37it/s]
Update Best NDCG@10 Model at 8
Epoch 9, loss 7.007 : 100%|██████████| 68/68 [00:12<00:00, 5.49it/s] Eval: N@1 0.080, N@5 0.180, N@10 0.227, R@1 0.080, R@5 0.277, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 14.37it/s]
Update Best NDCG@10 Model at 9
Epoch 10, loss 7.004 : 100%|██████████| 68/68 [00:12<00:00, 5.48it/s] Eval: N@1 0.080, N@5 0.178, N@10 0.226, R@1 0.080, R@5 0.274, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 14.33it/s] Eval: N@1 0.080, N@5 0.174, N@10 0.218, R@1 0.080, R@5 0.266, R@10 0.405: 100%|██████████| 48/48 [00:03<00:00, 14.35it/s]
def distill(args, bb_model_root=None, export_root=None, resume=False):
    """Distill a trained black-box teacher into a fresh student model.

    Args:
        args: hyper-parameter namespace; lr forced to 0.001, warmup disabled.
        bb_model_root: directory of the teacher checkpoint; defaults to
            experiments/<bb_code>/<dataset>.
        export_root: output directory for the distilled student.
        resume: if True, try to warm-start the student from a previous run.

    Raises:
        ValueError: on an unknown model code (previously a NameError).
    """
    def build_model(code):
        # Map a model code to a freshly constructed model instance.
        if code == 'bert':
            return BERT(args)
        if code == 'sas':
            return SASRec(args)
        if code == 'narm':
            return NARM(args)
        raise ValueError('Unknown model code: ' + str(code))

    args.lr = 0.001
    args.enable_lr_warmup = False
    fix_random_seed_as(args.model_init_seed)
    _, _, test_loader = dataloader_factory(args)
    model = build_model(args.model_code)
    # model_codes = {'b': 'bert', 's':'sas', 'n':'narm'}
    # bb_model_code = model_codes[input('Input black box model code, b for BERT, s for SASRec and n for NARM: ')]
    # args.num_generated_seqs = int(input('Input integer number of seqs budget: '))
    args.num_generated_seqs = 5
    bb_model_code = 'bert'
    bb_model = build_model(bb_model_code)
    if bb_model_root is None:
        bb_model_root = 'experiments/' + bb_model_code + '/' + args.dataset_code
    if export_root is None:
        folder_name = bb_model_code + '2' + args.model_code + '_autoregressive' + str(args.num_generated_seqs)
        export_root = 'experiments/distillation_rank/' + folder_name + '/' + args.dataset_code
    bb_model.load_state_dict(torch.load(os.path.join(bb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))
    if resume:
        try:
            model.load_state_dict(torch.load(os.path.join(export_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))
        except FileNotFoundError:
            print('Failed to load old model, continue training new model...')
    trainer = NoDataRankDistillationTrainer(args, args.model_code, model, bb_model, test_loader, export_root)
    trainer.train_autoregressive()
if __name__ == "__main__":
    # Fill in dataset/model-dependent defaults on `args` before distillation.
    set_template(args)
    # when use k-core beauty and k is not 5 (beauty-dense)
    # args.min_uc = k
    # args.min_sc = k
    args.num_epochs = 5
    distill(args=args, resume=False)
Already preprocessed. Skip preprocessing Negatives samples exist. Loading. Negatives samples exist. Loading. Generating dataset...
100%|██████████| 10/10 [00:43<00:00, 4.36s/it]
## Distilling model via autoregressive data... ##
Eval: N@1 0.000, N@5 0.000, N@10 0.000, R@1 0.000, R@5 0.000, R@10 0.000: 100%|██████████| 4/4 [00:00<00:00, 11.28it/s] 0%| | 0/778 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/distributions/distribution.py:151: UserWarning: sample_n will be deprecated. Use .sample((n,)) instead warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning) Epoch 1 Stage 1, loss 0.511 : 100%|██████████| 778/778 [02:03<00:00, 6.28it/s] Eval: N@1 1.000, N@5 1.000, N@10 1.000, R@1 1.000, R@5 1.000, R@10 1.000: 100%|██████████| 4/4 [00:00<00:00, 14.12it/s]
Update Best NDCG@10 Model at 1
Epoch 2 Stage 1, loss 0.443 : 100%|██████████| 778/778 [02:03<00:00, 6.29it/s] Eval: N@1 1.000, N@5 1.000, N@10 1.000, R@1 1.000, R@5 1.000, R@10 1.000: 100%|██████████| 4/4 [00:00<00:00, 14.09it/s] Epoch 3 Stage 1, loss 0.431 : 100%|██████████| 778/778 [02:03<00:00, 6.30it/s] Eval: N@1 1.000, N@5 1.000, N@10 1.000, R@1 1.000, R@5 1.000, R@10 1.000: 100%|██████████| 4/4 [00:00<00:00, 14.09it/s] Epoch 4 Stage 1, loss 0.422 : 100%|██████████| 778/778 [02:03<00:00, 6.30it/s] Eval: N@1 1.000, N@5 1.000, N@10 1.000, R@1 1.000, R@5 1.000, R@10 1.000: 100%|██████████| 4/4 [00:00<00:00, 13.72it/s] Epoch 5 Stage 1, loss 0.416 : 100%|██████████| 778/778 [02:03<00:00, 6.30it/s] Eval: N@1 1.000, N@5 1.000, N@10 1.000, R@1 1.000, R@5 1.000, R@10 1.000: 100%|██████████| 4/4 [00:00<00:00, 14.16it/s] Eval: N@1 0.073, N@5 0.140, N@10 0.160, R@1 0.073, R@5 0.202, R@10 0.264: 100%|██████████| 48/48 [00:09<00:00, 5.28it/s]
def zero_gradients(x):
    """Recursively zero the .grad of a tensor or any iterable of tensors.

    Local replacement for the removed torch.autograd.gradcheck.zero_gradients.
    Tensors without a grad are left untouched.
    """
    # Local import: the file only does `from collections import Counter`,
    # so the bare name `collections` used by the original is not guaranteed
    # to be bound.
    from collections.abc import Iterable
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, Iterable):
        for elem in x:
            zero_gradients(elem)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
# from torch.autograd.gradcheck import zero_gradients
from tqdm import tqdm
import json
import faiss
import numpy as np
from abc import *
from pathlib import Path
class AdversarialRankAttacker(metaclass=ABCMeta):
    """Targeted adversarial attacker for sequential recommenders.

    Uses a white-box surrogate (`wb_model`) to craft item-level perturbations
    that promote a target item; the perturbed profiles are then evaluated
    against the black-box victim (`bb_model`).
    """
    def __init__(self, args, wb_model, bb_model, test_loader):
        self.args = args
        self.device = args.device
        self.num_items = args.num_items
        self.max_len = args.bert_max_len
        self.wb_model = wb_model.to(self.device)  # white-box surrogate (gradient source)
        self.bb_model = bb_model.to(self.device)  # black-box victim (evaluation only)
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.test_loader = test_loader
        self.CLOZE_MASK_TOKEN = args.num_items + 1  # BERT mask-token id
        self.adv_ce = nn.CrossEntropyLoss(ignore_index=0)  # padding id 0 ignored
        # Strip the padding row (index 0); BERT additionally has a trailing
        # mask-token row, stripped with [-1].
        if isinstance(self.wb_model, BERT):
            self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:-1]
        else:
            self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:]
        # L2 nearest-neighbor index over the surrogate's item embeddings.
        self.faiss_index = faiss.IndexFlatL2(self.item_embeddings.shape[-1])
        self.faiss_index.add(self.item_embeddings)
        self.item_embeddings = torch.tensor(self.item_embeddings).to(self.device)
        if isinstance(self.bb_model, BERT):
            self.bb_item_embeddings = self.bb_model.embedding.token.weight.detach().cpu().numpy()[1:-1]
        else:
            self.bb_item_embeddings = self.bb_model.embedding.token.weight.detach().cpu().numpy()[1:]
        self.bb_item_embeddings = torch.tensor(self.bb_item_embeddings).to(self.device)
def attack(self, target, num_attack=10, repeated_search=10):
    """Targeted attack: append `num_attack` copies of `target` to each profile,
    then replace those slots with items chosen via a gradient (FGSM-style)
    step in the surrogate's embedding space, and evaluate the perturbed
    profiles against the black-box model.

    Args:
        target: item id to promote.
        num_attack: number of appended/perturbed positions per profile.
        repeated_search: how many top-ranked substitute items to try per slot.

    Returns:
        Averaged recall/NDCG metrics of the black-box model on the
        perturbed profiles with `target` substituted as the true item.
    """
    print('## Targeted Attack on Item {} ##'.format(str(target)))
    average_meter_set = AverageMeterSet()
    tqdm_dataloader = tqdm(self.test_loader)
    for batch_idx, batch in enumerate(tqdm_dataloader):
        self.wb_model.eval()
        with torch.no_grad():
            # Unpack according to the black-box model's batch layout.
            if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):
                seqs, candidates, labels = batch
                seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
                if isinstance(self.bb_model, BERT):
                    # Shift right by one to free the last position.
                    seqs[:, 1:] = seqs[:, :-1]
                    seqs[:, 0] = 0
            elif isinstance(self.bb_model, NARM):
                seqs, lengths, candidates, labels = batch
                seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
                seqs = self.post2pre_padding(seqs)
            perturbed_seqs = seqs.clone()
            # Append num_attack copies of the target, re-truncate to max_len.
            append_items = torch.tensor([target]*(perturbed_seqs.size(0)*num_attack)).reshape(-1, num_attack)
            perturbed_seqs = torch.cat((perturbed_seqs, torch.tensor(append_items).to(self.device)), 1)
            perturbed_seqs = perturbed_seqs[:, -self.max_len:]
            # Embed the perturbed profile with the white-box surrogate.
            if isinstance(self.wb_model, BERT):
                mask_items = torch.tensor([self.CLOZE_MASK_TOKEN] * perturbed_seqs.size(0)).to(self.device)
                perturbed_seqs[:, :-1] = perturbed_seqs[:, 1:]
                perturbed_seqs[:, -1] = mask_items
                wb_embedding, mask = self.wb_model.embedding(perturbed_seqs.long())
            elif isinstance(self.wb_model, SASRec):
                wb_embedding, mask = self.wb_model.embedding(perturbed_seqs.long())
            elif isinstance(self.wb_model, NARM):
                perturbed_seqs = self.pre2post_padding(perturbed_seqs)
                lengths = (perturbed_seqs > 0).sum(-1).cpu().flatten()
                wb_embedding, mask = self.wb_model.embedding(perturbed_seqs.long(), lengths)
        # Gradient of the adversarial CE loss w.r.t. the input embeddings.
        self.wb_model.train()
        wb_embedding = wb_embedding.detach().clone()
        wb_embedding.requires_grad = True
        zero_gradients(wb_embedding)
        if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):
            wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, mask)[:, -1, :]
        elif isinstance(self.wb_model, NARM):
            wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, lengths, mask)
        loss = self.adv_ce(wb_scores, torch.tensor([target] * perturbed_seqs.size(0)).to(self.device))
        self.wb_model.zero_grad()
        loss.backward()
        wb_embedding_grad = wb_embedding.grad.data
        self.wb_model.eval()
        with torch.no_grad():
            # Indices of real (non-pad, non-mask) positions; sorting the
            # position indices descending puts the appended target slots first.
            appended_indicies = (perturbed_seqs != self.CLOZE_MASK_TOKEN)
            appended_indicies = (perturbed_seqs != 0) * appended_indicies
            appended_indicies = torch.arange(perturbed_seqs.shape[1]).to(self.device) * appended_indicies
            _, appended_indicies = torch.sort(appended_indicies, -1, descending=True)
            appended_indicies = appended_indicies[:, :num_attack]
            best_seqs = perturbed_seqs.clone().detach()
            for num in range(num_attack):
                row_indices = torch.arange(seqs.size(0))
                col_indices = appended_indicies[:, num]
                current_embedding = wb_embedding[row_indices, col_indices]
                current_embedding_grad = wb_embedding_grad[row_indices, col_indices]
                all_embeddings = self.item_embeddings.unsqueeze(1).repeat_interleave(current_embedding.size(0), 1)
                cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
                # FGSM-style step (subtract sign of gradient), projected back
                # to real items by cosine similarity against the item table.
                multipication_results = torch.t(cos(current_embedding-current_embedding_grad.sign(), all_embeddings))
                _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=True)
                if num == 0:
                    # Forbid placing the target itself in the first slot.
                    multipication_results[:, target-1] = multipication_results[:, target-1] - 100000000
                    _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=True)
                    best_seqs[row_indices, col_indices] = candidate_indicies[:, 0] + 1
                    if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):
                        logits = F.softmax(self.wb_model(best_seqs)[:, -1, :], dim=-1)
                    elif isinstance(self.wb_model, NARM):
                        logits = F.softmax(self.wb_model(best_seqs, lengths), dim=-1)
                    best_scores = torch.gather(logits, -1, torch.tensor([target] * best_seqs.size(0)).unsqueeze(1).to(self.device)).squeeze()
                elif num > 0:
                    # Avoid two consecutive occurrences of the target item.
                    prev_col_indices = appended_indicies[:, num-1]
                    if_prev_target = (best_seqs[row_indices, prev_col_indices] == target)
                    multipication_results[:, target-1] = multipication_results[:, target-1] + (if_prev_target * -100000000)
                    _, candidate_indicies = torch.sort(multipication_results, dim=1, descending=True)
                    best_seqs[row_indices, col_indices] = best_seqs[row_indices, col_indices] * ~if_prev_target + \
                        (candidate_indicies[:, 0] + 1) * if_prev_target
                    if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):
                        logits = F.softmax(self.wb_model(best_seqs)[:, -1, :], dim=-1)
                    elif isinstance(self.wb_model, NARM):
                        logits = F.softmax(self.wb_model(best_seqs, lengths), dim=-1)
                    best_scores = torch.gather(logits, -1, torch.tensor([target] * best_seqs.size(0)).unsqueeze(1).to(self.device)).squeeze()
                # Greedy search over the top `repeated_search` candidates,
                # keeping whichever substitution scores the target highest.
                for time in range(repeated_search):
                    temp_seqs = best_seqs.clone().detach()
                    temp_seqs[row_indices, col_indices] = candidate_indicies[:, time] + 1
                    if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):
                        logits = F.softmax(self.wb_model(temp_seqs)[:, -1, :], dim=-1)
                    elif isinstance(self.wb_model, NARM):
                        logits = F.softmax(self.wb_model(temp_seqs, lengths), dim=-1)
                    temp_scores = torch.gather(logits, -1, torch.tensor([target] * temp_seqs.size(0)).unsqueeze(1).to(self.device)).squeeze()
                    best_seqs[row_indices, col_indices] = temp_seqs[row_indices, col_indices] * (temp_scores >= best_scores) + best_seqs[row_indices, col_indices] * (temp_scores < best_scores)
                    best_scores = temp_scores * (temp_scores >= best_scores) + best_scores * (temp_scores < best_scores)
                best_seqs = best_seqs.detach()
                best_scores = best_scores.detach()
                del temp_scores
            perturbed_seqs = best_seqs.detach()
            # Score the perturbed profiles with the black-box model, converting
            # between the white-box and black-box input conventions as needed.
            if isinstance(self.wb_model, BERT) and isinstance(self.bb_model, BERT):
                perturbed_scores = self.bb_model(perturbed_seqs)[:, -1, :]
            elif isinstance(self.wb_model, BERT) and isinstance(self.bb_model, SASRec):
                temp_seqs = torch.cat((torch.zeros(perturbed_seqs.size(0)).long().unsqueeze(1).to(self.device), perturbed_seqs[:, :-1]), dim=1)
                perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]
            elif isinstance(self.wb_model, BERT) and isinstance(self.bb_model, NARM):
                temp_seqs = torch.cat((torch.zeros(perturbed_seqs.size(0)).long().unsqueeze(1).to(self.device), perturbed_seqs[:, :-1]), dim=1)
                temp_seqs = self.pre2post_padding(temp_seqs)
                temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()
                perturbed_scores = self.bb_model(temp_seqs, temp_lengths)
            elif isinstance(self.wb_model, SASRec) and isinstance(self.bb_model, SASRec):
                perturbed_scores = self.bb_model(perturbed_seqs)[:, -1, :]
            elif isinstance(self.wb_model, SASRec) and isinstance(self.bb_model, BERT):
                temp_seqs = torch.cat((perturbed_seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * perturbed_seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)
                perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]
            elif isinstance(self.wb_model, SASRec) and isinstance(self.bb_model, NARM):
                temp_seqs = self.pre2post_padding(perturbed_seqs)
                temp_lengths = (temp_seqs > 0).sum(-1).cpu().flatten()
                perturbed_scores = self.bb_model(temp_seqs, temp_lengths)
            elif isinstance(self.wb_model, NARM) and isinstance(self.bb_model, NARM):
                perturbed_scores = self.bb_model(perturbed_seqs, lengths)
            elif isinstance(self.wb_model, NARM) and isinstance(self.bb_model, BERT):
                temp_seqs = self.post2pre_padding(perturbed_seqs)
                temp_seqs = torch.cat((temp_seqs[:, 1:], torch.tensor([self.CLOZE_MASK_TOKEN] * perturbed_seqs.size(0)).unsqueeze(1).to(self.device)), dim=1)
                perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]
            elif isinstance(self.wb_model, NARM) and isinstance(self.bb_model, SASRec):
                temp_seqs = self.post2pre_padding(perturbed_seqs)
                perturbed_scores = self.bb_model(temp_seqs)[:, -1, :]
            # Replace the true item slot with the target for targeted metrics.
            candidates[:, 0] = torch.tensor([target] * candidates.size(0)).to(self.device)
            perturbed_scores = perturbed_scores.gather(1, candidates)
            metrics = recalls_and_ndcgs_for_ks(perturbed_scores, labels, self.metric_ks)
        self._update_meter_set(average_meter_set, metrics)
        self._update_dataloader_metrics(tqdm_dataloader, average_meter_set)
    average_metrics = average_meter_set.averages()
    return average_metrics
def test(self, target=None):
if target is not None:
print('## Black-Box Targeted Test on Item {} ##'.format(str(target)))
else:
print('## Black-Box Untargeted Test on Item Level ##')
self.bb_model.eval()
average_meter_set = AverageMeterSet()
with torch.no_grad():
tqdm_dataloader = tqdm(self.test_loader)
for batch_idx, batch in enumerate(tqdm_dataloader):
if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):
seqs, candidates, labels = batch
seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
scores = self.bb_model(seqs)[:, -1, :]
elif isinstance(self.bb_model, NARM):
seqs, lengths, candidates, labels = batch
seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
lengths = lengths.flatten()
scores = self.bb_model(seqs, lengths)
if target is not None:
candidates[:, 0] = torch.tensor([target] * seqs.size(0)).to(self.device)
scores = scores.gather(1, candidates)
metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
self._update_meter_set(average_meter_set, metrics)
self._update_dataloader_metrics(
tqdm_dataloader, average_meter_set)
average_metrics = average_meter_set.averages()
return average_metrics
def calculate_metrics(self, batch):
self.bb_model.eval()
if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):
seqs, candidates, labels = batch
seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
scores = self.bb_model(seqs)[:, -1, :]
elif isinstance(self.bb_model, NARM):
seqs, lengths, candidates, labels = batch
seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
lengths = lengths.flatten()
scores = self.bb_model(seqs, lengths)
scores = scores.gather(1, candidates) # B x C
metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
return metrics
def pre2post_padding(self, seqs):
processed = torch.zeros_like(seqs)
lengths = (seqs > 0).sum(-1).squeeze()
for i in range(seqs.size(0)):
processed[i, :lengths[i]] = seqs[i, seqs.size(1)-lengths[i]:]
return processed
def post2pre_padding(self, seqs):
processed = torch.zeros_like(seqs)
lengths = (seqs > 0).sum(-1).squeeze()
for i in range(seqs.size(0)):
processed[i, seqs.size(1)-lengths[i]:] = seqs[i, :lengths[i]]
return processed
def _update_meter_set(self, meter_set, metrics):
for k, v in metrics.items():
meter_set.update(k, v)
def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
description_metrics = ['Recall@%d' % k for k in self.metric_ks[:3]] + ['NDCG@%d' % k for k in self.metric_ks[1:3]]
description = 'Val: ' + ', '.join(s + ' {:.3f}' for s in description_metrics)
description = description.replace('NDCG', 'N').replace('Recall', 'R')
description = description.format(*(meter_set[k].avg for k in description_metrics))
tqdm_dataloader.set_description(description)
import pickle
import shutil
import tempfile
import os
from pathlib import Path
import numpy as np
from abc import *
class AbstractPoisonedDataset(metaclass=ABCMeta):
    """Base class for datasets augmented with poisoned (fake-user) sequences.

    Subclasses only provide `code()` (the dataset identifier). The class
    knows how to merge generated sequences into the original dataset and
    save/load the result under a method/target-specific folder.

    Fixes applied (interface unchanged):
    * `load_dataset` previously leaked the pickle file handle;
    * folder creation now uses `exist_ok=True` (no check-then-create race);
    * removed the useless chained `target_spec` assignment in `__init__`.
    """

    def __init__(self, args, target, method_code, num_poisoned_seqs=0, num_original_seqs=0):
        self.args = args
        # A list of targets is flattened to a '1_2_3'-style spec string.
        if isinstance(target, list):
            self.target = '_'.join([str(t) for t in target])
        else:
            self.target = target
        self.method_code = method_code
        self.num_poisoned_seqs = num_poisoned_seqs
        self.num_original_seqs = num_original_seqs

    @classmethod
    @abstractmethod
    def code(cls):
        """Return the dataset identifier string (e.g. 'ml-1m')."""
        pass

    @classmethod
    def raw_code(cls):
        return cls.code()

    def check_data_present(self):
        """True if the poisoned dataset pickle already exists on disk."""
        return self._get_poisoned_dataset_path().is_file()

    def load_dataset(self):
        """Load the saved poisoned dataset dict, or None if not generated yet."""
        dataset_path = self._get_poisoned_dataset_path()
        if not dataset_path.is_file():
            print('Dataset not found, please generate distillation dataset first')
            return
        # Context manager so the file handle is closed deterministically.
        with dataset_path.open('rb') as f:
            return pickle.load(f)

    def save_dataset(self, tokens, original_dataset_size=0, valid_all=False):
        """Merge poisoned sequences `tokens` into the original dataset and save.

        tokens: list of item-id sequences; the last two items of each become
        that fake user's val/test items. When `original_dataset_size` > 0 the
        original users are downsampled to that many first.
        Returns (num_poisoned_seqs, num_original_seqs, poisoning_users).
        """
        original_dataset = dataset_factory(self.args)
        original_dataset = original_dataset.load_dataset()
        train = original_dataset['train']
        val = original_dataset['val']
        test = original_dataset['test']
        self.num_poisoned_seqs = len(tokens)
        self.num_original_seqs = len(train)
        start_index = len(train) + 1
        if original_dataset_size > 0:
            # NOTE(review): np.random.choice samples WITH replacement here, so
            # the same user may appear several times — confirm intended.
            sampled_users = np.random.choice(list(train.keys()), original_dataset_size)
            train_ = {idx + 1: train[user] for idx, user in enumerate(sampled_users)}
            val_ = {idx + 1: val[user] for idx, user in enumerate(sampled_users)}
            test_ = {idx + 1: test[user] for idx, user in enumerate(sampled_users)}
            train, val, test = train_, val_, test_
            self.num_original_seqs = original_dataset_size
            start_index = original_dataset_size + 1
        self.poisoning_users = []
        for i in range(len(tokens)):
            items = tokens[i]
            user = start_index + i
            self.poisoning_users.append(user)
            # Last two items become the fake user's val and test targets.
            train[user], val[user], test[user] = items[:-2], items[-2:-1], items[-1:]
        dataset_path = self._get_poisoned_dataset_path()
        dataset_path.parent.mkdir(parents=True, exist_ok=True)
        dataset = {'train': train,
                   'val': val,
                   'test': test}
        with dataset_path.open('wb') as f:
            pickle.dump(dataset, f)
        return self.num_poisoned_seqs, self.num_original_seqs, self.poisoning_users

    def _get_rawdata_root_path(self):
        return Path(GEN_DATASET_ROOT_FOLDER)

    def _get_folder_path(self):
        root = self._get_rawdata_root_path()
        return root.joinpath(self.raw_code())

    def _get_subfolder_path(self):
        """Folder encoding method, target and poisoned/original sequence counts."""
        root = self._get_folder_path()
        folder = 'poisoned' + str(self.num_poisoned_seqs) + '_' + 'original' + str(self.num_original_seqs)
        return root.joinpath(self.method_code + '_target_' + str(self.target) + '_' + folder)

    def _get_poisoned_dataset_path(self):
        return self._get_subfolder_path().joinpath('poisoned_dataset.pkl')
class ML1MPoisonedDataset(AbstractPoisonedDataset):
    """Poisoned-dataset variant for MovieLens-1M."""

    @classmethod
    def code(cls):
        return 'ml-1m'
class ML20MPoisonedDataset(AbstractPoisonedDataset):
    """Poisoned-dataset variant for MovieLens-20M."""

    @classmethod
    def code(cls):
        return 'ml-20m'
class BeautyPoisonedDataset(AbstractPoisonedDataset):
    """Poisoned-dataset variant for Amazon Beauty."""

    @classmethod
    def code(cls):
        return 'beauty'
class SteamPoisonedDataset(AbstractPoisonedDataset):
    """Poisoned-dataset variant for Steam."""

    @classmethod
    def code(cls):
        return 'steam'
class YooChoosePoisonedDataset(AbstractPoisonedDataset):
    """Poisoned-dataset variant for YooChoose."""

    @classmethod
    def code(cls):
        return 'yoochoose'
import torch
import torch.utils.data as data_utils
import random
# Registry: dataset code string -> poisoned-dataset class.
POI_DATASETS = {
    cls.code(): cls
    for cls in (
        ML1MPoisonedDataset,
        ML20MPoisonedDataset,
        BeautyPoisonedDataset,
        SteamPoisonedDataset,
        YooChoosePoisonedDataset,
    )
}
def poi_dataset_factory(args, target, method_code, num_poisoned_seqs=0, num_original_seqs=0):
    """Instantiate the poisoned-dataset class registered under args.dataset_code."""
    dataset_cls = POI_DATASETS[args.dataset_code]
    return dataset_cls(args, target, method_code, num_poisoned_seqs, num_original_seqs)
def poi_train_loader_factory(args, target, method_code, num_poisoned_seqs, num_original_seqs, poisoning_users=None):
    """Build (train, val, test) DataLoaders for a generated poisoned dataset.

    Returns None when the poisoned dataset file has not been generated yet.
    """
    dataset = poi_dataset_factory(args, target, method_code, num_poisoned_seqs, num_original_seqs)
    if not dataset.check_data_present():
        return None
    return PoisonedDataLoader(args, dataset).get_loaders(poisoning_users)
class PoisonedDataLoader():
    """Builds train/val/test DataLoaders over a saved poisoned dataset.

    Negative candidates for val/test are produced by the configured negative
    sampler and cached under the dataset's subfolder. The concrete dataset
    classes used in `_get_datasets` depend on args.model_code.
    """
    def __init__(self, args, dataset):
        self.args = args
        self.rng = random.Random()
        # Cache folder for the negative-sample files of this poisoned dataset.
        self.save_folder = dataset._get_subfolder_path()
        dataset = dataset.load_dataset()
        self.train = dataset['train']
        self.val = dataset['val']
        self.test = dataset['test']
        self.user_count = len(self.train)
        self.item_count = self.args.num_items
        self.max_len = args.bert_max_len
        self.mask_prob = args.bert_mask_prob
        self.max_predictions = args.bert_max_predictions
        self.sliding_size = args.sliding_window_size
        # Mask token id for BERT-style cloze training (one past the last item id).
        self.CLOZE_MASK_TOKEN = self.args.num_items + 1
        val_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                        self.train, self.val, self.test,
                                                        self.user_count, self.item_count,
                                                        args.test_negative_sample_size,
                                                        args.test_negative_sampling_seed,
                                                        'poisoned_val', self.save_folder)
        test_negative_sampler = negative_sampler_factory(args.test_negative_sampler_code,
                                                         self.train, self.val, self.test,
                                                         self.user_count, self.item_count,
                                                         args.test_negative_sample_size,
                                                         args.test_negative_sampling_seed,
                                                         'poisoned_test', self.save_folder)
        self.seen_samples, self.val_negative_samples = val_negative_sampler.get_negative_samples()
        # NOTE(review): this second assignment overwrites self.seen_samples from
        # the val sampler; only the test sampler's seen_samples survive — confirm
        # that is intended.
        self.seen_samples, self.test_negative_samples = test_negative_sampler.get_negative_samples()

    @classmethod
    def code(cls):
        return 'distillation_loader'

    def get_loaders(self, poisoning_users=None):
        """Return shuffled (train, val, test) DataLoaders, all using train_batch_size."""
        train, val, test = self._get_datasets(poisoning_users)
        train_loader = data_utils.DataLoader(train, batch_size=self.args.train_batch_size,
                                             shuffle=True, pin_memory=True)
        val_loader = data_utils.DataLoader(val, batch_size=self.args.train_batch_size,
                                           shuffle=True, pin_memory=True)
        test_loader = data_utils.DataLoader(test, batch_size=self.args.train_batch_size,
                                            shuffle=True, pin_memory=True)
        return train_loader, val_loader, test_loader

    def _get_datasets(self, poisoning_users=None):
        """Instantiate model-specific dataset wrappers for args.model_code."""
        if self.args.model_code == 'bert':
            train = BERTTrainDataset(self.train, self.max_len, self.mask_prob, self.max_predictions, self.sliding_size, self.CLOZE_MASK_TOKEN, self.item_count, self.rng)
            val = BERTValidDataset(self.train, self.val, self.max_len, self.CLOZE_MASK_TOKEN, self.val_negative_samples, poisoning_users)
            test = BERTTestDataset(self.train, self.val, self.test, self.max_len, self.CLOZE_MASK_TOKEN, self.test_negative_samples, poisoning_users)
        elif self.args.model_code == 'sas':
            train = SASTrainDataset(self.train, self.max_len, self.sliding_size, self.seen_samples, self.item_count, self.rng)
            val = SASValidDataset(self.train, self.val, self.max_len, self.val_negative_samples, poisoning_users)
            test = SASTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples, poisoning_users)
        elif self.args.model_code == 'narm':
            train = RNNTrainDataset(self.train, self.max_len)
            val = RNNValidDataset(self.train, self.val, self.max_len, self.val_negative_samples, poisoning_users)
            test = RNNTestDataset(self.train, self.val, self.test, self.max_len, self.test_negative_samples, poisoning_users)
        return train, val, test
import os
import torch
from abc import ABCMeta, abstractmethod
def save_state_dict(state_dict, path, filename):
    """Serialize `state_dict` with torch.save to `<path>/<filename>`."""
    destination = os.path.join(path, filename)
    torch.save(state_dict, destination)
class LoggerService(object):
    """Dispatches log events to registered train and validation loggers."""

    def __init__(self, train_loggers=None, val_loggers=None):
        # Keep the caller's lists by reference; fall back to fresh empty lists.
        self.train_loggers = train_loggers if train_loggers else []
        self.val_loggers = val_loggers if val_loggers else []

    def complete(self, log_data):
        """Forward the final snapshot to every logger (train loggers first)."""
        for logger in self.train_loggers + self.val_loggers:
            logger.complete(**log_data)

    def log_train(self, log_data):
        """Send a training-step record to the train loggers."""
        for logger in self.train_loggers:
            logger.log(**log_data)

    def log_val(self, log_data):
        """Send a validation record to the val loggers."""
        for logger in self.val_loggers:
            logger.log(**log_data)
class AbstractBaseLogger(metaclass=ABCMeta):
    """Logger interface: subclasses must implement log(); complete() is optional."""

    @abstractmethod
    def log(self, *args, **kwargs):
        raise NotImplementedError

    def complete(self, *args, **kwargs):
        # Default: nothing to finalize.
        pass
class RecentModelLogger(AbstractBaseLogger):
    """Checkpoints the most recent model state once per epoch."""

    def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):
        self.checkpoint_path = checkpoint_path
        # makedirs(exist_ok=True) also creates missing parent directories and
        # avoids the check-then-create race of the original exists()/mkdir pair.
        os.makedirs(self.checkpoint_path, exist_ok=True)
        self.recent_epoch = None
        self.filename = filename

    def log(self, *args, **kwargs):
        """Save kwargs['state_dict'] (tagged with the epoch) on epoch change."""
        epoch = kwargs['epoch']
        if self.recent_epoch != epoch:
            self.recent_epoch = epoch
            state_dict = kwargs['state_dict']
            state_dict['epoch'] = kwargs['epoch']  # NOTE: mutates the caller's dict
            save_state_dict(state_dict, self.checkpoint_path, self.filename)

    def complete(self, *args, **kwargs):
        """Persist the final state next to the per-epoch checkpoint."""
        save_state_dict(kwargs['state_dict'],
                        self.checkpoint_path, self.filename + '.final')
class BestModelLogger(AbstractBaseLogger):
    """Keeps the checkpoint with the best value of `metric_key` seen so far."""

    def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):
        self.checkpoint_path = checkpoint_path
        # makedirs(exist_ok=True) also creates missing parent directories and
        # avoids the check-then-create race of the original exists()/mkdir pair.
        os.makedirs(self.checkpoint_path, exist_ok=True)
        self.best_metric = 0.  # NOTE(review): assumes metrics are non-negative
        self.metric_key = metric_key
        self.filename = filename

    def log(self, *args, **kwargs):
        """Save kwargs['state_dict'] whenever `metric_key` improves on the best."""
        current_metric = kwargs[self.metric_key]
        if self.best_metric < current_metric:
            print("Update Best {} Model at {}".format(
                self.metric_key, kwargs['epoch']))
            self.best_metric = current_metric
            save_state_dict(kwargs['state_dict'],
                            self.checkpoint_path, self.filename)
class MetricGraphPrinter(AbstractBaseLogger):
    """Writes a single scalar metric to TensorBoard under `<group>/<graph>`."""

    def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):
        self.key = key
        self.graph_label = graph_name
        self.group_name = group_name
        self.writer = writer

    def log(self, *args, **kwargs):
        # Missing keys are logged as 0 so the graph keeps one point per step.
        value = kwargs.get(self.key, 0)
        tag = self.group_name + '/' + self.graph_label
        self.writer.add_scalar(tag, value, kwargs['accum_iter'])

    def complete(self, *args, **kwargs):
        self.writer.close()
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
# from torch.autograd.gradcheck import zero_gradients
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
import math
import faiss
import numpy as np
from abc import *
from pathlib import Path
class PoisonedGroupRetrainer(metaclass=ABCMeta):
    """Crafts poisoned training sequences with a white-box surrogate model and
    retrains the black-box victim on them to bias it toward target items.

    Fixes applied (external interface unchanged):
    * replaced `zero_gradients` — removed from `torch.autograd.gradcheck`
      (its import is commented out at the top of this file) — with an explicit
      gradient reset;
    * `train` referenced an undefined `train_loader` when LR warmup is enabled
      (now `self.train_loader`);
    * `clip_gradients` previously ignored its `limit` argument;
    * replaced the bare try/except used to grow `batch_tokens` with an
      explicit None check;
    * `== None` -> `is None`.
    """

    def __init__(self, args, wb_model_spec, wb_model, bb_model, original_test_loader, bb_model_root=None):
        self.args = args
        self.device = args.device
        self.num_items = args.num_items
        self.max_len = args.bert_max_len
        self.wb_model_spec = wb_model_spec
        self.wb_model = wb_model.to(self.device)  # white-box surrogate (gradient access)
        self.bb_model = bb_model.to(self.device)  # black-box victim to retrain
        self.is_parallel = args.num_gpu > 1
        if self.is_parallel:
            self.bb_model = nn.DataParallel(self.bb_model)
        self.num_epochs = args.num_epochs
        self.metric_ks = args.metric_ks
        self.best_metric = args.best_metric
        self.original_test_loader = original_test_loader
        if bb_model_root is None:
            self.bb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code
        else:
            self.bb_model_root = bb_model_root
        # Drop the padding row (index 0) and, for BERT, the trailing mask-token
        # row, so embedding row i corresponds to item id i+1.
        if isinstance(self.wb_model, BERT):
            self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:-1]
        else:
            self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:]
        self.faiss_index = faiss.IndexFlatL2(self.item_embeddings.shape[-1])
        self.faiss_index.add(self.item_embeddings)
        self.item_embeddings = torch.tensor(self.item_embeddings).to(self.device)
        self.CLOZE_MASK_TOKEN = args.num_items + 1
        self.adv_ce = nn.CrossEntropyLoss(ignore_index=0)
        # NOTE(review): when num_gpu > 1 these isinstance checks see the
        # DataParallel wrapper, so neither branch fires and self.ce stays
        # unset — confirm multi-GPU is actually supported here.
        if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, NARM):
            self.ce = nn.CrossEntropyLoss(ignore_index=0)
        elif isinstance(self.bb_model, SASRec):
            self.ce = nn.BCEWithLogitsLoss()

    def train_ours(self, targets, ratio, popular_items, num_items):
        """Generate poisoned data for `targets`, retrain the black-box model on
        it and return (metrics_before, metrics_after) from the targeted test."""
        num_poisoned, num_original, poisoning_users = self.generate_poisoned_data(targets, popular_items, num_items)
        target_spec = '_'.join([str(target) for target in targets])
        self.train_loader, self.val_loader, self.test_loader = poi_train_loader_factory(self.args, target_spec, self.wb_model_spec, num_poisoned, num_original)
        self.bb_model.load_state_dict(torch.load(os.path.join(self.bb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))
        self.export_root = 'experiments/retrained/' + self.wb_model_spec + '/' + self.args.dataset_code + '/ratio_' + str(ratio) + '_target_' + target_spec
        self.writer, self.train_loggers, self.val_loggers = self._create_loggers()
        self.logger_service = LoggerService(
            self.train_loggers, self.val_loggers)
        self.log_period_as_iter = self.args.log_period_as_iter
        metrics_before, metrics_after = self.train(targets)
        return metrics_before, metrics_after

    def generate_poisoned_data(self, targets, popular_items, num_items, batch_size=50, sample_prob=0.0):
        """Create fake user sequences interleaving adversarially chosen filler
        items with target items, and persist them as a poisoned dataset.

        Returns (num_poisoned, num_original, poisoning_users). `popular_items`,
        `num_items` and `sample_prob` are currently unused (kept for interface
        compatibility).
        """
        print('## Generate Biased Data with Target {} ##'.format(targets))
        target_spec = '_'.join([str(target) for target in targets])
        dataset = poi_dataset_factory(self.args, target_spec, self.wb_model_spec)
        # if dataset.check_data_present():
        #     print('Dataset already exists. Skip generation')
        #     return
        if isinstance(self.wb_model, BERT):
            self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:-1]
        else:
            self.item_embeddings = self.wb_model.embedding.token.weight.detach().cpu().numpy()[1:]
        self.item_embeddings = torch.tensor(self.item_embeddings).to(self.device)
        batch_num = math.ceil(self.args.num_poisoned_seqs / batch_size)
        batch_tokens = None  # accumulated (num_seqs, max_len) item-id matrix
        print('Generating poisoned dataset...')
        for i in tqdm(range(batch_num)):
            if i == batch_num - 1 and self.args.num_poisoned_seqs % batch_size != 0:
                batch_size = self.args.num_poisoned_seqs % batch_size
            # Every fake sequence starts with a randomly chosen target item.
            seqs = torch.tensor(np.random.choice(targets, size=batch_size)).reshape(batch_size, 1).to(self.device)
            for j in range(self.max_len - 1):
                self.wb_model.eval()
                if j % 2 == 0:
                    # On even steps append one adversarial filler item (chosen
                    # via surrogate gradients) followed by one target item.
                    selected_targets = torch.tensor(np.random.choice(targets, size=batch_size)).to(self.device)
                    rand_items = torch.tensor(np.random.choice(self.num_items, size=seqs.size(0))+1).to(self.device)
                    seqs = torch.cat((seqs, rand_items.unsqueeze(1)), 1)
                    if isinstance(self.wb_model, BERT):
                        # Left-pad the sequence and put the mask token last.
                        mask_items = torch.tensor([self.CLOZE_MASK_TOKEN] * seqs.size(0)).to(self.device)
                        input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)
                        if j < self.max_len - 2:
                            input_seqs[:, (self.max_len-3-j):-1] = seqs
                        elif j == self.max_len - 2:
                            input_seqs[:, :-1] = seqs[:, 1:]
                        input_seqs[:, -1] = mask_items
                        wb_embedding, mask = self.wb_model.embedding(input_seqs.long())
                    elif isinstance(self.wb_model, SASRec):
                        input_seqs = torch.zeros((seqs.size(0), self.max_len)).to(self.device)
                        input_seqs[:, (self.max_len-2-j):] = seqs
                        wb_embedding, mask = self.wb_model.embedding(input_seqs.long())
                    elif isinstance(self.wb_model, NARM):
                        input_seqs = seqs
                        lengths = torch.tensor([j + 2] * seqs.size(0))
                        wb_embedding, mask = self.wb_model.embedding(input_seqs, lengths)
                    self.wb_model.train()
                    wb_embedding = wb_embedding.detach().clone()
                    wb_embedding.requires_grad = True
                    # Explicit reset replaces the removed
                    # torch.autograd.gradcheck.zero_gradients helper.
                    if wb_embedding.grad is not None:
                        wb_embedding.grad.zero_()
                    if isinstance(self.wb_model, BERT) or isinstance(self.wb_model, SASRec):
                        wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, mask)[:, -1, :]
                    elif isinstance(self.wb_model, NARM):
                        wb_scores = self.wb_model.model(wb_embedding, self.wb_model.embedding.token.weight, lengths, mask)
                    loss = self.adv_ce(wb_scores, selected_targets)
                    self.wb_model.zero_grad()
                    loss.backward()
                    wb_embedding_grad = wb_embedding.grad.data
                    self.wb_model.eval()
                    with torch.no_grad():
                        if isinstance(self.wb_model, BERT):
                            # Slot before the trailing mask token.
                            current_embedding = wb_embedding[:, -2]
                            current_embedding_grad = wb_embedding_grad[:, -2]
                        else:
                            current_embedding = wb_embedding[:, -1]
                            current_embedding_grad = wb_embedding_grad[:, -1]
                        all_embeddings = self.item_embeddings.unsqueeze(1).repeat_interleave(current_embedding.size(0), 1)
                        cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
                        # Cosine similarity between the gradient-sign-perturbed
                        # slot embedding and every item embedding. The +2 boost
                        # pushes the target itself to the end of the ascending
                        # sort so it cannot be picked as a filler item.
                        # NOTE(review): ascending order samples the LEAST
                        # similar items — confirm this is intended.
                        similarity = torch.t(cos(current_embedding-current_embedding_grad.sign(), all_embeddings))
                        similarity[torch.arange(seqs.size(0)), selected_targets-1] = similarity[torch.arange(seqs.size(0)), selected_targets-1] + 2
                        _, candidate_indices = torch.sort(similarity, dim=1, descending=False)
                        # Sample uniformly among the first 10 sorted candidates.
                        sample_indices = torch.randint(0, 10, [seqs.size(0)])
                        seqs[:, -1] = candidate_indices[torch.arange(seqs.size(0)), sample_indices] + 1
                    seqs = torch.cat((seqs, selected_targets.unsqueeze(1)), 1)
            seqs = seqs[:, :self.max_len]
            if batch_tokens is None:
                batch_tokens = seqs.cpu().numpy()
            else:
                batch_tokens = np.concatenate((batch_tokens, seqs.cpu().numpy()))
        num_poisoned, num_original, poisoning_users = dataset.save_dataset(batch_tokens.tolist(), original_dataset_size=self.args.num_original_seqs)
        return num_poisoned, num_original, poisoning_users

    def train(self, targets):
        """Retrain on the poisoned loaders; return targeted metrics before and
        after retraining."""
        self.optimizer = self._create_optimizer()
        if self.args.enable_lr_schedule:
            if self.args.enable_lr_warmup:
                # Fix: the original referenced an undefined `train_loader`.
                self.lr_scheduler = self.get_linear_schedule_with_warmup(
                    self.optimizer, self.args.warmup_steps, len(self.train_loader) * self.num_epochs)
            else:
                self.lr_scheduler = optim.lr_scheduler.StepLR(
                    self.optimizer, step_size=self.args.decay_step, gamma=self.args.gamma)
        print('## Biased Retrain on Item {} ##'.format(targets))
        accum_iter = 0
        for epoch in range(self.num_epochs):
            accum_iter = self.train_one_epoch(epoch, accum_iter)
        print('## Clean Black-Box Model Targeted Test on Item {} ##'.format(targets))
        metrics_before = self.targeted_test(targets, load_retrained=False)
        print('## Retrained Black-Box Model Targeted Test on Item {} ##'.format(targets))
        metrics_after = self.targeted_test(targets, load_retrained=True)
        self.logger_service.complete({
            'state_dict': (self._create_state_dict()),
        })
        self.writer.close()
        return metrics_before, metrics_after

    def train_one_epoch(self, epoch, accum_iter):
        """Run one training epoch over self.train_loader; returns updated accum_iter."""
        self.bb_model.train()
        average_meter_set = AverageMeterSet()
        tqdm_dataloader = tqdm(self.train_loader)
        for batch_idx, batch in enumerate(tqdm_dataloader):
            self.optimizer.zero_grad()
            if isinstance(self.bb_model, BERT):
                seqs, labels = batch
                seqs, labels = seqs.to(self.device), labels.to(self.device)
                logits = self.bb_model(seqs)
                logits = logits.view(-1, logits.size(-1))
                labels = labels.view(-1)
                loss = self.ce(logits, labels)
            elif isinstance(self.bb_model, SASRec):
                seqs, labels, negs = batch
                seqs, labels, negs = seqs.to(self.device), labels.to(self.device), negs.to(self.device)
                logits = self.bb_model(seqs)  # F.softmax(self.bb_model(seqs), dim=-1)
                # BCE over positive vs sampled negative next items, padding excluded.
                pos_logits = logits.gather(-1, labels.unsqueeze(-1))[seqs > 0].squeeze()
                pos_targets = torch.ones_like(pos_logits)
                neg_logits = logits.gather(-1, negs.unsqueeze(-1))[seqs > 0].squeeze()
                neg_targets = torch.zeros_like(neg_logits)
                loss = self.ce(torch.cat((pos_logits, neg_logits), 0), torch.cat((pos_targets, neg_targets), 0))
            elif isinstance(self.bb_model, NARM):
                seqs, lengths, labels = batch
                lengths = lengths.flatten()
                seqs, labels = seqs.to(self.device), labels.to(self.device)
                logits = self.bb_model(seqs, lengths)
                loss = self.ce(logits, labels.squeeze())
            loss.backward()
            self.clip_gradients(5)
            self.optimizer.step()
            if self.args.enable_lr_schedule:
                self.lr_scheduler.step()
            average_meter_set.update('loss', loss.item())
            tqdm_dataloader.set_description(
                'Epoch {}, loss {:.3f} '.format(epoch+1, average_meter_set['loss'].avg))
            accum_iter += seqs.size(0)
            if self._needs_to_log(accum_iter):
                tqdm_dataloader.set_description('Logging to Tensorboard')
                log_data = {
                    'state_dict': (self._create_state_dict()),
                    'epoch': epoch + 1,
                    'accum_iter': accum_iter,
                }
                log_data.update(average_meter_set.averages())
                self.logger_service.log_train(log_data)
        self.validate(epoch, accum_iter)
        return accum_iter

    def validate(self, epoch, accum_iter):
        """Evaluate on the poisoned val loader and hand metrics to the val loggers
        (which also checkpoint the best model)."""
        self.bb_model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.val_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)
            log_data = {
                'state_dict': (self._create_state_dict()),
                'epoch': epoch+1,
                'accum_iter': accum_iter,
            }
            log_data.update(average_meter_set.averages())
            # self.log_extra_val_info(log_data)
            self.logger_service.log_val(log_data)

    def test(self, load_retrained=False):
        """Untargeted test on the ORIGINAL test loader; writes metrics JSON.

        load_retrained selects the retrained checkpoint (export_root) versus
        the clean black-box checkpoint (bb_model_root).
        """
        if load_retrained:
            best_model_dict = torch.load(os.path.join(
                self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
            self.bb_model.load_state_dict(best_model_dict)
        else:
            bb_model_dict = torch.load(os.path.join(
                self.bb_model_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
            self.bb_model.load_state_dict(bb_model_dict)
        self.bb_model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.original_test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                metrics = self.calculate_metrics(batch)
                self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)
            average_metrics = average_meter_set.averages()
            with open(os.path.join(self.export_root, 'logs', 'test_metrics.json'), 'w') as f:
                json.dump(average_metrics, f, indent=4)
        return average_metrics

    def targeted_test(self, targets, load_retrained=False):
        """Targeted test: for each target, rank it against each user's sampled
        negatives by overwriting candidate slot 0. Returns averaged metrics."""
        if load_retrained:
            best_model_dict = torch.load(os.path.join(
                self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
            self.bb_model.load_state_dict(best_model_dict)
        else:
            bb_model_dict = torch.load(os.path.join(
                self.bb_model_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
            self.bb_model.load_state_dict(bb_model_dict)
        self.bb_model.eval()
        average_meter_set = AverageMeterSet()
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.original_test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):
                    seqs, candidates, labels = batch
                    seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
                    scores = self.bb_model(seqs)[:, -1, :]
                elif isinstance(self.bb_model, NARM):
                    seqs, lengths, candidates, labels = batch
                    seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
                    lengths = lengths.flatten()
                    scores = self.bb_model(seqs, lengths)
                for target in targets:
                    candidates[:, 0] = torch.tensor([target] * seqs.size(0)).to(self.device)
                    metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)
                    self._update_meter_set(average_meter_set, metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)
            average_metrics = average_meter_set.averages()
        return average_metrics

    def targeted_test_item(self, targets, load_retrained=False):
        """Like targeted_test, but also returns per-target averaged metrics as a
        dict keyed by target item."""
        if load_retrained:
            best_model_dict = torch.load(os.path.join(
                self.export_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
            self.bb_model.load_state_dict(best_model_dict)
        else:
            bb_model_dict = torch.load(os.path.join(
                self.bb_model_root, 'models', 'best_acc_model.pth')).get(STATE_DICT_KEY)
            self.bb_model.load_state_dict(bb_model_dict)
        self.bb_model.eval()
        average_meter_set = AverageMeterSet()
        item_average_meter_set = {target: AverageMeterSet() for target in targets}
        with torch.no_grad():
            tqdm_dataloader = tqdm(self.original_test_loader)
            for batch_idx, batch in enumerate(tqdm_dataloader):
                if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):
                    seqs, candidates, labels = batch
                    seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
                    scores = self.bb_model(seqs)[:, -1, :]
                elif isinstance(self.bb_model, NARM):
                    seqs, lengths, candidates, labels = batch
                    seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
                    lengths = lengths.flatten()
                    scores = self.bb_model(seqs, lengths)
                for target in targets:
                    candidates[:, 0] = torch.tensor([target] * seqs.size(0)).to(self.device)
                    metrics = recalls_and_ndcgs_for_ks(scores.gather(1, candidates), labels, self.metric_ks)
                    self._update_meter_set(average_meter_set, metrics)
                    self._update_meter_set(item_average_meter_set[target], metrics)
                self._update_dataloader_metrics(
                    tqdm_dataloader, average_meter_set)
            average_metrics = average_meter_set.averages()
            for target in targets:
                item_average_meter_set[target] = item_average_meter_set[target].averages()
        return average_metrics, item_average_meter_set

    def calculate_metrics(self, batch):
        """Compute Recall/NDCG at self.metric_ks for a single evaluation batch."""
        self.bb_model.eval()
        if isinstance(self.bb_model, BERT) or isinstance(self.bb_model, SASRec):
            seqs, candidates, labels = batch
            seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
            scores = self.bb_model(seqs)[:, -1, :]
        elif isinstance(self.bb_model, NARM):
            seqs, lengths, candidates, labels = batch
            seqs, candidates, labels = seqs.to(self.device), candidates.to(self.device), labels.to(self.device)
            lengths = lengths.flatten()
            scores = self.bb_model(seqs, lengths)
        scores = scores.gather(1, candidates)  # B x C
        metrics = recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)
        return metrics

    def clip_gradients(self, limit=5):
        """Clip each parameter's gradient norm to `limit`.

        Fix: the original hard-coded 5 and ignored the argument.
        """
        for p in self.bb_model.parameters():
            nn.utils.clip_grad_norm_(p, limit)

    def _update_meter_set(self, meter_set, metrics):
        """Push every metric value into its correspondingly named meter."""
        for k, v in metrics.items():
            meter_set.update(k, v)

    def _update_dataloader_metrics(self, tqdm_dataloader, meter_set):
        """Render a compact 'Eval:' progress line on the tqdm bar."""
        description_metrics = ['NDCG@%d' % k for k in self.metric_ks[:3]
                               ] + ['Recall@%d' % k for k in self.metric_ks[:3]]
        description = 'Eval: ' + \
            ', '.join(s + ' {:.3f}' for s in description_metrics)
        description = description.replace('NDCG', 'N').replace('Recall', 'R')
        description = description.format(
            *(meter_set[k].avg for k in description_metrics))
        tqdm_dataloader.set_description(description)

    def _create_optimizer(self):
        """Build the optimizer; bias and layer-norm params get no weight decay."""
        args = self.args
        param_optimizer = list(self.bb_model.named_parameters())
        no_decay = ['bias', 'layer_norm']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay,
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        if args.optimizer.lower() == 'adamw':
            return optim.AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
        elif args.optimizer.lower() == 'adam':
            return optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer.lower() == 'sgd':
            return optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
        else:
            raise ValueError

    def get_linear_schedule_with_warmup(self, optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        """Linear warmup to base LR, then linear decay to 0.

        Based on Hugging Face's get_linear_schedule_with_warmup.
        """
        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            )
        return LambdaLR(optimizer, lr_lambda, last_epoch)

    def _create_loggers(self):
        """Create the TensorBoard writer plus train/val loggers (incl. best-model
        checkpointing keyed on self.best_metric)."""
        root = Path(self.export_root)
        writer = SummaryWriter(root.joinpath('logs'))
        model_checkpoint = root.joinpath('models')
        train_loggers = [
            MetricGraphPrinter(writer, key='epoch',
                               graph_name='Epoch', group_name='Train'),
            MetricGraphPrinter(writer, key='loss',
                               graph_name='Loss', group_name='Train'),
        ]
        val_loggers = []
        for k in self.metric_ks:
            val_loggers.append(
                MetricGraphPrinter(writer, key='NDCG@%d' % k, graph_name='NDCG@%d' % k, group_name='Validation'))
            val_loggers.append(
                MetricGraphPrinter(writer, key='Recall@%d' % k, graph_name='Recall@%d' % k, group_name='Validation'))
        val_loggers.append(RecentModelLogger(model_checkpoint))
        val_loggers.append(BestModelLogger(
            model_checkpoint, metric_key=self.best_metric))
        return writer, train_loggers, val_loggers

    def _create_state_dict(self):
        """Snapshot model (unwrapping DataParallel) and optimizer state."""
        return {
            STATE_DICT_KEY: self.bb_model.module.state_dict() if self.is_parallel else self.bb_model.state_dict(),
            OPTIMIZER_STATE_DICT_KEY: self.optimizer.state_dict(),
        }

    def _needs_to_log(self, accum_iter):
        """True once per log_period_as_iter worth of processed samples."""
        return accum_iter % self.log_period_as_iter < self.args.train_batch_size and accum_iter != 0
import json
import os
import pprint as pp
import random
from datetime import date
from pathlib import Path
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import optim as optim
def recall(scores, labels, k):
    """Mean Recall@k over a batch.

    scores: (B, C) model scores; labels: (B, C) 0/1 relevance. The denominator
    is min(k, #relevant) so a perfect top-k list scores 1.0.
    (Removed the dead no-op `scores = scores` / `labels = labels` lines.)
    """
    cut = (-scores).argsort(dim=1)[:, :k]
    hit = labels.gather(1, cut)
    denom = torch.min(torch.Tensor([k]).to(hit.device), labels.sum(1).float())
    return (hit.sum(1).float() / denom).mean().cpu().item()
def ndcg(scores, labels, k):
    """Mean NDCG@k over a batch (returned as a 0-dim tensor).

    scores: (B, C) model scores; labels: (B, C) 0/1 relevance. IDCG assumes
    all relevant items ranked first; rows with no relevant item yield NaN.
    """
    scores, labels = scores.cpu(), labels.cpu()
    topk = (-scores).argsort(dim=1)[:, :k]
    hits = labels.gather(1, topk).float()
    # Log-discount weights for ranks 1..k.
    weights = 1 / torch.log2(torch.arange(2, 2 + k).float())
    dcg = (hits * weights).sum(1)
    idcg = torch.Tensor([weights[:min(int(n), k)].sum()
                         for n in labels.sum(1)])
    return (dcg / idcg).mean()
def recalls_and_ndcgs_for_ks(scores, labels, ks):
    """Compute Recall@k and NDCG@k for every cutoff in `ks`.

    Args:
        scores: (batch, n_items) tensor of predicted scores; higher ranks earlier.
        labels: (batch, n_items) tensor of 0/1 relevance indicators.
        ks: iterable of integer cutoffs.

    Returns:
        dict: {'Recall@k': float, 'NDCG@k': float} for each k.
    """
    # Fix: removed the no-op self-assignments and the redundant second
    # `labels.sum(1)` computation (reuse `answer_count` instead).
    metrics = {}
    answer_count = labels.sum(1)
    labels_float = labels.float()
    rank = (-scores).argsort(dim=1)
    cut = rank
    # Iterate k from largest to smallest so `cut` can be narrowed in place.
    for k in sorted(ks, reverse=True):
        cut = cut[:, :k]
        hits = labels_float.gather(1, cut)
        # Recall normalized by min(k, #relevant) so short rows can reach 1.0.
        denom = torch.min(torch.Tensor([k]).to(labels.device), answer_count.float())
        metrics['Recall@%d' % k] = (hits.sum(1) / denom).mean().cpu().item()
        position = torch.arange(2, 2 + k)
        weights = 1 / torch.log2(position.float())
        dcg = (hits * weights.to(hits.device)).sum(1)
        idcg = torch.Tensor([weights[:min(int(n), k)].sum()
                             for n in answer_count]).to(dcg.device)
        metrics['NDCG@%d' % k] = (dcg / idcg).mean().cpu().item()
    return metrics
def setup_train(args):
    """Prepare the GPU environment and experiment folder; return the export root.

    Also dumps the experiment config as JSON and pretty-prints the non-None args.
    """
    set_up_gpu(args)
    export_root = create_experiment_export_folder(args)
    export_experiments_config_as_json(args, export_root)
    visible_args = {k: v for k, v in vars(args).items() if v is not None}
    pp.pprint(visible_args, width=1)
    return export_root
def create_experiment_export_folder(args):
    """Create a fresh, uniquely named experiment directory and return its path."""
    base_dir = args.experiment_dir
    description = args.experiment_description
    if not os.path.exists(base_dir):
        os.mkdir(base_dir)
    experiment_path = get_name_of_experiment_path(base_dir, description)
    os.mkdir(experiment_path)
    print('Folder created: ' + os.path.abspath(experiment_path))
    return experiment_path
def get_name_of_experiment_path(experiment_dir, experiment_description):
    """Build a unique '<dir>/<description>_<date>_<idx>' path that is not yet used."""
    base = os.path.join(
        experiment_dir, (experiment_description + "_" + str(date.today())))
    return base + "_" + str(_get_experiment_index(base))
def _get_experiment_index(experiment_path):
    """Return the smallest idx such that '<experiment_path>_<idx>' does not exist."""
    idx = 0
    while os.path.exists(f"{experiment_path}_{idx}"):
        idx += 1
    return idx
def load_weights(model, path):
    """No-op stub; actual checkpoint loading lives in load_pretrained_weights."""
    pass
def save_test_result(export_root, result):
    """Serialize `result` as pretty-printed JSON to <export_root>/test_result.txt."""
    out_path = Path(export_root) / 'test_result.txt'
    out_path.write_text(json.dumps(result, indent=2))
def export_experiments_config_as_json(args, experiment_path):
    """Dump every parsed argument to <experiment_path>/config.json."""
    config_file = os.path.join(experiment_path, 'config.json')
    with open(config_file, 'w') as fp:
        json.dump(vars(args), fp, indent=2)
def fix_random_seed_as(random_seed):
    """Seed python/torch/numpy RNGs and force deterministic cuDNN behavior."""
    for seed_fn in (random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(random_seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    cudnn.deterministic = True
    cudnn.benchmark = False
def set_up_gpu(args):
    """Expose the requested CUDA devices and record how many were requested."""
    devices = args.device_idx
    os.environ['CUDA_VISIBLE_DEVICES'] = devices
    # device_idx is a comma-separated string, e.g. "0,1".
    args.num_gpu = len(devices.split(","))
def load_pretrained_weights(model, path):
    """Load a checkpoint from `path` into `model`, handling both key layouts.

    NOTE(review): torch.load unpickles arbitrary data — only load trusted files.
    """
    checkpoint = torch.load(os.path.abspath(path))
    if STATE_DICT_KEY in checkpoint:
        state = checkpoint[STATE_DICT_KEY]
    else:
        state = checkpoint['state_dict']
    model.load_state_dict(state)
def setup_to_resume(args, model, optimizer):
    """Restore model and optimizer state from a previous run's recent checkpoint."""
    ckpt_path = os.path.join(
        os.path.abspath(args.resume_training), 'models/checkpoint-recent.pth')
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint[STATE_DICT_KEY])
    optimizer.load_state_dict(checkpoint[OPTIMIZER_STATE_DICT_KEY])
def create_optimizer(model, args):
    """Build Adam when requested by args.optimizer, SGD with momentum otherwise."""
    params = model.parameters()
    if args.optimizer == 'Adam':
        return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    return optim.SGD(params, lr=args.lr,
                     weight_decay=args.weight_decay, momentum=args.momentum)
class AverageMeterSet(object):
    """A named collection of AverageMeter objects."""

    def __init__(self, meters=None):
        self.meters = meters or {}

    def __getitem__(self, key):
        # Unknown keys yield a fresh zero-initialized meter that is NOT stored,
        # so updates to it are deliberately lost (placeholder semantics).
        try:
            return self.meters[key]
        except KeyError:
            placeholder = AverageMeter()
            placeholder.update(0)
            return placeholder

    def update(self, name, value, n=1):
        """Record `value` under `name`, creating the meter on first use."""
        if name not in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        """Zero every tracked meter (names are kept)."""
        for meter in self.meters.values():
            meter.reset()

    def values(self, format_string='{}'):
        return {format_string.format(name): meter.val
                for name, meter in self.meters.items()}

    def averages(self, format_string='{}'):
        return {format_string.format(name): meter.avg
                for name, meter in self.meters.items()}

    def sums(self, format_string='{}'):
        return {format_string.format(name): meter.sum
                for name, meter in self.meters.items()}

    def counts(self, format_string='{}'):
        return {format_string.format(name): meter.count
                for name, meter in self.meters.items()}
class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation.

        Note: `sum` accumulates `val` once even when n > 1, while `count`
        grows by n — this matches the original implementation.
        """
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        return f"{self.val:{format}} ({self.avg:{format}})"
import argparse
import torch
from pathlib import Path
from collections import defaultdict
def attack(args, attack_item_num=2, bb_model_root=None):
    """Interactively run the targeted ranking attack against the black-box model.

    Prompts for a white-box (surrogate) model, attacks 25 target items spread
    evenly across the item-popularity spectrum, and writes before/after metrics
    to experiments/attack_rank/<spec>/<dataset>/attack_bb_metrics.json.

    Args:
        args: parsed experiment arguments (model_code, dataset_code, ...).
        attack_item_num: number of adversarial items used per attack.
        bb_model_root: optional override for the black-box checkpoint folder.
    """
    fix_random_seed_as(args.model_init_seed)
    _, _, test_loader = dataloader_factory(args)

    # --- interactively choose the white-box surrogate ---
    model_codes = {'b': 'bert', 's': 'sas', 'n': 'narm'}
    wb_model_code = model_codes[input('Input white box model code, b for BERT, s for SASRec and n for NARM: ')]
    wb_model_folder = {}
    folder_list = [item for item in os.listdir('experiments/distillation_rank/')
                   if (args.model_code + '2' + wb_model_code in item)]
    for idx, folder_name in enumerate(folder_list):
        wb_model_folder[idx + 1] = folder_name
    # Bug fix: the original used `idx + 2` after the loop, which raised
    # NameError when folder_list was empty; len(folder_list) + 1 is the
    # same slot and is always defined.
    wb_model_folder[len(folder_list) + 1] = args.model_code + '_black_box'
    print(wb_model_folder)
    wb_model_spec = wb_model_folder[int(input('Input index of desired white box model: '))]
    wb_model_root = 'experiments/distillation_rank/' + wb_model_spec + '/' + args.dataset_code
    if wb_model_spec == args.model_code + '_black_box':
        wb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code
    if bb_model_root is None:
        bb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code

    # --- instantiate both models and load their checkpoints ---
    model_classes = {'bert': BERT, 'sas': SASRec, 'narm': NARM}
    wb_model = model_classes[wb_model_code](args)
    bb_model = model_classes[args.model_code](args)
    bb_model.load_state_dict(torch.load(os.path.join(
        bb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))
    wb_model.load_state_dict(torch.load(os.path.join(
        wb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))

    # --- rank all items by interaction count (popularity) ---
    item_counter = defaultdict(int)
    dataset = dataset_factory(args).load_dataset()
    train, val, test = dataset['train'], dataset['val'], dataset['test']
    for user in train:
        for item in train[user] + val[user] + test[user]:
            item_counter[item] += 1
    item_popularity = sorted(
        ((count, item) for item, count in item_counter.items()), reverse=True)

    attacker = AdversarialRankAttacker(args, wb_model, bb_model, test_loader)
    # Sample 25 target items spread evenly over the popularity ranking.
    # Bug fix: guard against step == 0 (range() error) for catalogs < 25 items.
    step = max(1, len(item_popularity) // 25)
    attack_ranks = list(range(0, len(item_popularity), step))[:25]

    # Keys are ordered to match the original output schema exactly.
    attack_metrics = {key: [] for key in (
        'item_id', 'item_rank',
        'item_R1_before', 'item_R5_before', 'item_R10_before',
        'item_N5_before', 'item_N10_before',
        'item_R1_ours', 'item_R5_ours', 'item_R10_ours',
        'item_N5_ours', 'item_N10_ours')}
    short_to_metric = (('R1', 'Recall@1'), ('R5', 'Recall@5'), ('R10', 'Recall@10'),
                       ('N5', 'NDCG@5'), ('N10', 'NDCG@10'))
    for rank in attack_ranks:
        target_item = item_popularity[rank][1]
        metrics_before = attacker.test(target=target_item)
        metrics_ours = attacker.attack(target=target_item, num_attack=attack_item_num)
        attack_metrics['item_id'].append(target_item)
        attack_metrics['item_rank'].append(rank)
        for short, metric in short_to_metric:
            attack_metrics['item_%s_before' % short].append(metrics_before[metric])
            attack_metrics['item_%s_ours' % short].append(metrics_ours[metric])

    metrics_root = 'experiments/attack_rank/' + wb_model_spec + '/' + args.dataset_code
    Path(metrics_root).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(metrics_root, 'attack_bb_metrics.json'), 'w') as f:
        json.dump(attack_metrics, f, indent=4)
if __name__ == "__main__":
    # Entry point for the attack experiment. `args` and `set_template` are
    # defined in earlier cells of this notebook/script.
    set_template(args)
    # when use k-core beauty and k is not 5 (beauty-dense)
    # args.min_uc = k
    # args.min_sc = k
    if args.dataset_code == 'ml-1m':
        # NOTE(review): ml-1m gets 5 epochs and a single attack item — looks
        # like a per-dataset tuning choice; confirm against the paper/config.
        args.num_epochs = 5
        attack(args=args, attack_item_num=1)
    else:
        attack(args=args, attack_item_num=2)
import argparse
import torch
import copy
from pathlib import Path
from collections import defaultdict
def retrain(args, bb_model_root=None):
    """Interactively retrain the black-box model on poisoned data and report the effect.

    Prompts for a white-box (surrogate) model, injects poisoned sequences at
    each ratio in `all_ratios`, retrains, and writes before/after metrics to
    experiments/retrained/<spec>/<dataset>/retrained_bb_metrics.json.

    Args:
        args: parsed experiment arguments (model_code, dataset_code, ...).
        bb_model_root: optional override for the black-box checkpoint folder.
    """
    fix_random_seed_as(args.model_init_seed)
    _, _, test_loader = dataloader_factory(args)

    # --- interactively choose the white-box surrogate ---
    model_codes = {'b': 'bert', 's': 'sas', 'n': 'narm'}
    wb_model_code = model_codes[input('Input white box model code, b for BERT, s for SASRec and n for NARM: ')]
    wb_model_folder = {}
    folder_list = [item for item in os.listdir('experiments/distillation_rank/')
                   if (args.model_code + '2' + wb_model_code in item)]
    for idx, folder_name in enumerate(folder_list):
        wb_model_folder[idx + 1] = folder_name
    # Bug fix: the original used `idx + 2` after the loop, which raised
    # NameError when folder_list was empty; len(folder_list) + 1 is the
    # same slot and is always defined.
    wb_model_folder[len(folder_list) + 1] = args.model_code + '_black_box'
    print(wb_model_folder)
    wb_model_spec = wb_model_folder[int(input('Input index of desired white box model: '))]
    wb_model_root = 'experiments/distillation_rank/' + wb_model_spec + '/' + args.dataset_code
    if wb_model_spec == args.model_code + '_black_box':
        wb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code
    if bb_model_root is None:
        bb_model_root = 'experiments/' + args.model_code + '/' + args.dataset_code

    # --- instantiate both models (only the white-box checkpoint is loaded here) ---
    model_classes = {'bert': BERT, 'sas': SASRec, 'narm': NARM}
    bb_model = model_classes[args.model_code](args)
    wb_model = model_classes[wb_model_code](args)

    # --- rank all items by interaction count (popularity) ---
    # Fix: the original also collected a `lengths` list that was never used.
    item_counter = defaultdict(int)
    dataset = dataset_factory(args).load_dataset()
    train, val, test = dataset['train'], dataset['val'], dataset['test']
    for user in train:
        for item in train[user] + val[user] + test[user]:
            item_counter[item] += 1
    item_popularity = sorted(
        ((count, item) for item, count in item_counter.items()), reverse=True)

    wb_model.load_state_dict(torch.load(os.path.join(
        wb_model_root, 'models', 'best_acc_model.pth'), map_location='cpu').get(STATE_DICT_KEY))

    # 25 target items spread evenly over the popularity ranking; the most
    # popular 5% of the catalog is passed separately to the retrainer.
    # Bug fix: guard against step == 0 (range() error) for catalogs < 25 items.
    step = max(1, len(item_popularity) // 25)
    num_popular = int(0.05 * len(item_popularity))
    popular_items = [item_popularity[i][1] for i in range(num_popular)]
    attack_ranks = list(range(0, len(item_popularity), step))[:25]
    targets = [item_popularity[i][1] for i in attack_ranks]

    bb_poisoned_metrics = {}
    all_ratios = [0.01]  # fraction of fake user sequences injected
    for ratio in all_ratios:
        args.num_poisoned_seqs = int(len(train) * ratio)
        retrainer = PoisonedGroupRetrainer(args, wb_model_spec, wb_model, bb_model, test_loader)
        metrics_before, metrics_bb_after = retrainer.train_ours(
            targets, ratio, popular_items, num_popular)
        bb_poisoned_metrics[ratio] = {
            'before': metrics_before,
            'ours': metrics_bb_after,
        }

    metrics_root = 'experiments/retrained/' + wb_model_spec + '/' + args.dataset_code
    Path(metrics_root).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(metrics_root, 'retrained_bb_metrics.json'), 'w') as f:
        json.dump(bb_poisoned_metrics, f, indent=4)
if __name__ == "__main__":
    # Entry point for the poisoned-retraining experiment. `args` and
    # `set_template` are defined in earlier cells of this notebook/script.
    set_template(args)
    # when use k-core beauty and k is not 5 (beauty-dense)
    # args.min_uc = k
    # args.min_sc = k
    # NOTE(review): fixed 5-epoch retraining budget — confirm this matches
    # the intended experimental setup.
    args.num_epochs = 5
    retrain(args=args)
Already preprocessed. Skip preprocessing Negatives samples exist. Loading. Negatives samples exist. Loading. Input white box model code, b for BERT, s for SASRec and n for NARM: b {1: 'bert2bert_autoregressive5', 2: 'bert_black_box'} Input index of desired white box model: 1 Already preprocessed. Skip preprocessing ## Generate Biased Data with Target [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ## Generating poisoned dataset...
100%|██████████| 2/2 [00:07<00:00, 3.96s/it]
Already preprocessed. Skip preprocessing Negatives samples exist. Loading. Negatives samples exist. Loading. ## Biased Retrain on Item [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##
Epoch 1, loss 7.003 : 100%|██████████| 69/69 [00:12<00:00, 5.48it/s] Eval: N@1 0.078, N@5 0.175, N@10 0.225, R@1 0.078, R@5 0.270, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 13.98it/s]
Update Best NDCG@10 Model at 1
Epoch 2, loss 6.993 : 100%|██████████| 69/69 [00:12<00:00, 5.44it/s] Eval: N@1 0.080, N@5 0.178, N@10 0.225, R@1 0.080, R@5 0.273, R@10 0.420: 100%|██████████| 48/48 [00:03<00:00, 13.96it/s]
Update Best NDCG@10 Model at 2
Epoch 3, loss 6.988 : 100%|██████████| 69/69 [00:12<00:00, 5.48it/s] Eval: N@1 0.085, N@5 0.179, N@10 0.230, R@1 0.085, R@5 0.269, R@10 0.426: 100%|██████████| 48/48 [00:03<00:00, 14.07it/s]
Update Best NDCG@10 Model at 3
Epoch 4, loss 6.986 : 100%|██████████| 69/69 [00:12<00:00, 5.49it/s] Eval: N@1 0.083, N@5 0.180, N@10 0.229, R@1 0.083, R@5 0.275, R@10 0.427: 100%|██████████| 48/48 [00:03<00:00, 13.97it/s] Epoch 5, loss 6.981 : 100%|██████████| 69/69 [00:12<00:00, 5.51it/s] Eval: N@1 0.078, N@5 0.179, N@10 0.227, R@1 0.078, R@5 0.274, R@10 0.423: 100%|██████████| 48/48 [00:03<00:00, 13.97it/s]
## Clean Black-Box Model Targeted Test on Item [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##
Eval: N@1 0.041, N@5 0.062, N@10 0.078, R@1 0.041, R@5 0.086, R@10 0.136: 100%|██████████| 48/48 [00:19<00:00, 2.45it/s]
## Retrained Black-Box Model Targeted Test on Item [2459, 1009, 2135, 918, 3233, 1226, 498, 2917, 1332, 3184, 264, 2490, 1696, 1448, 144, 365, 1368, 2714, 1874, 3285, 2235, 3406, 3155, 1322, 2928] ##
Eval: N@1 0.044, N@5 0.066, N@10 0.080, R@1 0.044, R@5 0.089, R@10 0.134: 100%|██████████| 48/48 [00:19<00:00, 2.43it/s]
!apt-get install tree
Reading package lists... Done Building dependency tree Reading state information... Done The following NEW packages will be installed: tree 0 upgraded, 1 newly installed, 0 to remove and 40 not upgraded. Need to get 40.7 kB of archives. After this operation, 105 kB of additional disk space will be used. Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 tree amd64 1.7.0-5 [40.7 kB] Fetched 40.7 kB in 0s (110 kB/s) Selecting previously unselected package tree. (Reading database ... 148560 files and directories currently installed.) Preparing to unpack .../tree_1.7.0-5_amd64.deb ... Unpacking tree (1.7.0-5) ... Setting up tree (1.7.0-5) ... Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
!tree .
. ├── data │ ├── ml-1m │ │ ├── movies.dat │ │ ├── ratings.dat │ │ ├── README │ │ └── users.dat │ └── preprocessed │ └── ml-1m_min_rating0-min_uc5-min_sc5-splitleave_one_out │ ├── dataset.pkl │ ├── random-sample_size100-seed98765-test.pkl │ └── random-sample_size100-seed98765-val.pkl ├── experiments │ ├── bert │ │ └── ml-1m │ │ ├── logs │ │ │ ├── events.out.tfevents.1631626670.6d6b3d5241b9.74.0 │ │ │ └── test_metrics.json │ │ └── models │ │ ├── best_acc_model.pth │ │ ├── checkpoint-recent.pth │ │ └── checkpoint-recent.pth.final │ ├── distillation_rank │ │ └── bert2bert_autoregressive5 │ │ └── ml-1m │ │ ├── logs │ │ │ ├── events.out.tfevents.1631626837.6d6b3d5241b9.74.1 │ │ │ ├── events.out.tfevents.1631627214.6d6b3d5241b9.74.2 │ │ │ ├── events.out.tfevents.1631627292.6d6b3d5241b9.74.3 │ │ │ ├── events.out.tfevents.1631627333.6d6b3d5241b9.74.4 │ │ │ ├── events.out.tfevents.1631627518.6d6b3d5241b9.74.5 │ │ │ ├── events.out.tfevents.1631627759.6d6b3d5241b9.74.6 │ │ │ └── test_metrics.json │ │ └── models │ │ ├── best_acc_model.pth │ │ ├── checkpoint-recent.pth │ │ └── checkpoint-recent.pth.final │ └── retrained │ └── bert2bert_autoregressive5 │ └── ml-1m │ ├── ratio_0.01_target_2459_1009_2135_918_3233_1226_498_2917_1332_3184_264_2490_1696_1448_144_365_1368_2714_1874_3285_2235_3406_3155_1322_2928 │ │ ├── logs │ │ │ ├── events.out.tfevents.1631630835.6d6b3d5241b9.74.7 │ │ │ └── events.out.tfevents.1631630902.6d6b3d5241b9.74.8 │ │ └── models │ │ ├── best_acc_model.pth │ │ ├── checkpoint-recent.pth │ │ └── checkpoint-recent.pth.final │ └── retrained_bb_metrics.json ├── gen_data │ └── ml-1m │ ├── bert2bert_autoregressive5_target_2459_1009_2135_918_3233_1226_498_2917_1332_3184_264_2490_1696_1448_144_365_1368_2714_1874_3285_2235_3406_3155_1322_2928_poisoned60_original6040 │ │ ├── poisoned_dataset.pkl │ │ ├── random-sample_size100-seed98765-poisoned_test.pkl │ │ └── random-sample_size100-seed98765-poisoned_val.pkl │ └── bert_5 │ └── autoregressive_dataset.pkl └── sample_data 
├── anscombe.json ├── california_housing_test.csv ├── california_housing_train.csv ├── mnist_test.csv ├── mnist_train_small.csv └── README.md 25 directories, 38 files