import time
import math
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
def _check_and_convert_ratio(test_size, multi_ratios):
if not test_size and not multi_ratios:
raise ValueError("must provide either 'test_size' or 'multi_ratios'")
elif test_size is not None:
assert isinstance(test_size, float), "test_size must be float value"
assert 0.0 < test_size < 1.0, "test_size must be in (0.0, 1.0)"
ratios = [1 - test_size, test_size]
return ratios, 2
elif isinstance(multi_ratios, (list, tuple)):
assert len(multi_ratios) > 1, (
"multi_ratios must at least have two elements")
assert all([r > 0.0 for r in multi_ratios]), (
"ratios should be positive values")
if math.fsum(multi_ratios) != 1.0:
ratios = [r / math.fsum(multi_ratios) for r in multi_ratios]
else:
ratios = multi_ratios
return ratios, len(ratios)
else:
raise ValueError("multi_ratios should be list or tuple")
def _filter_unknown_user_item(data_list):
train_data = data_list[0]
unique_values = dict(user=set(train_data.user.tolist()),
item=set(train_data.item.tolist()))
split_data_all = [train_data]
for i, test_data in enumerate(data_list[1:], start=1):
# print(f"Non_train_data {i} size before filtering: {len(test_data)}")
out_of_bounds_row_indices = set()
for col in ["user", "item"]:
for j, val in enumerate(test_data[col]):
if val not in unique_values[col]:
out_of_bounds_row_indices.add(j)
mask = np.arange(len(test_data))
test_data_clean = test_data[~np.isin(
mask, list(out_of_bounds_row_indices))]
split_data_all.append(test_data_clean)
# print(f"Non_train_data {i} size after filtering: "
# f"{len(test_data_clean)}")
return split_data_all
def _pad_unknown_user_item(data_list):
train_data, test_data = data_list
n_users = train_data.user.nunique()
n_items = train_data.item.nunique()
unique_users = set(train_data.user.tolist())
unique_items = set(train_data.item.tolist())
split_data_all = [train_data]
for i, test_data in enumerate(data_list[1:], start=1):
test_data.loc[~test_data.user.isin(unique_users), "user"] = n_users
test_data.loc[~test_data.item.isin(unique_items), "item"] = n_items
split_data_all.append(test_data)
return split_data_all
def _groupby_user(user_indices, order):
sort_kind = "mergesort" if order else "quicksort"
users, user_position, user_counts = np.unique(user_indices,
return_inverse=True,
return_counts=True)
user_split_indices = np.split(np.argsort(user_position, kind=sort_kind),
np.cumsum(user_counts)[:-1])
return user_split_indices
def random_split(data, test_size=None, multi_ratios=None, shuffle=True,
                 filter_unknown=True, pad_unknown=False, seed=42):
    """Randomly split `data` into two or more folds.

    Parameters
    ----------
    data : pandas.DataFrame
        Dataset with at least `user` and `item` columns.
    test_size : float or None
        Test proportion for a simple train/test split.
    multi_ratios : list or tuple of float, or None
        Ratios for an arbitrary number of folds; left untouched.
    shuffle : bool
        Whether to shuffle rows before splitting.
    filter_unknown : bool
        Drop eval/test rows whose user/item is unseen in train.
    pad_unknown : bool
        Remap unseen users/items to a padding index instead of dropping.
    seed : int
        Random state passed to `train_test_split`.

    Returns
    -------
    list of pandas.DataFrame
        One DataFrame per fold, train split first.
    """
    ratios, n_splits = _check_and_convert_ratio(test_size, multi_ratios)
    # fix: always copy — the original only copied non-list inputs, so
    # the `pop` below could mutate a caller-supplied `multi_ratios` list
    ratios = list(ratios)

    # to split into multiple folds, iteratively carve off the last fold
    # and renormalize the remaining ratios
    train_data = data.copy()
    split_data_all = []
    for _ in range(n_splits - 1):
        size = ratios.pop(-1)
        total = math.fsum(ratios)  # hoisted out of the comprehension
        ratios = [r / total for r in ratios]
        train_data, split_data = train_test_split(train_data,
                                                 test_size=size,
                                                 shuffle=shuffle,
                                                 random_state=seed)
        split_data_all.insert(0, split_data)
    split_data_all.insert(0, train_data)  # insert final fold of data

    if filter_unknown:
        split_data_all = _filter_unknown_user_item(split_data_all)
    elif pad_unknown:
        split_data_all = _pad_unknown_user_item(split_data_all)
    return split_data_all
