From 5c7cc895350cbeac29278b78e18c3e7a79da0450 Mon Sep 17 00:00:00 2001 From: "Jean A. Senellart" Date: Tue, 15 May 2018 00:08:22 +0200 Subject: [PATCH 1/2] read quantizedWeight file --- include/onmt/nn/ModuleFactory.hxx | 6 +++++ include/onmt/nn/qLinear.h | 35 +++++++++++++++++-------- include/onmt/nn/qLookupTable.h | 43 +++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 11 deletions(-) create mode 100644 include/onmt/nn/qLookupTable.h diff --git a/include/onmt/nn/ModuleFactory.hxx b/include/onmt/nn/ModuleFactory.hxx index d111e08..0edd339 100644 --- a/include/onmt/nn/ModuleFactory.hxx +++ b/include/onmt/nn/ModuleFactory.hxx @@ -33,6 +33,7 @@ #ifdef WITH_QLINEAR # include "onmt/nn/qLinear.h" +# include "onmt/nn/qLookupTable.h" #endif namespace onmt @@ -99,6 +100,11 @@ namespace onmt mod = new Linear(data); } else if (name == "nn.LookupTable") +#ifdef WITH_QLINEAR + if (_qlinear) + mod = new qLookupTable(data); + else +#endif mod = new LookupTable(data); else if (name == "nn.CAddTable") mod = new CAddTable(); diff --git a/include/onmt/nn/qLinear.h b/include/onmt/nn/qLinear.h index 84a6408..1eb0ec6 100644 --- a/include/onmt/nn/qLinear.h +++ b/include/onmt/nn/qLinear.h @@ -17,13 +17,23 @@ namespace onmt { public: qLinear(th::Table* data) - : Linear(data), _quant_input_buffer(nullptr) + : Linear(nullptr) + , _weight_short(StorageLoader >, short>::get_matrix(data, "weight")) + , _bias_short(StorageLoader >, short>::get_matrix(data, "bias")) + , _quant_input_buffer(nullptr) { - // Quantize the weight - ncols=width is supposed to be multiple of SIMD_VSIZE - if (this->_wcols % SIMD_VSIZE) - throw std::runtime_error("Weight matrix width should be multiple of 8/16 for qLinear"); - _malloc_align(_quant_weight_buffer, _quant_weight, this->_wrows * this->_wcols / SIMD_VSIZE); - simd::Quantize(this->_weight.data(), _quant_weight, this->_wrows, this->_wcols); + this->_wrows = this->_weight_short.rows(); + this->_wcols = this->_weight_short.cols(); + const short * ptr_weight = this->_weight_short.data(); + void * p = (void *) ptr_weight; + size_t size = this->_wrows * this->_wcols * sizeof(short); + size_t space = size + 64; + align(sizeof(SIMD_TYPE), size, p, space); + _quant_weight = (SIMD_TYPE*)p; + /* there might be a memory shift necessary for alignment */ + if (space - size != 64) { + memmove(p, ptr_weight, size); + } } virtual ~qLinear() @@ -66,14 +76,14 @@ namespace onmt _subdict); /* add bias */ - if (this->_bias.rows() > 0) + if (this->_bias_short.rows() > 0) { if (this->_rwrows) for (int i = 0; i < input.rows(); ++i) - this->_output.row(i).noalias() += this->_rbias.transpose(); + this->_output.row(i).noalias() += this->_rbias_short.transpose().template cast() / simd::quant_mult; else for (int i = 0; i < input.rows(); ++i) - this->_output.row(i).noalias() += this->_bias.transpose(); + this->_output.row(i).noalias() += this->_bias_short.transpose().template cast() / simd::quant_mult; } } @@ -82,14 +92,17 @@ namespace onmt { this->_rwrows = v.size(); _subdict = v; - this->_rbias.resize(v.size(), 1); + this->_rbias_short.resize(v.size(), 1); /* adjust bias */ for (size_t i = 0; i < v.size(); i++) { - this->_rbias.row(i) = this->_bias.row(v[i]); + this->_rbias_short.row(i) = this->_bias_short.row(v[i]); } } protected: + Eigen::Map > _weight_short; + Eigen::Map > _bias_short; + Eigen::RowMajorMat _rbias_short; void* _quant_weight_buffer; void* _quant_input_buffer; SIMD_TYPE* _quant_weight; diff --git a/include/onmt/nn/qLookupTable.h b/include/onmt/nn/qLookupTable.h new file mode 100644 index 0000000..e6951bd --- /dev/null +++ b/include/onmt/nn/qLookupTable.h @@ -0,0 +1,43 @@ +#pragma once + +#include "onmt/nn/Module.h" +#include "onmt/th/Obj.h" +#include "onmt/StorageLoader.h" +#include "onmt/simd/MatrixMult.h" + +namespace onmt +{ + namespace nn + { + + template + class qLookupTable: public Module + { + public: + qLookupTable(th::Table* data) + : Module("nn.qLookupTable") + , _weightq(StorageLoader >, short>::get_matrix(data, "weight")) + { + _weight.resize(_weightq.rows(), _weightq.cols()); + _weight.setZero(); + } + + void forward_impl(const MatFwd& input) override + { + this->_output.resize(input.rows(), _weight.cols()); + + for (size_t i = 0; i < input.batches(); ++i) { + if (_weight(input(i,0),0) == 0) { + _weight.row(input(i,0)) = _weightq.row(input(i,0)).template cast() / simd::quant_mult; + } + this->_output.row(i).noalias() = _weight.row(input(i, 0)); + } + } + + private: + Eigen::Map > _weightq; + MatFwd _weight; + }; + + } +} From 6fc8654e97419faae119e95e232222ba0450350d Mon Sep 17 00:00:00 2001 From: "Jean A. Senellart" Date: Tue, 15 May 2018 07:49:47 +0200 Subject: [PATCH 2/2] more strict test to check if row has been unquantized --- include/onmt/nn/qLookupTable.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onmt/nn/qLookupTable.h b/include/onmt/nn/qLookupTable.h index e6951bd..d0c0c49 100644 --- a/include/onmt/nn/qLookupTable.h +++ b/include/onmt/nn/qLookupTable.h @@ -27,7 +27,7 @@ namespace onmt this->_output.resize(input.rows(), _weight.cols()); for (size_t i = 0; i < input.batches(); ++i) { - if (_weight(input(i,0),0) == 0) { + if (_weight(input(i,0),0) == 0 && _weight(input(i,0),1) == 0 && _weight(input(i,0),2) == 0) { _weight.row(input(i,0)) = _weightq.row(input(i,0)).template cast() / simd::quant_mult; } this->_output.row(i).noalias() = _weight.row(input(i, 0));