diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..037a5516 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +tutorials/* linguist-vendored \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 025215c8..e3911359 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,24 +5,22 @@ cache: pip sudo: true env: - global: - - PYTHONPATH=$PYTHONPATH:$TRAVIS_BUILD_DIR/tests:$TRAVIS_BUILD_DIR/matchzoo + global: + - PYTHONPATH=$PYTHONPATH:$TRAVIS_BUILD_DIR/tests:$TRAVIS_BUILD_DIR/matchzoo matrix: - allow_failures: - - os: osx - include: - - os: linux - dist: trusty - python: 3.6 - - os: osx - language: generic - env: PYTHON_VERSION=3.6 - - os: osx - language: generic - env: PYTHON_VERSION=3.7 + allow_failures: + - os: osx + include: + - os: linux + dist: trusty + python: 3.6 + - os: osx + osx_image: xcode10.2 + language: shell install: + - pip3 install -U pip - pip3 install -r requirements.txt - python3 -m nltk.downloader punkt - python3 -m nltk.downloader wordnet @@ -31,7 +29,10 @@ install: script: - stty cols 80 - export COLUMNS=80 - - make test + - if [ "$TRAVIS_EVENT_TYPE" == "pull_request" ]; then make push; fi + - if [ "$TRAVIS_EVENT_TYPE" == "push" ]; then make push; fi + - if [ "$TRAVIS_EVENT_TYPE" == "cron" ]; then make cron; fi + after_success: - - codecov \ No newline at end of file + - codecov diff --git a/Makefile b/Makefile index a7bd9a58..df5408cd 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,70 @@ +# Usages: +# +# to install matchzoo dependencies: +# $ make init +# +# to run all matchzoo tests, recommended for big PRs and new versions: +# $ make test +# +# there are three kinds of tests: +# +# 1. "quick" tests +# - run in seconds +# - include all unit tests without marks and all doctests +# - for rapid prototyping +# - CI run this for all PRs +# +# 2. "slow" tests +# - run in minutes +# - include all unit tests marked "slow" +# - CI run this for all PRs +# +# 3. "cron" tests +# - run in minutes +# - involves underministic behavoirs (e.g. 
network connection) +# - include all unit tests marked "cron" +# - CI run this on a daily basis +# +# to run quick tests, excluding time consuming tests and crons: +# $ make quick +# +# to run slow tests, excluding normal tests and crons: +# $ make slow +# +# to run crons: +# $ make cron +# +# to run all tests: +# $ make test +# +# to run CI push/PR tests: +# $ make push +# +# to run docstring style check: +# $ make flake + init: pip install -r requirements.txt -TEST_ARGS = --doctest-modules --doctest-continue-on-failure --cov matchzoo/ --cov-report term-missing --cov-report html --cov-config .coveragerc matchzoo/ tests/ -W ignore::DeprecationWarning +TEST_ARGS = -v --full-trace -l --doctest-modules --doctest-continue-on-failure --cov matchzoo/ --cov-report term-missing --cov-report html --cov-config .coveragerc matchzoo/ tests/ -W ignore::DeprecationWarning --ignore=matchzoo/contrib FLAKE_ARGS = ./matchzoo --exclude=__init__.py,matchzoo/contrib test: pytest $(TEST_ARGS) flake8 $(FLAKE_ARGS) +push: + pytest -m 'not cron' $(TEST_ARGS) ${ARGS} + flake8 $(FLAKE_ARGS) + quick: - pytest -m 'not slow' $(TEST_ARGS) + pytest -m 'not slow and not cron' $(TEST_ARGS) ${ARGS} slow: - pytest -m 'slow' $(TEST_ARGS) + pytest -m 'slow and not cron' $(TEST_ARGS) ${ARGS} + +cron: + pytest -m 'cron' $(TEST_ARGS) ${ARGS} flake: - flake8 $(FLAKE_ARGS) + flake8 $(FLAKE_ARGS) ${ARGS} diff --git a/README.md b/README.md index b1a161d1..631e9d44 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ > MatchZoo 是一个通用的文本匹配工具包,它旨在方便大家快速的实现、比较、以及分享最新的深度文本匹配模型。 [![Python 3.6](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/downloads/release/python-360/) -[![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/NTMC-Community/community) [![Pypi Downloads](https://img.shields.io/pypi/dm/matchzoo.svg?label=pypi)](https://pypi.org/project/MatchZoo/) [![Documentation Status](https://readthedocs.org/projects/matchzoo/badge/?version=master)](https://matchzoo.readthedocs.io/en/master/?badge=master) [![Build Status](https://travis-ci.org/NTMC-Community/MatchZoo.svg?branch=master)](https://travis-ci.org/NTMC-Community/MatchZoo/) @@ -68,7 +67,6 @@ import matchzoo as mz train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking') valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking') -predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking') ``` Preprocess your input data in three lines of code, keep track parameters to be passed into the model. @@ -85,7 +83,6 @@ Make use of MatchZoo customized loss functions and evaluation metrics: ranking_task = mz.tasks.Ranking(loss=mz.losses.RankCrossEntropyLoss(num_neg=4)) ranking_task.metrics = [ mz.metrics.NormalizedDiscountedCumulativeGain(k=3), - mz.metrics.NormalizedDiscountedCumulativeGain(k=5), mz.metrics.MeanAveragePrecision() ] ``` @@ -96,10 +93,6 @@ Initialize the model, fine-tune the hyper-parameters. 
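The quick / slow / cron tiers documented in the Makefile earlier in this diff are selected with ordinary pytest marks. A minimal sketch of how a test opts into a tier (the test names and bodies below are hypothetical, not part of this PR):

```python
import pytest


def test_tokenize():
    # Unmarked: a "quick" test, collected by `make quick`, `make push`, and `make test`.
    assert "a b".split() == ["a", "b"]


@pytest.mark.slow
def test_train_tiny_model():
    # Marked "slow": skipped by `make quick`, still run by `make slow`, `make push`, `make test`.
    pass


@pytest.mark.cron
def test_download_pretrained_embedding():
    # Marked "cron": involves non-deterministic behaviour such as network access,
    # so only `make cron` (and the all-inclusive `make test`) selects it.
    pass
```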
model = mz.models.DSSM() model.params['input_shapes'] = preprocessor.context['input_shapes'] model.params['task'] = ranking_task -model.params['mlp_num_layers'] = 3 -model.params['mlp_num_units'] = 300 -model.params['mlp_num_fan_out'] = 128 -model.params['mlp_activation_func'] = 'relu' model.guess_and_fill_missing_params() model.build() model.compile() @@ -109,10 +102,8 @@ Generate pair-wise training data on-the-fly, evaluate model performance using cu ```python train_generator = mz.PairDataGenerator(train_processed, num_dup=1, num_neg=4, batch_size=64, shuffle=True) - valid_x, valid_y = valid_processed.unpack() -evaluate = mz.callbacks.EvaluateAllMetrics(model, x=valid_x, y=valid_y, batch_size=len(pred_x)) - +evaluate = mz.callbacks.EvaluateAllMetrics(model, x=valid_x, y=valid_y, batch_size=len(valid_x)) history = model.fit_generator(train_generator, epochs=20, callbacks=[evaluate], workers=5, use_multiprocessing=False) ``` @@ -127,7 +118,7 @@ If you're interested in the cutting-edge research progress, please take a look a ## Install -MatchZoo is dependent on [Keras](https://github.com/keras-team/keras), please install one of its backend engines: TensorFlow, Theano, or CNTK. We recommend the TensorFlow backend. Two ways to install MatchZoo: +MatchZoo is dependent on [Keras](https://github.com/keras-team/keras) and [Tensorflow](https://github.com/tensorflow/tensorflow). Two ways to install MatchZoo: **Install MatchZoo from Pypi:** @@ -144,7 +135,7 @@ python setup.py install ``` -## Models: +## Models 1. [DRMM](https://github.com/NTMC-Community/MatchZoo/tree/master/matchzoo/models/drmm.py): this model is an implementation of A Deep Relevance Matching Model for Ad-hoc Retrieval. diff --git a/docs/Readme.md b/docs/Readme.md index e5f392ce..7b1b43c7 100644 --- a/docs/Readme.md +++ b/docs/Readme.md @@ -16,7 +16,8 @@ pip install -r requirements.txt # Enter docs folder. cd docs # Use sphinx autodoc to generate rst. -sphinx-apidoc -o source/ ../matchzoo/ +# usage: sphinx-apidoc [OPTIONS] -o [EXCLUDE_PATTERN,...] 
+sphinx-apidoc -o source/ ../matchzoo/ ../matchzoo/contrib # Generate html from rst make clean make html diff --git a/docs/source/conf.py b/docs/source/conf.py index 8bc624db..55981b7d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,15 +20,16 @@ sys.path.insert(0, os.path.abspath('../../matchzoo/data_generator')) sys.path.insert(0, os.path.abspath('../../matchzoo/data_pack')) sys.path.insert(0, os.path.abspath('../../matchzoo/datasets')) +sys.path.insert(0, os.path.abspath('../../matchzoo/embedding')) sys.path.insert(0, os.path.abspath('../../matchzoo/engine')) sys.path.insert(0, os.path.abspath('../../matchzoo/layers')) sys.path.insert(0, os.path.abspath('../../matchzoo/losses')) -sys.path.insert(0, os.path.abspath('../../matchzoo/models')) sys.path.insert(0, os.path.abspath('../../matchzoo/metrics')) +sys.path.insert(0, os.path.abspath('../../matchzoo/models')) sys.path.insert(0, os.path.abspath('../../matchzoo/preprocessors')) -sys.path.insert(0, os.path.abspath('../../matchzoo/processor_units')) -sys.path.insert(0, os.path.abspath('../../matchzoo/utils')) sys.path.insert(0, os.path.abspath('../../matchzoo/tasks')) +sys.path.insert(0, os.path.abspath('../../matchzoo/utils')) + # -- Project information ----------------------------------------------------- @@ -39,7 +40,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '2.0' +release = '2.1' # -- General configuration --------------------------------------------------- diff --git a/docs/source/matchzoo.auto.preparer.rst b/docs/source/matchzoo.auto.preparer.rst new file mode 100644 index 00000000..3291f9fe --- /dev/null +++ b/docs/source/matchzoo.auto.preparer.rst @@ -0,0 +1,30 @@ +matchzoo.auto.preparer package +============================== + +Submodules +---------- + +matchzoo.auto.preparer.prepare module +------------------------------------- + +.. automodule:: matchzoo.auto.preparer.prepare + :members: + :undoc-members: + :show-inheritance: + +matchzoo.auto.preparer.preparer module +-------------------------------------- + +.. automodule:: matchzoo.auto.preparer.preparer + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: matchzoo.auto.preparer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/matchzoo.auto.rst b/docs/source/matchzoo.auto.rst index bb974511..1272de0a 100644 --- a/docs/source/matchzoo.auto.rst +++ b/docs/source/matchzoo.auto.rst @@ -1,25 +1,13 @@ matchzoo.auto package ===================== -Submodules ----------- +Subpackages +----------- -matchzoo.auto.prepare module ----------------------------- - -.. automodule:: matchzoo.auto.prepare - :members: - :undoc-members: - :show-inheritance: - -matchzoo.auto.tune module -------------------------- - -.. automodule:: matchzoo.auto.tune - :members: - :undoc-members: - :show-inheritance: +.. toctree:: + matchzoo.auto.preparer + matchzoo.auto.tuner Module contents --------------- diff --git a/docs/source/matchzoo.auto.tuner.callbacks.rst b/docs/source/matchzoo.auto.tuner.callbacks.rst new file mode 100644 index 00000000..1671dc48 --- /dev/null +++ b/docs/source/matchzoo.auto.tuner.callbacks.rst @@ -0,0 +1,46 @@ +matchzoo.auto.tuner.callbacks package +===================================== + +Submodules +---------- + +matchzoo.auto.tuner.callbacks.callback module +--------------------------------------------- + +.. 
automodule:: matchzoo.auto.tuner.callbacks.callback + :members: + :undoc-members: + :show-inheritance: + +matchzoo.auto.tuner.callbacks.lambda\_callback module +----------------------------------------------------- + +.. automodule:: matchzoo.auto.tuner.callbacks.lambda_callback + :members: + :undoc-members: + :show-inheritance: + +matchzoo.auto.tuner.callbacks.load\_embedding\_matrix module +------------------------------------------------------------ + +.. automodule:: matchzoo.auto.tuner.callbacks.load_embedding_matrix + :members: + :undoc-members: + :show-inheritance: + +matchzoo.auto.tuner.callbacks.save\_model module +------------------------------------------------ + +.. automodule:: matchzoo.auto.tuner.callbacks.save_model + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: matchzoo.auto.tuner.callbacks + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/matchzoo.auto.tuner.rst b/docs/source/matchzoo.auto.tuner.rst new file mode 100644 index 00000000..024d1e83 --- /dev/null +++ b/docs/source/matchzoo.auto.tuner.rst @@ -0,0 +1,37 @@ +matchzoo.auto.tuner package +=========================== + +Subpackages +----------- + +.. toctree:: + + matchzoo.auto.tuner.callbacks + +Submodules +---------- + +matchzoo.auto.tuner.tune module +------------------------------- + +.. automodule:: matchzoo.auto.tuner.tune + :members: + :undoc-members: + :show-inheritance: + +matchzoo.auto.tuner.tuner module +-------------------------------- + +.. automodule:: matchzoo.auto.tuner.tuner + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: matchzoo.auto.tuner + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/matchzoo.datasets.quora_qp.rst b/docs/source/matchzoo.datasets.quora_qp.rst new file mode 100644 index 00000000..e0eb3d83 --- /dev/null +++ b/docs/source/matchzoo.datasets.quora_qp.rst @@ -0,0 +1,22 @@ +matchzoo.datasets.quora\_qp package +=================================== + +Submodules +---------- + +matchzoo.datasets.quora\_qp.load\_data module +--------------------------------------------- + +.. automodule:: matchzoo.datasets.quora_qp.load_data + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: matchzoo.datasets.quora_qp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/matchzoo.datasets.rst b/docs/source/matchzoo.datasets.rst index 8a77b0fc..d559538d 100644 --- a/docs/source/matchzoo.datasets.rst +++ b/docs/source/matchzoo.datasets.rst @@ -7,6 +7,7 @@ Subpackages .. toctree:: matchzoo.datasets.embeddings + matchzoo.datasets.quora_qp matchzoo.datasets.snli matchzoo.datasets.toy matchzoo.datasets.wiki_qa diff --git a/docs/source/matchzoo.embedding.rst b/docs/source/matchzoo.embedding.rst new file mode 100644 index 00000000..f1b0265b --- /dev/null +++ b/docs/source/matchzoo.embedding.rst @@ -0,0 +1,22 @@ +matchzoo.embedding package +========================== + +Submodules +---------- + +matchzoo.embedding.embedding module +----------------------------------- + +.. automodule:: matchzoo.embedding.embedding + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. 
automodule:: matchzoo.embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/matchzoo.preprocessors.units.rst b/docs/source/matchzoo.preprocessors.units.rst new file mode 100644 index 00000000..dc501f4a --- /dev/null +++ b/docs/source/matchzoo.preprocessors.units.rst @@ -0,0 +1,134 @@ +matchzoo.preprocessors.units package +==================================== + +Submodules +---------- + +matchzoo.preprocessors.units.digit\_removal module +-------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.digit_removal + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.fixed\_length module +------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.fixed_length + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.frequency\_filter module +----------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.frequency_filter + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.lemmatization module +------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.lemmatization + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.lowercase module +--------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.lowercase + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.matching\_histogram module +------------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.matching_histogram + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.ngram\_letter module +------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.ngram_letter + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.punc\_removal module +------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.punc_removal + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.stateful\_unit module +-------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.stateful_unit + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.stemming module +-------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.stemming + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.stop\_removal module +------------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.stop_removal + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.tokenize module +-------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.tokenize + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.unit module +---------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.unit + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.vocabulary module +---------------------------------------------- + +.. automodule:: matchzoo.preprocessors.units.vocabulary + :members: + :undoc-members: + :show-inheritance: + +matchzoo.preprocessors.units.word\_hashing module +------------------------------------------------- + +.. 
automodule:: matchzoo.preprocessors.units.word_hashing + :members: + :undoc-members: + :show-inheritance: + + +Module contents +--------------- + +.. automodule:: matchzoo.preprocessors.units + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/matchzoo.rst b/docs/source/matchzoo.rst index ef27bcbb..56ee8b77 100644 --- a/docs/source/matchzoo.rst +++ b/docs/source/matchzoo.rst @@ -6,7 +6,7 @@ Subpackages .. toctree:: - matchzoo.contrib + matchzoo.auto matchzoo.data_generator matchzoo.data_pack matchzoo.datasets @@ -16,23 +16,13 @@ Subpackages matchzoo.losses matchzoo.metrics matchzoo.models - matchzoo.prepare matchzoo.preprocessors matchzoo.tasks - matchzoo.tune matchzoo.utils Submodules ---------- -matchzoo.logger module ----------------------- - -.. automodule:: matchzoo.logger - :members: - :undoc-members: - :show-inheritance: - matchzoo.version module ----------------------- diff --git a/docs/source/matchzoo.utils.rst b/docs/source/matchzoo.utils.rst index d19d2a78..c679d9fd 100644 --- a/docs/source/matchzoo.utils.rst +++ b/docs/source/matchzoo.utils.rst @@ -12,6 +12,14 @@ matchzoo.utils.list\_recursive\_subclasses module :undoc-members: :show-inheritance: +matchzoo.utils.make\_keras\_optimizer\_picklable module +------------------------------------------------------- + +.. automodule:: matchzoo.utils.make_keras_optimizer_picklable + :members: + :undoc-members: + :show-inheritance: + matchzoo.utils.one\_hot module ------------------------------ diff --git a/matchzoo/__init__.py b/matchzoo/__init__.py index 36484a2e..6205cb23 100644 --- a/matchzoo/__init__.py +++ b/matchzoo/__init__.py @@ -43,7 +43,7 @@ from .embedding.embedding import Embedding -from .utils import one_hot +from .utils import one_hot, make_keras_optimizer_picklable from .preprocessors.build_unit_from_data_pack import build_unit_from_data_pack from .preprocessors.build_vocab_unit import build_vocab_unit diff --git a/matchzoo/auto/preparer/preparer.py b/matchzoo/auto/preparer/preparer.py index 1e5b5fb0..eef74f91 100644 --- a/matchzoo/auto/preparer/preparer.py +++ b/matchzoo/auto/preparer/preparer.py @@ -137,6 +137,7 @@ def _build_model( else: embedding_matrix = None + self._handle_match_pyramid_dpool_size(model) self._handle_drmm_input_shapes(model) assert model.params.completed() @@ -148,6 +149,16 @@ def _build_model( return model, embedding_matrix + def _handle_match_pyramid_dpool_size(self, model): + if isinstance(model, mz.models.MatchPyramid): + suggestion = mz.layers.DynamicPoolingLayer.get_size_suggestion( + msize1=model.params['input_shapes'][0][0], + msize2=model.params['input_shapes'][1][0], + psize1=model.params['dpool_size'][0], + psize2=model.params['dpool_size'][1], + ) + model.params['dpool_size'] = suggestion + def _handle_drmm_input_shapes(self, model): if isinstance(model, mz.models.DRMM): left = model.params['input_shapes'][0] diff --git a/matchzoo/auto/tuner/callbacks/lambda_callback.py b/matchzoo/auto/tuner/callbacks/lambda_callback.py index 47ee3102..c64090de 100644 --- a/matchzoo/auto/tuner/callbacks/lambda_callback.py +++ b/matchzoo/auto/tuner/callbacks/lambda_callback.py @@ -1,5 +1,5 @@ from matchzoo.engine.base_model import BaseModel -from .callback import Callback +from matchzoo.auto.tuner.callbacks.callback import Callback class LambdaCallback(Callback): @@ -7,18 +7,46 @@ class LambdaCallback(Callback): LambdaCallback. Just a shorthand for creating a callback class. See :class:`matchzoo.tuner.callbacks.Callback` for more details. 
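Aside on the `Preparer` change above: `_handle_match_pyramid_dpool_size` clamps a MatchPyramid's `dpool_size` to a value compatible with its fixed input lengths via `DynamicPoolingLayer.get_size_suggestion`. A rough sketch of that call with made-up sizes (the exact tuple returned depends on that method's implementation):

```python
import matchzoo as mz

# Hypothetical fixed text lengths and a requested dynamic-pooling size.
left_len, right_len = 32, 32
requested_dpool = (10, 10)

suggestion = mz.layers.DynamicPoolingLayer.get_size_suggestion(
    msize1=left_len, msize2=right_len,
    psize1=requested_dpool[0], psize2=requested_dpool[1],
)
# The Preparer writes this suggestion back into model.params['dpool_size'].
print(suggestion)
```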
+ + Example: + + >>> import matchzoo as mz + >>> model = mz.models.Naive() + >>> model.guess_and_fill_missing_params(verbose=0) + >>> data = mz.datasets.toy.load_data() + >>> data = model.get_default_preprocessor().fit_transform( + ... data, verbose=0) + >>> def show_inputs(*args): + ... print(' '.join(map(str, map(type, args)))) + >>> callback = mz.auto.tuner.callbacks.LambdaCallback( + ... on_run_start=show_inputs, + ... on_build_end=show_inputs, + ... on_run_end=show_inputs + ... ) + >>> _ = mz.auto.tune( + ... params=model.params, + ... train_data=data, + ... test_data=data, + ... num_runs=1, + ... callbacks=[callback], + ... verbose=0, + ... ) # noqa: E501 + + + + """ def __init__( self, on_run_start=None, on_build_end=None, - on_result_end=None + on_run_end=None ): """Init.""" self._on_run_start = on_run_start self._on_build_end = on_build_end - self._on_result_end = on_result_end + self._on_run_end = on_run_end def on_run_start(self, tuner, sample: dict): """`on_run_start`.""" @@ -32,5 +60,5 @@ def on_build_end(self, tuner, model: BaseModel): def on_run_end(self, tuner, model: BaseModel, result: dict): """`on_run_end`.""" - if self._on_result_end: - self._on_result_end(tuner, model, result) + if self._on_run_end: + self._on_run_end(tuner, model, result) diff --git a/matchzoo/auto/tuner/callbacks/load_embedding_matrix.py b/matchzoo/auto/tuner/callbacks/load_embedding_matrix.py index 0b59582f..931d38b2 100644 --- a/matchzoo/auto/tuner/callbacks/load_embedding_matrix.py +++ b/matchzoo/auto/tuner/callbacks/load_embedding_matrix.py @@ -1,13 +1,39 @@ from matchzoo.engine.base_model import BaseModel -from .callback import Callback +from matchzoo.auto.tuner.callbacks.callback import Callback class LoadEmbeddingMatrix(Callback): """ Load a pre-trained embedding after the model is built. + Used with tuner to load a pre-trained embedding matrix for each newly built + model instance. + :param embedding_matrix: Embedding matrix to load. + Example: + + >>> import matchzoo as mz + >>> model = mz.models.ArcI() + >>> prpr = model.get_default_preprocessor() + >>> data = mz.datasets.toy.load_data() + >>> data = prpr.fit_transform(data, verbose=0) + >>> embed = mz.datasets.toy.load_embedding() + >>> term_index = prpr.context['vocab_unit'].state['term_index'] + >>> matrix = embed.build_matrix(term_index) + >>> callback = mz.auto.tuner.callbacks.LoadEmbeddingMatrix(matrix) + >>> model.params.update(prpr.context) + >>> model.params['task'] = mz.tasks.Ranking() + >>> model.params['embedding_output_dim'] = embed.output_dim + >>> result = mz.auto.tune( + ... params=model.params, + ... train_data=data, + ... test_data=data, + ... num_runs=1, + ... callbacks=[callback], + ... verbose=0 + ... 
) + """ def __init__(self, embedding_matrix): diff --git a/matchzoo/auto/tuner/callbacks/save_model.py b/matchzoo/auto/tuner/callbacks/save_model.py index 808f48a1..e50d5aef 100644 --- a/matchzoo/auto/tuner/callbacks/save_model.py +++ b/matchzoo/auto/tuner/callbacks/save_model.py @@ -21,9 +21,12 @@ class SaveModel(Callback): """ - def __init__(self, dir_path: typing.Union[str, Path]): + def __init__( + self, + dir_path: typing.Union[str, Path] = mz.USER_TUNED_MODELS_DIR + ): """Init.""" - self._dir_path = dir_path or mz.USER_TUNED_MODELS_DIR + self._dir_path = dir_path def on_run_end(self, tuner, model: BaseModel, result: dict): """Save model on run end.""" diff --git a/matchzoo/contrib/layers/__init__.py b/matchzoo/contrib/layers/__init__.py index e69de29b..09ef7e7c 100644 --- a/matchzoo/contrib/layers/__init__.py +++ b/matchzoo/contrib/layers/__init__.py @@ -0,0 +1,13 @@ +from .attention_layer import AttentionLayer +from .multi_perspective_layer import MultiPerspectiveLayer +from .matching_tensor_layer import MatchingTensorLayer +from .spatial_gru import SpatialGRU +from .decaying_dropout_layer import DecayingDropoutLayer +from .semantic_composite_layer import EncodingLayer + +layer_dict = { + "MatchingTensorLayer": MatchingTensorLayer, + "SpatialGRU": SpatialGRU, + "DecayingDropoutLayer": DecayingDropoutLayer, + "EncodingLayer": EncodingLayer +} diff --git a/matchzoo/contrib/layers/attention_layer.py b/matchzoo/contrib/layers/attention_layer.py new file mode 100644 index 00000000..049d72dc --- /dev/null +++ b/matchzoo/contrib/layers/attention_layer.py @@ -0,0 +1,144 @@ +"""An implementation of Attention Layer for Bimpm model.""" + +import tensorflow as tf +from keras import backend as K +from keras.engine import Layer + + +class AttentionLayer(Layer): + """ + Layer that compute attention for BiMPM model. + + For detailed information, see Bilateral Multi-Perspective Matching for + Natural Language Sentences, section 3.2. + + Reference: + https://github.com/zhiguowang/BiMPM/blob/master/src/layer_utils.py#L145-L196 + + Examples: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.AttentionLayer(att_dim=50) + >>> layer.compute_output_shape([(32, 10, 100), (32, 40, 100)]) + (32, 10, 40) + + """ + + def __init__(self, + att_dim: int, + att_type: str = 'default', + dropout_rate: float = 0.0): + """ + class: `AttentionLayer` constructor. + + :param att_dim: int + :param att_type: int + """ + super(AttentionLayer, self).__init__() + self._att_dim = att_dim + self._att_type = att_type + self._dropout_rate = dropout_rate + + @property + def att_dim(self): + """Get the attention dimension.""" + return self._att_dim + + @property + def att_type(self): + """Get the attention type.""" + return self._att_type + + def build(self, input_shapes): + """ + Build the layer. 
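One plausible use of the new `layer_dict` registry added to `matchzoo/contrib/layers/__init__.py` above is looking contrib layers up by name, for instance when rebuilding a model from a serialized config. This usage is an assumption for illustration, not something this PR does:

```python
from matchzoo.contrib.layers import layer_dict

# Fetch a contrib layer class by its registered name and instantiate it.
spatial_gru = layer_dict['SpatialGRU'](units=10)
matching_tensor = layer_dict['MatchingTensorLayer'](channels=4)
```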
+ + :param input_shapes: input_shape_lt, input_shape_rt + """ + if not isinstance(input_shapes, list): + raise ValueError('A attention layer should be called ' + 'on a list of inputs.') + + hidden_dim_lt = input_shapes[0][2] + hidden_dim_rt = input_shapes[1][2] + + self.attn_w1 = self.add_weight(name='attn_w1', + shape=(hidden_dim_lt, + self._att_dim), + initializer='uniform', + trainable=True) + if hidden_dim_lt == hidden_dim_rt: + self.attn_w2 = self.attn_w1 + else: + self.attn_w2 = self.add_weight(name='attn_w2', + shape=(hidden_dim_rt, + self._att_dim), + initializer='uniform', + trainable=True) + # diagonal_W: (1, 1, a) + self.diagonal_W = self.add_weight(name='diagonal_W', + shape=(1, + 1, + self._att_dim), + initializer='uniform', + trainable=True) + self.built = True + + def call(self, x: list, **kwargs): + """ + Calculate attention. + + :param x: [reps_lt, reps_rt] + :return attn_prob: [b, s_lt, s_rt] + """ + + if not isinstance(x, list): + raise ValueError('A attention layer should be called ' + 'on a list of inputs.') + + reps_lt, reps_rt = x + + attn_w1 = self.attn_w1 + attn_w1 = tf.expand_dims(tf.expand_dims(attn_w1, axis=0), axis=0) + # => [1, 1, d, a] + + reps_lt = tf.expand_dims(reps_lt, axis=-1) + attn_reps_lt = tf.reduce_sum(reps_lt * attn_w1, axis=2) + # => [b, s_lt, d, -1] + + attn_w2 = self.attn_w2 + attn_w2 = tf.expand_dims(tf.expand_dims(attn_w2, axis=0), axis=0) + # => [1, 1, d, a] + + reps_rt = tf.expand_dims(reps_rt, axis=-1) + attn_reps_rt = tf.reduce_sum(reps_rt * attn_w2, axis=2) # [b, s_rt, d, -1] + + attn_reps_lt = tf.tanh(attn_reps_lt) # [b, s_lt, a] + attn_reps_rt = tf.tanh(attn_reps_rt) # [b, s_rt, a] + + # diagonal_W + attn_reps_lt = attn_reps_lt * self.diagonal_W # [b, s_lt, a] + attn_reps_rt = tf.transpose(attn_reps_rt, (0, 2, 1)) + # => [b, a, s_rt] + + attn_value = K.batch_dot(attn_reps_lt, attn_reps_rt) # [b, s_lt, s_rt] + + # Softmax operation + attn_prob = tf.nn.softmax(attn_value) # [b, s_lt, s_rt] + + # TODO(tjf) remove diagonal or not for normalization + # if remove_diagonal: attn_value = attn_value * diagonal + + if len(x) == 4: + mask_lt, mask_rt = x[2], x[3] + attn_prob *= tf.expand_dims(mask_lt, axis=2) + attn_prob *= tf.expand_dims(mask_rt, axis=1) + + return attn_prob + + def compute_output_shape(self, input_shapes): + """Calculate the layer output shape.""" + if not isinstance(input_shapes, list): + raise ValueError('A attention layer should be called ' + 'on a list of inputs.') + input_shape_lt, input_shape_rt = input_shapes[0], input_shapes[1] + return input_shape_lt[0], input_shape_lt[1], input_shape_rt[1] diff --git a/matchzoo/contrib/layers/decaying_dropout_layer.py b/matchzoo/contrib/layers/decaying_dropout_layer.py new file mode 100644 index 00000000..eb3f5949 --- /dev/null +++ b/matchzoo/contrib/layers/decaying_dropout_layer.py @@ -0,0 +1,99 @@ +"""An implementation of Decaying Dropout Layer.""" + +import tensorflow as tf +from keras import backend as K +from keras.engine import Layer + +class DecayingDropoutLayer(Layer): + """ + Layer that processes dropout with exponential decayed keep rate during + training. + + :param initial_keep_rate: the initial keep rate of decaying dropout. + :param decay_interval: the decay interval of decaying dropout. + :param decay_rate: the decay rate of decaying dropout. + :param noise_shape: a 1D integer tensor representing the shape of the + binary dropout mask that will be multiplied with the input. + :param seed: a python integer to use as random seed. 
+ :param kwargs: standard layer keyword arguments. + + Examples: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.DecayingDropoutLayer( + ... initial_keep_rate=1.0, + ... decay_interval=10000, + ... decay_rate=0.977, + ... ) + >>> num_batch, num_dim =5, 10 + >>> layer.build([num_batch, num_dim]) + """ + + def __init__(self, + initial_keep_rate: float = 1.0, + decay_interval: int = 10000, + decay_rate: float = 0.977, + noise_shape=None, + seed=None, + **kwargs): + """:class: 'DecayingDropoutLayer' constructor.""" + super(DecayingDropoutLayer, self).__init__(**kwargs) + self._iterations = None + self._initial_keep_rate = initial_keep_rate + self._decay_interval = decay_interval + self._decay_rate = min(1.0, max(0.0, decay_rate)) + self._noise_shape = noise_shape + self._seed = seed + + def _get_noise_shape(self, inputs): + if self._noise_shape is None: + return self._noise_shape + + symbolic_shape = tf.shape(inputs) + noise_shape = [symbolic_shape[axis] if shape is None else shape + for axis, shape in enumerate(self._noise_shape)] + return tuple(noise_shape) + + def build(self, input_shape): + """ + Build the layer. + + :param input_shape: the shape of the input tensor, + for DecayingDropoutLayer we need one input tensor. + """ + + self._iterations = self.add_weight(name='iterations', + shape=(1,), + dtype=K.floatx(), + initializer='zeros', + trainable=False) + super(DecayingDropoutLayer, self).build(input_shape) + + def call(self, inputs, training=None): + """ + The computation logic of DecayingDropoutLayer. + + :param inputs: an input tensor. + """ + noise_shape = self._get_noise_shape(inputs) + t = tf.cast(self._iterations, K.floatx()) + 1 + p = t / float(self._decay_interval) + + keep_rate = self._initial_keep_rate * tf.pow(self._decay_rate, p) + + def dropped_inputs(): + update_op = self._iterations.assign_add([1]) + with tf.control_dependencies([update_op]): + return tf.nn.dropout(inputs, 1 - keep_rate[0], noise_shape, + seed=self._seed) + + return K.in_train_phase(dropped_inputs, inputs, training=training) + + def get_config(self): + """Get the config dict of DecayingDropoutLayer.""" + config = {'initial_keep_rate': self._initial_keep_rate, + 'decay_interval': self._decay_interval, + 'decay_rate': self._decay_rate, + 'noise_shape': self._noise_shape, + 'seed': self._seed} + base_config = super(DecayingDropoutLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/matchzoo/contrib/layers/matching_tensor_layer.py b/matchzoo/contrib/layers/matching_tensor_layer.py new file mode 100644 index 00000000..0578ed1c --- /dev/null +++ b/matchzoo/contrib/layers/matching_tensor_layer.py @@ -0,0 +1,135 @@ +"""An implementation of Matching Tensor Layer.""" +import typing + +import numpy as np +import tensorflow as tf +from keras import backend as K +from keras.engine import Layer +from keras.initializers import constant + + +class MatchingTensorLayer(Layer): + """ + Layer that captures the basic interactions between two tensors. + + :param channels: Number of word interaction tensor channels + :param normalize: Whether to L2-normalize samples along the + dot product axis before taking the dot product. + If set to True, then the output of the dot product + is the cosine proximity between the two samples. + :param init_diag: Whether to initialize the diagonal elements + of the matrix. + :param kwargs: Standard layer keyword arguments. + + Examples: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.MatchingTensorLayer(channels=4, + ... 
normalize=True, + ... init_diag=True) + >>> num_batch, left_len, right_len, num_dim = 5, 3, 2, 10 + >>> layer.build([[num_batch, left_len, num_dim], + ... [num_batch, right_len, num_dim]]) + + """ + + def __init__(self, channels: int = 4, normalize: bool = True, + init_diag: bool = True, **kwargs): + """:class:`MatchingTensorLayer` constructor.""" + super().__init__(**kwargs) + self._channels = channels + self._normalize = normalize + self._init_diag = init_diag + self._shape1 = None + self._shape2 = None + + def build(self, input_shape: list): + """ + Build the layer. + + :param input_shape: the shapes of the input tensors, + for MatchingTensorLayer we need two input tensors. + """ + # Used purely for shape validation. + if not isinstance(input_shape, list) or len(input_shape) != 2: + raise ValueError('A `MatchingTensorLayer` layer should be called ' + 'on a list of 2 inputs.') + self._shape1 = input_shape[0] + self._shape2 = input_shape[1] + for idx in (0, 2): + if self._shape1[idx] != self._shape2[idx]: + raise ValueError( + 'Incompatible dimensions: ' + f'{self._shape1[idx]} != {self._shape2[idx]}.' + f'Layer shapes: {self._shape1}, {self._shape2}.' + ) + + if self._init_diag: + interaction_matrix = np.float32( + np.random.uniform( + -0.05, 0.05, + [self._channels, self._shape1[2], self._shape2[2]] + ) + ) + for channel_index in range(self._channels): + np.fill_diagonal(interaction_matrix[channel_index], 0.1) + self.interaction_matrix = self.add_weight( + name='interaction_matrix', + shape=(self._channels, self._shape1[2], self._shape2[2]), + initializer=constant(interaction_matrix), + trainable=True + ) + else: + self.interaction_matrix = self.add_weight( + name='interaction_matrix', + shape=(self._channels, self._shape1[2], self._shape2[2]), + initializer='uniform', + trainable=True + ) + super(MatchingTensorLayer, self).build(input_shape) + + def call(self, inputs: list, **kwargs) -> typing.Any: + """ + The computation logic of MatchingTensorLayer. + + :param inputs: two input tensors. + """ + x1 = inputs[0] + x2 = inputs[1] + # Normalize x1 and x2 + if self._normalize: + x1 = K.l2_normalize(x1, axis=2) + x2 = K.l2_normalize(x2, axis=2) + + # b = batch size + # l = length of `x1` + # r = length of `x2` + # d, e = embedding size + # c = number of channels + # output = [b, c, l, r] + output = tf.einsum( + 'bld,cde,bre->bclr', + x1, self.interaction_matrix, x2 + ) + return output + + def compute_output_shape(self, input_shape: list) -> tuple: + """ + Calculate the layer output shape. + + :param input_shape: the shapes of the input tensors, + for MatchingTensorLayer we need two input tensors. 
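The `tf.einsum('bld,cde,bre->bclr', ...)` interaction above can be shape-checked with plain NumPy, using the same sizes as the class docstring example (batch 5, left length 3, right length 2, embedding dim 10, 4 channels). A small sanity sketch, not part of this PR:

```python
import numpy as np

x1 = np.random.rand(5, 3, 10)       # [b, l, d]
x2 = np.random.rand(5, 2, 10)       # [b, r, e]
M = np.random.rand(4, 10, 10)       # [c, d, e] interaction matrix
out = np.einsum('bld,cde,bre->bclr', x1, M, x2)
print(out.shape)                     # (5, 4, 3, 2) == [batch, channels, left_len, right_len]
```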
+ """ + if not isinstance(input_shape, list) or len(input_shape) != 2: + raise ValueError('A `MatchingTensorLayer` layer should be called ' + 'on a list of 2 inputs.') + shape1 = list(input_shape[0]) + shape2 = list(input_shape[1]) + if len(shape1) != 3 or len(shape2) != 3: + raise ValueError('A `MatchingTensorLayer` layer should be called ' + 'on 2 inputs with 3 dimensions.') + if shape1[0] != shape2[0] or shape1[2] != shape2[2]: + raise ValueError('A `MatchingTensorLayer` layer should be called ' + 'on 2 inputs with same 0,2 dimensions.') + + output_shape = [shape1[0], self._channels, shape1[1], shape2[1]] + return tuple(output_shape) diff --git a/matchzoo/contrib/layers/multi_perspective_layer.py b/matchzoo/contrib/layers/multi_perspective_layer.py new file mode 100644 index 00000000..64cfd338 --- /dev/null +++ b/matchzoo/contrib/layers/multi_perspective_layer.py @@ -0,0 +1,468 @@ +"""An implementation of MultiPerspectiveLayer for Bimpm model.""" + +import tensorflow as tf +from keras import backend as K +from keras.engine import Layer + +from matchzoo.contrib.layers.attention_layer import AttentionLayer + + +class MultiPerspectiveLayer(Layer): + """ + A keras implementation of multi-perspective layer of BiMPM. + + For detailed information, see Bilateral Multi-Perspective + Matching for Natural Language Sentences, section 3.2. + + Examples: + >>> import matchzoo as mz + >>> perspective={'full': True, 'max-pooling': True, + ... 'attentive': True, 'max-attentive': True} + >>> layer = mz.contrib.layers.MultiPerspectiveLayer( + ... att_dim=50, mp_dim=20, perspective=perspective) + >>> layer.compute_output_shape( + ... [(32, 10, 100), (32, 50), None, (32, 50), None, + ... [(32, 40, 100), (32, 50), None, (32, 50), None]]) + (32, 10, 83) + + """ + + def __init__(self, + att_dim: int, + mp_dim: int, + perspective: dict): + """Class initialization.""" + super(MultiPerspectiveLayer, self).__init__() + self._att_dim = att_dim + self._mp_dim = mp_dim + self._perspective = perspective + + @classmethod + def list_available_perspectives(cls) -> list: + """List available strategy for multi-perspective matching.""" + return ['full', 'max-pooling', 'attentive', 'max-attentive'] + + @property + def num_perspective(self): + """Get the number of perspectives that is True.""" + return sum(self._perspective.values()) + + def build(self, input_shape: list): + """Input shape.""" + # The shape of the weights is l * d. + if self._perspective.get('full'): + self.full_match = MpFullMatch(self._mp_dim) + + if self._perspective.get('max-pooling'): + self.max_pooling_match = MpMaxPoolingMatch(self._mp_dim) + + if self._perspective.get('attentive'): + self.attentive_match = MpAttentiveMatch(self._att_dim, + self._mp_dim) + + if self._perspective.get('max-attentive'): + self.max_attentive_match = MpMaxAttentiveMatch(self._att_dim) + self.built = True + + def call(self, x: list, **kwargs): + """Call.""" + seq_lt, seq_rt = x[:5], x[5:] + # unpack seq_left and seq_right + # all hidden states, last hidden state of forward pass, + # last cell state of forward pass, last hidden state of + # backward pass, last cell state of backward pass. + lstm_reps_lt, forward_h_lt, _, backward_h_lt, _ = seq_lt + lstm_reps_rt, forward_h_rt, _, backward_h_rt, _ = seq_rt + + match_tensor_list = [] + match_dim = 0 + if self._perspective.get('full'): + # Each forward & backward contextual embedding compare + # with the last step of the last time step of the other sentence. 
+ h_lt = tf.concat([forward_h_lt, backward_h_lt], axis=-1) + full_match_tensor = self.full_match([h_lt, lstm_reps_rt]) + match_tensor_list.append(full_match_tensor) + match_dim += self._mp_dim + 1 + + if self._perspective.get('max-pooling'): + # Each contextual embedding compare with each contextual embedding. + # retain the maximum of each dimension. + max_match_tensor = self.max_pooling_match([lstm_reps_lt, + lstm_reps_rt]) + match_tensor_list.append(max_match_tensor) + match_dim += self._mp_dim + + if self._perspective.get('attentive'): + # Each contextual embedding compare with each contextual embedding. + # retain sum of weighted mean of each dimension. + attentive_tensor = self.attentive_match([lstm_reps_lt, + lstm_reps_rt]) + match_tensor_list.append(attentive_tensor) + match_dim += self._mp_dim + 1 + + if self._perspective.get('max-attentive'): + # Each contextual embedding compare with each contextual embedding. + # retain max of weighted mean of each dimension. + relevancy_matrix = _calc_relevancy_matrix(lstm_reps_lt, + lstm_reps_rt) + max_attentive_tensor = self.max_attentive_match([lstm_reps_lt, + lstm_reps_rt, + relevancy_matrix]) + match_tensor_list.append(max_attentive_tensor) + match_dim += self._mp_dim + 1 + + mp_tensor = tf.concat(match_tensor_list, axis=-1) + return mp_tensor + + def compute_output_shape(self, input_shape: list): + """Compute output shape.""" + shape_a = input_shape[0] + + match_dim = 0 + if self._perspective.get('full'): + match_dim += self._mp_dim + 1 + if self._perspective.get('max-pooling'): + match_dim += self._mp_dim + if self._perspective.get('attentive'): + match_dim += self._mp_dim + 1 + if self._perspective.get('max-attentive'): + match_dim += self._mp_dim + 1 + + return shape_a[0], shape_a[1], match_dim + + +class MpFullMatch(Layer): + """Mp Full Match Layer.""" + + def __init__(self, mp_dim): + """Init.""" + super(MpFullMatch, self).__init__() + self.mp_dim = mp_dim + + def build(self, input_shapes): + """Build.""" + # input_shape = input_shapes[0] + self.built = True + + def call(self, x, **kwargs): + """Call. 
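To make the `83` in `MultiPerspectiveLayer`'s doctest concrete: with `mp_dim=20` and all four perspectives enabled, the matching dimensions concatenated along the last axis add up as follows (a worked check, not part of this PR):

```python
mp_dim = 20
match_dim = (
    (mp_dim + 1)    # 'full' match (cosine + mp-cosine)
    + mp_dim        # 'max-pooling' match
    + (mp_dim + 1)  # 'attentive' match
    + (mp_dim + 1)  # 'max-attentive' match
)
print(match_dim)  # 83, hence the (32, 10, 83) output shape in the doctest
```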
+ """ + rep_lt, reps_rt = x + att_lt = tf.expand_dims(rep_lt, 1) + + match_tensor, match_dim = _multi_perspective_match(self.mp_dim, + reps_rt, + att_lt) + # match_tensor => [b, len_rt, mp_dim+1] + return match_tensor + + def compute_output_shape(self, input_shape): + """Compute output shape.""" + return input_shape[1][0], input_shape[1][1], self.mp_dim + 1 + + +class MpMaxPoolingMatch(Layer): + """MpMaxPoolingMatch.""" + + def __init__(self, mp_dim): + """Init.""" + super(MpMaxPoolingMatch, self).__init__() + self.mp_dim = mp_dim + + def build(self, input_shapes): + """Build.""" + d = input_shapes[0][-1] + self.kernel = self.add_weight(name='kernel', + shape=(1, 1, 1, self.mp_dim, d), + initializer='uniform', + trainable=True) + self.built = True + + def call(self, x, **kwargs): + """Call.""" + reps_lt, reps_rt = x + + # kernel: [1, 1, 1, mp_dim, d] + # lstm_lt => [b, len_lt, 1, 1, d] + reps_lt = tf.expand_dims(reps_lt, axis=2) + reps_lt = tf.expand_dims(reps_lt, axis=2) + reps_lt = reps_lt * self.kernel + + # lstm_rt -> [b, 1, len_rt, 1, d] + reps_rt = tf.expand_dims(reps_rt, axis=2) + reps_rt = tf.expand_dims(reps_rt, axis=1) + + match_tensor = _cosine_distance(reps_lt, reps_rt, cosine_norm=False) + max_match_tensor = tf.reduce_max(match_tensor, axis=1) + # match_tensor => [b, len_rt, m] + return max_match_tensor + + def compute_output_shape(self, input_shape): + """Compute output shape.""" + return input_shape[1][0], input_shape[1][1], self.mp_dim + + +class MpAttentiveMatch(Layer): + """ + MpAttentiveMatch Layer. + + Reference: + https://github.com/zhiguowang/BiMPM/blob/master/src/match_utils.py#L188-L193 + + Examples: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.multi_perspective_layer.MpAttentiveMatch( + ... att_dim=30, mp_dim=20) + >>> layer.compute_output_shape([(32, 10, 100), (32, 40, 100)]) + (32, 40, 20) + + """ + + def __init__(self, att_dim, mp_dim): + """Init.""" + super(MpAttentiveMatch, self).__init__() + self.att_dim = att_dim + self.mp_dim = mp_dim + + def build(self, input_shapes): + """Build.""" + # input_shape = input_shapes[0] + self.built = True + + def call(self, x, **kwargs): + """Call.""" + reps_lt, reps_rt = x[0], x[1] + # attention prob matrix + attention_layer = AttentionLayer(self.att_dim) + attn_prob = attention_layer([reps_rt, reps_lt]) + # attention reps + att_lt = K.batch_dot(attn_prob, reps_lt) + # mp match + attn_match_tensor, match_dim = _multi_perspective_match(self.mp_dim, + reps_rt, + att_lt) + return attn_match_tensor + + def compute_output_shape(self, input_shape): + """Compute output shape.""" + return input_shape[1][0], input_shape[1][1], self.mp_dim + + +class MpMaxAttentiveMatch(Layer): + """MpMaxAttentiveMatch.""" + + def __init__(self, mp_dim): + """Init.""" + super(MpMaxAttentiveMatch, self).__init__() + self.mp_dim = mp_dim + + def build(self, input_shapes): + """Build.""" + # input_shape = input_shapes[0] + self.built = True + + def call(self, x): + """Call.""" + reps_lt, reps_rt = x[0], x[1] + relevancy_matrix = x[2] + max_att_lt = cal_max_question_representation(reps_lt, relevancy_matrix) + max_attentive_tensor, match_dim = _multi_perspective_match(self.mp_dim, + reps_rt, + max_att_lt) + return max_attentive_tensor + + +def cal_max_question_representation(reps_lt, attn_scores): + """ + Calculate max_question_representation. + + :param reps_lt: [batch_size, passage_len, hidden_size] + :param attn_scores: [] + :return: [batch_size, passage_len, hidden_size]. 
+ """ + attn_positions = tf.argmax(attn_scores, axis=2) + max_reps_lt = collect_representation(reps_lt, attn_positions) + return max_reps_lt + + +def collect_representation(representation, positions): + """ + Collect_representation. + + :param representation: [batch_size, node_num, feature_dim] + :param positions: [batch_size, neighbour_num] + :return: [batch_size, neighbour_num]? + """ + return collect_probs(representation, positions) + + +def collect_final_step_of_lstm(lstm_representation, lengths): + """ + Collect final step of lstm. + + :param lstm_representation: [batch_size, len_rt, dim] + :param lengths: [batch_size] + :return: [batch_size, dim] + """ + lengths = tf.maximum(lengths, K.zeros_like(lengths)) + + batch_size = tf.shape(lengths)[0] + # shape (batch_size) + batch_nums = tf.range(0, limit=batch_size) + # shape (batch_size, 2) + indices = tf.stack((batch_nums, lengths), axis=1) + result = tf.gather_nd(lstm_representation, indices, + name='last-forwar-lstm') + # [batch_size, dim] + return result + + +def collect_probs(probs, positions): + """ + Collect Probabilities. + + Reference: + https://github.com/zhiguowang/BiMPM/blob/master/src/layer_utils.py#L128-L140 + :param probs: [batch_size, chunks_size] + :param positions: [batch_size, pair_size] + :return: [batch_size, pair_size] + """ + batch_size = tf.shape(probs)[0] + pair_size = tf.shape(positions)[1] + # shape (batch_size) + batch_nums = K.arange(0, batch_size) + # [batch_size, 1] + batch_nums = tf.reshape(batch_nums, shape=[-1, 1]) + # [batch_size, pair_size] + batch_nums = K.tile(batch_nums, [1, pair_size]) + + # shape (batch_size, pair_size, 2) + # Alert: to solve error message + positions = tf.cast(positions, tf.int32) + indices = tf.stack([batch_nums, positions], axis=2) + + pair_probs = tf.gather_nd(probs, indices) + # pair_probs = tf.reshape(pair_probs, shape=[batch_size, pair_size]) + return pair_probs + + +def _multi_perspective_match(mp_dim, reps_rt, att_lt, + with_cosine=True, with_mp_cosine=True): + """ + The core function of zhiguowang's implementation. + + reference: + https://github.com/zhiguowang/BiMPM/blob/master/src/match_utils.py#L207-L223 + :param mp_dim: about 20 + :param reps_rt: [batch, len_rt, dim] + :param att_lt: [batch, len_rt, dim] + :param with_cosine: True + :param with_mp_cosine: True + :return: [batch, len, 1 + mp_dim] + """ + shape_rt = tf.shape(reps_rt) + batch_size = shape_rt[0] + len_lt = shape_rt[1] + + match_dim = 0 + match_result_list = [] + if with_cosine: + cosine_tensor = _cosine_distance(reps_rt, att_lt, False) + cosine_tensor = tf.reshape(cosine_tensor, + [batch_size, len_lt, 1]) + match_result_list.append(cosine_tensor) + match_dim += 1 + + if with_mp_cosine: + mp_cosine_layer = MpCosineLayer(mp_dim) + mp_cosine_tensor = mp_cosine_layer([reps_rt, att_lt]) + mp_cosine_tensor = tf.reshape(mp_cosine_tensor, + [batch_size, len_lt, mp_dim]) + match_result_list.append(mp_cosine_tensor) + match_dim += mp_cosine_layer.mp_dim + + match_result = tf.concat(match_result_list, 2) + return match_result, match_dim + + +class MpCosineLayer(Layer): + """ + Implementation of Multi-Perspective Cosine Distance. + + Reference: + https://github.com/zhiguowang/BiMPM/blob/master/src/match_utils.py#L121-L129 + + Examples: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.multi_perspective_layer.MpCosineLayer( + ... 
mp_dim=50) + >>> layer.compute_output_shape([(32, 10, 100), (32, 10, 100)]) + (32, 10, 50) + + """ + + def __init__(self, mp_dim, **kwargs): + """Init.""" + self.mp_dim = mp_dim + super(MpCosineLayer, self).__init__(**kwargs) + + def build(self, input_shape): + """Build.""" + self.kernel = self.add_weight(name='kernel', + shape=(1, 1, self.mp_dim, + input_shape[0][-1]), + initializer='uniform', + trainable=True) + super(MpCosineLayer, self).build(input_shape) + + def call(self, x, **kwargs): + """Call.""" + v1, v2 = x + v1 = tf.expand_dims(v1, 2) * self.kernel # [b, s_lt, m, d] + v2 = tf.expand_dims(v2, 2) # [b, s_lt, 1, d] + return _cosine_distance(v1, v2, False) + + def compute_output_shape(self, input_shape): + """Compute output shape.""" + return input_shape[0][0], input_shape[0][1], self.mp_dim + + +def _calc_relevancy_matrix(reps_lt, reps_rt): + reps_lt = tf.expand_dims(reps_lt, 1) # [b, 1, len_lt, d] + reps_rt = tf.expand_dims(reps_rt, 2) # [b, len_rt, 1, d] + relevancy_matrix = _cosine_distance(reps_lt, reps_rt) + # => [b, len_rt, len_lt, d] + return relevancy_matrix + + +def _mask_relevancy_matrix(relevancy_matrix, mask_lt, mask_rt): + """ + Mask relevancy matrix. + + :param relevancy_matrix: [b, len_rt, len_lt] + :param mask_lt: [b, len_lt] + :param mask_rt: [b, len_rt] + :return: masked_matrix: [b, len_rt, len_lt] + """ + if mask_lt is not None: + relevancy_matrix = relevancy_matrix * tf.expand_dims(mask_lt, 1) + relevancy_matrix = relevancy_matrix * tf.expand_dims(mask_rt, 2) + return relevancy_matrix + + +def _cosine_distance(v1, v2, cosine_norm=True, eps=1e-6): + """ + Only requires `tf.reduce_sum(v1 * v2, axis=-1)`. + + :param v1: [batch, time_steps(v1), 1, m, d] + :param v2: [batch, 1, time_steps(v2), m, d] + :param cosine_norm: True + :param eps: 1e-6 + :return: [batch, time_steps(v1), time_steps(v2), m] + """ + cosine_numerator = tf.reduce_sum(v1 * v2, axis=-1) + if not cosine_norm: + return K.tanh(cosine_numerator) + v1_norm = K.sqrt(tf.maximum(tf.reduce_sum(tf.square(v1), axis=-1), eps)) + v2_norm = K.sqrt(tf.maximum(tf.reduce_sum(tf.square(v2), axis=-1), eps)) + return cosine_numerator / v1_norm / v2_norm diff --git a/matchzoo/contrib/layers/semantic_composite_layer.py b/matchzoo/contrib/layers/semantic_composite_layer.py new file mode 100644 index 00000000..9f6cb5b4 --- /dev/null +++ b/matchzoo/contrib/layers/semantic_composite_layer.py @@ -0,0 +1,121 @@ +"""An implementation of EncodingModule for DIIN model.""" + +import tensorflow as tf +from keras import backend as K +from keras.engine import Layer + +from matchzoo.contrib.layers import DecayingDropoutLayer + + +class EncodingLayer(Layer): + """ + Apply a self-attention layer and a semantic composite fuse gate + to compute the encoding result of one tensor. + + :param initial_keep_rate: the initial_keep_rate parameter of + DecayingDropoutLayer. + :param decay_interval: the decay_interval parameter of + DecayingDropoutLayer. + :param decay_rate: the decay_rate parameter of DecayingDropoutLayer. + :param kwargs: standard layer keyword arguments. 
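The broadcasting inside `_cosine_distance` in `multi_perspective_layer.py` above can be checked with plain NumPy; a small sketch with made-up sizes (batch 2, 5 left steps, 7 right steps, 20 perspectives, dim 8), not part of this PR:

```python
import numpy as np

v1 = np.random.rand(2, 5, 1, 20, 8)      # [batch, time_steps(v1), 1, m, d]
v2 = np.random.rand(2, 1, 7, 20, 8)      # [batch, 1, time_steps(v2), m, d]
cosine_numerator = (v1 * v2).sum(axis=-1)
print(cosine_numerator.shape)            # (2, 5, 7, 20) == [batch, T1, T2, m]
```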
+ + Example: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.EncodingLayer(1.0, 10000, 0.977) + >>> num_batch, left_len, num_dim = 5, 32, 10 + >>> layer.build([num_batch, left_len, num_dim]) + """ + + def __init__(self, + initial_keep_rate: float, + decay_interval: int, + decay_rate: float, + **kwargs): + """:class: 'EncodingLayer' constructor.""" + super(EncodingLayer, self).__init__(**kwargs) + self._initial_keep_rate = initial_keep_rate + self._decay_interval = decay_interval + self._decay_rate = decay_rate + self._w_itr_att = None + self._w1 = None + self._w2 = None + self._w3 = None + self._b1 = None + self._b2 = None + self._b3 = None + + def build(self, input_shape): + """ + Build the layer. + + :param input_shape: the shape of the input tensor, + for EncodingLayer we need one input tensor. + """ + d = input_shape[-1] + + self._w_itr_att = self.add_weight( + name='w_itr_att', shape=(3 * d,), initializer='glorot_uniform') + self._w1 = self.add_weight( + name='w1', shape=(2 * d, d,), initializer='glorot_uniform') + self._w2 = self.add_weight( + name='w2', shape=(2 * d, d,), initializer='glorot_uniform') + self._w3 = self.add_weight( + name='w3', shape=(2 * d, d,), initializer='glorot_uniform') + self._b1 = self.add_weight( + name='b1', shape=(d,), initializer='zeros') + self._b2 = self.add_weight( + name='b2', shape=(d,), initializer='zeros') + self._b3 = self.add_weight( + name='b3', shape=(d,), initializer='zeros') + + super(EncodingLayer, self).build(input_shape) + + def call(self, inputs, **kwargs): + """ + The computation logic of EncodingLayer. + + :param inputs: an input tensor. + """ + # Scalar dimensions referenced here: + # b = batch size + # p = inputs.shape()[1] + # d = inputs.shape()[2] + + # The input shape is [b, p, d] + # shape = [b, 1, p, d] + x = tf.expand_dims(inputs, 1) * 0 + # shape = [b, 1, d, p] + x = tf.transpose(x, (0, 1, 3, 2)) + # shape = [b, p, d, p] + mid = x + tf.expand_dims(inputs, -1) + # shape = [b, p, d, p] + up = tf.transpose(mid, (0, 3, 2, 1)) + # shape = [b, p, 3d, p] + inputs_concat = tf.concat([up, mid, up * mid], axis=2) + + # Self-attention layer. + # shape = [b, p, p] + A = K.dot(self._w_itr_att, inputs_concat) + # shape = [b, p, p] + SA = tf.nn.softmax(A, axis=2) + # shape = [b, p, d] + itr_attn = K.batch_dot(SA, inputs) + + # Semantic composite fuse gate. + # shape = [b, p, 2d] + inputs_attn_concat = tf.concat([inputs, itr_attn], axis=2) + concat_dropout = DecayingDropoutLayer( + initial_keep_rate=self._initial_keep_rate, + decay_interval=self._decay_interval, + decay_rate=self._decay_rate + )(inputs_attn_concat) + # shape = [b, p, d] + z = tf.tanh(K.dot(concat_dropout, self._w1) + self._b1) + # shape = [b, p, d] + r = tf.sigmoid(K.dot(concat_dropout, self._w2) + self._b2) + # shape = [b, p, d] + f = tf.sigmoid(K.dot(concat_dropout, self._w3) + self._b3) + # shape = [b, p, d] + encoding = r * inputs + f * z + + return encoding diff --git a/matchzoo/contrib/layers/spatial_gru.py b/matchzoo/contrib/layers/spatial_gru.py new file mode 100644 index 00000000..c583c9d6 --- /dev/null +++ b/matchzoo/contrib/layers/spatial_gru.py @@ -0,0 +1,290 @@ +"""An implementation of Spatial GRU Layer.""" +import typing +import tensorflow as tf +from keras import backend as K +from keras.engine import Layer +from keras.layers import Permute +from keras.layers import Reshape +from keras import activations +from keras import initializers + + +class SpatialGRU(Layer): + """ + Spatial GRU layer. + + :param units: Number of SpatialGRU units. 
+ :param activation: Activation function to use. Default: + hyperbolic tangent (`tanh`). If you pass `None`, no + activation is applied (ie. "linear" activation: `a(x) = x`). + :param recurrent_activation: Activation function to use for + the recurrent step. Default: sigmoid (`sigmoid`). + If you pass `None`, no activation is applied (ie. "linear" + activation: `a(x) = x`). + :param kernel_initializer: Initializer for the `kernel` weights + matrix, used for the linear transformation of the inputs. + :param recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the + recurrent state. + :param direction: Scanning direction. `lt` (i.e., left top) + indicates the scanning from left top to right bottom, and + `rb` (i.e., right bottom) indicates the scanning from + right bottom to left top. + :param kwargs: Standard layer keyword arguments. + + Examples: + >>> import matchzoo as mz + >>> layer = mz.contrib.layers.SpatialGRU(units=10, + ... direction='lt') + >>> num_batch, channel, left_len, right_len = 5, 5, 3, 2 + >>> layer.build([num_batch, channel, left_len, right_len]) + + """ + + def __init__( + self, + units: int = 10, + activation: str = 'tanh', + recurrent_activation: str = 'sigmoid', + kernel_initializer: str = 'glorot_uniform', + recurrent_initializer: str = 'orthogonal', + direction: str = 'lt', + **kwargs + ): + """:class:`SpatialGRU` constructor.""" + super().__init__(**kwargs) + self._units = units + self._activation = activations.get(activation) + self._recurrent_activation = activations.get(recurrent_activation) + + self._kernel_initializer = initializers.get(kernel_initializer) + self._recurrent_initializer = initializers.get(recurrent_initializer) + self._direction = direction + + def build(self, input_shape: typing.Any): + """ + Build the layer. + + :param input_shape: the shapes of the input tensors. 
+ """ + # Scalar dimensions referenced here: + # B = batch size (number of sequences) + # L = `input_left` sequence length + # R = `input_right` sequence length + # C = number of channels + # U = number of units + + # input_shape = [B, C, L, R] + self._batch_size = input_shape[0] + self._channel = input_shape[1] + self._input_dim = self._channel + 3 * self._units + + self._text1_maxlen = input_shape[2] + self._text2_maxlen = input_shape[3] + self._recurrent_step = self._text1_maxlen * self._text2_maxlen + # W = [3*U+C, 7*U] + self._W = self.add_weight( + name='W', + shape=(self._input_dim, self._units * 7), + initializer=self._kernel_initializer, + trainable=True + ) + # U = [3*U, U] + self._U = self.add_weight( + name='U', + shape=(self._units * 3, self._units), + initializer=self._recurrent_initializer, + trainable=True + ) + # bias = [8*U,] + self._bias = self.add_weight( + name='bias', + shape=(self._units * 8,), + initializer='zeros', + trainable=True + ) + + # w_rl, w_rt, w_rd = [B, 3*U] + self._wr = self._W[:, :self._units * 3] + # b_rl, b_rt, b_rd = [B, 3*U] + self._br = self._bias[:self._units * 3] + # w_zi, w_zl, w_zt, w_zd = [B, 4*U] + self._wz = self._W[:, self._units * 3: self._units * 7] + # b_zi, b_zl, b_zt, b_zd = [B, 4*U] + self._bz = self._bias[self._units * 3: self._units * 7] + # w_ij = [C, U] + self._w_ij = self.add_weight( + name='W_ij', + shape=(self._channel, self._units), + initializer=self._recurrent_initializer, + trainable=True + ) + # b_ij = [7*U] + self._b_ij = self._bias[self._units * 7:] + super(SpatialGRU, self).build(input_shape) + + def softmax_by_row(self, z: typing.Any) -> tuple: + """Conduct softmax on each dimension across the four gates.""" + + # z_transform: [B, U, 4] + z_transform = Permute((2, 1))(Reshape((4, self._units))(z)) + size = [-1, 1, -1] + # Perform softmax on each slice + for i in range(0, self._units): + begin = [0, i, 0] + # z_slice: [B, 1, 4] + z_slice = tf.slice(z_transform, begin, size) + if i == 0: + z_s = tf.nn.softmax(z_slice) + else: + z_s = tf.concat([z_s, tf.nn.softmax(z_slice)], 1) + # zi, zl, zt, zd: [B, U] + zi, zl, zt, zd = tf.unstack(z_s, axis=2) + return zi, zl, zt, zd + + def calculate_recurrent_unit( + self, + inputs: typing.Any, + states: typing.Any, + step: int, + h: typing.Any, + ) -> tuple: + """ + Calculate recurrent unit. + + :param inputs: A TensorArray which contains interaction + between left text and right text. + :param states: A TensorArray which stores the hidden state + of every step. + :param step: Recurrent step. + :param h: Hidden state from last operation. 
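`softmax_by_row` above normalises the four gates (zi, zl, zt, zd) per unit rather than per timestep. The same computation restated in NumPy, with toy sizes of my own, as an illustration only:

    import numpy as np

    B, U = 2, 3
    z = np.random.rand(B, 4 * U)                            # raw gate logits
    z = z.reshape(B, 4, U).transpose(0, 2, 1)               # [B, U, 4]
    gates = np.exp(z) / np.exp(z).sum(-1, keepdims=True)    # softmax over the 4 gates
    zi, zl, zt, zd = np.moveaxis(gates, -1, 0)              # each [B, U]
    assert np.allclose(zi + zl + zt + zd, 1.0)              # gates sum to one per unit
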
+ """ + # Get index i, j + i = tf.math.floordiv(step, tf.constant(self._text2_maxlen)) + j = tf.math.mod(step, tf.constant(self._text2_maxlen)) + + # Get hidden state h_diag, h_top, h_left + # h_diag, h_top, h_left = [B, U] + h_diag = states.read(i * (self._text2_maxlen + 1) + j) + h_top = states.read(i * (self._text2_maxlen + 1) + j + 1) + h_left = states.read((i + 1) * (self._text2_maxlen + 1) + j) + + # Get interaction between word i, j: s_ij + # s_ij = [B, C] + s_ij = inputs.read(step) + + # Concatenate h_top, h_left, h_diag, s_ij + # q = [B, 3*U+C] + q = tf.concat([tf.concat([h_top, h_left], 1), + tf.concat([h_diag, s_ij], 1)], 1) + + # Calculate reset gate + # r = [B, 3*U] + r = self._recurrent_activation( + self._time_distributed_dense(self._wr, q, self._br)) + + # Calculate updating gate + # z: [B, 4*U] + z = self._time_distributed_dense(self._wz, q, self._bz) + + # Perform softmax + # zi, zl, zt, zd: [B, U] + zi, zl, zt, zd = self.softmax_by_row(z) + + # Get h_ij_ + # h_ij_ = [B, U] + h_ij_l = self._time_distributed_dense(self._w_ij, s_ij, self._b_ij) + h_ij_r = K.dot(r * (tf.concat([h_left, h_top, h_diag], 1)), self._U) + h_ij_ = self._activation(h_ij_l + h_ij_r) + + # Calculate h_ij + # h_ij = [B, U] + h_ij = zl * h_left + zt * h_top + zd * h_diag + zi * h_ij_ + + # Write h_ij to states + states = states.write(((i + 1) * (self._text2_maxlen + 1) + j + 1), + h_ij) + h_ij.set_shape(h_top.get_shape()) + + return inputs, states, step + 1, h_ij + + def call(self, inputs: list, **kwargs) -> typing.Any: + """ + The computation logic of SpatialGRU. + + :param inputs: input tensors. + """ + batch_size = tf.shape(inputs)[0] + # h0 = [B, U] + self._bounder_state_h0 = tf.zeros([batch_size, self._units]) + + # input_x = [L, R, B, C] + input_x = tf.transpose(inputs, [2, 3, 0, 1]) + if self._direction == 'rb': + # input_x: [R, L, B, C] + input_x = tf.reverse(input_x, [0, 1]) + elif self._direction != 'lt': + raise ValueError(f"Invalid direction. " + f"`{self._direction}` received. " + f"Must be in `lt`, `rb`.") + # input_x = [L*R*B, C] + input_x = tf.reshape(input_x, [-1, self._channel]) + # input_x = L*R * [B, C] + input_x = tf.split( + axis=0, + num_or_size_splits=self._text1_maxlen * self._text2_maxlen, + value=input_x + ) + + # inputs = L*R * [B, C] + inputs = tf.TensorArray( + dtype=tf.float32, + size=self._text1_maxlen * self._text2_maxlen, + name='inputs' + ) + inputs = inputs.unstack(input_x) + + # states = (L+1)*(R+1) * [B, U] + states = tf.TensorArray( + dtype=tf.float32, + size=(self._text1_maxlen + 1) * (self._text2_maxlen + 1), + name='states', + clear_after_read=False + ) + # Initialize states + for i in range(self._text2_maxlen + 1): + states = states.write(i, self._bounder_state_h0) + for i in range(1, self._text1_maxlen + 1): + states = states.write(i * (self._text2_maxlen + 1), + self._bounder_state_h0) + + # Calculate h_ij + # h_ij = [B, U] + _, _, _, h_ij = tf.while_loop( + cond=lambda _0, _1, i, _3: tf.less(i, self._recurrent_step), + body=self.calculate_recurrent_unit, + loop_vars=( + inputs, + states, + tf.constant(0, dtype=tf.int32), + self._bounder_state_h0 + ), + parallel_iterations=1, + swap_memory=True + ) + return h_ij + + def compute_output_shape(self, input_shape: typing.Any) -> tuple: + """ + Calculate the layer output shape. + + :param input_shape: the shapes of the input tensors. 
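`calculate_recurrent_unit` and `call` above share one indexing convention: hidden states live in a flat TensorArray that represents a padded (L+1) x (R+1) grid, with row 0 and column 0 holding the zero boundary state. A small pure-Python trace of that mapping, mirroring the arithmetic used above (illustration only):

    # Toy lengths; the real values come from input_shape in build().
    L_len, R_len = 3, 2

    def flat(i, j, r_len=R_len):
        """Flat index of grid cell (i, j) in the padded (L+1) x (R+1) grid."""
        return i * (r_len + 1) + j

    for step in range(L_len * R_len):
        i, j = divmod(step, R_len)
        h_diag = flat(i, j)          # states.read(i * (R + 1) + j)
        h_top = flat(i, j + 1)       # states.read(i * (R + 1) + j + 1)
        h_left = flat(i + 1, j)      # states.read((i + 1) * (R + 1) + j)
        h_ij = flat(i + 1, j + 1)    # states.write((i + 1) * (R + 1) + j + 1, ...)
        print(step, (i, j), (h_diag, h_top, h_left), '->', h_ij)
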
+ """ + output_shape = [input_shape[0], self._units] + return tuple(output_shape) + + @classmethod + def _time_distributed_dense(cls, w, x, b): + x = K.dot(x, w) + x = K.bias_add(x, b) + return x diff --git a/matchzoo/contrib/models/__init__.py b/matchzoo/contrib/models/__init__.py index 6faa60bb..cefd02a0 100644 --- a/matchzoo/contrib/models/__init__.py +++ b/matchzoo/contrib/models/__init__.py @@ -1 +1,6 @@ from .match_lstm import MatchLSTM +from .match_srnn import MatchSRNN +from .hbmp import HBMP +from .esim import ESIM +from .bimpm import BiMPM +from .diin import DIIN diff --git a/matchzoo/contrib/models/bimpm.py b/matchzoo/contrib/models/bimpm.py new file mode 100644 index 00000000..112967ea --- /dev/null +++ b/matchzoo/contrib/models/bimpm.py @@ -0,0 +1,149 @@ +"""BiMPM.""" + +from keras.models import Model +from keras.layers import Dense, Concatenate, Dropout +from keras.layers import Bidirectional, LSTM + +from matchzoo.engine.param import Param +from matchzoo.engine.param_table import ParamTable +from matchzoo.engine.base_model import BaseModel +from matchzoo.contrib.layers import MultiPerspectiveLayer + + +class BiMPM(BaseModel): + """ + BiMPM. + + Reference: + https://github.com/zhiguowang/BiMPM/blob/master/src/SentenceMatchModelGraph.py#L43-L186 + Examples: + >>> import matchzoo as mz + >>> model = mz.contrib.models.BiMPM() + >>> model.guess_and_fill_missing_params(verbose=0) + >>> model.build() + + """ + + @classmethod + def get_default_params(cls) -> ParamTable: + """:return: model default parameters.""" + params = super().get_default_params(with_embedding=True) + params['optimizer'] = 'adam' + + # params.add(Param('dim_word_embedding', 50)) + # TODO(tjf): remove unused params in the final version + # params.add(Param('dim_char_embedding', 50)) + # params.add(Param('word_embedding_mat')) + # params.add(Param('char_embedding_mat')) + # params.add(Param('embedding_random_scale', 0.2)) + # params.add(Param('activation_embedding', 'softmax')) + + # BiMPM Setting + params.add(Param('perspective', {'full': True, + 'max-pooling': True, + 'attentive': True, + 'max-attentive': True})) + params.add(Param('mp_dim', 3)) + params.add(Param('att_dim', 3)) + params.add(Param('hidden_size', 4)) + params.add(Param('dropout_rate', 0.0)) + params.add(Param('w_initializer', 'glorot_uniform')) + params.add(Param('b_initializer', 'zeros')) + params.add(Param('activation_hidden', 'linear')) + + params.add(Param('with_match_highway', False)) + params.add(Param('with_aggregation_highway', False)) + + return params + + def build(self): + """Build model structure.""" + # ~ Input Layer + input_left, input_right = self._make_inputs() + + # Word Representation Layer + # TODO: concatenate word level embedding and character level embedding. + embedding = self._make_embedding_layer() + embed_left = embedding(input_left) + embed_right = embedding(input_right) + + # L119-L121 + # https://github.com/zhiguowang/BiMPM/blob/master/src/SentenceMatchModelGraph.py#L119-L121 + embed_left = Dropout(self._params['dropout_rate'])(embed_left) + embed_right = Dropout(self._params['dropout_rate'])(embed_right) + + # ~ Word Level Matching Layer + # Reference: + # https://github.com/zhiguowang/BiMPM/blob/master/src/match_utils.py#L207-L223 + # TODO + pass + + # ~ Encoding Layer + # Note: When merge_mode = None, output will be [forward, backward], + # The default merge_mode is concat, and the output will be [lstm]. + # If with return_state, then the output would append [h,c,h,c]. 
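The comment above about `merge_mode` and `return_state` can be checked directly: with `merge_mode='concat'` and `return_state=True`, Keras returns the concatenated sequence followed by the forward and backward [h, c] states. A quick sanity sketch with toy sizes, not part of the patch:

    import numpy as np
    import keras

    inp = keras.layers.Input(shape=(7, 4))
    outs = keras.layers.Bidirectional(
        keras.layers.LSTM(3, return_sequences=True, return_state=True),
        merge_mode='concat')(inp)
    probe = keras.Model(inp, outs)
    seq, fwd_h, fwd_c, bwd_h, bwd_c = probe.predict(np.zeros((1, 7, 4)))
    print(seq.shape, fwd_h.shape)    # (1, 7, 6) (1, 3)
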
+ bi_lstm = Bidirectional( + LSTM(self._params['hidden_size'], + return_sequences=True, + return_state=True, + dropout=self._params['dropout_rate'], + kernel_initializer=self._params['w_initializer'], + bias_initializer=self._params['b_initializer']), + merge_mode='concat') + # x_left = [lstm_lt, forward_h_lt, _, backward_h_lt, _ ] + x_left = bi_lstm(embed_left) + x_right = bi_lstm(embed_right) + + # ~ Multi-Perspective Matching layer. + # Output is two sequence of vectors. + # Cons: Haven't support multiple context layer + multi_perspective = MultiPerspectiveLayer(self._params['att_dim'], + self._params['mp_dim'], + self._params['perspective']) + # Note: input to `keras layer` must be list of tensors. + mp_left = multi_perspective(x_left + x_right) + mp_right = multi_perspective(x_right + x_left) + + # ~ Dropout Layer + mp_left = Dropout(self._params['dropout_rate'])(mp_left) + mp_right = Dropout(self._params['dropout_rate'])(mp_right) + + # ~ Highway Layer + # reference: + # https://github.com/zhiguowang/BiMPM/blob/master/src/match_utils.py#L289-L295 + if self._params['with_match_highway']: + # the input is left matching representations (question / passage) + pass + + # ~ Aggregation layer + # TODO: mask the above layer + aggregation = Bidirectional( + LSTM(self._params['hidden_size'], + return_sequences=False, + return_state=False, + dropout=self._params['dropout_rate'], + kernel_initializer=self._params['w_initializer'], + bias_initializer=self._params['b_initializer']), + merge_mode='concat') + rep_left = aggregation(mp_left) + rep_right = aggregation(mp_right) + + # Concatenate the concatenated vector of left and right. + x = Concatenate()([rep_left, rep_right]) + + # ~ Highway Network + # reference: + # https://github.com/zhiguowang/BiMPM/blob/master/src/match_utils.py#L289-L295 + if self._params['with_aggregation_highway']: + pass + + # ~ Prediction layer. + # reference: + # https://github.com/zhiguowang/BiMPM/blob/master/src/SentenceMatchModelGraph.py#L140-L153 + x = Dense(self._params['hidden_size'], + activation=self._params['activation_hidden'])(x) + x = Dense(self._params['hidden_size'], + activation=self._params['activation_hidden'])(x) + x_out = self._make_output_layer()(x) + self._backend = Model(inputs=[input_left, input_right], + outputs=x_out) diff --git a/matchzoo/contrib/models/diin.py b/matchzoo/contrib/models/diin.py new file mode 100644 index 00000000..a346c84c --- /dev/null +++ b/matchzoo/contrib/models/diin.py @@ -0,0 +1,313 @@ +"""DIIN model.""" +import typing + +import keras +import keras.backend as K +import tensorflow as tf + +from matchzoo import preprocessors +from matchzoo.contrib.layers import DecayingDropoutLayer +from matchzoo.contrib.layers import EncodingLayer +from matchzoo.engine import hyper_spaces +from matchzoo.engine.base_model import BaseModel +from matchzoo.engine.param import Param +from matchzoo.engine.param_table import ParamTable + + +class DIIN(BaseModel): + """ + DIIN model. 
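One detail of the matching call in `BiMPM.build` above that is easy to misread: `x_left` and `x_right` are Python lists of tensors (`[sequence, fwd_h, fwd_c, bwd_h, bwd_c]`), so `x_left + x_right` is list concatenation, not tensor addition; `MultiPerspectiveLayer` receives a single flat list of ten tensors. A trivial illustration with placeholder names (the real entries are Keras tensors):

    x_left = ['seq_l', 'fwd_h_l', 'fwd_c_l', 'bwd_h_l', 'bwd_c_l']
    x_right = ['seq_r', 'fwd_h_r', 'fwd_c_r', 'bwd_h_r', 'bwd_c_r']
    print(len(x_left + x_right))   # 10
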
+ + Examples: + >>> model = DIIN() + >>> model.guess_and_fill_missing_params() + >>> model.params['embedding_input_dim'] = 10000 + >>> model.params['embedding_output_dim'] = 300 + >>> model.params['embedding_trainable'] = True + >>> model.params['optimizer'] = 'adam' + >>> model.params['dropout_initial_keep_rate'] = 1.0 + >>> model.params['dropout_decay_interval'] = 10000 + >>> model.params['dropout_decay_rate'] = 0.977 + >>> model.params['char_embedding_input_dim'] = 100 + >>> model.params['char_embedding_output_dim'] = 8 + >>> model.params['char_conv_filters'] = 100 + >>> model.params['char_conv_kernel_size'] = 5 + >>> model.params['first_scale_down_ratio'] = 0.3 + >>> model.params['nb_dense_blocks'] = 3 + >>> model.params['layers_per_dense_block'] = 8 + >>> model.params['growth_rate'] = 20 + >>> model.params['transition_scale_down_ratio'] = 0.5 + >>> model.build() + """ + + @classmethod + def get_default_params(cls) -> ParamTable: + """:return: model default parameters.""" + params = super().get_default_params(with_embedding=True) + params['optimizer'] = 'adam' + params.add(Param(name='dropout_decay_interval', value=10000, + desc="The decay interval of decaying_dropout.")) + params.add(Param(name='char_embedding_input_dim', value=100, + desc="The input dimension of character embedding " + "layer.")) + params.add(Param(name='char_embedding_output_dim', value=2, + desc="The output dimension of character embedding " + "layer.")) + params.add(Param(name='char_conv_filters', value=8, + desc="The filter size of character convolution " + "layer.")) + params.add(Param(name='char_conv_kernel_size', value=2, + desc="The kernel size of character convolution " + "layer.")) + params.add(Param(name='first_scale_down_ratio', value=0.3, + desc="The channel scale down ratio of the " + "convolution layer before densenet.")) + params.add(Param(name='nb_dense_blocks', value=1, + desc="The number of blocks in densenet.")) + params.add(Param(name='layers_per_dense_block', value=2, + desc="The number of convolution layers in dense " + "block.")) + params.add(Param(name='growth_rate', value=2, + desc="The filter size of each convolution layer in " + "dense block.")) + params.add(Param(name='transition_scale_down_ratio', value=0.5, + desc="The channel scale down ratio of the " + "convolution layer in transition block.")) + params.add(Param( + name='dropout_initial_keep_rate', value=1.0, + hyper_space=hyper_spaces.quniform( + low=0.8, high=1.0, q=0.02), + desc="The initial keep rate of decaying_dropout." + )) + params.add(Param( + name='dropout_decay_rate', value=0.97, + hyper_space=hyper_spaces.quniform( + low=0.90, high=0.99, q=0.01), + desc="The decay rate of decaying_dropout." + )) + return params + + @classmethod + def get_default_preprocessor(cls): + """:return: Default preprocessor.""" + return preprocessors.DIINPreprocessor() + + def guess_and_fill_missing_params(self, verbose: int = 1): + """ + Guess and fill missing parameters in :attr:'params'. + + Use this method to automatically fill-in hyper parameters. + This involves some guessing so the parameter it fills could be + wrong. For example, the default task is 'Ranking', and if we do not + set it to 'Classification' manually for data packs prepared for + classification, then the shape of the model output and the data will + mismatch. + + :param verbose: Verbosity. 
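For orientation, DIIN consumes six aligned inputs per text pair. The shapes below restate the `input_shapes` default filled in immediately below; they are listed here only as an illustration:

    # Default input shapes per example (32 words per text, 16 chars per word).
    diin_default_input_shapes = {
        'text_left': (32,),       # word ids of the left text
        'text_right': (32,),      # word ids of the right text
        'char_left': (32, 16),    # char ids per word, left text
        'char_right': (32, 16),   # char ids per word, right text
        'match_left': (32,),      # exact-match flags, left text
        'match_right': (32,),     # exact-match flags, right text
    }
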
+ """ + self._params.get('input_shapes').set_default([(32,), + (32,), + (32, 16), + (32, 16), + (32,), + (32,)], verbose) + super().guess_and_fill_missing_params(verbose) + + def _make_inputs(self) -> list: + text_left = keras.layers.Input( + name='text_left', + shape=self._params['input_shapes'][0] + ) + text_right = keras.layers.Input( + name='text_right', + shape=self._params['input_shapes'][1] + ) + char_left = keras.layers.Input( + name='char_left', + shape=self._params['input_shapes'][2] + ) + char_right = keras.layers.Input( + name='char_right', + shape=self._params['input_shapes'][3] + ) + match_left = keras.layers.Input( + name='match_left', + shape=self._params['input_shapes'][4] + ) + match_right = keras.layers.Input( + name='match_right', + shape=self._params['input_shapes'][5] + ) + return [text_left, text_right, + char_left, char_right, + match_left, match_right] + + def build(self): + """Build model structure.""" + + # Scalar dimensions referenced here: + # B = batch size (number of sequences) + # D = word embedding size + # L = 'input_left' sequence length + # R = 'input_right' sequence length + # C = fixed word length + + inputs = self._make_inputs() + # Left text and right text. + # shape = [B, L] + # shape = [B, R] + text_left, text_right = inputs[0:2] + # Left character and right character. + # shape = [B, L, C] + # shape = [B, R, C] + char_left, char_right = inputs[2:4] + # Left exact match and right exact match. + # shape = [B, L] + # shape = [B, R] + match_left, match_right = inputs[4:6] + + # Embedding module + left_embeddings = [] + right_embeddings = [] + + # Word embedding feature + word_embedding = self._make_embedding_layer() + # shape = [B, L, D] + left_word_embedding = word_embedding(text_left) + # shape = [B, R, D] + right_word_embedding = word_embedding(text_right) + left_word_embedding = DecayingDropoutLayer( + initial_keep_rate=self._params['dropout_initial_keep_rate'], + decay_interval=self._params['dropout_decay_interval'], + decay_rate=self._params['dropout_decay_rate'] + )(left_word_embedding) + right_word_embedding = DecayingDropoutLayer( + initial_keep_rate=self._params['dropout_initial_keep_rate'], + decay_interval=self._params['dropout_decay_interval'], + decay_rate=self._params['dropout_decay_rate'] + )(right_word_embedding) + left_embeddings.append(left_word_embedding) + right_embeddings.append(right_word_embedding) + + # Exact match feature + # shape = [B, L, 1] + left_exact_match = keras.layers.Reshape( + target_shape=(K.int_shape(match_left)[1], 1,) + )(match_left) + # shape = [B, R, 1] + right_exact_match = keras.layers.Reshape( + target_shape=(K.int_shape(match_left)[1], 1,) + )(match_right) + left_embeddings.append(left_exact_match) + right_embeddings.append(right_exact_match) + + # Char embedding feature + char_embedding = self._make_char_embedding_layer() + char_embedding.build( + input_shape=(None, None, K.int_shape(char_left)[-1])) + left_char_embedding = char_embedding(char_left) + right_char_embedding = char_embedding(char_right) + left_embeddings.append(left_char_embedding) + right_embeddings.append(right_char_embedding) + + # Concatenate + left_embedding = keras.layers.Concatenate()(left_embeddings) + right_embedding = keras.layers.Concatenate()(right_embeddings) + d = K.int_shape(left_embedding)[-1] + + # Encoding module + left_encoding = EncodingLayer( + initial_keep_rate=self._params['dropout_initial_keep_rate'], + decay_interval=self._params['dropout_decay_interval'], + decay_rate=self._params['dropout_decay_rate'] + 
)(left_embedding) + right_encoding = EncodingLayer( + initial_keep_rate=self._params['dropout_initial_keep_rate'], + decay_interval=self._params['dropout_decay_interval'], + decay_rate=self._params['dropout_decay_rate'] + )(right_embedding) + + # Interaction module + interaction = keras.layers.Lambda(self._make_interaction)( + [left_encoding, right_encoding]) + + # Feature extraction module + feature_extractor_input = keras.layers.Conv2D( + filters=int(d * self._params['first_scale_down_ratio']), + kernel_size=(1, 1), + activation=None)(interaction) + feature_extractor = self._create_densenet() + features = feature_extractor(feature_extractor_input) + + # Output module + features = DecayingDropoutLayer( + initial_keep_rate=self._params['dropout_initial_keep_rate'], + decay_interval=self._params['dropout_decay_interval'], + decay_rate=self._params['dropout_decay_rate'])(features) + out = self._make_output_layer()(features) + + self._backend = keras.Model(inputs=inputs, outputs=out) + + def _make_char_embedding_layer(self) -> keras.layers.Layer: + """ + Apply embedding, conv and maxpooling operation over time dimension + for each token to obtain a vector. + + :return: Wrapper Keras 'Layer' as character embedding feature + extractor. + """ + + return keras.layers.TimeDistributed(keras.Sequential([ + keras.layers.Embedding( + input_dim=self._params['char_embedding_input_dim'], + output_dim=self._params['char_embedding_output_dim'], + input_length=self._params['input_shapes'][2][-1]), + keras.layers.Conv1D( + filters=self._params['char_conv_filters'], + kernel_size=self._params['char_conv_kernel_size']), + keras.layers.GlobalMaxPooling1D()])) + + def _make_interaction(self, inputs_) -> typing.Any: + left_encoding = inputs_[0] + right_encoding = inputs_[1] + + left_encoding = tf.expand_dims(left_encoding, axis=2) + right_encoding = tf.expand_dims(right_encoding, axis=1) + + interaction = left_encoding * right_encoding + return interaction + + def _create_densenet(self) -> typing.Callable: + """ + DenseNet is consisted of 'nb_dense_blocks' sets of Dense block + and Transition block pair. + + :return: Wrapper Keras 'Layer' as DenseNet, tensor in tensor out. + """ + def _wrapper(x): + for _ in range(self._params['nb_dense_blocks']): + # Dense block + # Apply 'layers_per_dense_block' convolution layers. + for _ in range(self._params['layers_per_dense_block']): + out_conv = keras.layers.Conv2D( + filters=self._params['growth_rate'], + kernel_size=(3, 3), + padding='same', + activation='relu')(x) + x = keras.layers.Concatenate(axis=-1)([x, out_conv]) + + # Transition block + # Apply a convolution layer and a maxpooling layer. + scale_down_ratio = self._params['transition_scale_down_ratio'] + nb_filter = int(K.int_shape(x)[-1] * scale_down_ratio) + x = keras.layers.Conv2D( + filters=nb_filter, + kernel_size=(1, 1), + padding='same', + activation=None)(x) + x = keras.layers.MaxPool2D(strides=(2, 2))(x) + + out_densenet = keras.layers.Flatten()(x) + return out_densenet + + return _wrapper diff --git a/matchzoo/contrib/models/esim.py b/matchzoo/contrib/models/esim.py new file mode 100644 index 00000000..539903ba --- /dev/null +++ b/matchzoo/contrib/models/esim.py @@ -0,0 +1,212 @@ +"""ESIM model.""" + +import keras +import keras.backend as K +import tensorflow as tf + +import matchzoo as mz +from matchzoo.engine.base_model import BaseModel +from matchzoo.engine.param import Param +from matchzoo.engine.param_table import ParamTable + + +class ESIM(BaseModel): + """ + ESIM model. 
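`_make_interaction` above relies on broadcasting: expanding the left encoding on axis 2 and the right encoding on axis 1 turns an element-wise product into a full [B, L, R, D] interaction cube. The same arithmetic in NumPy, with toy sizes of my own:

    import numpy as np

    left = np.random.rand(2, 5, 4)                      # [B, L, D]
    right = np.random.rand(2, 7, 4)                     # [B, R, D]
    interaction = left[:, :, None, :] * right[:, None, :, :]
    print(interaction.shape)                            # (2, 5, 7, 4) == [B, L, R, D]
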
+ + Examples: + >>> model = ESIM() + >>> task = classification_task = mz.tasks.Classification(num_classes=2) + >>> model.params['task'] = task + >>> model.params['input_shapes'] = [(20, ), (40, )] + >>> model.params['lstm_dim'] = 300 + >>> model.params['mlp_num_units'] = 300 + >>> model.params['embedding_input_dim'] = 5000 + >>> model.params['embedding_output_dim'] = 10 + >>> model.params['embedding_trainable'] = False + >>> model.params['mlp_num_layers'] = 0 + >>> model.params['mlp_num_fan_out'] = 300 + >>> model.params['mlp_activation_func'] = 'tanh' + >>> model.params['mask_value'] = 0 + >>> model.params['dropout_rate'] = 0.5 + >>> model.params['optimizer'] = keras.optimizers.Adam(lr=4e-4) + >>> model.guess_and_fill_missing_params() + >>> model.build() + """ + + @classmethod + def get_default_params(cls) -> ParamTable: + """Get default parameters.""" + params = super().get_default_params(with_embedding=True, + with_multi_layer_perceptron=True) + + params.add(Param( + name='dropout_rate', + value=0.5, + desc="The dropout rate for all fully-connected layer" + )) + + params.add(Param( + name='lstm_dim', + value=8, + desc="The dimension of LSTM layer." + )) + + params.add(Param( + name='mask_value', + value=0, + desc="The value would be regarded as pad" + )) + + return params + + def _expand_dim(self, inp: tf.Tensor, axis: int) -> keras.layers.Layer: + """ + Wrap keras.backend.expand_dims into a Lambda layer. + + :param inp: input tensor to expand the dimension + :param axis: the axis of new dimension + """ + return keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=axis))(inp) + + def _make_atten_mask_layer(self) -> keras.layers.Layer: + """ + Make mask layer for attention weight matrix so that + each word won't pay attention to timestep. + """ + return keras.layers.Lambda( + lambda weight_mask: weight_mask[0] + (1.0 - weight_mask[1]) * -1e7, + name="atten_mask") + + def _make_bilstm_layer(self, lstm_dim: int) -> keras.layers.Layer: + """ + Bidirectional LSTM layer in ESIM. + + :param lstm_dim: int, dimension of LSTM layer + :return: `keras.layers.Layer`. 
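The `-1e7` trick in `_make_atten_mask_layer` above keeps padded timesteps from receiving attention: adding a large negative number wherever the mask is zero drives those entries to (numerically) zero probability after the softmax. A toy numeric check, values of my own choosing:

    import numpy as np

    scores = np.array([[2.0, 1.0, 0.5]])           # raw attention scores
    mask = np.array([[1.0, 1.0, 0.0]])             # last timestep is padding
    masked = scores + (1.0 - mask) * -1e7
    probs = np.exp(masked) / np.exp(masked).sum(-1, keepdims=True)
    print(probs.round(3))                          # [[0.731 0.269 0.   ]]
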
+ """ + return keras.layers.Bidirectional( + layer=keras.layers.LSTM(lstm_dim, return_sequences=True), + merge_mode='concat') + + def _max(self, texts: tf.Tensor, mask: tf.Tensor) -> tf.Tensor: + """ + Compute the max of each text according to their real length + + :param texts: np.array with shape [B, T, H] + :param lengths: np.array with shape [B, T, ], + where 1 means valid, 0 means pad + """ + mask = self._expand_dim(mask, axis=2) + new_texts = keras.layers.Multiply()([texts, mask]) + + text_max = keras.layers.Lambda( + lambda x: tf.reduce_max(x, axis=1), + )(new_texts) + + return text_max + + def _avg(self, texts: tf.Tensor, mask: tf.Tensor) -> tf.Tensor: + """ + Compute the mean of each text according to their real length + + :param texts: np.array with shape [B, T, H] + :param lengths: np.array with shape [B, T, ], + where 1 means valid, 0 means pad + """ + mask = self._expand_dim(mask, axis=2) + new_texts = keras.layers.Multiply()([texts, mask]) + + # timestep-wise division, exclude the PAD number when calc avg + text_avg = keras.layers.Lambda( + lambda text_mask: + tf.reduce_sum(text_mask[0], axis=1) / tf.reduce_sum(text_mask[1], axis=1), + )([new_texts, mask]) + + return text_avg + + def build(self): + """Build model.""" + # parameters + lstm_dim = self._params['lstm_dim'] + dropout_rate = self._params['dropout_rate'] + + # layers + create_mask = keras.layers.Lambda( + lambda x: + tf.cast(tf.not_equal(x, self._params['mask_value']), K.floatx()) + ) + embedding = self._make_embedding_layer() + lstm_compare = self._make_bilstm_layer(lstm_dim) + lstm_compose = self._make_bilstm_layer(lstm_dim) + dense_compare = keras.layers.Dense(units=lstm_dim, + activation='relu', + use_bias=True) + dropout = keras.layers.Dropout(dropout_rate) + + # model + a, b = self._make_inputs() # [B, T_a], [B, T_b] + a_mask = create_mask(a) # [B, T_a] + b_mask = create_mask(b) # [B, T_b] + + # encoding + a_emb = dropout(embedding(a)) # [B, T_a, E_dim] + b_emb = dropout(embedding(b)) # [B, T_b, E_dim] + + a_ = lstm_compare(a_emb) # [B, T_a, H*2] + b_ = lstm_compare(b_emb) # [B, T_b, H*2] + + # mask a_ and b_, since the position is no more zero + a_ = keras.layers.Multiply()([a_, self._expand_dim(a_mask, axis=2)]) + b_ = keras.layers.Multiply()([b_, self._expand_dim(b_mask, axis=2)]) + + # local inference + e = keras.layers.Dot(axes=-1)([a_, b_]) # [B, T_a, T_b] + _ab_mask = keras.layers.Multiply()( # _ab_mask: [B, T_a, T_b] + [self._expand_dim(a_mask, axis=2), # [B, T_a, 1] + self._expand_dim(b_mask, axis=1)]) # [B, 1, T_b] + + pm = keras.layers.Permute((2, 1)) + mask_layer = self._make_atten_mask_layer() + softmax_layer = keras.layers.Softmax(axis=-1) + + e_a = softmax_layer(mask_layer([e, _ab_mask])) # [B, T_a, T_b] + e_b = softmax_layer(mask_layer([pm(e), pm(_ab_mask)])) # [B, T_b, T_a] + + # alignment (a_t = a~, b_t = b~ ) + a_t = keras.layers.Dot(axes=(2, 1))([e_a, b_]) # [B, T_a, H*2] + b_t = keras.layers.Dot(axes=(2, 1))([e_b, a_]) # [B, T_b, H*2] + + # local inference info enhancement + m_a = keras.layers.Concatenate(axis=-1)([ + a_, + a_t, + keras.layers.Subtract()([a_, a_t]), + keras.layers.Multiply()([a_, a_t])]) # [B, T_a, H*2*4] + m_b = keras.layers.Concatenate(axis=-1)([ + b_, + b_t, + keras.layers.Subtract()([b_, b_t]), + keras.layers.Multiply()([b_, b_t])]) # [B, T_b, H*2*4] + + # project m_a and m_b from 4*H*2 dim to H dim + m_a = dropout(dense_compare(m_a)) # [B, T_a, H] + m_b = dropout(dense_compare(m_b)) # [B, T_a, H] + + # inference composition + v_a = lstm_compose(m_a) # [B, T_a, H*2] + 
v_b = lstm_compose(m_b) # [B, T_b, H*2] + + # pooling + v_a = keras.layers.Concatenate(axis=-1)( + [self._avg(v_a, a_mask), self._max(v_a, a_mask)]) # [B, H*4] + v_b = keras.layers.Concatenate(axis=-1)( + [self._avg(v_b, b_mask), self._max(v_b, b_mask)]) # [B, H*4] + v = keras.layers.Concatenate(axis=-1)([v_a, v_b]) # [B, H*8] + + # mlp (multilayer perceptron) classifier + output = self._make_multi_layer_perceptron_layer()(v) # [B, H] + output = dropout(output) + output = self._make_output_layer()(output) # [B, #classes] + + self._backend = keras.Model(inputs=[a, b], outputs=output) diff --git a/matchzoo/contrib/models/hbmp.py b/matchzoo/contrib/models/hbmp.py new file mode 100644 index 00000000..bc16605d --- /dev/null +++ b/matchzoo/contrib/models/hbmp.py @@ -0,0 +1,154 @@ +"""HBMP model.""" +import keras +import typing + +from matchzoo.engine import hyper_spaces +from matchzoo.engine.param_table import ParamTable +from matchzoo.engine.param import Param +from matchzoo.engine.base_model import BaseModel + + +class HBMP(BaseModel): + """ + HBMP model. + + Examples: + >>> model = HBMP() + >>> model.guess_and_fill_missing_params(verbose=0) + >>> model.params['embedding_input_dim'] = 200 + >>> model.params['embedding_output_dim'] = 100 + >>> model.params['embedding_trainable'] = True + >>> model.params['alpha'] = 0.1 + >>> model.params['mlp_num_layers'] = 3 + >>> model.params['mlp_num_units'] = [10, 10] + >>> model.params['lstm_num_units'] = 5 + >>> model.params['dropout_rate'] = 0.1 + >>> model.build() + """ + + @classmethod + def get_default_params(cls) -> ParamTable: + """:return: model default parameters.""" + params = super().get_default_params(with_embedding=True) + params['optimizer'] = 'adam' + params.add(Param(name='alpha', value=0.1, + desc="Negative slope coefficient of LeakyReLU " + "function.")) + params.add(Param(name='mlp_num_layers', value=3, + desc="The number of layers of mlp.")) + params.add(Param(name='mlp_num_units', value=[10, 10], + desc="The hidden size of the FC layers, but not " + "include the final layer.")) + params.add(Param(name='lstm_num_units', value=5, + desc="The hidden size of the LSTM layer.")) + params.add(Param( + name='dropout_rate', value=0.1, + hyper_space=hyper_spaces.quniform( + low=0.0, high=0.8, q=0.01), + desc="The dropout rate." 
+ )) + return params + + def build(self): + """Build model structure.""" + input_left, input_right = self._make_inputs() + + embedding = self._make_embedding_layer() + embed_left = embedding(input_left) + embed_right = embedding(input_right) + + # Get sentence embedding + embed_sen_left = self._sentence_encoder( + embed_left, + lstm_num_units=self._params['lstm_num_units'], + drop_rate=self._params['dropout_rate']) + embed_sen_right = self._sentence_encoder( + embed_right, + lstm_num_units=self._params['lstm_num_units'], + drop_rate=self._params['dropout_rate']) + + # Concatenate two sentence embedding: [embed_sen_left, embed_sen_right, + # |embed_sen_left-embed_sen_right|, embed_sen_left*embed_sen_right] + embed_minus = keras.layers.Subtract()( + [embed_sen_left, embed_sen_right]) + embed_minus_abs = keras.layers.Lambda(lambda x: abs(x))(embed_minus) + embed_multiply = keras.layers.Multiply()( + [embed_sen_left, embed_sen_right]) + concat = keras.layers.Concatenate(axis=1)( + [embed_sen_left, embed_sen_right, embed_minus_abs, embed_multiply]) + + # Multiply perception layers to classify + mlp_out = self._classifier( + concat, + mlp_num_layers=self._params['mlp_num_layers'], + mlp_num_units=self._params['mlp_num_units'], + drop_rate=self._params['dropout_rate'], + leaky_relu_alpah=self._params['alpha']) + out = self._make_output_layer()(mlp_out) + + self._backend = keras.Model( + inputs=[input_left, input_right], outputs=out) + + def _classifier( + self, + input_: typing.Any, + mlp_num_layers: int, + mlp_num_units: list, + drop_rate: float, + leaky_relu_alpah: float + ) -> typing.Any: + for i in range(mlp_num_layers - 1): + input_ = keras.layers.Dropout(rate=drop_rate)(input_) + input_ = keras.layers.Dense(mlp_num_units[i])(input_) + input_ = keras.layers.LeakyReLU(alpha=leaky_relu_alpah)(input_) + + return input_ + + def _sentence_encoder( + self, + input_: typing.Any, + lstm_num_units: int, + drop_rate: float + ) -> typing.Any: + """ + Stack three BiLSTM MaxPooling blocks as a hierarchical structure. + Concatenate the output of three blocs as the input sentence embedding. + Each BiLSTM layer reads the input sentence as the input. + Each BiLSTM layer except the first one is initialized(the initial + hidden state and the cell state) with the final state of the previous + layer. 
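The sentence-pair feature built in `HBMP.build` above is the common [u, v, |u - v|, u * v] combination. A toy numeric illustration (values are mine):

    import numpy as np

    u = np.array([1.0, 2.0])                       # left sentence embedding
    v = np.array([0.5, 4.0])                       # right sentence embedding
    features = np.concatenate([u, v, np.abs(u - v), u * v])
    print(features)   # [1.  2.  0.5 4.  0.5 2.  0.5 8. ]
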
+ """ + emb1 = keras.layers.Bidirectional( + keras.layers.LSTM( + units=lstm_num_units, + return_sequences=True, + return_state=True, + dropout=drop_rate, + recurrent_dropout=drop_rate), + merge_mode='concat')(input_) + emb1_maxpooling = keras.layers.GlobalMaxPooling1D()(emb1[0]) + + emb2 = keras.layers.Bidirectional( + keras.layers.LSTM( + units=lstm_num_units, + return_sequences=True, + return_state=True, + dropout=drop_rate, + recurrent_dropout=drop_rate), + merge_mode='concat')(input_, initial_state=emb1[1:5]) + emb2_maxpooling = keras.layers.GlobalMaxPooling1D()(emb2[0]) + + emb3 = keras.layers.Bidirectional( + keras.layers.LSTM( + units=lstm_num_units, + return_sequences=True, + return_state=True, + dropout=drop_rate, + recurrent_dropout=drop_rate), + merge_mode='concat')(input_, initial_state=emb2[1:5]) + emb3_maxpooling = keras.layers.GlobalMaxPooling1D()(emb3[0]) + + emb = keras.layers.Concatenate(axis=1)( + [emb1_maxpooling, emb2_maxpooling, emb3_maxpooling]) + + return emb diff --git a/matchzoo/contrib/models/match_lstm.py b/matchzoo/contrib/models/match_lstm.py index e1150afb..f8c073d3 100644 --- a/matchzoo/contrib/models/match_lstm.py +++ b/matchzoo/contrib/models/match_lstm.py @@ -1,6 +1,7 @@ """Match LSTM model.""" import keras import keras.backend as K +import tensorflow as tf from matchzoo.engine.base_model import BaseModel from matchzoo.engine.param import Param @@ -68,19 +69,19 @@ def build(self): def attention(tensors): """Attention layer.""" left, right = tensors - tensor_left = K.expand_dims(left, axis=2) - tensor_right = K.expand_dims(right, axis=1) + tensor_left = tf.expand_dims(left, axis=2) + tensor_right = tf.expand_dims(right, axis=1) tensor_left = K.repeat_elements(tensor_left, len_right, 2) tensor_right = K.repeat_elements(tensor_right, len_left, 1) - tensor_merged = K.concatenate([tensor_left, tensor_right], axis=-1) + tensor_merged = tf.concat([tensor_left, tensor_right], axis=-1) middle_output = keras.layers.Dense(self._params['fc_num_units'], activation='tanh')( tensor_merged) attn_scores = keras.layers.Dense(1)(middle_output) - attn_scores = K.squeeze(attn_scores, axis=3) - exp_attn_scores = K.exp( - attn_scores - K.max(attn_scores, axis=-1, keepdims=True)) - exp_sum = K.sum(exp_attn_scores, axis=-1, keepdims=True) + attn_scores = tf.squeeze(attn_scores, axis=3) + exp_attn_scores = tf.math.exp( + attn_scores - tf.reduce_max(attn_scores, axis=-1, keepdims=True)) + exp_sum = tf.reduce_sum(exp_attn_scores, axis=-1, keepdims=True) attention_weights = exp_attn_scores / exp_sum return K.batch_dot(attention_weights, right) diff --git a/matchzoo/contrib/models/match_srnn.py b/matchzoo/contrib/models/match_srnn.py new file mode 100644 index 00000000..66ae800a --- /dev/null +++ b/matchzoo/contrib/models/match_srnn.py @@ -0,0 +1,93 @@ +"""An implementation of Match-SRNN Model.""" + +import keras + +from matchzoo.contrib.layers import MatchingTensorLayer +from matchzoo.contrib.layers import SpatialGRU +from matchzoo.engine import hyper_spaces +from matchzoo.engine.base_model import BaseModel +from matchzoo.engine.param import Param +from matchzoo.engine.param_table import ParamTable + + +class MatchSRNN(BaseModel): + """ + Match-SRNN Model. 
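The attention change in `match_lstm.py` above only swaps Keras backend calls for their TensorFlow equivalents; the max-shifted exponentiation is still an ordinary softmax. A quick NumPy sanity check with toy scores of my own:

    import numpy as np

    scores = np.array([[1.0, 2.0, 3.0]])
    shifted = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = shifted / shifted.sum(axis=-1, keepdims=True)
    print(weights.round(4))   # [[0.09   0.2447 0.6652]]
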
+ + Examples: + >>> model = MatchSRNN() + >>> model.params['channels'] = 4 + >>> model.params['units'] = 10 + >>> model.params['dropout_rate'] = 0.0 + >>> model.params['direction'] = 'lt' + >>> model.guess_and_fill_missing_params(verbose=0) + >>> model.build() + + """ + + @classmethod + def get_default_params(cls) -> ParamTable: + """:return: model default parameters.""" + params = super().get_default_params(with_embedding=True) + params.add(Param(name='channels', value=4, + desc="Number of word interaction tensor channels")) + params.add(Param(name='units', value=10, + desc="Number of SpatialGRU units")) + params.add(Param(name='direction', value='lt', + desc="Direction of SpatialGRU scanning")) + params.add(Param( + name='dropout_rate', value=0.0, + hyper_space=hyper_spaces.quniform(low=0.0, high=0.8, + q=0.01), + desc="The dropout rate." + )) + return params + + def build(self): + """ + Build model structure. + + Match-SRNN: Modeling the Recursive Matching Structure + with Spatial RNN + """ + + # Scalar dimensions referenced here: + # B = batch size (number of sequences) + # D = embedding size + # L = `input_left` sequence length + # R = `input_right` sequence length + # C = number of channels + + # Left input and right input. + # query = [B, L] + # doc = [B, R] + query, doc = self._make_inputs() + + # Process left and right input. + # embed_query = [B, L, D] + # embed_doc = [B, R, D] + embedding = self._make_embedding_layer() + embed_query = embedding(query) + embed_doc = embedding(doc) + + # Get matching tensor + # matching_tensor = [B, C, L, R] + matching_tensor_layer = MatchingTensorLayer( + channels=self._params['channels']) + matching_tensor = matching_tensor_layer([embed_query, embed_doc]) + + # Apply spatial GRU to the word level interaction tensor + # h_ij = [B, U] + spatial_gru = SpatialGRU( + units=self._params['units'], + direction=self._params['direction']) + h_ij = spatial_gru(matching_tensor) + + # Apply Dropout + x = keras.layers.Dropout( + rate=self._params['dropout_rate'])(h_ij) + + # Make output layer + x_out = self._make_output_layer()(x) + + self._backend = keras.Model(inputs=[query, doc], outputs=x_out) diff --git a/matchzoo/data_generator/callbacks/lambda_callback.py b/matchzoo/data_generator/callbacks/lambda_callback.py index 27797159..684171ba 100644 --- a/matchzoo/data_generator/callbacks/lambda_callback.py +++ b/matchzoo/data_generator/callbacks/lambda_callback.py @@ -8,10 +8,19 @@ class LambdaCallback(Callback): See :class:`matchzoo.data_generator.callbacks.Callback` for more details. Example: + + >>> import matchzoo as mz >>> from matchzoo.data_generator.callbacks import LambdaCallback - >>> callback = LambdaCallback(on_batch_unpacked=print) - >>> callback.on_batch_unpacked('x', 'y') - x y + >>> data = mz.datasets.toy.load_data() + >>> batch_func = lambda x: print(type(x)) + >>> unpack_func = lambda x, y: print(type(x), type(y)) + >>> callback = LambdaCallback(on_batch_data_pack=batch_func, + ... on_batch_unpacked=unpack_func) + >>> data_gen = mz.DataGenerator( + ... data, batch_size=len(data), callbacks=[callback]) + >>> _ = data_gen[0] + + """ diff --git a/matchzoo/data_pack/data_pack.py b/matchzoo/data_pack/data_pack.py index e8a94015..5b3e862d 100644 --- a/matchzoo/data_pack/data_pack.py +++ b/matchzoo/data_pack/data_pack.py @@ -294,6 +294,32 @@ def drop_label(self): """ self._relation = self._relation.drop(columns='label') + @_optional_inplace + def drop_invalid(self): + """ + Remove rows from the data pack where the length is zero. 
+ + :param inplace: `True` to modify inplace, `False` to return a modified + copy. (default: `False`) + + Example: + >>> import matchzoo as mz + >>> data_pack = mz.datasets.toy.load_data() + >>> data_pack.append_text_length(inplace=True, verbose=0) + >>> data_pack.drop_invalid(inplace=True) + """ + if not ('length_left' in self._left and 'length_right' in self._right): + raise ValueError(f"`lenght_left` or `length_right` is missing. " + f"Please call `append_text_length` in advance.") + valid_left = self._left.loc[self._left.length_left != 0] + valid_right = self._right.loc[self._right.length_right != 0] + self._left = self._left[self._left.index.isin(valid_left.index)] + self._right = self._right[self._right.index.isin(valid_right.index)] + self._relation = self._relation[self._relation.id_left.isin( + valid_left.index) & self._relation.id_right.isin( + valid_right.index)] + self._relation.reset_index(drop=True, inplace=True) + @_optional_inplace def append_text_length(self, verbose=1): """ @@ -350,27 +376,24 @@ def apply_on_text( ... rename='length_left', ... inplace=True, ... verbose=0) - >>> list(frame[0].columns) - ['id_left', 'text_left', 'length_left', 'id_right', 'text_right', \ -'label'] + >>> list(frame[0].columns) # noqa: E501 + ['id_left', 'text_left', 'length_left', 'id_right', 'text_right', 'label'] To do the same to the right text: >>> data_pack.apply_on_text(len, mode='right', ... rename='length_right', ... inplace=True, ... verbose=0) - >>> list(frame[0].columns) - ['id_left', 'text_left', 'length_left', 'id_right', 'text_right', \ -'length_right', 'label'] + >>> list(frame[0].columns) # noqa: E501 + ['id_left', 'text_left', 'length_left', 'id_right', 'text_right', 'length_right', 'label'] To do the same to the both texts at the same time: >>> data_pack.apply_on_text(len, mode='both', ... rename=('extra_left', 'extra_right'), ... inplace=True, ... verbose=0) - >>> list(frame[0].columns) - ['id_left', 'text_left', 'length_left', 'extra_left', 'id_right', \ -'text_right', 'length_right', 'extra_right', 'label'] + >>> list(frame[0].columns) # noqa: E501 + ['id_left', 'text_left', 'length_left', 'extra_left', 'id_right', 'text_right', 'length_right', 'extra_right', 'label'] To suppress outputs: >>> data_pack.apply_on_text(len, mode='both', verbose=0, diff --git a/matchzoo/datasets/__init__.py b/matchzoo/datasets/__init__.py index 45f592ff..44eb4669 100644 --- a/matchzoo/datasets/__init__.py +++ b/matchzoo/datasets/__init__.py @@ -3,6 +3,7 @@ from . import embeddings from . import snli from . import quora_qp +from . 
import cqa_ql_16 from pathlib import Path diff --git a/matchzoo/datasets/bert_resources/uncased_vocab_100.txt b/matchzoo/datasets/bert_resources/uncased_vocab_100.txt new file mode 100644 index 00000000..bbdb8526 --- /dev/null +++ b/matchzoo/datasets/bert_resources/uncased_vocab_100.txt @@ -0,0 +1,101 @@ +[PAD] +##ness +episode +bed +added +table +indian +private +charles +route +available +idea +throughout +centre +addition +appointed +style +1994 +books +eight +construction +press +mean +wall +friends +remained +schools +study +##ch +##um +institute +oh +chinese +sometimes +events +possible +1992 +australian +type +brown +forward +talk +process +food +debut +seat +performance +committee +features +character +arts +herself +else +lot +strong +russian +range +hours +peter +arm +##da +morning +dr +sold +##ry +quickly +directed +1993 +guitar +china +##w +31 +list +##ma +performed +media +uk +players +smile +##rs +myself +40 +placed +coach +province +##gawa +typed +##dry +favors +allegheny +glaciers +##rly +recalling +aziz +##log +parasite +requiem +auf +##berto +##llin +[UNK] \ No newline at end of file diff --git a/matchzoo/datasets/cqa_ql_16/__init__.py b/matchzoo/datasets/cqa_ql_16/__init__.py new file mode 100644 index 00000000..0394d77f --- /dev/null +++ b/matchzoo/datasets/cqa_ql_16/__init__.py @@ -0,0 +1 @@ +from .load_data import load_data \ No newline at end of file diff --git a/matchzoo/datasets/cqa_ql_16/load_data.py b/matchzoo/datasets/cqa_ql_16/load_data.py new file mode 100644 index 00000000..5b6b0a2e --- /dev/null +++ b/matchzoo/datasets/cqa_ql_16/load_data.py @@ -0,0 +1,203 @@ +"""CQA-QL-16 data loader.""" + +import xml +import typing +from pathlib import Path + +import keras +import pandas as pd + +import matchzoo + + +_train_dev_url = "http://alt.qcri.org/semeval2016/task3/data/uploads/" \ + "semeval2016-task3-cqa-ql-traindev-v3.2.zip" +_test_url = "http://alt.qcri.org/semeval2016/task3/data/uploads/" \ + "semeval2016_task3_test.zip" + + +def load_data( + stage: str = 'train', + task: str = 'classification', + target_label: str = 'PerfectMatch', + return_classes: bool = False, + match_type: str = 'question', + mode: str = 'both', +) -> typing.Union[matchzoo.DataPack, tuple]: + """ + Load CQA-QL-16 data. + + :param stage: One of `train`, `dev`, and `test`. + (default: `train`) + :param task: Could be one of `ranking`, `classification` or instance + of :class:`matchzoo.engine.BaseTask`. (default: `classification`) + :param target_label: If `ranking`, choose one of classification + label as the positive label. (default: `PerfectMatch`) + :param return_classes: `True` to return classes for classification + task, `False` otherwise. + :param match_type: Matching text types. One of `question`, + `answer`, and `external_answer`. (default: `question`) + :param mode: Train data use method. One of `part1`, `part2`, + and `both`. (default: `both`) + + :return: A DataPack unless `task` is `classification` and `return_classes` + is `True`: a tuple of `(DataPack, classes)` in that case. + """ + if stage not in ('train', 'dev', 'test'): + raise ValueError(f"{stage} is not a valid stage." + f"Must be one of `train`, `dev`, and `test`.") + + if match_type not in ('question', 'answer', 'external_answer'): + raise ValueError(f"{match_type} is not a valid method. Must be one of" + f" `question`, `answer`, `external_answer`.") + + if mode not in ('part1', 'part2', 'both'): + raise ValueError(f"{mode} is not a valid method." 
+ f"Must be one of `part1`, `part2`, `both`.") + + data_root = _download_data(stage) + data_pack = _read_data(data_root, stage, match_type, mode) + + if task == 'ranking': + if match_type in ('anwer', 'external_answer') and target_label not in [ + 'Good', 'PotentiallyUseful', 'Bad']: + raise ValueError(f"{target_label} is not a valid target label." + f"Must be one of `Good`, `PotentiallyUseful`," + f" `Bad`.") + elif match_type == 'question' and target_label not in [ + 'PerfectMatch', 'Relevant', 'Irrelevant']: + raise ValueError(f"{target_label} is not a valid target label." + f" Must be one of `PerfectMatch`, `Relevant`," + f" `Irrelevant`.") + binary = (data_pack.relation['label'] == target_label).astype(float) + data_pack.relation['label'] = binary + return data_pack + elif task == 'classification': + if match_type in ('answer', 'external_answer'): + classes = ['Good', 'PotentiallyUseful', 'Bad'] + else: + classes = ['PerfectMatch', 'Relevant', 'Irrelevant'] + label = data_pack.relation['label'].apply(classes.index) + data_pack.relation['label'] = label + data_pack.one_hot_encode_label(num_classes=3, inplace=True) + if return_classes: + return data_pack, classes + else: + return data_pack + else: + raise ValueError(f"{task} is not a valid task." + f"Must be one of `Ranking` and `Classification`.") + + +def _download_data(stage): + if stage in ['train', 'dev']: + return _download_train_dev_data() + else: + return _download_test_data() + + +def _download_train_dev_data(): + ref_path = keras.utils.data_utils.get_file( + 'semeval_train', _train_dev_url, extract=True, + cache_dir=matchzoo.USER_DATA_DIR, + cache_subdir='semeval_train' + ) + return Path(ref_path).parent.joinpath('v3.2') + + +def _download_test_data(): + ref_path = keras.utils.data_utils.get_file( + 'semeval_test', _test_url, extract=True, + cache_dir=matchzoo.USER_DATA_DIR, + cache_subdir='semeval_test' + ) + return Path(ref_path).parent.joinpath('SemEval2016_task3_test/English') + + +def _read_data(path, stage, match_type, mode='both'): + if stage == 'train': + if mode == 'part1': + path = path.joinpath( + 'train/SemEval2016-Task3-CQA-QL-train-part1.xml') + data = _load_data_by_type(path, match_type) + elif mode == 'part2': + path = path.joinpath( + 'train/SemEval2016-Task3-CQA-QL-train-part2.xml') + data = _load_data_by_type(path, match_type) + else: + part1 = path.joinpath( + 'train/SemEval2016-Task3-CQA-QL-train-part1.xml') + p1 = _load_data_by_type(part1, match_type) + part2 = path.joinpath( + 'train/SemEval2016-Task3-CQA-QL-train-part1.xml') + p2 = _load_data_by_type(part2, match_type) + data = pd.concat([p1, p2], ignore_index=True) + return matchzoo.pack(data) + elif stage == 'dev': + path = path.joinpath('dev/SemEval2016-Task3-CQA-QL-dev.xml') + data = _load_data_by_type(path, match_type) + return matchzoo.pack(data) + else: + path = path.joinpath('SemEval2016-Task3-CQA-QL-test.xml') + data = _load_data_by_type(path, match_type) + return matchzoo.pack(data) + + +def _load_data_by_type(path, match_type): + if match_type == 'question': + return _load_question(path) + elif match_type == 'answer': + return _load_answer(path) + else: + return _load_external_answer(path) + + +def _load_question(path): + doc = xml.etree.ElementTree.parse(path) + dataset = [] + for question in doc.iterfind('OrgQuestion'): + qid = question.attrib['ORGQ_ID'] + query = question.findtext('OrgQBody') + rel_question = question.find('Thread').find('RelQuestion') + question = rel_question.findtext('RelQBody') + question_id = 
rel_question.attrib['RELQ_ID'] + dataset.append([qid, question_id, query, question, + rel_question.attrib['RELQ_RELEVANCE2ORGQ']]) + df = pd.DataFrame(dataset, columns=[ + 'id_left', 'id_right', 'text_left', 'text_right', 'label']) + return df + + +def _load_answer(path): + doc = xml.etree.ElementTree.parse(path) + dataset = [] + for org_q in doc.iterfind('OrgQuestion'): + for thread in org_q.iterfind('Thread'): + ques = thread.find('RelQuestion') + qid = ques.attrib['RELQ_ID'] + question = ques.findtext('RelQBody') + for comment in thread.iterfind('RelComment'): + aid = comment.attrib['RELC_ID'] + answer = comment.findtext('RelCText') + dataset.append([qid, aid, question, answer, + comment.attrib['RELC_RELEVANCE2RELQ']]) + df = pd.DataFrame(dataset, columns=[ + 'id_left', 'id_right', 'text_left', 'text_right', 'label']) + return df + + +def _load_external_answer(path): + doc = xml.etree.ElementTree.parse(path) + dataset = [] + for question in doc.iterfind('OrgQuestion'): + qid = question.attrib['ORGQ_ID'] + query = question.findtext('OrgQBody') + thread = question.find('Thread') + for comment in thread.iterfind('RelComment'): + answer = comment.findtext('RelCText') + aid = comment.attrib['RELC_ID'] + dataset.append([qid, aid, query, answer, + comment.attrib['RELC_RELEVANCE2ORGQ']]) + df = pd.DataFrame(dataset, columns=[ + 'id_left', 'id_right', 'text_left', 'text_right', 'label']) + return df diff --git a/matchzoo/embedding/embedding.py b/matchzoo/embedding/embedding.py index 3787f8bc..fb21af36 100644 --- a/matchzoo/embedding/embedding.py +++ b/matchzoo/embedding/embedding.py @@ -25,13 +25,13 @@ class Embedding(object): To load from a file: >>> embedding = mz.embedding.load_from_file(embed_path) >>> matrix = embedding.build_matrix(term_index) - >>> matrix.shape[0] == len(term_index) + 1 + >>> matrix.shape[0] == len(term_index) True To build your own: >>> data = pd.DataFrame(data=[[0, 1], [2, 3]], index=['A', 'B']) >>> embedding = mz.Embedding(data) - >>> matrix = embedding.build_matrix({'A': 2, 'B': 1}) + >>> matrix = embedding.build_matrix({'A': 2, 'B': 1, '_PAD': 0}) >>> matrix.shape == (3, 2) True @@ -70,7 +70,7 @@ def build_matrix( `(-0.2, 0.2)`). :return: A matrix. """ - input_dim = len(term_index) + 1 + input_dim = len(term_index) matrix = np.empty((input_dim, self.output_dim)) for index in np.ndindex(*matrix.shape): diff --git a/matchzoo/engine/base_model.py b/matchzoo/engine/base_model.py index bc77e41b..c134bbb2 100644 --- a/matchzoo/engine/base_model.py +++ b/matchzoo/engine/base_model.py @@ -405,6 +405,17 @@ def save(self, dirpath: typing.Union[str, Path]): h5 file saved by `keras`. :param dirpath: directory path of the saved model + + Example: + + >>> import matchzoo as mz + >>> model = mz.models.Naive() + >>> model.guess_and_fill_missing_params(verbose=0) + >>> model.build() + >>> model.save('temp-model') + >>> import shutil + >>> shutil.rmtree('temp-model') + """ dirpath = Path(dirpath) params_path = dirpath.joinpath(self.PARAMS_FILENAME) @@ -505,13 +516,17 @@ def _make_output_layer(self) -> keras.layers.Layer: raise ValueError(f"{task} is not a valid task type." 
f"Must be in `Ranking` and `Classification`.") - def _make_embedding_layer(self, name: str = 'embedding' - ) -> keras.layers.Layer: + def _make_embedding_layer( + self, + name: str = 'embedding', + **kwargs + ) -> keras.layers.Layer: return keras.layers.Embedding( self._params['embedding_input_dim'], self._params['embedding_output_dim'], trainable=self._params['embedding_trainable'], - name=name + name=name, + **kwargs ) def _make_multi_layer_perceptron_layer(self) -> keras.layers.Layer: @@ -537,6 +552,19 @@ def load_model(dirpath: typing.Union[str, Path]) -> BaseModel: :param dirpath: directory path of the saved model :return: a :class:`BaseModel` instance + + Example: + + >>> import matchzoo as mz + >>> model = mz.models.Naive() + >>> model.guess_and_fill_missing_params(verbose=0) + >>> model.build() + >>> model.save('my-model') + >>> model.params.keys() == mz.load_model('my-model').params.keys() + True + >>> import shutil + >>> shutil.rmtree('my-model') + """ dirpath = Path(dirpath) diff --git a/matchzoo/layers/dynamic_pooling_layer.py b/matchzoo/layers/dynamic_pooling_layer.py index 40b989f7..049bbf82 100644 --- a/matchzoo/layers/dynamic_pooling_layer.py +++ b/matchzoo/layers/dynamic_pooling_layer.py @@ -1,7 +1,7 @@ """An implementation of Dynamic Pooling Layer.""" import typing -from keras import backend as K +import tensorflow as tf from keras.engine import Layer @@ -49,34 +49,25 @@ def call(self, inputs: list, **kwargs) -> typing.Any: :param inputs: two input tensors. """ + self._validate_dpool_size() x, dpool_index = inputs - dpool_shape = K.tf.shape(dpool_index) - batch_index_one = K.tf.expand_dims( - K.tf.expand_dims( - K.tf.range(dpool_shape[0]), axis=-1), + dpool_shape = tf.shape(dpool_index) + batch_index_one = tf.expand_dims( + tf.expand_dims( + tf.range(dpool_shape[0]), axis=-1), axis=-1) - batch_index = K.tf.expand_dims( - K.tf.tile(batch_index_one, [1, self._msize1, self._msize2]), + batch_index = tf.expand_dims( + tf.tile(batch_index_one, [1, self._msize1, self._msize2]), axis=-1) - dpool_index_ex = K.tf.concat([batch_index, dpool_index], axis=3) - x_expand = K.tf.gather_nd(x, dpool_index_ex) - stride1 = self._msize1 / self._psize1 - stride2 = self._msize2 / self._psize2 - - suggestion1 = self._msize1 / stride1 - suggestion2 = self._msize2 / stride2 - - if suggestion1 != self._psize1 or suggestion2 != self._psize2: - raise ValueError("DynamicPooling Layer can not " - "generate ({} x {}) output feature map, " - "please use ({} x {} instead.)" - .format(self._psize1, self._psize2, - suggestion1, suggestion2)) - - x_pool = K.tf.nn.max_pool(x_expand, - [1, stride1, stride2, 1], - [1, stride1, stride2, 1], - "VALID") + dpool_index_ex = tf.concat([batch_index, dpool_index], axis=3) + x_expand = tf.gather_nd(x, dpool_index_ex) + stride1 = self._msize1 // self._psize1 + stride2 = self._msize2 // self._psize2 + + x_pool = tf.nn.max_pool(x_expand, + [1, stride1, stride2, 1], + [1, stride1, stride2, 1], + "VALID") return x_pool def compute_output_shape(self, input_shape: list) -> tuple: @@ -97,3 +88,41 @@ def get_config(self) -> dict: } base_config = super(DynamicPoolingLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + def _validate_dpool_size(self): + suggestion = self.get_size_suggestion( + self._msize1, self._msize2, self._psize1, self._psize2 + ) + if suggestion != (self._psize1, self._psize2): + raise ValueError( + "DynamicPooling Layer can not " + f"generate ({self._psize1} x {self._psize2}) output " + f"feature map, please use 
({suggestion[0]} x {suggestion[1]})" + f" instead. `model.params['dpool_size'] = {suggestion}` " + ) + + @classmethod + def get_size_suggestion( + cls, + msize1: int, + msize2: int, + psize1: int, + psize2: int + ) -> typing.Tuple[int, int]: + """ + Get `dpool_size` suggestion for a given shape. + + Returns the nearest legal `dpool_size` for the given combination of + `(psize1, psize2)`. + + :param msize1: size of the left text. + :param msize2: size of the right text. + :param psize1: base size of the pool. + :param psize2: base size of the pool. + :return: + """ + stride1 = msize1 // psize1 + stride2 = msize2 // psize2 + suggestion1 = msize1 // stride1 + suggestion2 = msize2 // stride2 + return (suggestion1, suggestion2) diff --git a/matchzoo/layers/matching_layer.py b/matchzoo/layers/matching_layer.py index 24b591cf..54dead2a 100644 --- a/matchzoo/layers/matching_layer.py +++ b/matchzoo/layers/matching_layer.py @@ -1,7 +1,7 @@ """An implementation of Matching Layer.""" import typing -from keras import backend as K +import tensorflow as tf from keras.engine import Layer @@ -74,9 +74,9 @@ def call(self, inputs: list, **kwargs) -> typing.Any: x2 = inputs[1] if self._matching_type == 'dot': if self._normalize: - x1 = K.l2_normalize(x1, axis=2) - x2 = K.l2_normalize(x2, axis=2) - return K.tf.expand_dims(K.tf.einsum('abd,acd->abc', x1, x2), 3) + x1 = tf.math.l2_normalize(x1, axis=2) + x2 = tf.math.l2_normalize(x2, axis=2) + return tf.expand_dims(tf.einsum('abd,acd->abc', x1, x2), 3) else: if self._matching_type == 'mul': def func(x, y): @@ -89,14 +89,14 @@ def func(x, y): return x - y elif self._matching_type == 'concat': def func(x, y): - return K.tf.concat([x, y], axis=3) + return tf.concat([x, y], axis=3) else: raise ValueError(f"Invalid matching type." f"{self._matching_type} received." f"Mut be in `dot`, `mul`, `plus`, " f"`minus` and `concat`.") - x1_exp = K.tf.stack([x1] * self._shape2[1], 2) - x2_exp = K.tf.stack([x2] * self._shape1[1], 1) + x1_exp = tf.stack([x1] * self._shape2[1], 2) + x2_exp = tf.stack([x2] * self._shape1[1], 1) return func(x1_exp, x2_exp) def compute_output_shape(self, input_shape: list) -> tuple: diff --git a/matchzoo/losses/rank_cross_entropy_loss.py b/matchzoo/losses/rank_cross_entropy_loss.py index 2be64023..97e13fe0 100644 --- a/matchzoo/losses/rank_cross_entropy_loss.py +++ b/matchzoo/losses/rank_cross_entropy_loss.py @@ -1,10 +1,13 @@ """The rank cross entropy loss.""" -import numpy as np +import numpy as np +import tensorflow as tf from keras import layers, backend as K +from keras.losses import Loss +from keras.utils import losses_utils -class RankCrossEntropyLoss(object): +class RankCrossEntropyLoss(Loss): """ Rank cross entropy loss. @@ -26,9 +29,12 @@ def __init__(self, num_neg: int = 1): :param num_neg: number of negative instances in cross entropy loss. """ + super().__init__(reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE, + name="rank_crossentropy") self._num_neg = num_neg - def __call__(self, y_true: np.array, y_pred: np.array) -> np.array: + def call(self, y_true: np.array, y_pred: np.array, + sample_weight=None) -> np.array: """ Calculate rank cross entropy loss. 
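The new `_validate_dpool_size` / `get_size_suggestion` pair above makes the failure mode concrete: strides are now computed with integer division, and an infeasible `dpool_size` is reported together with the nearest feasible one. A worked example with toy sizes of my own:

    from matchzoo.layers.dynamic_pooling_layer import DynamicPoolingLayer

    # Texts of length 32 and 33 cannot be pooled to (10 x 10):
    # strides are 32 // 10 = 3 and 33 // 10 = 3, giving a 10 x 11 feature map.
    print(DynamicPoolingLayer.get_size_suggestion(
        msize1=32, msize2=33, psize1=10, psize2=10))   # (10, 11)
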
@@ -46,9 +52,12 @@ def __call__(self, y_true: np.array, y_pred: np.array) -> np.array: lambda a: a[neg_idx + 1::(self._num_neg + 1), :])(y_true) logits.append(neg_logits) labels.append(neg_labels) - logits = K.concatenate(logits, axis=-1) - labels = K.concatenate(labels, axis=-1) - return -K.mean(K.sum(labels * K.log(K.softmax(logits)), axis=-1)) + logits = tf.concat(logits, axis=-1) + labels = tf.concat(labels, axis=-1) + smoothed_prob = tf.nn.softmax(logits) + np.finfo(float).eps + loss = -(tf.reduce_sum(labels * tf.math.log(smoothed_prob), axis=-1)) + return losses_utils.compute_weighted_loss( + loss, sample_weight, reduction=self.reduction) @property def num_neg(self): diff --git a/matchzoo/losses/rank_hinge_loss.py b/matchzoo/losses/rank_hinge_loss.py index 43dccbc7..157f8a85 100644 --- a/matchzoo/losses/rank_hinge_loss.py +++ b/matchzoo/losses/rank_hinge_loss.py @@ -1,10 +1,13 @@ """The rank hinge loss.""" -import numpy as np +import numpy as np +import tensorflow as tf from keras import layers, backend as K +from keras.losses import Loss +from keras.utils import losses_utils -class RankHingeLoss(object): +class RankHingeLoss(Loss): """ Rank hinge loss. @@ -28,10 +31,14 @@ def __init__(self, num_neg: int = 1, margin: float = 1.0): :param num_neg: number of negative instances in hinge loss. :param margin: the margin between positive and negative scores. """ + super().__init__(reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE, + name="rank_hinge") + self._num_neg = num_neg self._margin = margin - def __call__(self, y_true: np.array, y_pred: np.array) -> np.array: + def call(self, y_true: np.array, y_pred: np.array, + sample_weight=None) -> np.array: """ Calculate rank hinge loss. @@ -47,9 +54,11 @@ def __call__(self, y_true: np.array, y_pred: np.array) -> np.array: layers.Lambda( lambda a: a[(neg_idx + 1)::(self._num_neg + 1), :], output_shape=(1,))(y_pred)) - y_neg = K.mean(K.concatenate(y_neg, axis=-1), axis=-1, keepdims=True) - loss = K.maximum(0., self._margin + y_neg - y_pos) - return K.mean(loss) + y_neg = tf.concat(y_neg, axis=-1) + y_neg = tf.reduce_mean(y_neg, axis=-1, keepdims=True) + loss = tf.maximum(0., self._margin + y_neg - y_pos) + return losses_utils.compute_weighted_loss( + loss, sample_weight, reduction=self.reduction) @property def num_neg(self): diff --git a/matchzoo/models/conv_knrm.py b/matchzoo/models/conv_knrm.py index 47eb64a2..c074371e 100644 --- a/matchzoo/models/conv_knrm.py +++ b/matchzoo/models/conv_knrm.py @@ -1,12 +1,10 @@ """ConvKNRM model.""" import keras -import keras.backend as K +import tensorflow as tf from .knrm import KNRM from matchzoo.engine.param import Param -from matchzoo.engine.param_table import ParamTable -from matchzoo.engine import hyper_spaces class ConvKNRM(KNRM): @@ -87,13 +85,13 @@ def build(self): mu = 1.0 mm_exp = self._kernel_layer(mu, sigma)(mm) mm_doc_sum = keras.layers.Lambda( - lambda x: K.tf.reduce_sum(x, 2))( + lambda x: tf.reduce_sum(x, 2))( mm_exp) - mm_log = keras.layers.Activation(K.tf.log1p)(mm_doc_sum) + mm_log = keras.layers.Activation(tf.math.log1p)(mm_doc_sum) mm_sum = keras.layers.Lambda( - lambda x: K.tf.reduce_sum(x, 1))(mm_log) + lambda x: tf.reduce_sum(x, 1))(mm_log) KM.append(mm_sum) - phi = keras.layers.Lambda(lambda x: K.tf.stack(x, 1))(KM) + phi = keras.layers.Lambda(lambda x: tf.stack(x, 1))(KM) out = self._make_output_layer()(phi) self._backend = keras.Model(inputs=[query, doc], outputs=[out]) diff --git a/matchzoo/models/drmm.py b/matchzoo/models/drmm.py index cfaeb4cc..f1b19b48 100644 --- 
a/matchzoo/models/drmm.py +++ b/matchzoo/models/drmm.py @@ -3,11 +3,11 @@ import keras import keras.backend as K +import tensorflow as tf from matchzoo.engine.base_model import BaseModel from matchzoo.engine.param import Param from matchzoo.engine.param_table import ParamTable -from matchzoo.engine import hyper_spaces class DRMM(BaseModel): @@ -67,11 +67,11 @@ def build(self): # shape = [B, L, D] embed_query = embedding(query) # shape = [B, L] - atten_mask = K.not_equal(query, self._params['mask_value']) + atten_mask = tf.not_equal(query, self._params['mask_value']) # shape = [B, L] - atten_mask = K.cast(atten_mask, K.floatx()) + atten_mask = tf.cast(atten_mask, K.floatx()) # shape = [B, L, D] - atten_mask = K.expand_dims(atten_mask, axis=2) + atten_mask = tf.expand_dims(atten_mask, axis=2) # shape = [B, L, D] attention_probs = self.attention_layer(embed_query, atten_mask) @@ -114,7 +114,7 @@ def attention_layer(cls, attention_input: typing.Any, )(dense_input) # shape = [B, L, 1] attention_probs = keras.layers.Lambda( - lambda x: keras.layers.activations.softmax(x, axis=1), + lambda x: tf.nn.softmax(x, axis=1), output_shape=lambda s: (s[0], s[1], s[2]), name="attention_probs" )(dense_input) diff --git a/matchzoo/models/drmmtks.py b/matchzoo/models/drmmtks.py index b0111b2f..4ce5edee 100644 --- a/matchzoo/models/drmmtks.py +++ b/matchzoo/models/drmmtks.py @@ -2,7 +2,7 @@ import typing import keras -import keras.backend as K +import tensorflow as tf from matchzoo.engine.base_model import BaseModel from matchzoo.engine.param import Param @@ -67,11 +67,11 @@ def build(self): # shape = [B, R, D] embed_doc = embedding(doc) # shape = [B, L] - atten_mask = K.not_equal(query, self._params['mask_value']) + atten_mask = tf.not_equal(query, self._params['mask_value']) # shape = [B, L] - atten_mask = K.cast(atten_mask, K.floatx()) + atten_mask = tf.cast(atten_mask, keras.backend.floatx()) # shape = [B, L, 1] - atten_mask = K.expand_dims(atten_mask, axis=2) + atten_mask = tf.expand_dims(atten_mask, axis=2) # shape = [B, L, 1] attention_probs = self.attention_layer(embed_query, atten_mask) @@ -85,7 +85,7 @@ def build(self): self.params['input_shapes'][0][0], self.params['input_shapes'][1][0]) matching_topk = keras.layers.Lambda( - lambda x: K.tf.nn.top_k(x, k=effective_top_k, sorted=True)[0] + lambda x: tf.nn.top_k(x, k=effective_top_k, sorted=True)[0] )(matching_matrix) # Process right input. 
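# Editor's sketch (not part of the patch): how the DRMM/DRMMTKS attention mask
# above is built. Padded positions (token id equal to `mask_value`, 0 in this
# toy example) become 0.0 and real tokens become 1.0; the trailing axis lets
# the mask broadcast when it is passed to `attention_layer` with the embedded
# query.
import tensorflow as tf

query = tf.constant([[3, 7, 12, 0, 0]])           # one query padded with 0s
atten_mask = tf.not_equal(query, 0)               # shape [B, L], booleans
atten_mask = tf.cast(atten_mask, tf.float32)      # shape [B, L], 1.0 / 0.0
atten_mask = tf.expand_dims(atten_mask, axis=2)   # shape [B, L, 1]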
@@ -127,7 +127,7 @@ def attention_layer(cls, attention_input: typing.Any, )(dense_input) # shape = [B, L, 1] attention_probs = keras.layers.Lambda( - lambda x: keras.layers.activations.softmax(x, axis=1), + lambda x: tf.nn.softmax(x, axis=1), output_shape=lambda s: (s[0], s[1], s[2]), name="attention_probs" )(dense_input) diff --git a/matchzoo/models/duet.py b/matchzoo/models/duet.py index cb85b602..22783fac 100644 --- a/matchzoo/models/duet.py +++ b/matchzoo/models/duet.py @@ -1,13 +1,11 @@ """DUET Model.""" import keras -import keras.backend as K import tensorflow as tf +from matchzoo.engine import hyper_spaces from matchzoo.engine.base_model import BaseModel from matchzoo.engine.param import Param -from matchzoo.engine.param_table import ParamTable -from matchzoo.engine import hyper_spaces class DUET(BaseModel): @@ -149,10 +147,10 @@ def _xor_match(cls, x): t2 = x[1] t1_shape = t1.get_shape() t2_shape = t2.get_shape() - t1_expand = K.tf.stack([t1] * t2_shape[1], 2) - t2_expand = K.tf.stack([t2] * t1_shape[1], 1) - out_bool = K.tf.equal(t1_expand, t2_expand) - out = K.tf.cast(out_bool, K.tf.float32) + t1_expand = tf.stack([t1] * t2_shape[1], 2) + t2_expand = tf.stack([t2] * t1_shape[1], 1) + out_bool = tf.equal(t1_expand, t2_expand) + out = tf.cast(out_bool, tf.float32) return out @classmethod diff --git a/matchzoo/models/knrm.py b/matchzoo/models/knrm.py index 6e0d3d77..7d1ff915 100644 --- a/matchzoo/models/knrm.py +++ b/matchzoo/models/knrm.py @@ -1,10 +1,9 @@ """KNRM model.""" import keras -import keras.backend as K +import tensorflow as tf from matchzoo.engine.base_model import BaseModel from matchzoo.engine.param import Param -from matchzoo.engine.param_table import ParamTable from matchzoo.engine import hyper_spaces @@ -69,13 +68,13 @@ def build(self): mu = 1.0 mm_exp = self._kernel_layer(mu, sigma)(mm) mm_doc_sum = keras.layers.Lambda( - lambda x: K.tf.reduce_sum(x, 2))(mm_exp) - mm_log = keras.layers.Activation(K.tf.log1p)(mm_doc_sum) + lambda x: tf.reduce_sum(x, 2))(mm_exp) + mm_log = keras.layers.Activation(tf.math.log1p)(mm_doc_sum) mm_sum = keras.layers.Lambda( - lambda x: K.tf.reduce_sum(x, 1))(mm_log) + lambda x: tf.reduce_sum(x, 1))(mm_log) KM.append(mm_sum) - phi = keras.layers.Lambda(lambda x: K.tf.stack(x, 1))(KM) + phi = keras.layers.Lambda(lambda x: tf.stack(x, 1))(KM) out = self._make_output_layer()(phi) self._backend = keras.Model(inputs=[query, doc], outputs=[out]) @@ -90,6 +89,6 @@ def _kernel_layer(cls, mu: float, sigma: float) -> keras.layers.Layer: """ def kernel(x): - return K.tf.exp(-0.5 * (x - mu) * (x - mu) / sigma / sigma) + return tf.math.exp(-0.5 * (x - mu) * (x - mu) / sigma / sigma) return keras.layers.Activation(kernel) diff --git a/matchzoo/models/mvlstm.py b/matchzoo/models/mvlstm.py index dc896d83..425a2b97 100644 --- a/matchzoo/models/mvlstm.py +++ b/matchzoo/models/mvlstm.py @@ -1,14 +1,12 @@ """An implementation of MVLSTM Model.""" -import typing import keras -import keras.backend as K +import tensorflow as tf -import matchzoo +from matchzoo.engine import hyper_spaces from matchzoo.engine.base_model import BaseModel from matchzoo.engine.param import Param from matchzoo.engine.param_table import ParamTable -from matchzoo.engine import hyper_spaces class MVLSTM(BaseModel): @@ -52,7 +50,7 @@ def build(self): query, doc = self._make_inputs() # Embedding layer - embedding = self._make_embedding_layer() + embedding = self._make_embedding_layer(mask_zero=True) embed_query = embedding(query) embed_doc = embedding(doc) @@ -73,7 +71,7 @@ def build(self): 
axes=[2, 2], normalize=False)([rep_query, rep_doc]) matching_signals = keras.layers.Reshape((-1,))(matching_matrix) matching_topk = keras.layers.Lambda( - lambda x: K.tf.nn.top_k(x, k=self._params['top_k'], sorted=True)[0] + lambda x: tf.nn.top_k(x, k=self._params['top_k'], sorted=True)[0] )(matching_signals) # Multilayer perceptron layer. diff --git a/matchzoo/preprocessors/__init__.py b/matchzoo/preprocessors/__init__.py index 3042e574..f119f4f7 100644 --- a/matchzoo/preprocessors/__init__.py +++ b/matchzoo/preprocessors/__init__.py @@ -3,6 +3,8 @@ from .naive_preprocessor import NaivePreprocessor from .basic_preprocessor import BasicPreprocessor from .cdssm_preprocessor import CDSSMPreprocessor +from .diin_preprocessor import DIINPreprocessor +from .bert_preprocessor import BertPreprocessor def list_available() -> list: diff --git a/matchzoo/preprocessors/basic_preprocessor.py b/matchzoo/preprocessors/basic_preprocessor.py index 0d8415c9..0fd82d37 100644 --- a/matchzoo/preprocessors/basic_preprocessor.py +++ b/matchzoo/preprocessors/basic_preprocessor.py @@ -44,7 +44,7 @@ class BasicPreprocessor(BasePreprocessor): >>> preprocessor.context['input_shapes'] [(10,), (20,)] >>> preprocessor.context['vocab_size'] - 225 + 228 >>> processed_train_data = preprocessor.transform(train_data, ... verbose=0) >>> type(processed_train_data) @@ -105,7 +105,7 @@ def fit(self, data_pack: DataPack, verbose: int = 1): vocab_unit = build_vocab_unit(data_pack, verbose=verbose) self._context['vocab_unit'] = vocab_unit - vocab_size = len(vocab_unit.state['term_index']) + 1 + vocab_size = len(vocab_unit.state['term_index']) self._context['vocab_size'] = vocab_size self._context['embedding_input_dim'] = vocab_size self._context['input_shapes'] = [(self._fixed_length_left,), diff --git a/matchzoo/preprocessors/bert_preprocessor.py b/matchzoo/preprocessors/bert_preprocessor.py new file mode 100644 index 00000000..2c6b64ce --- /dev/null +++ b/matchzoo/preprocessors/bert_preprocessor.py @@ -0,0 +1,139 @@ +"""Bert Preprocessor.""" + +from tqdm import tqdm + +from . import units +from .chain_transform import chain_transform +from matchzoo import DataPack +from matchzoo.engine.base_preprocessor import BasePreprocessor +from .build_vocab_unit import built_bert_vocab_unit +from .build_unit_from_data_pack import build_unit_from_data_pack + +tqdm.pandas() + + +class BertPreprocessor(BasePreprocessor): + """Bert-base Model preprocessor.""" + + def __init__(self, bert_vocab_path: str, + fixed_length_left: int = 30, + fixed_length_right: int = 30, + filter_mode: str = 'df', + filter_low_freq: float = 2, + filter_high_freq: float = float('inf'), + remove_stop_words: bool = False, + lower_case: bool = True, + chinese_version: bool = False, + ): + """ + Bert-base Model preprocessor. + + Example: + >>> import matchzoo as mz + >>> train_data = mz.datasets.toy.load_data() + >>> test_data = mz.datasets.toy.load_data(stage='test') + >>> # The argument 'bert_vocab_path' must feed the bert vocab path + >>> bert_preprocessor = mz.preprocessors.BertPreprocessor( + ... bert_vocab_path= + ... 'matchzoo/datasets/bert_resources/uncased_vocab_100.txt') + >>> train_data_processed = bert_preprocessor.fit_transform( + ... 
train_data) + >>> test_data_processed = bert_preprocessor.transform(test_data) + + """ + super().__init__() + self._fixed_length_left = fixed_length_left + self._fixed_length_right = fixed_length_right + self._bert_vocab_path = bert_vocab_path + self._left_fixedlength_unit = units.FixedLength( + self._fixed_length_left, + pad_mode='post' + ) + self._right_fixedlength_unit = units.FixedLength( + self._fixed_length_right, + pad_mode='post' + ) + self._filter_unit = units.FrequencyFilter( + low=filter_low_freq, + high=filter_high_freq, + mode=filter_mode + ) + self._units = self._default_units() + self._vocab_unit = built_bert_vocab_unit(self._bert_vocab_path) + + if chinese_version: + self._units.insert(1, units.ChineseTokenize()) + if lower_case: + self._units.append(units.Lowercase()) + self._units.append(units.StripAccent()) + self._units.append(units.WordPieceTokenize( + self._vocab_unit.state['term_index'])) + if remove_stop_words: + self._units.append(units.StopRemoval()) + + def fit(self, data_pack: DataPack, verbose: int = 1): + """ + Fit pre-processing context for transformation. + + :param verbose: Verbosity. + :param data_pack: Data_pack to be preprocessed. + :return: class:`BertPreprocessor` instance. + """ + data_pack = data_pack.apply_on_text(chain_transform(self._units), + verbose=verbose) + fitted_filter_unit = build_unit_from_data_pack(self._filter_unit, + data_pack, + flatten=False, + mode='right', + verbose=verbose) + self._context['filter_unit'] = fitted_filter_unit + self._context['vocab_unit'] = self._vocab_unit + vocab_size = len(self._vocab_unit.state['term_index']) + self._context['vocab_size'] = vocab_size + self._context['embedding_input_dim'] = vocab_size + self._context['input_shapes'] = [(self._fixed_length_left,), + (self._fixed_length_right,)] + return self + + def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack: + """ + Apply transformation on data, create fixed length representation. + + :param data_pack: Inputs to be preprocessed. + :param verbose: Verbosity. + + :return: Transformed data as :class:`DataPack` object. 
+ """ + data_pack = data_pack.copy() + data_pack.apply_on_text(chain_transform(self._units), inplace=True, + verbose=verbose) + + data_pack.apply_on_text(self._context['filter_unit'].transform, + mode='right', inplace=True, verbose=verbose) + data_pack.apply_on_text(self._context['vocab_unit'].transform, + mode='both', inplace=True, verbose=verbose) + data_pack.append_text_length(inplace=True, verbose=verbose) + data_pack.apply_on_text(self._left_fixedlength_unit.transform, + mode='left', inplace=True, verbose=verbose) + data_pack.apply_on_text(self._right_fixedlength_unit.transform, + mode='right', inplace=True, verbose=verbose) + + max_len_left = self._fixed_length_left + max_len_right = self._fixed_length_right + + data_pack.left['length_left'] = \ + data_pack.left['length_left'].apply( + lambda val: min(val, max_len_left)) + + data_pack.right['length_right'] = \ + data_pack.right['length_right'].apply( + lambda val: min(val, max_len_right)) + return data_pack + + @classmethod + def _default_units(cls) -> list: + """Prepare needed process units.""" + return [ + units.BertClean(), + units.BasicTokenize() + ] diff --git a/matchzoo/preprocessors/build_vocab_unit.py b/matchzoo/preprocessors/build_vocab_unit.py index 3d9442de..77dc54a8 100644 --- a/matchzoo/preprocessors/build_vocab_unit.py +++ b/matchzoo/preprocessors/build_vocab_unit.py @@ -1,12 +1,13 @@ from matchzoo.data_pack import DataPack from .units import Vocabulary from .build_unit_from_data_pack import build_unit_from_data_pack +from .units import BertVocabulary def build_vocab_unit( - data_pack: DataPack, - mode: str = 'both', - verbose: int = 1 + data_pack: DataPack, + mode: str = 'both', + verbose: int = 1 ) -> Vocabulary: """ Build a :class:`preprocessor.units.Vocabulary` given `data_pack`. @@ -28,3 +29,16 @@ def build_vocab_unit( mode=mode, flatten=True, verbose=verbose ) + + +def built_bert_vocab_unit(vocab_path: str) -> BertVocabulary: + """ + Build a :class:`preprocessor.units.BertVocabulary` given `vocab_path`. + + :param vocab_path: bert vocabulary path. + :return: A built vocabulary unit. + + """ + vocab_unit = BertVocabulary(pad_value='[PAD]', oov_value='[UNK]') + vocab_unit.fit(vocab_path) + return vocab_unit diff --git a/matchzoo/preprocessors/cdssm_preprocessor.py b/matchzoo/preprocessors/cdssm_preprocessor.py index edeac4e9..d7f16754 100644 --- a/matchzoo/preprocessors/cdssm_preprocessor.py +++ b/matchzoo/preprocessors/cdssm_preprocessor.py @@ -73,7 +73,7 @@ def fit(self, data_pack: DataPack, verbose: int = 1): vocab_unit = build_vocab_unit(data_pack, verbose=verbose) self._context['vocab_unit'] = vocab_unit - vocab_size = len(vocab_unit.state['term_index']) + 1 + vocab_size = len(vocab_unit.state['term_index']) self._context['input_shapes'] = [ (self._fixed_length_left, vocab_size), (self._fixed_length_right, vocab_size) diff --git a/matchzoo/preprocessors/diin_preprocessor.py b/matchzoo/preprocessors/diin_preprocessor.py new file mode 100644 index 00000000..7d64ed16 --- /dev/null +++ b/matchzoo/preprocessors/diin_preprocessor.py @@ -0,0 +1,159 @@ +"""DIIN Preprocessor.""" + +from tqdm import tqdm +import pandas as pd + +from matchzoo.engine.base_preprocessor import BasePreprocessor +from matchzoo import DataPack +from .build_vocab_unit import build_vocab_unit +from .chain_transform import chain_transform +from . 
import units + +tqdm.pandas() + + +class DIINPreprocessor(BasePreprocessor): + """DIIN Model preprocessor.""" + + def __init__(self, + fixed_length_left: int = 10, + fixed_length_right: int = 10, + fixed_length_word: int = 5): + """ + DIIN Model preprocessor. + + :param fixed_length_left: Integer, maximize length of :attr:'left' in + the data_pack. + :param fixed_length_right: Integer, maximize length of :attr:'right' in + the data_pack. + :param fixed_length_word: Integer, maximize length of each word. + + Example: + >>> import matchzoo as mz + >>> train_data = mz.datasets.toy.load_data() + >>> test_data = mz.datasets.toy.load_data(stage='test') + >>> diin_preprocessor = mz.preprocessors.DIINPreprocessor( + ... fixed_length_left=5, + ... fixed_length_right=5, + ... fixed_length_word=3, + ... ) + >>> diin_preprocessor = diin_preprocessor.fit( + ... train_data, verbose=0) + >>> diin_preprocessor.context['input_shapes'] + [(5,), (5,), (5, 3), (5, 3), (5,), (5,)] + >>> diin_preprocessor.context['vocab_size'] + 893 + >>> train_data_processed = diin_preprocessor.transform( + ... train_data, verbose=0) + >>> type(train_data_processed) + + >>> test_data_processed = diin_preprocessor.transform( + ... test_data, verbose=0) + >>> type(test_data_processed) + + + """ + super().__init__() + self._fixed_length_left = fixed_length_left + self._fixed_length_right = fixed_length_right + self._fixed_length_word = fixed_length_word + self._left_fixedlength_unit = units.FixedLength( + self._fixed_length_left, + pad_value='0', + pad_mode='post' + ) + self._right_fixedlength_unit = units.FixedLength( + self._fixed_length_right, + pad_value='0', + pad_mode='post' + ) + self._units = self._default_units() + + def fit(self, data_pack: DataPack, verbose: int = 1): + """ + Fit pre-processing context for transformation. + + :param data_pack: data_pack to be preprocessed. + :param verbose: Verbosity. + :return: class:'DIINPreprocessor' instance. + """ + func = chain_transform(self._units) + data_pack = data_pack.apply_on_text(func, mode='both', verbose=verbose) + + vocab_unit = build_vocab_unit(data_pack, verbose=verbose) + vocab_size = len(vocab_unit.state['term_index']) + self._context['vocab_unit'] = vocab_unit + self._context['vocab_size'] = vocab_size + self._context['embedding_input_dim'] = vocab_size + + data_pack = data_pack.apply_on_text( + units.NgramLetter(ngram=1, reduce_dim=True).transform, + mode='both', verbose=verbose) + char_unit = build_vocab_unit(data_pack, verbose=verbose) + self._context['char_unit'] = char_unit + + self._context['input_shapes'] = [ + (self._fixed_length_left,), + (self._fixed_length_right,), + (self._fixed_length_left, self._fixed_length_word,), + (self._fixed_length_right, self._fixed_length_word,), + (self._fixed_length_left,), + (self._fixed_length_right,) + ] + return self + + def transform(self, data_pack: DataPack, verbose: int = 1) -> DataPack: + """ + Apply transformation on data. + + :param data_pack: Inputs to be preprocessed. + :param verbose: Verbosity. + + :return: Transformed data as :class:'DataPack' object. 
+ """ + data_pack = data_pack.copy() + data_pack.apply_on_text( + chain_transform(self._units), + mode='both', inplace=True, verbose=verbose) + + # Process character representation + data_pack.apply_on_text( + units.NgramLetter(ngram=1, reduce_dim=False).transform, + rename=('char_left', 'char_right'), + mode='both', inplace=True, verbose=verbose) + char_index_dict = self._context['char_unit'].state['term_index'] + left_charindex_unit = units.CharacterIndex( + char_index_dict, self._fixed_length_left, self._fixed_length_word) + right_charindex_unit = units.CharacterIndex( + char_index_dict, self._fixed_length_right, self._fixed_length_word) + data_pack.left['char_left'] = data_pack.left['char_left'].apply( + left_charindex_unit.transform) + data_pack.right['char_right'] = data_pack.right['char_right'].apply( + right_charindex_unit.transform) + + # Process word representation + data_pack.apply_on_text( + self._context['vocab_unit'].transform, + mode='both', inplace=True, verbose=verbose) + + # Process exact match representation + frame = data_pack.relation.join( + data_pack.left, on='id_left', how='left' + ).join(data_pack.right, on='id_right', how='left') + left_exactmatch_unit = units.WordExactMatch( + self._fixed_length_left, match='text_left', to_match='text_right') + right_exactmatch_unit = units.WordExactMatch( + self._fixed_length_right, match='text_right', to_match='text_left') + data_pack.relation['match_left'] = frame.apply( + left_exactmatch_unit.transform, axis=1) + data_pack.relation['match_right'] = frame.apply( + right_exactmatch_unit.transform, axis=1) + + data_pack.apply_on_text( + self._left_fixedlength_unit.transform, + mode='left', inplace=True, verbose=verbose) + data_pack.apply_on_text( + self._right_fixedlength_unit.transform, + mode='right', inplace=True, verbose=verbose) + + return data_pack diff --git a/matchzoo/preprocessors/dssm_preprocessor.py b/matchzoo/preprocessors/dssm_preprocessor.py index 561cb2c2..2c0212a4 100644 --- a/matchzoo/preprocessors/dssm_preprocessor.py +++ b/matchzoo/preprocessors/dssm_preprocessor.py @@ -58,7 +58,7 @@ def fit(self, data_pack: DataPack, verbose: int = 1): vocab_unit = build_vocab_unit(data_pack, verbose=verbose) self._context['vocab_unit'] = vocab_unit - vocab_size = len(vocab_unit.state['term_index']) + 1 + vocab_size = len(vocab_unit.state['term_index']) self._context['vocab_size'] = vocab_size self._context['embedding_input_dim'] = vocab_size self._context['input_shapes'] = [(vocab_size,), (vocab_size,)] diff --git a/matchzoo/preprocessors/units/__init__.py b/matchzoo/preprocessors/units/__init__.py index d9e88953..7faa1c06 100644 --- a/matchzoo/preprocessors/units/__init__.py +++ b/matchzoo/preprocessors/units/__init__.py @@ -13,6 +13,14 @@ from .tokenize import Tokenize from .vocabulary import Vocabulary from .word_hashing import WordHashing +from .character_index import CharacterIndex +from .word_exact_match import WordExactMatch +from .bert_clean import BertClean +from .bert_clean import StripAccent +from .tokenize import ChineseTokenize +from .tokenize import BasicTokenize +from .tokenize import WordPieceTokenize +from .vocabulary import BertVocabulary def list_available() -> list: diff --git a/matchzoo/preprocessors/units/bert_clean.py b/matchzoo/preprocessors/units/bert_clean.py new file mode 100644 index 00000000..e6747a78 --- /dev/null +++ b/matchzoo/preprocessors/units/bert_clean.py @@ -0,0 +1,42 @@ +from .unit import Unit +from matchzoo.utils.bert_utils import \ + is_whitespace, is_control, run_strip_accents + + 
+class BertClean(Unit):
+    """Clean unit for raw text."""
+
+    def transform(self, input_: str) -> str:
+        """
+        Process input data from raw terms to cleaned text.
+
+        :param input_: raw textual input.
+
+        :return cleaned_text: cleaned text.
+        """
+        output = []
+        for char in input_:
+            cp = ord(char)
+            if cp == 0 or cp == 0xfffd or is_control(char):
+                continue
+            if is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        cleaned_text = "".join(output)
+        return cleaned_text
+
+
+class StripAccent(Unit):
+    """Process unit for stripping accents from tokens."""
+
+    def transform(self, input_: list) -> list:
+        """
+        Strip accents from each token.
+
+        :param input_: list of tokens.
+
+        :return tokens: accent-stripped list of tokens.
+        """
+
+        return [run_strip_accents(token) for token in input_]
diff --git a/matchzoo/preprocessors/units/character_index.py b/matchzoo/preprocessors/units/character_index.py
new file mode 100644
index 00000000..17126765
--- /dev/null
+++ b/matchzoo/preprocessors/units/character_index.py
@@ -0,0 +1,62 @@
+import numpy as np
+
+from .unit import Unit
+
+
+class CharacterIndex(Unit):
+    """
+    CharacterIndexUnit for the DIIN model.
+
+    The input of :class:`CharacterIndexUnit` should be a list of word
+    character lists extracted from a text. The output is the character
+    index representation of this text.
+
+    :class:`NgramLetterUnit` and :class:`VocabularyUnit` are two
+    essential prerequisites of :class:`CharacterIndexUnit`.
+
+    Examples:
+        >>> input_ = [['#', 'a', '#'], ['#', 'o', 'n', 'e', '#']]
+        >>> character_index = CharacterIndex(
+        ...     char_index={
+        ...         '<PAD>': 0, '<OOV>': 1, 'a': 2, 'n': 3, 'e': 4, '#': 5},
+        ...     fixed_length_text=2,
+        ...     fixed_length_word=5)
+        >>> index = character_index.transform(input_)
+        >>> index
+        [[5.0, 2.0, 5.0, 0.0, 0.0], [5.0, 1.0, 3.0, 4.0, 5.0]]
+
+    """
+
+    def __init__(
+        self,
+        char_index: dict,
+        fixed_length_text: int,
+        fixed_length_word: int
+    ):
+        """
+        Class initialization.
+
+        :param char_index: character-index mapping generated by
+            :class:`VocabularyUnit`.
+        :param fixed_length_text: maximum length of a text.
+        :param fixed_length_word: maximum length of a word.
+        """
+        self._char_index = char_index
+        self._fixed_length_text = fixed_length_text
+        self._fixed_length_word = fixed_length_word
+
+    def transform(self, input_: list) -> list:
+        """
+        Transform list of characters to corresponding indices.
+
+        :param input_: list of characters generated by
+            :class:`NgramLetterUnit`.
+
+        :return: character index representation of a text.
+ """ + idx = np.zeros((self._fixed_length_text, self._fixed_length_word)) + for i in range(min(len(input_), self._fixed_length_text)): + for j in range(min(len(input_[i]), self._fixed_length_word)): + idx[i, j] = self._char_index.get(input_[i][j], 1) + + return idx.tolist() diff --git a/matchzoo/preprocessors/units/frequency_filter.py b/matchzoo/preprocessors/units/frequency_filter.py index 5de7d69e..89a7523c 100644 --- a/matchzoo/preprocessors/units/frequency_filter.py +++ b/matchzoo/preprocessors/units/frequency_filter.py @@ -66,11 +66,11 @@ def fit(self, list_of_tokens: typing.List[typing.List[str]]): if self._low <= v < self._high: valid_terms.add(k) - self._state[self._mode] = valid_terms + self._context[self._mode] = valid_terms def transform(self, input_: list) -> list: """Transform a list of tokens by filtering out unwanted words.""" - valid_terms = self._state[self._mode] + valid_terms = self._context[self._mode] return list(filter(lambda token: token in valid_terms, input_)) @classmethod diff --git a/matchzoo/preprocessors/units/punc_removal.py b/matchzoo/preprocessors/units/punc_removal.py index 302a4702..af55d582 100644 --- a/matchzoo/preprocessors/units/punc_removal.py +++ b/matchzoo/preprocessors/units/punc_removal.py @@ -1,4 +1,4 @@ -import re +import string from .unit import Unit @@ -6,8 +6,6 @@ class PuncRemoval(Unit): """Process unit for remove punctuations.""" - _MATCH_PUNC = re.compile(r'[^\w\s]') - def transform(self, input_: list) -> list: """ Remove punctuations from list of tokens. @@ -16,5 +14,5 @@ def transform(self, input_: list) -> list: :return rv: tokens without punctuation. """ - return [token for token in input_ if - not self._MATCH_PUNC.search(token)] + table = str.maketrans({key: None for key in string.punctuation}) + return [item.translate(table) for item in input_] diff --git a/matchzoo/preprocessors/units/stateful_unit.py b/matchzoo/preprocessors/units/stateful_unit.py index 9f8b3fca..423075dc 100644 --- a/matchzoo/preprocessors/units/stateful_unit.py +++ b/matchzoo/preprocessors/units/stateful_unit.py @@ -5,17 +5,32 @@ class StatefulUnit(Unit, metaclass=abc.ABCMeta): - """Process unit do persive state (i.e. need fit).""" + """ + Unit with inner state. + + Usually need to be fit before transforming. All information gathered in the + fit phrase will be stored into its `context`. + """ def __init__(self): """Initialization.""" - self._state = {} + self._context = {} @property def state(self): - """Get current state.""" - return self._state + """ + Get current context. Same as `unit.context`. + + Deprecated since v2.2.0, and will be removed in the future. + Used `unit.context` instead. + """ + return self._context + + @property + def context(self): + """Get current context. Same as `unit.state`.""" + return self._context @abc.abstractmethod - def fit(self, input: typing.Any): + def fit(self, input_: typing.Any): """Abstract base method, need to be implemented in subclass.""" diff --git a/matchzoo/preprocessors/units/tokenize.py b/matchzoo/preprocessors/units/tokenize.py index 1aeb2e62..befdcc56 100644 --- a/matchzoo/preprocessors/units/tokenize.py +++ b/matchzoo/preprocessors/units/tokenize.py @@ -1,4 +1,6 @@ import nltk +from matchzoo.utils.bert_utils import is_chinese_char, \ + whitespace_tokenize, run_split_on_punc from .unit import Unit @@ -15,3 +17,110 @@ def transform(self, input_: str) -> list: :return tokens: tokenized tokens as a list. 
""" return nltk.word_tokenize(input_) + + +class ChineseTokenize(Unit): + """Process unit for text containing Chinese tokens.""" + + def transform(self, input_: str) -> str: + """ + Process input data from raw terms to processed text. + + :param input_: raw textual input. + + :return output: text with at least one blank between adjacent + Chinese tokens. + """ + output = [] + for char in input_: + cp = ord(char) + if is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class BasicTokenize(Unit): + """Process unit for text tokenization.""" + + def transform(self, input_: str) -> list: + """ + Process input data from raw terms to list of tokens. + + :param input_: raw textual input. + + :return tokens: tokenized tokens as a list. + """ + orig_tokens = whitespace_tokenize(input_) + split_tokens = [] + for token in orig_tokens: + split_tokens.extend(run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + +class WordPieceTokenize(Unit): + """Process unit for text tokenization.""" + + def __init__(self, vocab: dict, max_input_chars_per_word: int = 200): + """Initialization.""" + self.vocab = vocab + self.unk_token = '[UNK]' + self.max_input_chars_per_word = max_input_chars_per_word + + def transform(self, input_: list) -> list: + """ + Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example: + >>> input_list = ["unaffable"] + >>> vocab = {"un": 0, "##aff": 1, "##able":2} + >>> wordpiece_unit = WordPieceTokenize(vocab) + >>> output = wordpiece_unit.transform(input_list) + >>> golden_output = ["un", "##aff", "##able"] + >>> assert output == golden_output + + :param input_: token list. + + :return tokens: A list of wordpiece tokens. + """ + output_tokens = [] + for token in input_: + chars = list(token) + token_length = len(chars) + if token_length > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + unknown_suffix = False + start = 0 + sub_tokens = [] + while start < token_length: + end = token_length + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + unknown_suffix = True + break + sub_tokens.append(cur_substr) + start = end + + if unknown_suffix: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens diff --git a/matchzoo/preprocessors/units/vocabulary.py b/matchzoo/preprocessors/units/vocabulary.py index c725c469..711e4d50 100644 --- a/matchzoo/preprocessors/units/vocabulary.py +++ b/matchzoo/preprocessors/units/vocabulary.py @@ -5,20 +5,23 @@ class Vocabulary(StatefulUnit): """ Vocabulary class. + :param pad_value: The string value for the padding position. + :param oov_value: The string value for the out-of-vocabulary terms. 
+ Examples: - >>> vocab = Vocabulary() + >>> vocab = Vocabulary(pad_value='[PAD]', oov_value='[OOV]') >>> vocab.fit(['A', 'B', 'C', 'D', 'E']) >>> term_index = vocab.state['term_index'] >>> term_index # doctest: +SKIP - {'E': 1, 'C': 2, 'D': 3, 'A': 4, 'B': 5} + {'[PAD]': 0, '[OOV]': 1, 'D': 2, 'A': 3, 'B': 4, 'C': 5, 'E': 6} >>> index_term = vocab.state['index_term'] >>> index_term # doctest: +SKIP - {1: 'C', 2: 'A', 3: 'E', 4: 'B', 5: 'D'} + {0: '[PAD]', 1: '[OOV]', 2: 'D', 3: 'A', 4: 'B', 5: 'C', 6: 'E'} >>> term_index['out-of-vocabulary-term'] - 0 + 1 >>> index_term[0] - '' + '[PAD]' >>> index_term[42] Traceback (most recent call last): ... @@ -27,40 +30,81 @@ class Vocabulary(StatefulUnit): >>> c_index = term_index['C'] >>> vocab.transform(['C', 'A', 'C']) == [c_index, a_index, c_index] True - >>> vocab.transform(['C', 'A', 'OOV']) == [c_index, a_index, 0] + >>> vocab.transform(['C', 'A', '[OOV]']) == [c_index, a_index, 1] True >>> indices = vocab.transform(list('ABCDDZZZ')) - >>> ''.join(vocab.state['index_term'][i] for i in indices) - 'ABCDD' + >>> ' '.join(vocab.state['index_term'][i] for i in indices) + 'A B C D D [OOV] [OOV] [OOV]' """ - class IndexTerm(dict): - """Map index to term.""" - - def __missing__(self, key): - """Map out-of-vocabulary indices to empty string.""" - if key == 0: - return '' - else: - raise KeyError(key) + def __init__(self, pad_value: str = '', oov_value: str = ''): + """Vocabulary unit initializer.""" + super().__init__() + self._pad = pad_value + self._oov = oov_value + self._context['term_index'] = self.TermIndex() + self._context['index_term'] = dict() class TermIndex(dict): """Map term to index.""" def __missing__(self, key): - """Map out-of-vocabulary terms to index 0.""" - return 0 + """Map out-of-vocabulary terms to index 1.""" + return 1 def fit(self, tokens: list): """Build a :class:`TermIndex` and a :class:`IndexTerm`.""" - self._state['term_index'] = self.TermIndex() - self._state['index_term'] = self.IndexTerm() + self._context['term_index'][self._pad] = 0 + self._context['term_index'][self._oov] = 1 + self._context['index_term'][0] = self._pad + self._context['index_term'][1] = self._oov terms = set(tokens) for index, term in enumerate(terms): - self._state['term_index'][term] = index + 1 - self._state['index_term'][index + 1] = term + self._context['term_index'][term] = index + 2 + self._context['index_term'][index + 2] = term + + def transform(self, input_: list) -> list: + """Transform a list of tokens to corresponding indices.""" + return [self._context['term_index'][token] for token in input_] + + +class BertVocabulary(StatefulUnit): + """ + Vocabulary class. + + :param pad_value: The string value for the padding position. + :param oov_value: The string value for the out-of-vocabulary terms. 
+ + Examples: + >>> vocab = BertVocabulary(pad_value='[PAD]', oov_value='[UNK]') + >>> indices = vocab.transform(list('ABCDDZZZ')) + + """ + + def __init__(self, pad_value: str = '[PAD]', oov_value: str = '[UNK]'): + """Vocabulary unit initializer.""" + super().__init__() + self._pad = pad_value + self._oov = oov_value + self._context['term_index'] = self.TermIndex() + self._context['index_term'] = {} + + class TermIndex(dict): + """Map term to index.""" + + def __missing__(self, key): + """Map out-of-vocabulary terms to index 100 .""" + return 100 + + def fit(self, vocab_path: str): + """Build a :class:`TermIndex` and a :class:`IndexTerm`.""" + with open(vocab_path, 'r', encoding='utf-8') as vocab_file: + for idx, line in enumerate(vocab_file): + term = line.strip() + self._context['term_index'][term] = idx + self._context['index_term'][idx] = term def transform(self, input_: list) -> list: """Transform a list of tokens to corresponding indices.""" - return [self._state['term_index'][token] for token in input_] + return [self._context['term_index'][token] for token in input_] diff --git a/matchzoo/preprocessors/units/word_exact_match.py b/matchzoo/preprocessors/units/word_exact_match.py new file mode 100644 index 00000000..717b196d --- /dev/null +++ b/matchzoo/preprocessors/units/word_exact_match.py @@ -0,0 +1,72 @@ +import numpy as np +import pandas + +from .unit import Unit + + +class WordExactMatch(Unit): + """ + WordExactUnit Class. + + Process unit to get a binary match list of two word index lists. The + word index list is the word representation of a text. + + Examples: + >>> input_ = pandas.DataFrame({ + ... 'text_left':[[1, 2, 3],[4, 5, 7, 9]], + ... 'text_right':[[5, 3, 2, 7],[2, 3, 5]]} + ... ) + >>> left_word_exact_match = WordExactMatch( + ... fixed_length_text=5, + ... match='text_left', to_match='text_right' + ... ) + >>> left_out = input_.apply(left_word_exact_match.transform, axis=1) + >>> left_out[0] + [0.0, 1.0, 1.0, 0.0, 0.0] + >>> left_out[1] + [0.0, 1.0, 0.0, 0.0, 0.0] + >>> right_word_exact_match = WordExactMatch( + ... fixed_length_text=5, + ... match='text_right', to_match='text_left' + ... ) + >>> right_out = input_.apply(right_word_exact_match.transform, axis=1) + >>> right_out[0] + [0.0, 1.0, 1.0, 0.0, 0.0] + >>> right_out[1] + [0.0, 0.0, 1.0, 0.0, 0.0] + + """ + + def __init__( + self, + fixed_length_text: int, + match: str, + to_match: str + ): + """ + Class initialization. + + :param fixed_length_text: fixed length of the text. + :param match: the 'match' column name. + :param to_match: the 'to_match' column name. + """ + self._fixed_length_text = fixed_length_text + self._match = match + self._to_match = to_match + + def transform(self, input_) -> list: + """ + Transform two word index lists into a binary match list. + + :param input_: a dataframe include 'match' column and + 'to_match' column. + + :return: a binary match result list of two word index lists. 
+ """ + match_length = len(input_[self._match]) + match_binary = np.zeros((self._fixed_length_text)) + for i in range(min(self._fixed_length_text, match_length)): + if input_[self._match][i] in set(input_[self._to_match]): + match_binary[i] = 1 + + return match_binary.tolist() diff --git a/matchzoo/preprocessors/units/word_hashing.py b/matchzoo/preprocessors/units/word_hashing.py index a17a1677..805c1ba3 100644 --- a/matchzoo/preprocessors/units/word_hashing.py +++ b/matchzoo/preprocessors/units/word_hashing.py @@ -19,12 +19,14 @@ class WordHashing(Unit): Examples: >>> letters = [['#te', 'tes','est', 'st#'], ['oov']] >>> word_hashing = WordHashing( - ... term_index={'': 0,'st#': 1, '#te': 2, 'est': 3, 'tes': 4}) + ... term_index={ + ... '_PAD': 0, 'OOV': 1, 'st#': 2, '#te': 3, 'est': 4, 'tes': 5 + ... }) >>> hashing = word_hashing.transform(letters) >>> hashing[0] - [0.0, 1.0, 1.0, 1.0, 1.0, 0.0] + [0.0, 0.0, 1.0, 1.0, 1.0, 1.0] >>> hashing[1] - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0] + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0] """ @@ -52,18 +54,18 @@ def transform(self, input_: list) -> list: if any([isinstance(elem, list) for elem in input_]): # The input shape for CDSSM is # [[word1 ngram, ngram], [word2, ngram, ngram], ...]. - hashing = np.zeros((len(input_), len(self._term_index) + 1)) + hashing = np.zeros((len(input_), len(self._term_index))) for idx, word in enumerate(input_): counted_letters = collections.Counter(word) for key, value in counted_letters.items(): - letter_id = self._term_index.get(key, 0) + letter_id = self._term_index.get(key, 1) hashing[idx, letter_id] = value else: # The input shape for DSSM model [ngram, ngram, ...]. - hashing = np.zeros((len(self._term_index) + 1)) + hashing = np.zeros(len(self._term_index)) counted_letters = collections.Counter(input_) for key, value in counted_letters.items(): - letter_id = self._term_index.get(key, 0) + letter_id = self._term_index.get(key, 1) hashing[letter_id] = value return hashing.tolist() diff --git a/matchzoo/utils/__init__.py b/matchzoo/utils/__init__.py index 8c76f5fa..63a840db 100644 --- a/matchzoo/utils/__init__.py +++ b/matchzoo/utils/__init__.py @@ -1,3 +1,4 @@ from .one_hot import one_hot from .tensor_type import TensorType from .list_recursive_subclasses import list_recursive_concrete_subclasses +from .make_keras_optimizer_picklable import make_keras_optimizer_picklable diff --git a/matchzoo/utils/bert_utils.py b/matchzoo/utils/bert_utils.py new file mode 100644 index 00000000..8490ddbb --- /dev/null +++ b/matchzoo/utils/bert_utils.py @@ -0,0 +1,94 @@ +import unicodedata + + +def is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + return (char == " ") or \ + (char == "\t") or \ + (char == "\n") or \ + (char == "\r") or \ + (unicodedata.category(char) == "Zs") + + +def is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ["Cc", "Cf"]: + return True + return False + + +def is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. 
+ # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + condition = (33 <= cp <= 47) or (58 <= cp <= 64) or \ + (91 <= cp <= 96) or (123 <= cp <= 126) + cat = unicodedata.category(char) + if condition or cat.startswith("P"): + return True + return False + + +def is_chinese_char(cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean + # characters, despite its name. The modern Korean Hangul alphabet is a + # different block, as is Japanese Hiragana and Katakana. Those alphabets + # are used to write space-separated words, so they are not treated + # specially and handled like the all of the other languages. + return (0x4E00 <= cp <= 0x9FFF) or \ + (0x3400 <= cp <= 0x4DBF) or \ + (0x20000 <= cp <= 0x2A6DF) or \ + (0x2A700 <= cp <= 0x2B73F) or \ + (0x2B740 <= cp <= 0x2B81F) or \ + (0x2B820 <= cp <= 0x2CEAF) or \ + (0xF900 <= cp <= 0xFAFF) or \ + (0x2F800 <= cp <= 0x2FA1F) + + +def run_strip_accents(text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [char for char in text if not unicodedata.category(char) == 'Mn'] + return "".join(output) + + +def run_split_on_punc(text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + tokens = text.split() + return tokens diff --git a/matchzoo/utils/make_keras_optimizer_picklable.py b/matchzoo/utils/make_keras_optimizer_picklable.py new file mode 100644 index 00000000..c45edba4 --- /dev/null +++ b/matchzoo/utils/make_keras_optimizer_picklable.py @@ -0,0 +1,19 @@ +import keras + + +def make_keras_optimizer_picklable(): + """ + Fix https://github.com/NTMC-Community/MatchZoo/issues/726. + + This function changes how keras behaves, use with caution. 
+ """ + def __getstate__(self): + return keras.optimizers.serialize(self) + + def __setstate__(self, state): + optimizer = keras.optimizers.deserialize(state) + self.__dict__ = optimizer.__dict__ + + cls = keras.optimizers.Optimizer + cls.__getstate__ = __getstate__ + cls.__setstate__ = __setstate__ diff --git a/requirements.txt b/requirements.txt index 978099aa..995b0417 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ -keras == 2.2.4 +keras == 2.3.0 tabulate >= 0.8.2 -tensorflow >= 1.1.0 +tensorflow >= 2.0.0 nltk >= 3.2.3 numpy >= 1.14 -tqdm >= 4.19.4 +tqdm >= 4.23.4 dill >= 0.2.7.1 hyperopt >= 0.1.1 -pandas >= 0.23.1 +pandas == 0.24.2 networkx >= 2.1 h5py >= 2.8.0 coverage >= 4.3.4 diff --git a/setup.py b/setup.py index 8aadcbf9..f99f2f42 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ long_description = f.read() install_requires = [ - 'keras == 2.2.4', + 'keras >= 2.3.0', 'nltk >= 3.2.3', 'numpy >= 1.14', 'tqdm >= 4.19.4', diff --git a/tests/unit_test/models/test_models.py b/tests/unit_test/models/test_models.py index ecf3da20..743fe9c0 100644 --- a/tests/unit_test/models/test_models.py +++ b/tests/unit_test/models/test_models.py @@ -9,7 +9,7 @@ import shutil import matchzoo as mz - +from keras.backend import clear_session @pytest.fixture(scope='module', params=[ mz.tasks.Ranking(loss=mz.losses.RankCrossEntropyLoss(num_neg=2)), @@ -21,7 +21,7 @@ def task(request): @pytest.fixture(scope='module') def train_raw(task): - return mz.datasets.toy.load_data('train', task) + return mz.datasets.toy.load_data('train', task)[:5] @pytest.fixture(scope='module', params=mz.models.list_available()) @@ -31,11 +31,12 @@ def model_class(request): @pytest.fixture(scope='module') def embedding(): - return mz.datasets.embeddings.load_glove_embedding(dimension=50) + return mz.datasets.toy.load_embedding() @pytest.fixture(scope='module') def setup(task, model_class, train_raw, embedding): + clear_session() # prevent OOM during CI tests return mz.auto.prepare( task=task, model_class=model_class, @@ -65,19 +66,20 @@ def embedding_matrix(setup): @pytest.fixture(scope='module') -def gen(train_raw, preprocessor, gen_builder): - return gen_builder.build(preprocessor.transform(train_raw)) +def data(train_raw, preprocessor, gen_builder): + return gen_builder.build(preprocessor.transform(train_raw))[0] @pytest.mark.slow -def test_model_fit_eval_predict(model, gen): - x, y = gen[0] - assert model.fit(x, y, verbose=0) - assert model.evaluate(x, y) - assert model.predict(x) is not None +def test_model_fit_eval_predict(model, data): + x, y = data + batch_size = len(x['id_left']) + assert model.fit(x, y, batch_size=batch_size, verbose=0) + assert model.evaluate(x, y, batch_size=batch_size) + assert model.predict(x, batch_size=batch_size) is not None -@pytest.mark.slow +@pytest.mark.cron def test_save_load_model(model): tmpdir = '.matchzoo_test_save_load_tmpdir' @@ -94,9 +96,9 @@ def test_save_load_model(model): shutil.rmtree(tmpdir) -@pytest.mark.slow +@pytest.mark.cron def test_hyper_space(model): - for _ in range(8): + for _ in range(2): new_params = copy.deepcopy(model.params) sample = mz.hyper_spaces.sample(new_params.hyper_space) for key, value in sample.items(): diff --git a/tests/unit_test/processor_units/test_processor_units.py b/tests/unit_test/processor_units/test_processor_units.py index 45afb0ad..fccd87e0 100644 --- a/tests/unit_test/processor_units/test_processor_units.py +++ b/tests/unit_test/processor_units/test_processor_units.py @@ -136,3 +136,33 @@ def test_this(): 
type(train_data_processed) test_data_transformed = dssm_preprocessor.transform(test_data) type(test_data_transformed) + + +import tempfile +import os + + +def test_bert_tokenizer_unit(): + vocab_tokens = [ + "[PAD]", "further", "##more", ",", "under", "the", "micro", "##scope", "neither", + "entity", "contains", "glands", ".", "此", "外", "在", "显", "微", "镜", "下" + ] + raw_text = "furthermore, \r under the microscope \t neither entity \n contains sebaceous glands. 此外, 在显微镜下" + + golden_tokens = ['further', '##more', ',', 'under', 'the', 'micro', '##scope', 'neither', 'entity', 'contains', + '[UNK]', 'glands', '.', '此', '外', ',', '在', '显', '微', '镜', '下'] + + vocab_dict = {} + for idx, token in enumerate(vocab_tokens): + vocab_dict[token] = idx + + clean_unit = units.BertClean() + cleaned_text = clean_unit.transform(raw_text) + chinese_tokenize_unit = units.ChineseTokenize() + chinese_tokenized_text = chinese_tokenize_unit.transform(cleaned_text) + basic_tokenize_unit = units.BasicTokenize() + basic_tokens = basic_tokenize_unit.transform(chinese_tokenized_text) + wordpiece_unit = units.WordPieceTokenize(vocab_dict) + wordpiece_tokens = wordpiece_unit.transform(basic_tokens) + + assert wordpiece_tokens == golden_tokens diff --git a/tests/unit_test/test_datasets.py b/tests/unit_test/test_datasets.py index 8781e444..f94fb5cd 100644 --- a/tests/unit_test/test_datasets.py +++ b/tests/unit_test/test_datasets.py @@ -3,7 +3,7 @@ import matchzoo as mz -@pytest.mark.slow +@pytest.mark.cron def test_load_data(): train_data = mz.datasets.wiki_qa.load_data('train', task='ranking') assert len(train_data) == 20360 @@ -32,7 +32,7 @@ def test_load_data(): assert tag == [False, True] -@pytest.mark.slow +@pytest.mark.cron def test_load_snli(): train_data, classes = mz.datasets.snli.load_data('train', 'classification', @@ -60,7 +60,7 @@ def test_load_snli(): assert y.shape == (num_samples, 1) -@pytest.mark.slow +@pytest.mark.cron def test_load_quora_qp(): train_data = mz.datasets.quora_qp.load_data(task='classification') assert len(train_data) == 363177 @@ -83,3 +83,38 @@ def test_load_quora_qp(): x, y = dev_data.unpack() assert y.shape == (40371, 1) + +@pytest.mark.cron +def test_load_cqa_ql_16(): + # test load question pairs + train_data = mz.datasets.cqa_ql_16.load_data(task='classification') + assert len(train_data) == 3998 + dev_data, tag = mz.datasets.cqa_ql_16.load_data( + 'dev', + task='classification', + return_classes=True) + assert tag == ['PerfectMatch', 'Relevant', 'Irrelevant'] + assert len(dev_data) == 500 + x, y = dev_data.unpack() + assert y.shape == (500, 3) + test_data = mz.datasets.cqa_ql_16.load_data('test') + assert len(test_data) == 700 + + # test load answer pairs + train_data = mz.datasets.cqa_ql_16.load_data(match_type='answer') + assert len(train_data) == 39980 + test_data = mz.datasets.cqa_ql_16.load_data(stage='test', match_type='answer') + assert len(test_data) == 7000 + + # test load external answer pairs + train_data = mz.datasets.cqa_ql_16.load_data(match_type='external_answer') + assert len(train_data) == 39980 + + # test load rank data + train_data = mz.datasets.cqa_ql_16.load_data(task='ranking') + x, y = train_data.unpack() + assert y.shape == (3998, 1) + + dev_data = mz.datasets.cqa_ql_16.load_data('dev', task='ranking', match_type='answer', target_label='Good') + x, y = dev_data.unpack() + assert y.shape == (5000, 1) diff --git a/tests/unit_test/test_embedding.py b/tests/unit_test/test_embedding.py index 444e9daa..81a9c9e6 100644 --- a/tests/unit_test/test_embedding.py +++ 
b/tests/unit_test/test_embedding.py @@ -5,15 +5,15 @@ @pytest.fixture def term_index(): - return {'G': 1, 'C': 2, 'D': 3, 'A': 4, '[PAD]': 0} + return {'G': 1, 'C': 2, 'D': 3, 'A': 4, '_PAD': 0} def test_embedding(term_index): embed = mz.embedding.load_from_file(mz.datasets.embeddings.EMBED_RANK) matrix = embed.build_matrix(term_index) - assert matrix.shape == (len(term_index) + 1, 50) + assert matrix.shape == (len(term_index), 50) embed = mz.embedding.load_from_file(mz.datasets.embeddings.EMBED_10_GLOVE, mode='glove') matrix = embed.build_matrix(term_index) - assert matrix.shape == (len(term_index) + 1, 10) + assert matrix.shape == (len(term_index), 10) assert embed.input_dim == 5 diff --git a/tests/unit_test/test_layers.py b/tests/unit_test/test_layers.py index 88c9c1ab..cda16191 100644 --- a/tests/unit_test/test_layers.py +++ b/tests/unit_test/test_layers.py @@ -3,6 +3,8 @@ from keras import backend as K from matchzoo import layers +from matchzoo.contrib.layers import SpatialGRU +from matchzoo.contrib.layers import MatchingTensorLayer def test_matching_layers(): @@ -26,3 +28,33 @@ def test_matching_layers(): layers.MatchingLayer(matching_type='error') with pytest.raises(ValueError): layers.MatchingLayer()([s1_tensor, s3_tensor]) + + +def test_spatial_gru(): + s_value = K.variable(np.array([[[[1, 2], [2, 3], [3, 4]], + [[4, 5], [5, 6], [6, 7]]], + [[[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]], + [[0.4, 0.5], [0.5, 0.6], [0.6, 0.7]]]])) + for direction in ['lt', 'rb']: + model = SpatialGRU(direction=direction) + _ = K.eval(model(s_value)) + with pytest.raises(ValueError): + SpatialGRU(direction='lr')(s_value) + + +def test_matching_tensor_layer(): + s1_value = np.array([[[1, 2], [2, 3], [3, 4]], + [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]]) + s2_value = np.array([[[1, 2], [2, 3]], + [[0.1, 0.2], [0.2, 0.3]]]) + s3_value = np.array([[[1, 2], [2, 3]], + [[0.1, 0.2], [0.2, 0.3]], + [[0.1, 0.2], [0.2, 0.3]]]) + s1_tensor = K.variable(s1_value) + s2_tensor = K.variable(s2_value) + s3_tensor = K.variable(s3_value) + for init_diag in [True, False]: + model = MatchingTensorLayer(init_diag=init_diag) + _ = K.eval(model([s1_tensor, s2_tensor])) + with pytest.raises(ValueError): + MatchingTensorLayer()([s1_tensor, s3_tensor]) diff --git a/tests/unit_test/test_utils.py b/tests/unit_test/test_utils.py new file mode 100644 index 00000000..e69de29b diff --git a/tutorials/quora/esim.ipynb b/tutorials/quora/esim.ipynb new file mode 100644 index 00000000..4bdf927a --- /dev/null +++ b/tutorials/quora/esim.ipynb @@ -0,0 +1,675 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "matchzoo version 2.1.0\n", + "\n", + "data loading ...\n", + "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n", + "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n", + "loading embedding ...\n", + "embedding loaded as `glove_embedding`\n" + ] + } + ], + "source": [ + "%run ./tutorials/wikiqa/init.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from keras.optimizers import Adam\n", + "from keras.utils import to_categorical\n", + "\n", + "import matchzoo as mz\n", + "from 
matchzoo.contrib.models.esim import ESIM" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def load_filtered_data(preprocessor, data_type):\n", + " assert ( data_type in ['train', 'dev', 'test'])\n", + " data_pack = mz.datasets.wiki_qa.load_data(data_type, task='ranking')\n", + "\n", + " if data_type == 'train':\n", + " X, Y = preprocessor.fit_transform(data_pack).unpack()\n", + " else:\n", + " X, Y = preprocessor.transform(data_pack).unpack()\n", + "\n", + " new_idx = []\n", + " for i in range(Y.shape[0]):\n", + " if X[\"length_left\"][i] == 0 or X[\"length_right\"][i] == 0:\n", + " continue\n", + " new_idx.append(i)\n", + " new_idx = np.array(new_idx)\n", + " print(\"Removed empty data. Found \", (Y.shape[0] - new_idx.shape[0]))\n", + "\n", + " for k in X.keys():\n", + " X[k] = X[k][new_idx]\n", + " Y = Y[new_idx]\n", + "\n", + " pos_idx = (Y == 1)[:, 0]\n", + " pos_qid = X[\"id_left\"][pos_idx]\n", + " keep_idx_bool = np.array([ qid in pos_qid for qid in X[\"id_left\"]])\n", + " keep_idx = np.arange(keep_idx_bool.shape[0])\n", + " keep_idx = keep_idx[keep_idx_bool]\n", + " print(\"Removed questions with no pos label. Found \", (keep_idx_bool == 0).sum())\n", + "\n", + " print(\"shuffling...\")\n", + " np.random.shuffle(keep_idx)\n", + " for k in X.keys():\n", + " X[k] = X[k][keep_idx]\n", + " Y = Y[keep_idx]\n", + "\n", + " return X, Y, preprocessor" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "fixed_length_left = 10\n", + "fixed_length_right = 40\n", + "batch_size = 32\n", + "epochs = 5" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 10798.93it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 8019.65it/s]\n", + "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 1415354.12it/s]\n", + "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 226166.63it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 233892.08it/s]\n", + "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 782897.32it/s]\n", + "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 1175423.27it/s]\n", + "Building Vocabulary from a datapack.: 100%|██████████| 358408/358408 [00:00<00:00, 4845654.07it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 15108.05it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 8129.15it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 222548.25it/s]\n", + "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 324738.11it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 122413.67it/s]\n", + "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 821484.73it/s]\n", + "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 1319786.92it/s]\n", + "Processing text_left with transform: 
100%|██████████| 2118/2118 [00:00<00:00, 200871.36it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 180842.83it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removed empty data. Found 91\n", + "Removed questions with no pos label. Found 11642\n", + "shuffling...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 296/296 [00:00<00:00, 15853.43it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2708/2708 [00:00<00:00, 8318.22it/s]\n", + "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 232964.32it/s]\n", + "Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 200892.23it/s]\n", + "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 231808.96it/s]\n", + "Processing length_left with len: 100%|██████████| 296/296 [00:00<00:00, 562279.88it/s]\n", + "Processing length_right with len: 100%|██████████| 2708/2708 [00:00<00:00, 1159470.73it/s]\n", + "Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 183357.55it/s]\n", + "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 178815.40it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removed empty data. Found 8\n", + "Removed questions with no pos label. Found 1595\n", + "shuffling...\n" + ] + } + ], + "source": [ + "# prepare data\n", + "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=fixed_length_left,\n", + " fixed_length_right=fixed_length_right,\n", + " remove_stop_words=False,\n", + " filter_low_freq=10)\n", + "\n", + "train_X, train_Y, preprocessor = load_filtered_data(preprocessor, 'train')\n", + "val_X, val_Y, _ = load_filtered_data(preprocessor, 'dev')\n", + "pred_X, pred_Y = val_X, val_Y\n", + "# pred_X, pred_Y, _ = load_filtered_data(preprocessor, 'test') # no prediction label for quora dataset\n", + "\n", + "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'], initializer=lambda: 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "text_left (InputLayer) (None, 10) 0 \n", + "__________________________________________________________________________________________________\n", + "text_right (InputLayer) (None, 40) 0 \n", + "__________________________________________________________________________________________________\n", + "embedding (Embedding) multiple 1930500 text_left[0][0] \n", + " text_right[0][0] \n", + "__________________________________________________________________________________________________\n", + "dropout_1 (Dropout) multiple 0 embedding[0][0] \n", + " embedding[1][0] \n", + " dense_1[0][0] \n", + " dense_1[1][0] \n", + " dense_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_1 (Lambda) multiple 0 text_left[0][0] \n", + " text_right[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "bidirectional_1 (Bidirectional) multiple 1442400 dropout_1[0][0] \n", + " dropout_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_2 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_3 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_1 (Multiply) (None, 10, 600) 0 bidirectional_1[0][0] \n", + " lambda_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_2 (Multiply) (None, 40, 600) 0 bidirectional_1[1][0] \n", + " lambda_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_4 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_5 (Lambda) (None, 1, 40) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "dot_1 (Dot) (None, 10, 40) 0 multiply_1[0][0] \n", + " multiply_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_3 (Multiply) (None, 10, 40) 0 lambda_4[0][0] \n", + " lambda_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "permute_1 (Permute) (None, 40, 10) 0 dot_1[0][0] \n", + " multiply_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "atten_mask (Lambda) multiple 0 dot_1[0][0] \n", + " multiply_3[0][0] \n", + " permute_1[0][0] \n", + " permute_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "softmax_1 (Softmax) multiple 0 atten_mask[0][0] \n", + " atten_mask[1][0] \n", + "__________________________________________________________________________________________________\n", + "dot_2 (Dot) (None, 10, 600) 0 softmax_1[0][0] \n", + " multiply_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "dot_3 (Dot) (None, 40, 600) 0 softmax_1[1][0] \n", + " multiply_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "subtract_1 (Subtract) (None, 10, 600) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_4 (Multiply) (None, 10, 600) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "subtract_2 (Subtract) (None, 40, 600) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_5 (Multiply) (None, 40, 600) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_1 (Concatenate) (None, 10, 2400) 0 multiply_1[0][0] 
\n", + " dot_2[0][0] \n", + " subtract_1[0][0] \n", + " multiply_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_2 (Concatenate) (None, 40, 2400) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + " subtract_2[0][0] \n", + " multiply_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_1 (Dense) multiple 720300 concatenate_1[0][0] \n", + " concatenate_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_2 (Bidirectional) multiple 1442400 dropout_1[2][0] \n", + " dropout_1[3][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_6 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_8 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_10 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_12 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_6 (Multiply) (None, 10, 600) 0 bidirectional_2[0][0] \n", + " lambda_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_7 (Multiply) (None, 10, 600) 0 bidirectional_2[0][0] \n", + " lambda_8[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_8 (Multiply) (None, 40, 600) 0 bidirectional_2[1][0] \n", + " lambda_10[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_9 (Multiply) (None, 40, 600) 0 bidirectional_2[1][0] \n", + " lambda_12[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_7 (Lambda) (None, 600) 0 multiply_6[0][0] \n", + " lambda_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_9 (Lambda) (None, 600) 0 multiply_7[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_11 (Lambda) (None, 600) 0 multiply_8[0][0] \n", + " lambda_10[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_13 (Lambda) (None, 600) 0 multiply_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_3 (Concatenate) (None, 1200) 0 lambda_7[0][0] \n", + " lambda_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_4 (Concatenate) (None, 1200) 0 lambda_11[0][0] \n", + " lambda_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_5 (Concatenate) (None, 2400) 0 concatenate_3[0][0] \n", + " concatenate_4[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "dense_2 (Dense) (None, 300) 720300 concatenate_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_3 (Dense) (None, 1) 301 dropout_1[4][0] \n", + "==================================================================================================\n", + "Total params: 6,256,201\n", + "Trainable params: 4,325,701\n", + "Non-trainable params: 1,930,500\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "model = ESIM()\n", + "model.params['task'] = mz.tasks.Ranking()\n", + "model.params['mask_value'] = 0\n", + "model.params['input_shapes'] = [[fixed_length_left, ],\n", + " [fixed_length_right, ]]\n", + "model.params['lstm_dim'] = 300\n", + "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n", + "model.params['embedding_output_dim'] = 300\n", + "model.params['embedding_trainable'] = False\n", + "model.params['dropout_rate'] = 0.5\n", + "\n", + "model.params['mlp_num_units'] = 300\n", + "model.params['mlp_num_layers'] = 0\n", + "model.params['mlp_num_fan_out'] = 300\n", + "model.params['mlp_activation_func'] = 'tanh'\n", + "model.params['optimizer'] = Adam(lr=4e-4)\n", + "\n", + "model.guess_and_fill_missing_params()\n", + "model.build()\n", + "\n", + "model.compile()\n", + "model.backend.summary() # not visualize" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 8627 samples, validate on 1130 samples\n", + "Epoch 1/5\n", + "8627/8627 [==============================] - 48s 6ms/step - loss: 0.1073 - val_loss: 0.0984\n", + "Validation: mean_average_precision(0.0): 0.6222655981584554\n", + "Epoch 2/5\n", + "8627/8627 [==============================] - 44s 5ms/step - loss: 0.0994 - val_loss: 0.0974\n", + "Validation: mean_average_precision(0.0): 0.640342571890191\n", + "Epoch 3/5\n", + "8627/8627 [==============================] - 44s 5ms/step - loss: 0.0944 - val_loss: 0.0981\n", + "Validation: mean_average_precision(0.0): 0.633281742507933\n", + "Epoch 4/5\n", + "8627/8627 [==============================] - 44s 5ms/step - loss: 0.0915 - val_loss: 0.0898\n", + "Validation: mean_average_precision(0.0): 0.6479046351993808\n", + "Epoch 5/5\n", + "8627/8627 [==============================] - 44s 5ms/step - loss: 0.0893 - val_loss: 0.0931\n", + "Validation: mean_average_precision(0.0): 0.6506805763854636\n" + ] + } + ], + "source": [ + "# run as classification task\n", + "model.load_embedding_matrix(embedding_matrix)\n", + "evaluate = mz.callbacks.EvaluateAllMetrics(model,\n", + " x=pred_X,\n", + " y=pred_Y,\n", + " once_every=1,\n", + " batch_size=len(pred_Y))\n", + "\n", + "history = model.fit(x = [train_X['text_left'],\n", + " train_X['text_right']],\n", + " y = train_Y,\n", + " validation_data = (val_X, val_Y),\n", + " batch_size = batch_size,\n", + " epochs = epochs,\n", + " callbacks=[evaluate]\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + 
"==================================================================================================\n", + "text_left (InputLayer) (None, 10) 0 \n", + "__________________________________________________________________________________________________\n", + "text_right (InputLayer) (None, 40) 0 \n", + "__________________________________________________________________________________________________\n", + "embedding (Embedding) multiple 1930500 text_left[0][0] \n", + " text_right[0][0] \n", + "__________________________________________________________________________________________________\n", + "dropout_1 (Dropout) multiple 0 embedding[0][0] \n", + " embedding[1][0] \n", + " dense_1[0][0] \n", + " dense_1[1][0] \n", + " dense_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_1 (Lambda) multiple 0 text_left[0][0] \n", + " text_right[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_1 (Bidirectional) multiple 1442400 dropout_1[0][0] \n", + " dropout_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_2 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_3 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_1 (Multiply) (None, 10, 600) 0 bidirectional_1[0][0] \n", + " lambda_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_2 (Multiply) (None, 40, 600) 0 bidirectional_1[1][0] \n", + " lambda_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_4 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_5 (Lambda) (None, 1, 40) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "dot_1 (Dot) (None, 10, 40) 0 multiply_1[0][0] \n", + " multiply_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_3 (Multiply) (None, 10, 40) 0 lambda_4[0][0] \n", + " lambda_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "permute_1 (Permute) (None, 40, 10) 0 dot_1[0][0] \n", + " multiply_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "atten_mask (Lambda) multiple 0 dot_1[0][0] \n", + " multiply_3[0][0] \n", + " permute_1[0][0] \n", + " permute_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "softmax_1 (Softmax) multiple 0 atten_mask[0][0] \n", + " atten_mask[1][0] \n", + "__________________________________________________________________________________________________\n", + "dot_2 (Dot) (None, 10, 600) 0 softmax_1[0][0] \n", + " multiply_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "dot_3 (Dot) (None, 40, 600) 0 softmax_1[1][0] \n", + " multiply_1[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "subtract_1 (Subtract) (None, 10, 600) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_4 (Multiply) (None, 10, 600) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "subtract_2 (Subtract) (None, 40, 600) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_5 (Multiply) (None, 40, 600) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_1 (Concatenate) (None, 10, 2400) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + " subtract_1[0][0] \n", + " multiply_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_2 (Concatenate) (None, 40, 2400) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + " subtract_2[0][0] \n", + " multiply_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_1 (Dense) multiple 720300 concatenate_1[0][0] \n", + " concatenate_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_2 (Bidirectional) multiple 1442400 dropout_1[2][0] \n", + " dropout_1[3][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_6 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_8 (Lambda) (None, 10, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_10 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_12 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_6 (Multiply) (None, 10, 600) 0 bidirectional_2[0][0] \n", + " lambda_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_7 (Multiply) (None, 10, 600) 0 bidirectional_2[0][0] \n", + " lambda_8[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_8 (Multiply) (None, 40, 600) 0 bidirectional_2[1][0] \n", + " lambda_10[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_9 (Multiply) (None, 40, 600) 0 bidirectional_2[1][0] \n", + " lambda_12[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_7 (Lambda) (None, 600) 0 multiply_6[0][0] \n", + " lambda_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_9 (Lambda) (None, 600) 0 multiply_7[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "lambda_11 (Lambda) (None, 600) 0 multiply_8[0][0] \n", + " lambda_10[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_13 (Lambda) (None, 600) 0 multiply_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_3 (Concatenate) (None, 1200) 0 lambda_7[0][0] \n", + " lambda_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_4 (Concatenate) (None, 1200) 0 lambda_11[0][0] \n", + " lambda_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_5 (Concatenate) (None, 2400) 0 concatenate_3[0][0] \n", + " concatenate_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_2 (Dense) (None, 300) 720300 concatenate_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_3 (Dense) (None, 2) 602 dropout_1[4][0] \n", + "==================================================================================================\n", + "Total params: 6,256,502\n", + "Trainable params: 4,326,002\n", + "Non-trainable params: 1,930,500\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "# run as classification task\n", + "classification_task = mz.tasks.Classification(num_classes=2)\n", + "classification_task.metrics = 'acc'\n", + "\n", + "model = ESIM()\n", + "model.params['task'] = classification_task\n", + "model.params['mask_value'] = 0\n", + "model.params['input_shapes'] = [[fixed_length_left, ],\n", + " [fixed_length_right, ]]\n", + "model.params['lstm_dim'] = 300\n", + "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n", + "model.params['embedding_output_dim'] = 300\n", + "model.params['embedding_trainable'] = False\n", + "model.params['dropout_rate'] = 0.5\n", + "\n", + "model.params['mlp_num_units'] = 300\n", + "model.params['mlp_num_layers'] = 0\n", + "model.params['mlp_num_fan_out'] = 300\n", + "model.params['mlp_activation_func'] = 'tanh'\n", + "model.params['optimizer'] = Adam(lr=4e-4)\n", + "\n", + "model.guess_and_fill_missing_params()\n", + "model.build()\n", + "\n", + "model.compile()\n", + "model.backend.summary() # not visualize" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 8627 samples, validate on 1130 samples\n", + "Epoch 1/5\n", + "8627/8627 [==============================] - 48s 6ms/step - loss: 0.3607 - val_loss: 0.3330\n", + "Validation: categorical_accuracy: 1.0\n", + "Epoch 2/5\n", + "8627/8627 [==============================] - 43s 5ms/step - loss: 0.3273 - val_loss: 0.3490\n", + "Validation: categorical_accuracy: 0.9451327323913574\n", + "Epoch 3/5\n", + "8627/8627 [==============================] - 44s 5ms/step - loss: 0.3096 - val_loss: 0.3498\n", + "Validation: categorical_accuracy: 0.9938052892684937\n", + "Epoch 4/5\n", + "8627/8627 [==============================] - 44s 5ms/step - loss: 0.2970 - val_loss: 0.3170\n", + "Validation: categorical_accuracy: 0.969911515712738\n", + "Epoch 5/5\n", + 
"8627/8627 [==============================] - 44s 5ms/step - loss: 0.2787 - val_loss: 0.3543\n", + "Validation: categorical_accuracy: 0.8778761029243469\n" + ] + } + ], + "source": [ + "evaluate = mz.callbacks.EvaluateAllMetrics(model,\n", + " x=pred_X,\n", + " y=pred_Y,\n", + " once_every=1,\n", + " batch_size=len(pred_Y))\n", + "\n", + "train_Y = to_categorical(train_Y)\n", + "val_Y = to_categorical(val_Y)\n", + "\n", + "model.load_embedding_matrix(embedding_matrix)\n", + "history = model.fit(x = [train_X['text_left'],\n", + " train_X['text_right']],\n", + " y = train_Y,\n", + " validation_data = (val_X, val_Y),\n", + " batch_size = batch_size,\n", + " epochs = epochs,\n", + " callbacks=[evaluate]\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'categorical_accuracy': 0.8920354}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.evaluate(val_X, val_Y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mz_play", + "language": "python", + "name": "mz_play" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/wikiqa/README.rst b/tutorials/wikiqa/README.rst index 58da266f..a5521c9a 100644 --- a/tutorials/wikiqa/README.rst +++ b/tutorials/wikiqa/README.rst @@ -92,3 +92,20 @@ MatchLSTM 10 dropout_rate 0.5 ==== ==================== ====================================================== +DSSM +#### + +==== =========================== =================================== + .. 
Name Value +==== =========================== =================================== + 0 model_class + 1 input_shapes [(9645,), (9645,)] + 2 task Ranking Task + 3 optimizer adam + 4 with_multi_layer_perceptron True + 5 mlp_num_units 300 + 6 mlp_num_layers 3 + 7 mlp_num_fan_out 128 + 8 mlp_activation_func relu +==== =========================== =================================== + diff --git a/tutorials/wikiqa/dssm.ipynb b/tutorials/wikiqa/dssm.ipynb index 8eb6e11d..a7545b85 100644 --- a/tutorials/wikiqa/dssm.ipynb +++ b/tutorials/wikiqa/dssm.ipynb @@ -32,24 +32,24 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 3802.39it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:04<00:00, 3959.06it/s]\n", - "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 822625.79it/s]\n", - "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 597166.86it/s]\n", - "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 4642343.92it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:00<00:00, 2853.90it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 18841/18841 [00:12<00:00, 1456.96it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 2308.40it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:00<00:00, 2025.86it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 2678.58it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:01<00:00, 1345.18it/s]\n" + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 3587.72it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:04<00:00, 4528.13it/s]\n", + "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 592156.77it/s]\n", + "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 432217.30it/s]\n", + "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 4239505.32it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:00<00:00, 2709.71it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 18841/18841 [00:11<00:00, 1656.57it/s]\n", 
+ "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 1120.91it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:00<00:00, 1895.34it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 1910.44it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:01<00:00, 1630.79it/s]\n" ] } ], @@ -62,19 +62,19 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'vocab_unit': ,\n", + "{'vocab_unit': ,\n", " 'vocab_size': 9645,\n", " 'embedding_input_dim': 9645,\n", " 'input_shapes': [(9645,), (9645,)]}" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -152,7 +152,9 @@ "model.guess_and_fill_missing_params()\n", "model.build()\n", "model.compile()\n", - "model.backend.summary()" + "model.backend.summary()\n", + "\n", + "append_params_to_readme(model)" ] }, { diff --git a/tutorials/wikiqa/esim.ipynb b/tutorials/wikiqa/esim.ipynb new file mode 100644 index 00000000..042910bc --- /dev/null +++ b/tutorials/wikiqa/esim.ipynb @@ -0,0 +1,524 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "matchzoo version 2.1.0\n", + "\n", + "data loading ...\n", + "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n", + "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n", + "loading embedding ...\n", + "embedding loaded as `glove_embedding`\n" + ] + } + ], + "source": [ + "%run ./tutorials/wikiqa/init.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from keras.backend.tensorflow_backend import set_session\n", + "config = tf.ConfigProto()\n", + "config.gpu_options.visible_device_list=\"1\"\n", + "config.gpu_options.allow_growth = True # dynamically grow the memory used on the GPU\n", + "sess = tf.Session(config=config)\n", + "set_session(sess) # set this TensorFlow session as the default session for Keras" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def load_filtered_data(preprocessor, data_type):\n", + " assert ( data_type in ['train', 'dev', 'test'])\n", + " data_pack = mz.datasets.wiki_qa.load_data(data_type, task='ranking')\n", + "\n", + " if data_type == 'train':\n", + " X, Y = preprocessor.fit_transform(data_pack).unpack()\n", + " else:\n", + " X, Y = preprocessor.transform(data_pack).unpack()\n", + "\n", + " new_idx = []\n", + " for i in 
range(Y.shape[0]):\n", + " if X[\"length_left\"][i] == 0 or X[\"length_right\"][i] == 0:\n", + " continue\n", + " new_idx.append(i)\n", + " new_idx = np.array(new_idx)\n", + " print(\"Removed empty data. Found \", (Y.shape[0] - new_idx.shape[0]))\n", + "\n", + " for k in X.keys():\n", + " X[k] = X[k][new_idx]\n", + " Y = Y[new_idx]\n", + "\n", + " pos_idx = (Y == 1)[:, 0]\n", + " pos_qid = X[\"id_left\"][pos_idx]\n", + " keep_idx_bool = np.array([ qid in pos_qid for qid in X[\"id_left\"]])\n", + " keep_idx = np.arange(keep_idx_bool.shape[0])\n", + " keep_idx = keep_idx[keep_idx_bool]\n", + " print(\"Removed questions with no pos label. Found \", (keep_idx_bool == 0).sum())\n", + "\n", + " print(\"shuffling...\")\n", + " np.random.shuffle(keep_idx)\n", + " for k in X.keys():\n", + " X[k] = X[k][keep_idx]\n", + " Y = Y[keep_idx]\n", + "\n", + " return X, Y, preprocessor" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 12754.26it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 6500.31it/s]\n", + "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 1215206.55it/s]\n", + "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 185258.28it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 184455.70it/s]\n", + "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 922581.36it/s]\n", + "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 1082236.12it/s]\n", + "Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 3795031.47it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 13650.60it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 6764.51it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 171037.31it/s]\n", + "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 288623.28it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 90725.37it/s]\n", + "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 583636.81it/s]\n", + "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 1203693.44it/s]\n", + "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 193145.54it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 134549.60it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removed empty data. 
Found 38\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 296/296 [00:00<00:00, 14135.26it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 0%| | 0/2708 [00:00<?, ?it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2708/2708 [00:00<00:00, 6731.87it/s]\n", + "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 168473.93it/s]\n", + "Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 204701.40it/s]\n", + "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 159066.95it/s]\n", + "Processing length_left with len: 100%|██████████| 296/296 [00:00<00:00, 442607.48it/s]\n", + "Processing length_right with len: 100%|██████████| 2708/2708 [00:00<00:00, 1038699.15it/s]\n", + "Processing text_left with transform: 100%|██████████| 296/296 [00:00<00:00, 149130.81it/s]\n", + "Processing text_right with transform: 100%|██████████| 2708/2708 [00:00<00:00, 140864.36it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 633/633 [00:00<00:00, 12189.39it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removed empty data. Found 2\n", + "Removed questions with no pos label. Found 1601\n", + "shuffling...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 5961/5961 [00:00<00:00, 7064.16it/s]\n", + "Processing text_right with transform: 100%|██████████| 5961/5961 [00:00<00:00, 187399.25it/s]\n", + "Processing text_left with transform: 100%|██████████| 633/633 [00:00<00:00, 259733.36it/s]\n", + "Processing text_right with transform: 100%|██████████| 5961/5961 [00:00<00:00, 160878.23it/s]\n", + "Processing length_left with len: 100%|██████████| 633/633 [00:00<00:00, 688714.51it/s]\n", + "Processing length_right with len: 100%|██████████| 5961/5961 [00:00<00:00, 1166965.98it/s]\n", + "Processing text_left with transform: 100%|██████████| 633/633 [00:00<00:00, 158526.06it/s]\n", + "Processing text_right with transform: 100%|██████████| 5961/5961 [00:00<00:00, 137558.64it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removed empty data. Found 18\n", + "Removed questions with no pos label. 
Found 3805\n", + "shuffling...\n" + ] + } + ], + "source": [ + "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=20,\n", + " fixed_length_right=40,\n", + " remove_stop_words=False)\n", + "train_X, train_Y, preprocessor = load_filtered_data(preprocessor, 'train')\n", + "val_X, val_Y, _ = load_filtered_data(preprocessor, 'dev')\n", + "pred_X, pred_Y, _ = load_filtered_data(preprocessor, 'test')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "__________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + "text_left (InputLayer) (None, 20) 0 \n", + "__________________________________________________________________________________________________\n", + "text_right (InputLayer) (None, 40) 0 \n", + "__________________________________________________________________________________________________\n", + "embedding (Embedding) multiple 5002500 text_left[0][0] \n", + " text_right[0][0] \n", + "__________________________________________________________________________________________________\n", + "dropout_1 (Dropout) multiple 0 embedding[0][0] \n", + " embedding[1][0] \n", + " dense_1[0][0] \n", + " dense_1[1][0] \n", + " dense_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_1 (Lambda) multiple 0 text_left[0][0] \n", + " text_right[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_1 (Bidirectional) multiple 1442400 dropout_1[0][0] \n", + " dropout_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_2 (Lambda) (None, 20, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_3 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_1 (Multiply) (None, 20, 600) 0 bidirectional_1[0][0] \n", + " lambda_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_2 (Multiply) (None, 40, 600) 0 bidirectional_1[1][0] \n", + " lambda_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_4 (Lambda) (None, 20, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_5 (Lambda) (None, 1, 40) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "dot_1 (Dot) (None, 20, 40) 0 multiply_1[0][0] \n", + " multiply_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_3 (Multiply) (None, 20, 40) 0 lambda_4[0][0] \n", + " lambda_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "permute_1 (Permute) (None, 40, 20) 0 dot_1[0][0] \n", + " multiply_3[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "atten_mask (Lambda) multiple 0 dot_1[0][0] \n", + " multiply_3[0][0] \n", + " permute_1[0][0] \n", + " permute_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "softmax_1 (Softmax) multiple 0 atten_mask[0][0] \n", + " atten_mask[1][0] \n", + "__________________________________________________________________________________________________\n", + "dot_2 (Dot) (None, 20, 600) 0 softmax_1[0][0] \n", + " multiply_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "dot_3 (Dot) (None, 40, 600) 0 softmax_1[1][0] \n", + " multiply_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "subtract_1 (Subtract) (None, 20, 600) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_4 (Multiply) (None, 20, 600) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "subtract_2 (Subtract) (None, 40, 600) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_5 (Multiply) (None, 40, 600) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_1 (Concatenate) (None, 20, 2400) 0 multiply_1[0][0] \n", + " dot_2[0][0] \n", + " subtract_1[0][0] \n", + " multiply_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_2 (Concatenate) (None, 40, 2400) 0 multiply_2[0][0] \n", + " dot_3[0][0] \n", + " subtract_2[0][0] \n", + " multiply_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_1 (Dense) multiple 720300 concatenate_1[0][0] \n", + " concatenate_2[0][0] \n", + "__________________________________________________________________________________________________\n", + "bidirectional_2 (Bidirectional) multiple 1442400 dropout_1[2][0] \n", + " dropout_1[3][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_6 (Lambda) (None, 20, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_8 (Lambda) (None, 20, 1) 0 lambda_1[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_10 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_12 (Lambda) (None, 40, 1) 0 lambda_1[1][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_6 (Multiply) (None, 20, 600) 0 bidirectional_2[0][0] \n", + " lambda_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_7 (Multiply) (None, 20, 600) 0 bidirectional_2[0][0] \n", + " lambda_8[0][0] \n", + 
"__________________________________________________________________________________________________\n", + "multiply_8 (Multiply) (None, 40, 600) 0 bidirectional_2[1][0] \n", + " lambda_10[0][0] \n", + "__________________________________________________________________________________________________\n", + "multiply_9 (Multiply) (None, 40, 600) 0 bidirectional_2[1][0] \n", + " lambda_12[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_7 (Lambda) (None, 600) 0 multiply_6[0][0] \n", + " lambda_6[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_9 (Lambda) (None, 600) 0 multiply_7[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_11 (Lambda) (None, 600) 0 multiply_8[0][0] \n", + " lambda_10[0][0] \n", + "__________________________________________________________________________________________________\n", + "lambda_13 (Lambda) (None, 600) 0 multiply_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_3 (Concatenate) (None, 1200) 0 lambda_7[0][0] \n", + " lambda_9[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_4 (Concatenate) (None, 1200) 0 lambda_11[0][0] \n", + " lambda_13[0][0] \n", + "__________________________________________________________________________________________________\n", + "concatenate_5 (Concatenate) (None, 2400) 0 concatenate_3[0][0] \n", + " concatenate_4[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_2 (Dense) (None, 300) 720300 concatenate_5[0][0] \n", + "__________________________________________________________________________________________________\n", + "dense_3 (Dense) (None, 2) 602 dropout_1[4][0] \n", + "==================================================================================================\n", + "Total params: 9,328,502\n", + "Trainable params: 4,326,002\n", + "Non-trainable params: 5,002,500\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "from keras.optimizers import Adam\n", + "import matchzoo\n", + "\n", + "model = matchzoo.contrib.models.ESIM()\n", + "\n", + "# update `input_shapes` and `embedding_input_dim`\n", + "# model.params['task'] = mz.tasks.Ranking() \n", + "# or \n", + "model.params['task'] = mz.tasks.Classification(num_classes=2)\n", + "model.params.update(preprocessor.context)\n", + "\n", + "model.params['mask_value'] = 0\n", + "model.params['lstm_dim'] = 300\n", + "model.params['embedding_output_dim'] = 300\n", + "model.params['embedding_trainable'] = False\n", + "model.params['dropout_rate'] = 0.5\n", + "\n", + "model.params['mlp_num_units'] = 300\n", + "model.params['mlp_num_layers'] = 0\n", + "model.params['mlp_num_fan_out'] = 300\n", + "model.params['mlp_activation_func'] = 'tanh'\n", + "model.params['optimizer'] = Adam(lr=1e-4)\n", + "model.guess_and_fill_missing_params()\n", + "model.build()\n", + "model.compile()\n", + "model.backend.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'], 
initializer=lambda: 0)\n", + "model.load_embedding_matrix(embedding_matrix)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 8650 samples, validate on 1130 samples\n", + "Epoch 1/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0985 - val_loss: 0.0977\n", + "Validation: mean_average_precision(0.0): 0.6377925262180991\n", + "Epoch 2/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0947 - val_loss: 0.0939\n", + "Validation: mean_average_precision(0.0): 0.6323746460063332\n", + "Epoch 3/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0923 - val_loss: 0.0896\n", + "Validation: mean_average_precision(0.0): 0.6447892278707743\n", + "Epoch 4/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0895 - val_loss: 0.0904\n", + "Validation: mean_average_precision(0.0): 0.6645210508066117\n", + "Epoch 5/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0883 - val_loss: 0.0900\n", + "Validation: mean_average_precision(0.0): 0.6622282952529867\n", + "Epoch 6/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0839 - val_loss: 0.0900\n", + "Validation: mean_average_precision(0.0): 0.6654279587941297\n", + "Epoch 7/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0821 - val_loss: 0.0896\n", + "Validation: mean_average_precision(0.0): 0.6668269018575894\n", + "Epoch 8/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0792 - val_loss: 0.0885\n", + "Validation: mean_average_precision(0.0): 0.6723704781393599\n", + "Epoch 9/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0754 - val_loss: 0.0895\n", + "Validation: mean_average_precision(0.0): 0.6552521148587158\n", + "Epoch 10/10\n", + "8650/8650 [==============================] - 52s 6ms/step - loss: 0.0731 - val_loss: 0.0910\n", + "Validation: mean_average_precision(0.0): 0.6695447388956829\n" + ] + } + ], + "source": [ + "# train as ranking task\n", + "model.params['task'] = mz.tasks.Ranking()\n", + "evaluate = mz.callbacks.EvaluateAllMetrics(model,\n", + " x=pred_X,\n", + " y=pred_Y,\n", + " once_every=1,\n", + " batch_size=len(pred_Y))\n", + "history = model.fit(x = [train_X['text_left'],\n", + " train_X['text_right']], # (20360, 1000)\n", + " y = train_Y, # (20360, 2)\n", + " validation_data = (val_X, val_Y),\n", + " callbacks=[evaluate],\n", + " batch_size = 32,\n", + " epochs = 10)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 8650 samples, validate on 1130 samples\n", + "Epoch 1/10\n", + "8650/8650 [==============================] - 68s 8ms/step - loss: 0.3628 - val_loss: 0.3552\n", + "Epoch 2/10\n", + "8650/8650 [==============================] - 63s 7ms/step - loss: 0.3285 - val_loss: 0.3591\n", + "Epoch 3/10\n", + "8650/8650 [==============================] - 63s 7ms/step - loss: 0.3105 - val_loss: 0.3681\n", + "Epoch 4/10\n", + "8650/8650 [==============================] - 64s 7ms/step - loss: 0.3012 - val_loss: 0.3166\n", + "Epoch 5/10\n", + "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2888 - val_loss: 0.2961\n", + "Epoch 6/10\n", + "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2801 - val_loss: 0.3362\n", + "Epoch 
7/10\n", + "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2692 - val_loss: 0.3324\n", + "Epoch 8/10\n", + "8650/8650 [==============================] - 64s 7ms/step - loss: 0.2609 - val_loss: 0.3172\n", + "Epoch 9/10\n", + "8650/8650 [==============================] - 58s 7ms/step - loss: 0.2542 - val_loss: 0.3296\n", + "Epoch 10/10\n", + "8650/8650 [==============================] - 53s 6ms/step - loss: 0.2365 - val_loss: 0.3058\n" + ] + } + ], + "source": [ + "# train as classification task \n", + "\n", + "from keras.utils import to_categorical\n", + "train_Y = to_categorical(train_Y)\n", + "val_Y = to_categorical(val_Y)\n", + "\n", + "model.params['task'] = mz.tasks.Classification(num_classes=2)\n", + "\n", + "history = model.fit(x = [train_X['text_left'],\n", + " train_X['text_right']], # (20360, 1000)\n", + " y = train_Y, # (20360, 2)\n", + " validation_data = (val_X, val_Y),\n", + " batch_size = 32,\n", + " epochs = 10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mz_play", + "language": "python", + "name": "mz_play" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
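The two ESIM notebooks added above share the same setup boilerplate. For quick reference outside the notebook context, the following is a minimal stand-alone sketch of the ranking setup, distilled from the tutorial cells in tutorials/wikiqa/esim.ipynb. It is not part of the patch: it skips the empty-row / no-positive-label filtering done by load_filtered_data(), and mz.datasets.embeddings.load_glove_embedding is assumed to be the helper behind the `glove_embedding` that init.ipynb loads; all other calls mirror the cells above.

# Sketch: train the contrib ESIM model on WikiQA as a ranking task.
# Assumes MatchZoo 2.1.x; parameter names mirror tutorials/wikiqa/esim.ipynb.
import matchzoo as mz
from keras.optimizers import Adam
from matchzoo.contrib.models import ESIM

# preprocess raw WikiQA packs (the notebooks additionally filter out
# empty rows and questions without a positive label; omitted here)
preprocessor = mz.preprocessors.BasicPreprocessor(
    fixed_length_left=20, fixed_length_right=40, remove_stop_words=False)
train_X, train_Y = preprocessor.fit_transform(
    mz.datasets.wiki_qa.load_data('train', task='ranking')).unpack()
val_X, val_Y = preprocessor.transform(
    mz.datasets.wiki_qa.load_data('dev', task='ranking')).unpack()

# assumed loader for the GloVe embedding used via init.ipynb
glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)
embedding_matrix = glove_embedding.build_matrix(
    preprocessor.context['vocab_unit'].state['term_index'],
    initializer=lambda: 0)

model = ESIM()
model.params['task'] = mz.tasks.Ranking()
model.params.update(preprocessor.context)  # input_shapes, embedding_input_dim
model.params['mask_value'] = 0
model.params['lstm_dim'] = 300
model.params['embedding_output_dim'] = 300
model.params['embedding_trainable'] = False
model.params['dropout_rate'] = 0.5
model.params['mlp_num_layers'] = 0
model.params['optimizer'] = Adam(lr=4e-4)
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.load_embedding_matrix(embedding_matrix)

model.fit(x=[train_X['text_left'], train_X['text_right']], y=train_Y,
          validation_data=(val_X, val_Y), batch_size=32, epochs=5)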