def split_by_ratio(data, order=True, shuffle=False, test_size=None,
                   multi_ratios=None, filter_unknown=True, pad_unknown=False,
                   seed=42):
    """Split each user's interactions into folds by ratio.

    Users with 3 or fewer interactions go entirely to the train fold.
    With `order=True` each user's rows keep their original order within
    a fold; with `shuffle=True` rows are shuffled inside every fold.

    Returns a list of DataFrames, train split first.
    """
    np.random.seed(seed)
    assert ("user" in data.columns), "data must contains user column"
    ratios, n_splits = _check_and_convert_ratio(test_size, multi_ratios)

    n_users = data.user.nunique()
    user_indices = data.user.to_numpy()
    user_split_indices = _groupby_user(user_indices, order)

    # cumulative cut points in (0, 1); the last one (== 1.0) is implicit
    cum_ratios = np.cumsum(ratios).tolist()[:-1]
    split_indices_all = [[] for _ in range(n_splits)]
    for u in range(n_users):
        u_data = user_split_indices[u]
        u_data_len = len(u_data)
        if u_data_len <= 3:   # keep items of rare users in trainset
            split_indices_all[0].extend(u_data)
        else:
            u_split_data = np.split(u_data, [
                round(cum * u_data_len) for cum in cum_ratios
            ])
            for i in range(n_splits):
                split_indices_all[i].extend(list(u_split_data[i]))

    if shuffle:
        # fix: permute the row positions and select with .iloc — the
        # original `data[idx]` indexed *columns* by row positions
        # (KeyError), and np.random.permutation over a DataFrame would
        # have degraded it to a bare ndarray
        split_data_all = [
            data.iloc[np.random.permutation(idx)] for idx in split_indices_all
        ]
    else:
        split_data_all = [data.iloc[idx] for idx in split_indices_all]

    if filter_unknown:
        split_data_all = _filter_unknown_user_item(split_data_all)
    elif pad_unknown:
        split_data_all = _pad_unknown_user_item(split_data_all)
    return split_data_all
def split_by_num(data, order=True, shuffle=False, test_size=1,
                 filter_unknown=True, pad_unknown=False, seed=42):
    """Hold out a fixed number of interactions per user as the test set.

    The last `test_size` rows of every user go to the test split. Users
    with 3 or fewer rows stay entirely in train, and users with at most
    `test_size` rows contribute exactly one test row.

    Returns a (train, test) pair unless a filter/pad post-step rebuilds
    the splits into a list.
    """
    np.random.seed(seed)
    assert ("user" in data.columns), "data must contains user column"
    assert isinstance(test_size, int), "test_size must be int value"
    assert 0 < test_size < len(data), "test_size must be in (0, len(data))"

    n_users = data.user.nunique()
    user_indices = data.user.to_numpy()
    user_split_indices = _groupby_user(user_indices, order)

    train_indices, test_indices = [], []
    for u in range(n_users):
        rows = user_split_indices[u]
        n_rows = len(rows)
        if n_rows <= 3:  # keep items of rare users in trainset
            train_indices.extend(rows)
        elif n_rows <= test_size:
            # too few rows for a full holdout: take just the last one
            train_indices.extend(rows[:-1])
            test_indices.extend(rows[-1:])
        else:
            train_indices.extend(rows[:n_rows - test_size])
            test_indices.extend(rows[-test_size:])

    if shuffle:
        train_indices = np.random.permutation(train_indices)
        test_indices = np.random.permutation(test_indices)

    split_data_all = (data.iloc[train_indices], data.iloc[test_indices])
    if filter_unknown:
        split_data_all = _filter_unknown_user_item(split_data_all)
    elif pad_unknown:
        split_data_all = _pad_unknown_user_item(split_data_all)
    return split_data_all
def split_by_ratio_chrono(data, order=True, shuffle=False, test_size=None,
                          multi_ratios=None, seed=42):
    """Chronological ratio split: sort by time, then split per user.

    Note: sorts and reindexes `data` in place before delegating to
    `split_by_ratio`.
    """
    assert all([
        "user" in data.columns,
        "time" in data.columns
    ]), "data must contains user and time column"

    data.sort_values(by=["time"], inplace=True)
    data.reset_index(drop=True, inplace=True)
    # fix: pass arguments explicitly — the original's `**locals()`
    # silently breaks as soon as any local variable is added here
    return split_by_ratio(data, order=order, shuffle=shuffle,
                          test_size=test_size, multi_ratios=multi_ratios,
                          seed=seed)
