本仓库主要是进行论文的复现以及在基准数据集上的评测,提供简单易用的调用接口。使用TensorFlow实现是为了计算图定义的灵活性,同时提供类似Keras的调用接口。欢迎指正交流!
- python3
- tensorflow==1.4.0
- numpy==1.13.3
- scikit-learn==0.19.1
base
基类仿照keras
模型实现以下公有方法,包括
- compile
编译模型,指定优化器和损失函数以及度量 - save_model 保存模型
- load_mdel 加载模型
- train_on_batch 小批量训练
- fit 全量训练
- test_on_batch 小批量评估
- evaluate 模型评估
- predict_on_batch 小批量预测
- predict 全量预测 私有方法包括
- _create_optimizer
- _create_metrics
- _compute_sample_weight
同时设计了若干抽象方法
- _get_input_data
- _get_input_target
- _get_output_target
- _get_optimizer_loss
- _build_graph
要求子类在__init__
方法末尾调用self._build_graph()
构建计算图。
- 添加
tf.summary.FileWriter
- 添加自定义度量函数
- 添加带权损失函数
- 添加损失函数
DeepFM: A Factorization-Machine based Neural Network for CTR Prediction arxiv
- 正则化项和损失函数还需要修改
Deep & Cross Network for Ad Click Predictions arxiv
实现和论文的区别
- embedding_size
论文里提出每个field根据cardinality的不同来设置。这里用的是所有field具有相同embedding_size的实现,所以feature的编码是global的。 - Batch normalization
论文在Deep Network部分采用了BN,采用BN貌似没有dropout效果好。 - gradient clip
论文采用了梯度截断,范数设置为100 - 其他 论文采用512batch size,并提出使用早停来防止过拟合,L2和dropout正则效果一般。