from fastai import * # Quick access to most common functionality
from fastai.collab import * # Quick access to collab filtering functionality
`collab` models use data in a `DataFrame` of users, items, and ratings.
# Fetch the ML_SAMPLE dataset and keep the location it was extracted to.
# NOTE(review): untar_data presumably downloads the archive on first use and
# reuses a local copy afterwards — confirm against the fastai docs.
path = untar_data(URLs.ML_SAMPLE)
# Bare expression: displays the resulting path as the cell output.
path
PosixPath('/data1/jhoward/git/fastai/fastai/../data/movie_lens_sample')
# Load the ratings table shipped with the sample dataset.
ratings = pd.read_csv(path/'ratings.csv')
# NOTE(review): the return value is not used, so series2cat presumably converts
# the 'userId' and 'movieId' columns to categorical in place — confirm.
series2cat(ratings, 'userId', 'movieId')
# Display the first rows to check the columns (userId, movieId, rating, timestamp).
ratings.head()
| | userId | movieId | rating | timestamp |
|---|---|---|---|---|
| 0 | 73 | 1097 | 4.0 | 1255504951 |
| 1 | 561 | 924 | 3.5 | 1172695223 |
| 2 | 157 | 260 | 3.5 | 1291598691 |
| 3 | 358 | 1210 | 5.0 | 957481884 |
| 4 | 130 | 316 | 2.0 | 1138999234 |
That's all we need to create and train a model:
# Build a collaborative-filtering learner directly from the ratings DataFrame.
# NOTE(review): n_factors is presumably the embedding size, and
# min_score/max_score the range that predicted ratings are constrained to — confirm.
learn = get_collab_learner(ratings, n_factors=50, min_score=0., max_score=5.)
# Train for 4 epochs.
# NOTE(review): 5e-3 is presumably the maximum learning rate of the one-cycle schedule — confirm.
learn.fit_one_cycle(4, 5e-3)
VBox(children=(HBox(children=(IntProgress(value=0, max=4), HTML(value='0.00% [0/4 00:00<00:00]'))), HTML(value…
Total time: 00:04

| epoch | train loss | valid loss | |
|---|---|---|---|
| 0 | 2.214395 | 1.604201 | (00:01) |
| 1 | 1.006937 | 0.719938 | (00:01) |
| 2 | 0.704926 | 0.713904 | (00:01) |
| 3 | 0.600082 | 0.709458 | (00:01) |