def split_by_num_chrono(data, order=True, shuffle=False, test_size=1, seed=42):
    """Chronological fixed-count split: sort by time, then hold out the
    last `test_size` interactions of each user.

    Note: sorts and reindexes `data` in place before delegating to
    `split_by_num`.
    """
    assert all([
        "user" in data.columns,
        "time" in data.columns
    ]), "data must contains user and time column"

    data.sort_values(by=["time"], inplace=True)
    data.reset_index(drop=True, inplace=True)
    # fix: pass arguments explicitly — the original's `**locals()`
    # silently breaks as soon as any local variable is added here
    return split_by_num(data, order=order, shuffle=shuffle,
                        test_size=test_size, seed=seed)
# load the pre-merged MovieLens sample (user/item/label/time + side features)
data = pd.read_csv('sample_movielens_merged.csv')
data  # notebook display of the loaded frame
user | item | label | time | sex | age | occupation | genre1 | genre2 | genre3 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 4617 | 260 | 3 | 964137835 | F | 25 | 6 | action | adventure | fantasy |
1 | 4617 | 2468 | 4 | 964137948 | F | 25 | 6 | action | comedy | romance |
2 | 4617 | 593 | 5 | 964138086 | F | 25 | 6 | drama | thriller | missing |
3 | 4617 | 296 | 2 | 964138229 | F | 25 | 6 | crime | drama | missing |
4 | 4617 | 608 | 2 | 964138310 | F | 25 | 6 | crime | drama | thriller |
5 | 4617 | 780 | 3 | 964138459 | F | 25 | 6 | action | sci-fi | war |
6 | 4617 | 1643 | 3 | 964138734 | F | 25 | 6 | drama | romance | missing |
7 | 4617 | 440 | 3 | 964138734 | F | 25 | 6 | comedy | romance | missing |
8 | 4617 | 1569 | 4 | 964138754 | F | 25 | 6 | comedy | romance | missing |
9 | 4617 | 1732 | 1 | 964138882 | F | 25 | 6 | comedy | crime | mystery |
10 | 3706 | 436 | 2 | 966279993 | M | 25 | 12 | drama | thriller | missing |
11 | 3706 | 1772 | 2 | 966280154 | M | 25 | 12 | action | comedy | musical |
12 | 3706 | 3334 | 3 | 966280523 | M | 25 | 12 | crime | drama | film-noir |
13 | 3706 | 1136 | 5 | 966376465 | M | 25 | 12 | comedy | missing | missing |
14 | 3706 | 1394 | 4 | 966376516 | M | 25 | 12 | comedy | missing | missing |
15 | 2137 | 3948 | 5 | 974639801 | F | 1 | 10 | comedy | missing | missing |
16 | 2137 | 1215 | 3 | 974640099 | F | 1 | 10 | action | adventure | comedy |
17 | 2137 | 1356 | 4 | 974640343 | F | 1 | 10 | action | adventure | sci-fi |
18 | 2137 | 2021 | 1 | 974640436 | F | 1 | 10 | fantasy | sci-fi | missing |
19 | 2137 | 780 | 5 | 974640455 | F | 1 | 10 | action | sci-fi | war |
20 | 2137 | 2012 | 4 | 974640506 | F | 1 | 10 | comedy | sci-fi | western |
21 | 2137 | 1037 | 5 | 974640534 | F | 1 | 10 | action | sci-fi | thriller |
22 | 2137 | 2701 | 3 | 974640720 | F | 1 | 10 | action | sci-fi | western |
23 | 2137 | 34 | 4 | 974641074 | F | 1 | 10 | children's | comedy | drama |
24 | 2137 | 748 | 4 | 974641742 | F | 1 | 10 | action | sci-fi | thriller |
25 | 2137 | 3745 | 5 | 974641844 | F | 1 | 10 | adventure | animation | sci-fi |
26 | 2137 | 3793 | 5 | 974641844 | F | 1 | 10 | action | sci-fi | missing |
data.info()  # schema overview: 27 rows, 6 int64 + 4 object columns (output below)
<class 'pandas.core.frame.DataFrame'> RangeIndex: 27 entries, 0 to 26 Data columns (total 10 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 user 27 non-null int64 1 item 27 non-null int64 2 label 27 non-null int64 3 time 27 non-null int64 4 sex 27 non-null object 5 age 27 non-null int64 6 occupation 27 non-null int64 7 genre1 27 non-null object 8 genre2 27 non-null object 9 genre3 27 non-null object dtypes: int64(6), object(4) memory usage: 2.2+ KB
# ratios [0.5, 0.1, 0.1] sum to 0.7, so they are normalized internally
train_data, eval_data, test_data = random_split(data, multi_ratios=[0.5, 0.1, 0.1], seed=42,
                                                filter_unknown=False)
train_data.shape, eval_data.shape, test_data.shape
((19, 10), (4, 10), (4, 10))
test_data.head()  # inspect the first rows of the random test split
user | item | label | time | sex | age | occupation | genre1 | genre2 | genre3 | |
---|---|---|---|---|---|---|---|---|---|---|
8 | 4617 | 1569 | 4 | 964138754 | F | 25 | 6 | comedy | romance | missing |
13 | 3706 | 1136 | 5 | 966376465 | M | 25 | 12 | comedy | missing | missing |
9 | 4617 | 1732 | 1 | 964138882 | F | 25 | 6 | comedy | crime | mystery |
21 | 2137 | 1037 | 5 | 974640534 | F | 1 | 10 | action | sci-fi | thriller |
# with filter_unknown=True, eval/test rows whose user or item is absent
# from the (tiny) train split are dropped — here every such row is removed,
# leaving empty eval/test frames (see shapes below)
train_data, eval_data, test_data = random_split(data,
                                                multi_ratios=[0.8, 0.1, 0.1],
                                                seed=42,
                                                filter_unknown=True,
                                                pad_unknown=False)
train_data.shape, eval_data.shape, test_data.shape
((21, 10), (0, 10), (0, 10))
eval_data.head()  # empty after filtering (all rows contained unknown users/items)
user | item | label | time | sex | age | occupation | genre1 | genre2 | genre3 |
---|
# per-user ratio split: with the defaults, the last 20% of each user's
# rows form the eval set (shapes below come from the full dataset, not
# the 27-row sample displayed above)
train_data, eval_data = split_by_ratio(data, test_size=0.2)
train_data.shape, eval_data.shape
((80490, 10), (19473, 10))
eval_data.head()  # inspect the first eval rows of the per-user ratio split
user | item | label | time | sex | age | occupation | genre1 | genre2 | genre3 | |
---|---|---|---|---|---|---|---|---|---|---|
90449 | 2 | 1544 | 4 | 978300174 | M | 56 | 16 | action | adventure | sci-fi |
90418 | 3 | 2355 | 5 | 978298430 | M | 25 | 15 | animation | children's | comedy |
90382 | 5 | 3079 | 2 | 978246162 | M | 25 | 20 | drama | missing | missing |
90344 | 6 | 3717 | 4 | 978238371 | F | 50 | 9 | action | crime | missing |
90335 | 7 | 1573 | 4 | 978234874 | M | 35 | 1 | action | sci-fi | thriller |
# leave-one-out style: the last interaction of each user goes to eval
train_data, eval_data = split_by_num(data, test_size=1)
train_data.shape, eval_data.shape
((95128, 10), (4882, 10))
# chronological variant: data is time-sorted first, then split per user
train_data, eval_data = split_by_ratio_chrono(data, test_size=0.2)
train_data.shape, eval_data.shape
((80490, 10), (19392, 10))
eval_data.head()  # inspect the first eval rows of the chronological ratio split
user | item | label | time | sex | age | occupation | genre1 | genre2 | genre3 | |
---|---|---|---|---|---|---|---|---|---|---|
90449 | 2 | 1544 | 4 | 978300174 | M | 56 | 16 | action | adventure | sci-fi |
90418 | 3 | 2355 | 5 | 978298430 | M | 25 | 15 | animation | children's | comedy |
90382 | 5 | 3079 | 2 | 978246162 | M | 25 | 20 | drama | missing | missing |
90344 | 6 | 3717 | 4 | 978238371 | F | 50 | 9 | action | crime | missing |
90335 | 7 | 1573 | 4 | 978234874 | M | 35 | 1 | action | sci-fi | thriller |
# chronological leave-one-out: time-sort, then hold out each user's last row
train_data, eval_data = split_by_num_chrono(data, test_size=1)
train_data.shape, eval_data.shape
((95128, 10), (4880, 10))
eval_data.head()  # inspect the first eval rows of the chronological holdout
user | item | label | time | sex | age | occupation | genre1 | genre2 | genre3 | |
---|---|---|---|---|---|---|---|---|---|---|
90449 | 2 | 1544 | 4 | 978300174 | M | 56 | 16 | action | adventure | sci-fi |
90418 | 3 | 2355 | 5 | 978298430 | M | 25 | 15 | animation | children's | comedy |
90382 | 5 | 3079 | 2 | 978246162 | M | 25 | 20 | drama | missing | missing |
90344 | 6 | 3717 | 4 | 978238371 | F | 50 | 9 | action | crime | missing |
90335 | 7 | 1573 | 4 | 978234874 | M | 35 | 1 | action | sci-fi | thriller |