Tensorflow implementation of DeepFM for CTR prediction.

Suyash-pande-89/tensorflow-DeepFM

tensorflow-DeepFM

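This project includes a TensorFlow implementation of DeepFM [1], a CTR prediction model that combines a factorization machine (FM) component and a deep neural network (DNN) component which share the same feature embeddings.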
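The paper [1] defines the DeepFM output as sigmoid(y_FM + y_DNN), where both components read the same embedding vectors. The NumPy sketch below follows that paper formulation to show how the Xi/Xv inputs described under Input Format feed both parts. It is an illustration under stated assumptions, not this repository's TensorFlow graph: the function name deepfm_forward and its argument layout (w, V, mlp, bias) are hypothetical, and the repository's code may combine the two parts differently (for example, through a final dense layer).

```python
import numpy as np


def deepfm_forward(xi, xv, w, V, mlp, bias=0.0):
    """Illustrative DeepFM forward pass (paper formulation, hypothetical helper).

    xi:  int array (batch, fields), feature index per field (the Xi format)
    xv:  float array (batch, fields), feature value per field (the Xv format)
    w:   float array (feature_size,), first-order FM weights
    V:   float array (feature_size, k), embeddings shared by the FM and DNN parts
    mlp: list of (W, b) pairs; hidden layers use ReLU, the last maps to 1 unit
    """
    xi = np.asarray(xi)
    xv = np.asarray(xv, dtype=float)
    # First-order FM term: sum_j w[ind_j] * val_j
    y_first = np.sum(w[xi] * xv, axis=1)                      # (batch,)
    # Embeddings scaled by feature values: e_j = V[ind_j] * val_j
    emb = V[xi] * xv[:, :, None]                              # (batch, fields, k)
    # Second-order FM term: sum over field pairs of <e_a, e_b>, computed as
    # 0.5 * ((sum_j e_j)^2 - sum_j e_j^2), summed over the k dimensions
    square_of_sum = np.sum(emb, axis=1) ** 2                  # (batch, k)
    sum_of_square = np.sum(emb ** 2, axis=1)                  # (batch, k)
    y_second = 0.5 * np.sum(square_of_sum - sum_of_square, axis=1)
    # DNN part: concatenate the field embeddings and run them through an MLP
    h = emb.reshape(emb.shape[0], -1)                         # (batch, fields * k)
    for W, b in mlp[:-1]:
        h = np.maximum(h @ W + b, 0.0)                        # ReLU hidden layer
    W_out, b_out = mlp[-1]
    y_dnn = (h @ W_out + b_out).ravel()                       # (batch,)
    # Paper: y_hat = sigmoid(y_FM + y_DNN), with y_FM = bias + first + second order
    logit = bias + y_first + y_second + y_dnn
    return 1.0 / (1.0 + np.exp(-logit))
```

The second-order term uses the square-of-sum minus sum-of-squares identity, so it costs O(fields * k) per sample instead of looping over all O(fields^2) field pairs.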

Usage

Input Format

This implementation requires the input data in the following format:

  • Xi: [[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...]
    • indi_j is the feature index of feature field j of sample i in the dataset
  • Xv: [[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...]
    • vali_j is the feature value of feature field j of sample i in the dataset
    • vali_j can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features)
  • y: target of each sample in the dataset (1/0 for classification, a real number for regression)

Please see example/DataReader.py for an example of how to prepare the data in the required format for DeepFM.
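As a rough illustration of the Xi/Xv format, the sketch below converts a few records into index and value lists. It is similar in spirit to, but not copied from, example/DataReader.py: the helper names build_feature_dict and to_xi_xv and the toy fields (age, city, device) are hypothetical. Each numeric field gets one index and keeps its raw value in Xv; each distinct category of a categorical field gets its own index with value 1.

```python
def build_feature_dict(records, num_fields, cat_fields):
    """Map each numeric field, and each (field, category) pair, to a global index."""
    feat_dict = {}
    next_index = 0
    for field in num_fields:
        # A numeric field occupies a single index; its value goes into Xv.
        feat_dict[field] = next_index
        next_index += 1
    for field in cat_fields:
        # Each distinct category gets its own index; sorting keeps the mapping stable.
        for value in sorted({r[field] for r in records}, key=str):
            feat_dict[(field, value)] = next_index
            next_index += 1
    return feat_dict, next_index  # next_index == total number of feature indices


def to_xi_xv(records, feat_dict, num_fields, cat_fields):
    """Convert records (dicts) into the Xi / Xv lists described above."""
    Xi, Xv = [], []
    for r in records:
        xi, xv = [], []
        for field in num_fields:
            xi.append(feat_dict[field])
            xv.append(float(r[field]))  # raw numeric value
        for field in cat_fields:
            xi.append(feat_dict[(field, r[field])])
            xv.append(1.0)  # one-hot: the active category has value 1
        Xi.append(xi)
        Xv.append(xv)
    return Xi, Xv


records = [
    {"city": "NY", "device": "ios", "age": 31.0},
    {"city": "SF", "device": "android", "age": 24.0},
    {"city": "NY", "device": "android", "age": 45.5},
]
feat_dict, feature_size = build_feature_dict(records, ["age"], ["city", "device"])
Xi, Xv = to_xi_xv(records, feat_dict, ["age"], ["city", "device"])
# Xi == [[0, 1, 4], [0, 2, 3], [0, 1, 3]]
# Xv == [[31.0, 1.0, 1.0], [24.0, 1.0, 1.0], [45.5, 1.0, 1.0]]
```

Every sample lists its fields in the same order, so position j in Xi and Xv always refers to field j. A category that was not seen when the dictionary was built raises KeyError here; a real pipeline would typically reserve an index for unseen values.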

Init and train a model

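import tensorflow as tf
from sklearn.metrics import roc_auc_score

# DeepFM.py sits at the repository root
from DeepFM import DeepFM

# params
dfm_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "dropout_fm": [1.0, 1.0],
    "deep_layers": [32, 32],
    "dropout_deep": [0.5, 0.5, 0.5],
    "deep_layers_activation": tf.nn.relu,
    "epoch": 30,
    "batch_size": 1024,
    "learning_rate": 0.001,
    "optimizer_type": "adam",
    "batch_norm": 1,
    "batch_norm_decay": 0.995,
    "l2_reg": 0.01,
    "verbose": True,
    "eval_metric": roc_auc_score,
    "random_seed": 2017
}

# prepare training and validation data in the required format
# (prepare(...) is a placeholder for your own loader, e.g. one based on example/DataReader.py)
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)

# the constructor also needs the input dimensions:
# feature_size = number of distinct feature indices (if you built a feature
# dictionary, use its total size instead; it must also cover any test data),
# field_size = number of fields per sample
dfm_params["feature_size"] = 1 + max(max(row) for row in list(Xi_train) + list(Xi_valid))
dfm_params["field_size"] = len(Xi_train[0])

# init a DeepFM model
dfm = DeepFM(**dfm_params)

# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)

# make predictions
y_pred = dfm.predict(Xi_valid, Xv_valid)

# evaluate a trained model with eval_metric
valid_score = dfm.evaluate(Xi_valid, Xv_valid, y_valid)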

You can enable early stopping on the validation set during training as follows:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)
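Early stopping watches the validation metric after each epoch and ends training once it stops improving. The sketch below shows one common patience-based rule; the helper should_stop and its patience and greater_is_better arguments are hypothetical, and the exact rule that DeepFM.fit applies internally may differ.

```python
def should_stop(valid_scores, patience=3, greater_is_better=True):
    """Patience rule: stop when none of the last `patience` epochs beat the earlier best."""
    if len(valid_scores) <= patience:
        return False  # not enough history yet
    earlier, recent = valid_scores[:-patience], valid_scores[-patience:]
    if greater_is_better:  # e.g. AUC
        return max(recent) <= max(earlier)
    return min(recent) >= min(earlier)  # e.g. MSE or log loss


# Example: validation AUC recorded after each epoch
history = []
for epoch, auc in enumerate([0.70, 0.72, 0.74, 0.73, 0.73, 0.72], start=1):
    history.append(auc)
    if should_stop(history, patience=3):
        print(f"stopping after epoch {epoch}")  # prints: stopping after epoch 6
        break
```

The direction matters: AUC should go up, while losses such as MSE should go down, so the comparison has to match the chosen eval_metric.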

You can also refit the model on the combined training and validation sets as follows:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)
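One common refit recipe, assumed here rather than taken from this repository's code, is to note the best epoch from the early-stopped run and then retrain on the union of the training and validation data. The hypothetical helpers best_epoch and merge_train_valid sketch those two steps; the repository's refit logic may use a different stopping criterion.

```python
def best_epoch(valid_scores, greater_is_better=True):
    """1-based index of the epoch with the best validation score."""
    best = max(valid_scores) if greater_is_better else min(valid_scores)
    return valid_scores.index(best) + 1


def merge_train_valid(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid):
    """Stack training and validation samples into one training set.

    list(...) makes this a row concatenation even for NumPy arrays,
    where `+` would add element-wise instead.
    """
    Xi_all = list(Xi_train) + list(Xi_valid)
    Xv_all = list(Xv_train) + list(Xv_valid)
    y_all = list(y_train) + list(y_valid)
    return Xi_all, Xv_all, y_all


n_epochs = best_epoch([0.70, 0.75, 0.74])  # -> 2
Xi_all, Xv_all, y_all = merge_train_valid(
    [[0, 1], [0, 2]], [[1.0, 1.0], [2.0, 1.0]], [1, 0],
    [[0, 1]], [[3.0, 1.0]], [1],
)
```

Refitting on more data usually helps, but there is no held-out set left to monitor, which is why the epoch budget is fixed from the earlier run.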

To train only the FM part, set use_deep to False; to train only the DNN part, set use_fm to False.
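The parameter dict in the usage example uses the keys use_fm and use_deep for the two parts. The sketch below derives the full model and its two single-part variants from one base dict; the helper make_variants is hypothetical, and each resulting dict would be passed as DeepFM(**params).

```python
def make_variants(base_params):
    """Build parameter dicts for the full DeepFM and its FM-only / DNN-only variants."""
    return {
        "DeepFM": dict(base_params, use_fm=True, use_deep=True),
        "FM": dict(base_params, use_fm=True, use_deep=False),   # drop the DNN part
        "DNN": dict(base_params, use_fm=False, use_deep=True),  # drop the FM part
    }


base = {"embedding_size": 8, "epoch": 30}
variants = make_variants(base)
```

dict(base_params, ...) returns a new dict each time, so the base parameters are left untouched and the three variants can be trained side by side.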

Regression

This implementation also supports regression tasks. To use DeepFM for regression, set loss_type="mse" and choose a regression eval_metric accordingly, e.g., MSE or MAE.
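A minimal sketch of the regression settings, assuming eval_metric is called as eval_metric(y_true, y_pred) like scikit-learn's metrics. The mse and mae functions and the regression_overrides dict are illustrative; sklearn.metrics.mean_squared_error and mean_absolute_error could be used instead.

```python
def _check(y_true, y_pred):
    y_true, y_pred = list(y_true), list(y_pred)
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    return y_true, y_pred


def mse(y_true, y_pred):
    """Mean squared error; lower is better."""
    y_true, y_pred = _check(y_true, y_pred)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true, y_pred):
    """Mean absolute error; lower is better."""
    y_true, y_pred = _check(y_true, y_pred)
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Overrides to merge into dfm_params for a regression task
regression_overrides = {
    "loss_type": "mse",  # regression loss, as described above
    "eval_metric": mse,  # assumed to be called as eval_metric(y_true, y_pred)
}
```

Unlike AUC, both metrics improve as they decrease, so early stopping and refitting need to treat the metric as lower-is-better.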

Example

The example folder contains an example of using the DeepFM, FM, and DNN models for the Porto Seguro's Safe Driver Prediction competition on Kaggle.

Please download the data from the competition website and put it into the example/data folder.

To train the DeepFM model on this dataset, run:

$ cd example
$ python main.py

See example/DataReader.py for how the raw dataset is parsed into the required format for DeepFM.

Performance

DeepFM

[Figure: DeepFM performance]

FM

[Figure: FM performance]

DNN

[Figure: DNN performance]

Some tips

  • Tune the parameters of each model separately to get reasonable performance.
  • You can also try to ensemble these models or ensemble them with other models (e.g., XGBoost or LightGBM).
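The simplest form of ensembling is a weighted average of the models' predicted probabilities on the same samples. The sketch below uses a hypothetical helper weighted_average; the weights are assumed to be tuned on validation data, and rank averaging or stacking are common alternatives when models' scores are on different scales.

```python
def weighted_average(predictions, weights=None):
    """Blend per-model prediction lists by a weighted arithmetic mean."""
    if not predictions:
        raise ValueError("need at least one model's predictions")
    if weights is None:
        weights = [1.0] * len(predictions)
    if len(weights) != len(predictions):
        raise ValueError("need exactly one weight per model")
    n = len(predictions[0])
    if any(len(p) != n for p in predictions):
        raise ValueError("all models must score the same samples")
    total = float(sum(weights))
    return [sum(w * p[i] for w, p in zip(weights, predictions)) / total for i in range(n)]


# e.g. DeepFM and FM predictions for two samples (illustrative numbers)
blend = weighted_average([[0.2, 0.8], [0.4, 0.6]], weights=[3, 1])
```

Equal weights are a reasonable starting point; giving more weight to the stronger model only pays off if the validation metric confirms it.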

Reference

[1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. IJCAI 2017.

Acknowledgments

This project draws inspiration from the following projects:

License

MIT
