Skip to content
This repository has been archived by the owner on Jun 10, 2021. It is now read-only.

[WIP] read quantizedWeight file #32

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/onmt/nn/ModuleFactory.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#ifdef WITH_QLINEAR
# include "onmt/nn/qLinear.h"
# include "onmt/nn/qLookupTable.h"
#endif

namespace onmt
Expand Down Expand Up @@ -99,6 +100,11 @@ namespace onmt
mod = new Linear<MatFwd, MatIn, ModelT>(data);
}
else if (name == "nn.LookupTable")
#ifdef WITH_QLINEAR
if (_qlinear)
mod = new qLookupTable<MatFwd, MatIn, ModelT>(data);
else
#endif
mod = new LookupTable<MatFwd, MatEmb, ModelT>(data);
else if (name == "nn.CAddTable")
mod = new CAddTable<MatFwd>();
Expand Down
35 changes: 24 additions & 11 deletions include/onmt/nn/qLinear.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,23 @@ namespace onmt
{
public:
qLinear(th::Table* data)
: Linear<MatFwd, MatIn, ModelT>(data), _quant_input_buffer(nullptr)
: Linear<MatFwd, MatIn, ModelT>(nullptr)
, _weight_short(StorageLoader<Eigen::Map<const Eigen::RowMajorMat<short> >, short>::get_matrix(data, "weight"))
, _bias_short(StorageLoader<Eigen::Map<const Eigen::RowMajorMat<short> >, short>::get_matrix(data, "bias"))
, _quant_input_buffer(nullptr)
{
// Quantize the weight - ncols=width is supposed to be multiple of SIMD_VSIZE
if (this->_wcols % SIMD_VSIZE)
throw std::runtime_error("Weight matrix width should be multiple of 8/16 for qLinear");
_malloc_align(_quant_weight_buffer, _quant_weight, this->_wrows * this->_wcols / SIMD_VSIZE);
simd::Quantize(this->_weight.data(), _quant_weight, this->_wrows, this->_wcols);
this->_wrows = this->_weight_short.rows();
this->_wcols = this->_weight_short.cols();
const short * ptr_weight = this->_weight_short.data();
void * p = (void *) ptr_weight;
size_t size = this->_wrows * this->_wcols * sizeof(short);
size_t space = size + 64;
align(sizeof(SIMD_TYPE), size, p, space);
_quant_weight = (SIMD_TYPE*)p;
/* there might be a memory shift necessary for alignment */
if (space - size != 64) {
memmove(p, ptr_weight, size);
}
}

virtual ~qLinear()
Expand Down Expand Up @@ -66,14 +76,14 @@ namespace onmt
_subdict);

/* add bias */
if (this->_bias.rows() > 0)
if (this->_bias_short.rows() > 0)
{
if (this->_rwrows)
for (int i = 0; i < input.rows(); ++i)
this->_output.row(i).noalias() += this->_rbias.transpose();
this->_output.row(i).noalias() += this->_rbias_short.transpose().template cast<float>() / simd::quant_mult;
else
for (int i = 0; i < input.rows(); ++i)
this->_output.row(i).noalias() += this->_bias.transpose();
this->_output.row(i).noalias() += this->_bias_short.transpose().template cast<float>() / simd::quant_mult;
}
}

Expand All @@ -82,14 +92,17 @@ namespace onmt
{
this->_rwrows = v.size();
_subdict = v;
this->_rbias.resize(v.size(), 1);
this->_rbias_short.resize(v.size(), 1);
/* adjust bias */
for (size_t i = 0; i < v.size(); i++) {
this->_rbias.row(i) = this->_bias.row(v[i]);
this->_rbias_short.row(i) = this->_bias_short.row(v[i]);
}
}

protected:
Eigen::Map<const Eigen::RowMajorMat<short> > _weight_short;
Eigen::Map<const Eigen::RowMajorMat<short> > _bias_short;
Eigen::RowMajorMat<short> _rbias_short;
void* _quant_weight_buffer;
void* _quant_input_buffer;
SIMD_TYPE* _quant_weight;
Expand Down
43 changes: 43 additions & 0 deletions include/onmt/nn/qLookupTable.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include "onmt/nn/Module.h"
#include "onmt/th/Obj.h"
#include "onmt/StorageLoader.h"
#include "onmt/simd/MatrixMult.h"

namespace onmt
{
namespace nn
{

template <typename MatFwd, typename MatEmb, typename ModelT>
class qLookupTable: public Module<MatFwd>
{
public:
qLookupTable(th::Table* data)
: Module<MatFwd>("nn.qLookupTable")
, _weightq(StorageLoader<Eigen::Map<const Eigen::RowMajorMat<short> >, short>::get_matrix(data, "weight"))
{
_weight.resize(_weightq.rows(), _weightq.cols());
_weight.setZero();
}

void forward_impl(const MatFwd& input) override
{
this->_output.resize(input.rows(), _weight.cols());

for (size_t i = 0; i < input.batches(); ++i) {
if (_weight(input(i,0),0) == 0 && _weight(input(i,0),1) == 0 && _weight(input(i,0),2) == 0) {
_weight.row(input(i,0)) = _weightq.row(input(i,0)).template cast<ModelT>() / simd::quant_mult;
}
this->_output.row(i).noalias() = _weight.row(input(i, 0));
}
}

private:
Eigen::Map<const Eigen::RowMajorMat<short> > _weightq;
MatFwd _weight;
};

}
}