-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy path movie_recommender.py
executable file
·53 lines (40 loc) · 2.04 KB
/
movie_recommender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
### Load Data ###
# MovieLens dataset collected by the GroupLens Research Project at the University of Minnesota.
# For more information, see http://grouplens.org/datasets/movielens/
import re
from datetime import datetime
from os import path

import graphlab as gl
# Path to the dataset directory
data_dir = './dataset/ml-20m'
# Table of movies we are recommending: movieId, title, genres
items = gl.SFrame.read_csv(path.join(data_dir, 'movies.csv'))
# Table of interactions between users and items: userId, movieId, rating, timestamp
actions = gl.SFrame.read_csv(path.join(data_dir, 'ratings.csv'))
### Prepare Data ###
# A movie with at most this many ratings is treated as rare and dropped.
# The count is taken over *all* ratings, before the positive-rating filter below.
RARE_ITEM_MAX_COUNT = 5
# Only ratings of at least this value are kept as (positive) interactions.
MIN_POSITIVE_RATING = 4
# Count ratings per movie. The result is only used as a membership set by
# filter_by below, so sorting it (as before) was wasted work.
rare_items = actions.groupby('movieId', gl.aggregate.COUNT)
rare_items = rare_items[rare_items['Count'] <= RARE_ITEM_MAX_COUNT]
# NOTE(review): movies with no ratings at all never appear in the grouped
# counts, so they are not removed from `items` here -- confirm that is intended.
items = items.filter_by(rare_items['movieId'], 'movieId', exclude=True)
actions = actions[actions['rating'] >= MIN_POSITIVE_RATING]
actions = actions.filter_by(rare_items['movieId'], 'movieId', exclude=True)
# Extract year, title, and genre
# MovieLens titles normally end in " (YYYY)". Split that suffix off with a
# regex rather than fixed slices, so a title without a trailing year (or with
# trailing whitespace) is left intact instead of being silently truncated.
_TITLE_YEAR_RE = re.compile(r'^(.*?)\s*\((\d{4})\)\s*$')


def _split_title_year(raw_title):
    """Split a MovieLens title into (title, year).

    If raw_title ends with a parenthesised four-digit year, return the text
    before it and the year as a four-character string. Otherwise return the
    whitespace-stripped title and None for the year.
    """
    match = _TITLE_YEAR_RE.match(raw_title)
    if match is None:
        return raw_title.strip(), None
    return match.group(1), match.group(2)


# The year must be computed before 'title' is overwritten with the bare title.
items['year'] = items['title'].apply(lambda x: _split_title_year(x)[1], dtype=str)
items['title'] = items['title'].apply(lambda x: _split_title_year(x)[0], dtype=str)
# Genres are stored pipe-separated; turn them into a list per movie.
items['genres'] = items['genres'].apply(lambda x: x.split('|'))
actions['timestamp'] = actions['timestamp'].astype(datetime)
# Get the metadata ready
# Attach each movie's URL, matched on movieId.
# NOTE(review): join() uses its default join type here (presumably inner), which
# would drop movies missing from movie_urls.csv -- confirm that is intended.
urls = gl.SFrame.read_csv(path.join(data_dir, 'movie_urls.csv'))
items = items.join(urls, on='movieId')
# Display names for users, shown in the interactive view.
users = gl.SFrame.read_csv(path.join(data_dir, 'user_names.csv'))
# Split the kept interactions into training and validation sets, by user.
training_data, validation_data = gl.recommender.util.random_split_by_user(
    actions,
    user_id='userId',
    item_id='movieId',
)
### Train Recommender Model ###
model = gl.recommender.create(
    training_data,
    user_id='userId',
    item_id='movieId',
)
# Interactively evaluate and explore recommendations
view = model.views.overview(
    observation_data=training_data,
    validation_set=validation_data,
    user_data=users,
    user_name_column='name',
    item_data=items,
    item_name_column='title',
    item_url_column='url',
)
view.show()