This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Quantize QuestionAnswering models #1581

Open
wants to merge 10 commits into base: master

18 changes: 14 additions & 4 deletions scripts/question_answering/models.py
@@ -180,6 +180,7 @@ def __init__(self, backbone, units=768, layer_norm_eps=1E-12, dropout_prob=0.1,
self.answerable_scores.add(nn.Dense(2, flatten=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer))
self.quantized_backbone = None

def get_start_logits(self, contextual_embedding, p_mask):
"""
@@ -287,10 +288,14 @@ def forward(self, tokens, token_types, valid_length, p_mask, start_position):
Shape (batch_size, sequence_length)
answerable_logits
"""
backbone_net = self.backbone
if self.quantized_backbone is not None:
backbone_net = self.quantized_backbone

if self.use_segmentation:
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
contextual_embeddings = backbone_net(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
contextual_embeddings = backbone_net(tokens, valid_length)
start_logits = self.get_start_logits(contextual_embeddings, p_mask)
end_logits = self.get_end_logits(contextual_embeddings,
np.expand_dims(start_position, axis=1),
@@ -337,11 +342,16 @@ def inference(self, tokens, token_types, valid_length, p_mask,
The answerable logits. Here 0 --> answerable and 1 --> not answerable.
Shape (batch_size, sequence_length, 2)
"""
backbone_net = self.backbone
if self.quantized_backbone is not None:
backbone_net = self.quantized_backbone

# Shape (batch_size, sequence_length, C)
if self.use_segmentation:
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
contextual_embeddings = backbone_net(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
contextual_embeddings = backbone_net(tokens, valid_length)

start_logits = self.get_start_logits(contextual_embeddings, p_mask)
# The shape of start_top_index will be (..., start_top_n)
start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
105 changes: 95 additions & 10 deletions scripts/question_answering/run_squad.py
@@ -120,7 +120,7 @@ def parse_args():
'this will be truncated to this length. default is 64')
parser.add_argument('--pre_shuffle_seed', type=int, default=100,
help='Random seed for pre split shuffle')
parser.add_argument('--round_to', type=int, default=None,
parser.add_argument('--round_to', type=int, default=8,
help='The length of padded sequences will be rounded up to be multiple'
' of this argument. When round to is set to 8, training throughput '
'may increase for mixed precision training on GPUs with TensorCores.')
@@ -147,9 +147,9 @@ def parse_args():
parser.add_argument('--max_saved_ckpt', type=int, default=5,
help='The maximum number of saved checkpoints')
parser.add_argument('--dtype', type=str, default='float32',
help='Data type used for evaluation. Either float32 or float16. When you '
help='Data type used for evaluation. Either float32, float16 or int8. When you '
'use --dtype float16, amp will be turned on in the training phase and '
'fp16 will be used in evaluation.')
'fp16 will be used in evaluation. For now, the int8 data type is supported on CPU only.')
args = parser.parse_args()
return args
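The default for --round_to changes from None to 8, and the batchify setup below now reuses args.round_to instead of a hard-coded 8. As a rough sketch of what that rounding means (not part of the diff; pad_round_to is a made-up stand-in for bf.Pad):

import numpy as np

def pad_round_to(seqs, pad_val, round_to=None):
    # Pad 1-D sequences to the batch max length, rounded up to a multiple of round_to.
    max_len = max(len(s) for s in seqs)
    if round_to is not None:
        max_len = -(-max_len // round_to) * round_to  # ceiling division
    out = np.full((len(seqs), max_len), pad_val, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out

# Lengths 3, 5 and 6 pad to 8 columns (not 6) when round_to=8, which keeps shapes
# TensorCore-friendly for fp16 and consistent for the int8 kernels.
print(pad_round_to([[1, 2, 3], [4, 5, 6, 7, 8], [9] * 6], pad_val=0, round_to=8).shape)  # (3, 8)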

@@ -195,13 +195,12 @@ def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id

# TODO(sxjscience) Consider to combine the NamedTuple and batchify functionality.
# Here, we use round_to=8 to improve the throughput.
self.BatchifyFunction = bf.NamedTuple(ChunkFeature,
{'qas_id': bf.List(),
'data': bf.Pad(val=self.pad_id, round_to=8),
'data': bf.Pad(val=self.pad_id, round_to=args.round_to),
'valid_length': bf.Stack(),
'segment_ids': bf.Pad(round_to=8),
'masks': bf.Pad(val=1, round_to=8),
'segment_ids': bf.Pad(round_to=args.round_to),
'masks': bf.Pad(val=1, round_to=args.round_to),
'is_impossible': bf.Stack(),
'gt_start': bf.Stack(),
'gt_end': bf.Stack(),
@@ -357,7 +356,7 @@ def get_squad_features(args, tokenizer, segment):
tokenizer=tokenizer,
is_training=is_training), data_examples)
logging.info('Done! Time spent:{:.2f} seconds'.format(time.time() - start))
with open(data_cache_path, 'w') as f:
with open(data_cache_path, 'w', encoding='utf-8') as f:
for feature in data_features:
f.write(feature.to_json() + '\n')

@@ -815,6 +814,83 @@ def predict_extended(original_feature,
assert len(nbest_json) >= 1
return not_answerable_score, nbest[0][0], nbest_json

def quantize_and_calibrate(net, dataloader):
class QuantizationDataLoader(mx.gluon.data.DataLoader):
def __init__(self, dataloader, use_segmentation):
self._dataloader = dataloader
self._iter = None
self._use_segmentation = use_segmentation

def __iter__(self):
self._iter = iter(self._dataloader)
return self

def __next__(self):
batch = next(self._iter)
if self._use_segmentation:
return [batch.data, batch.segment_ids, batch.valid_length]
else:
return [batch.data, batch.valid_length]

def __del__(self):
del self._dataloader

class BertLayerCollector(mx.contrib.quantization.CalibrationCollector):
"""Saves layer output min and max values in a dict with layer names as keys.
The collected min and max values will be directly used as thresholds for quantization.
"""
def __init__(self, clip_min, clip_max):
super(BertLayerCollector, self).__init__()
self.clip_min = clip_min
self.clip_max = clip_max

def collect(self, name, op_name, arr):
"""Callback function for collecting min and max values from an NDArray."""
if name not in self.include_layers:
return
arr = arr.copyto(mx.cpu()).asnumpy()
min_range = np.min(arr)
max_range = np.max(arr)

if (name.find("sg_onednn_fully_connected_eltwise") != -1 or op_name.find("LayerNorm") != -1) \
and max_range > self.clip_max:
max_range = self.clip_max
elif name.find('sg_onednn_fully_connected') != -1 and min_range < self.clip_min:
min_range = self.clip_min

if name in self.min_max_dict:
cur_min_max = self.min_max_dict[name]
self.min_max_dict[name] = (min(cur_min_max[0], min_range),
max(cur_min_max[1], max_range))
else:
self.min_max_dict[name] = (min_range, max_range)

calib_data = QuantizationDataLoader(dataloader, net.use_segmentation)
model_name = args.model_name
# disable specific layers in some models for the sake of accuracy

if model_name == 'google_albert_base_v2':
logging.warn(f"Currently quantized {model_name} shows significant accuracy drop which is not fixed yet")

exclude_layers_map = {"google_electra_large":
["sg_onednn_fully_connected_eltwise_2", "sg_onednn_fully_connected_eltwise_14",
"sg_onednn_fully_connected_eltwise_18", "sg_onednn_fully_connected_eltwise_22",
"sg_onednn_fully_connected_eltwise_26"
]}
exclude_layers = None
if model_name in exclude_layers_map:
exclude_layers = exclude_layers_map[model_name]
net.quantized_backbone = mx.contrib.quant.quantize_net(net.backbone, quantized_dtype='auto',
quantize_mode='smart',
exclude_layers=exclude_layers,
exclude_layers_match=None,
calib_data=calib_data,
calib_mode='custom',
LayerOutputCollector=BertLayerCollector(clip_min=-50, clip_max=10),
num_calib_batches=10,
ctx=mx.cpu())
return net
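A toy, numpy-only restatement of the clipping rule in BertLayerCollector.collect above (not part of the diff; the layer name and activation values are made up):

import numpy as np

def clipped_threshold(name, op_name, arr, clip_min=-50, clip_max=10):
    # Same idea as BertLayerCollector.collect: clamp the max of fused FC+eltwise (GELU)
    # and LayerNorm outputs, and the min of plain fused FC outputs, so rare outliers do
    # not stretch the int8 calibration range.
    min_range, max_range = float(np.min(arr)), float(np.max(arr))
    if ('fully_connected_eltwise' in name or 'LayerNorm' in op_name) and max_range > clip_max:
        max_range = clip_max
    elif 'fully_connected' in name and min_range < clip_min:
        min_range = clip_min
    return min_range, max_range

acts = np.concatenate([np.random.uniform(-3.0, 3.0, 1000), [57.0]])
print(clipped_threshold('sg_onednn_fully_connected_eltwise_0', 'FullyConnected', acts))
# roughly (-3.0, 10): the 57.0 outlier is ignored when picking the threshold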


def evaluate(args, last=True):
store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
@@ -826,11 +902,12 @@ def evaluate(args, last=True):
return
ctx_l = parse_ctx(args.gpus)
logging.info(
'Srarting inference without horovod on the first node on device {}'.format(
'Starting inference without horovod on the first node on device {}'.format(
str(ctx_l)))
network_dtype = args.dtype if args.dtype != 'int8' else 'float32'

cfg, tokenizer, qa_net, use_segmentation = get_network(
args.model_name, ctx_l, args.classifier_dropout, dtype=args.dtype)
args.model_name, ctx_l, args.classifier_dropout, dtype=network_dtype)
if args.dtype == 'float16':
qa_net.cast('float16')
qa_net.hybridize()
@@ -860,6 +937,9 @@ def eval_validation(ckpt_name, best_eval):
num_workers=0,
shuffle=False)

if args.dtype == 'int8':
quantize_and_calibrate(qa_net, dev_dataloader)

log_interval = args.eval_log_interval
all_results = []
epoch_tic = time.time()
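Not part of the diff: a condensed view of how the int8 path fits together during evaluation, using the helper names as they appear elsewhere in run_squad.py (args and dev_dataloader assumed to be in scope; exact call sites differ):

# int8 inference is CPU-only for now, so a single CPU context is assumed.
ctx_l = [mx.cpu()]
# Load the fp32 network first; the int8 copy is produced afterwards by calibration.
cfg, tokenizer, qa_net, use_segmentation = get_network(args.model_name, ctx_l,
                                                       args.classifier_dropout,
                                                       dtype='float32')
qa_net.hybridize()
# Calibrate on the dev dataloader; this attaches qa_net.quantized_backbone, and
# forward()/inference() in models.py then route through it automatically.
qa_net = quantize_and_calibrate(qa_net, dev_dataloader)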
@@ -999,6 +1079,11 @@ def eval_validation(ckpt_name, best_eval):
if __name__ == '__main__':
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
args = parse_args()
if args.dtype == 'int8':
ctx_l = parse_ctx(args.gpus)
if ctx_l[0] != mx.cpu() or len(ctx_l) != 1:
raise ValueError("Evaluation on int8 data type is supported only for CPU for now")

if args.do_train:
if args.dtype == 'float16':
# Initialize amp if it's fp16 training
1 change: 1 addition & 0 deletions setup.py
@@ -133,6 +133,7 @@ def find_version(*file_paths):
'pylint_quotes',
'flake8',
'recommonmark',
'sphinx>=1.5.5',
'sphinx-gallery',
'sphinx_rtd_theme',
'mxtheme',