// Patch overview (reviewer preamble; the patch text below is unchanged).
//
// NOTE(review): in this copy of the patch the line breaks are collapsed and
// template argument lists appear to have been stripped (e.g. "template class
// qLookupTable", "cast()", "StorageLoader >, short>"); restore them from the
// original commit before applying.
//
// include/onmt/nn/ModuleFactory.hxx
//   - Inside the existing "#ifdef WITH_QLINEAR" include guard, also includes
//     "onmt/nn/qLookupTable.h".
//   - In the name == "nn.LookupTable" branch: when WITH_QLINEAR is defined and
//     _qlinear is true, a qLookupTable(data) is created; otherwise the
//     existing LookupTable(data) is created.
//
// include/onmt/nn/qLinear.h
//   - The constructor now calls Linear(nullptr) and reads "weight" and "bias"
//     through StorageLoader::get_matrix as short matrices (_weight_short,
//     _bias_short). The former runtime simd::Quantize call, the
//     _malloc_align call and the "multiple of SIMD_VSIZE" width check are
//     removed.
//   - _wrows / _wcols are taken from _weight_short. _quant_weight is set to
//     the weight data pointer after align(sizeof(SIMD_TYPE), size, p, space)
//     (presumably std::align - confirm), with space declared as size + 64.
//     When space changed, the weight bytes are memmove'd to the adjusted
//     address p.
//   - NOTE(review): the memmove writes `size` bytes starting at p, which is
//     at or after ptr_weight, so it ends (p - ptr_weight) bytes beyond the
//     first `size` bytes of the mapped storage. Nothing visible here shows
//     that the storage really has the extra 64 bytes - confirm against
//     StorageLoader.
//   - NOTE(review): ptr_weight is a `const short *` that is C-cast to
//     `void *` and then written through.
//   - NOTE(review): after a shift, _weight_short still starts at the original
//     address; only its rows()/cols() are used in the visible code.
//   - NOTE(review): _quant_weight_buffer is still declared but is no longer
//     assigned by the constructor shown here; the ~qLinear() body is not
//     visible - confirm it does not release that pointer.
//   - The bias add now uses _bias_short (or _rbias_short when _rwrows is
//     set), cast and divided by simd::quant_mult, and the subdict setter
//     copies the selected rows of _bias_short into _rbias_short.
//
// include/onmt/nn/qLookupTable.h (new file)
//   - qLookupTable derives from Module with the name "nn.qLookupTable". It
//     keeps the "weight" storage as a short matrix (_weightq) plus a
//     zero-initialised matrix _weight of the same shape.
//   - forward_impl resizes _output to input.rows() x _weight.cols(). For each
//     i below input.batches(), it looks at row input(i,0) of _weight: if
//     columns 0, 1 and 2 are all zero, the row is filled from _weightq cast
//     and divided by simd::quant_mult; the row is then copied to
//     _output.row(i).
//   - NOTE(review): columns 0..2 are read unconditionally, so this presumably
//     expects at least three columns - confirm.
//   - NOTE(review): a row whose first three converted values are zero is
//     converted again on every lookup.
//   - NOTE(review): _output is sized with input.rows() but the loop bound is
//     input.batches() - confirm these agree.
//   - NOTE(review): forward_impl writes to the member _weight.
//
diff --git a/include/onmt/nn/ModuleFactory.hxx b/include/onmt/nn/ModuleFactory.hxx index d111e08..0edd339 100644 --- a/include/onmt/nn/ModuleFactory.hxx +++ b/include/onmt/nn/ModuleFactory.hxx @@ -33,6 +33,7 @@ #ifdef WITH_QLINEAR # include "onmt/nn/qLinear.h" +# include "onmt/nn/qLookupTable.h" #endif namespace onmt @@ -99,6 +100,11 @@ namespace onmt mod = new Linear(data); } else if (name == "nn.LookupTable") +#ifdef WITH_QLINEAR + if (_qlinear) + mod = new qLookupTable(data); + else +#endif mod = new LookupTable(data); else if (name == "nn.CAddTable") mod = new CAddTable(); diff --git a/include/onmt/nn/qLinear.h b/include/onmt/nn/qLinear.h index 84a6408..1eb0ec6 100644 --- a/include/onmt/nn/qLinear.h +++ b/include/onmt/nn/qLinear.h @@ -17,13 +17,23 @@ namespace onmt { public: qLinear(th::Table* data) - : Linear(data), _quant_input_buffer(nullptr) + : Linear(nullptr) + , _weight_short(StorageLoader >, short>::get_matrix(data, "weight")) + , _bias_short(StorageLoader >, short>::get_matrix(data, "bias")) + , _quant_input_buffer(nullptr) { - // Quantize the weight - ncols=width is supposed to be multiple of SIMD_VSIZE - if (this->_wcols % SIMD_VSIZE) - throw std::runtime_error("Weight matrix width should be multiple of 8/16 for qLinear"); - _malloc_align(_quant_weight_buffer, _quant_weight, this->_wrows * this->_wcols / SIMD_VSIZE); - simd::Quantize(this->_weight.data(), _quant_weight, this->_wrows, this->_wcols); + this->_wrows = this->_weight_short.rows(); + this->_wcols = this->_weight_short.cols(); + const short * ptr_weight = this->_weight_short.data(); + void * p = (void *) ptr_weight; + size_t size = this->_wrows * this->_wcols * sizeof(short); + size_t space = size + 64; + align(sizeof(SIMD_TYPE), size, p, space); + _quant_weight = (SIMD_TYPE*)p; + /* there might be a memory shift necessary for alignment */ + if (space - size != 64) { + memmove(p, ptr_weight, size); + } } virtual ~qLinear() @@ -66,14 +76,14 @@ namespace onmt _subdict); /* add bias */ - if 
(this->_bias.rows() > 0) + if (this->_bias_short.rows() > 0) { if (this->_rwrows) for (int i = 0; i < input.rows(); ++i) - this->_output.row(i).noalias() += this->_rbias.transpose(); + this->_output.row(i).noalias() += this->_rbias_short.transpose().template cast() / simd::quant_mult; else for (int i = 0; i < input.rows(); ++i) - this->_output.row(i).noalias() += this->_bias.transpose(); + this->_output.row(i).noalias() += this->_bias_short.transpose().template cast() / simd::quant_mult; } } @@ -82,14 +92,17 @@ namespace onmt { this->_rwrows = v.size(); _subdict = v; - this->_rbias.resize(v.size(), 1); + this->_rbias_short.resize(v.size(), 1); /* adjust bias */ for (size_t i = 0; i < v.size(); i++) { - this->_rbias.row(i) = this->_bias.row(v[i]); + this->_rbias_short.row(i) = this->_bias_short.row(v[i]); } } protected: + Eigen::Map > _weight_short; + Eigen::Map > _bias_short; + Eigen::RowMajorMat _rbias_short; void* _quant_weight_buffer; void* _quant_input_buffer; SIMD_TYPE* _quant_weight; diff --git a/include/onmt/nn/qLookupTable.h b/include/onmt/nn/qLookupTable.h new file mode 100644 index 0000000..d0c0c49 --- /dev/null +++ b/include/onmt/nn/qLookupTable.h @@ -0,0 +1,43 @@ +#pragma once + +#include "onmt/nn/Module.h" +#include "onmt/th/Obj.h" +#include "onmt/StorageLoader.h" +#include "onmt/simd/MatrixMult.h" + +namespace onmt +{ + namespace nn + { + + template + class qLookupTable: public Module + { + public: + qLookupTable(th::Table* data) + : Module("nn.qLookupTable") + , _weightq(StorageLoader >, short>::get_matrix(data, "weight")) + { + _weight.resize(_weightq.rows(), _weightq.cols()); + _weight.setZero(); + } + + void forward_impl(const MatFwd& input) override + { + this->_output.resize(input.rows(), _weight.cols()); + + for (size_t i = 0; i < input.batches(); ++i) { + if (_weight(input(i,0),0) == 0 && _weight(input(i,0),1) == 0 && _weight(input(i,0),2) == 0) { + _weight.row(input(i,0)) = _weightq.row(input(i,0)).template cast() / simd::quant_mult; + } + 
this->_output.row(i).noalias() = _weight.row(input(i, 0)); + } + } + + private: + Eigen::Map > _weightq; + MatFwd _weight; + }; + + } +}