From 12f1702d937e06d6095cedd92a759e8fe6d36bd8 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 19:34:54 +0000 Subject: [PATCH 01/16] add package structure --- .github/workflows/ci.yml | 86 ++ .gitignore | 165 ++- .pre-commit-config.yaml | 6 - .readthedocs.yml | 23 + CONTRIBUTING.md | 96 ++ LICENSE | 28 + README.md | 30 +- docs/authors.rst | 8 + docs/baseline_models.rst | 12 + docs/conf.py | 61 + docs/index.rst | 36 + docs/seqsignet_models.rst | 17 + docs/sw_attention.rst | 12 + docs/swnu.rst | 12 + .../AnnoMI/anno_mi-client-baseline-bert.py | 45 +- .../AnnoMI/anno_mi-client-baseline-bilstm.py | 18 +- .../anno_mi-client-baseline-ffn-history.py | 27 +- .../AnnoMI/anno_mi-client-baseline-ffn.py | 27 +- ...lient-seqsignet-attention-bilstm-script.py | 22 +- ...ient-seqsignet-attention-encoder-script.py | 18 +- .../AnnoMI/anno_mi-client-seqsignet-script.py | 14 +- .../AnnoMI/anno_mi-client-swmhau-script.py | 18 +- .../AnnoMI/anno_mi-client-swnu-script.py | 18 +- notebooks/AnnoMI/anno_mi-client.ipynb | 2 +- ...seqsignet-attention-encoder-script copy.py | 123 ++ .../anno_mi-client-seqsignet-script copy | 101 ++ .../anno_mi-client-swmhau-script copy | 101 ++ .../anno_mi-client-swnu-script copy | 102 ++ notebooks/Rumours/rumours-baseline-bert.py | 35 +- notebooks/Rumours/rumours-baseline-bilstm.py | 12 +- .../Rumours/rumours-baseline-ffn-history.py | 21 +- notebooks/Rumours/rumours-baseline-ffn.py | 20 +- ...mours-seqsignet-attention-bilstm-script.py | 16 +- ...ours-seqsignet-attention-encoder-script.py | 12 +- notebooks/Rumours/rumours-seqsignet-script.py | 10 +- notebooks/Rumours/rumours-swmhau-script.py | 12 +- notebooks/Rumours/rumours-swnu-script.py | 12 +- notebooks/feed-forward-mnist.ipynb | 8 +- noxfile.py | 91 ++ poetry.lock | 1077 ----------------- pyproject.toml | 140 ++- src/sig_networks/__init__.py | 3 + .../sig_networks}/feature_concatenation.py | 2 +- .../sig_networks}/ffn_baseline.py | 2 +- .../sig_networks}/focal_loss.py | 11 +- .../sig_networks}/huggingface_loader.py | 5 +- .../sig_networks}/lstm_baseline.py | 2 +- .../sig_networks}/pytorch_utils.py | 64 +- .../sig_networks/scripts}/__init__.py | 0 .../scripts/ffn_baseline_functions.py | 6 +- .../scripts/fine_tune_bert_classification.py | 14 +- .../sig_networks}/scripts/implement_model.py | 4 +- .../scripts/lstm_baseline_functions.py | 6 +- .../seqsignet_attention_bilstm_functions.py | 8 +- .../seqsignet_attention_encoder_functions.py | 8 +- .../scripts/seqsignet_functions.py | 6 +- .../scripts/swmhau_network_functions.py | 8 +- .../scripts/swnu_network_functions.py | 8 +- .../seqsignet_attention_bilstm.py | 8 +- .../seqsignet_attention_encoder.py | 10 +- .../sig_networks}/seqsignet_bilstm.py | 8 +- .../specific_classification_utils.py | 6 +- .../sig_networks}/swmhau.py | 10 +- .../sig_networks}/swmhau_network.py | 8 +- {nlpsig_networks => src/sig_networks}/swnu.py | 13 +- .../sig_networks}/swnu_network.py | 8 +- .../sig_networks}/utils.py | 0 67 files changed, 1512 insertions(+), 1410 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 .readthedocs.yml create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 docs/authors.rst create mode 100644 docs/baseline_models.rst create mode 100644 docs/conf.py create mode 100644 docs/index.rst create mode 100644 docs/seqsignet_models.rst create mode 100644 docs/sw_attention.rst create mode 100644 docs/swnu.rst create mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py 
create mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy create mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy create mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy create mode 100644 noxfile.py delete mode 100644 poetry.lock create mode 100644 src/sig_networks/__init__.py rename {nlpsig_networks => src/sig_networks}/feature_concatenation.py (99%) rename {nlpsig_networks => src/sig_networks}/ffn_baseline.py (97%) rename {nlpsig_networks => src/sig_networks}/focal_loss.py (95%) rename {nlpsig_networks => src/sig_networks}/huggingface_loader.py (97%) rename {nlpsig_networks => src/sig_networks}/lstm_baseline.py (98%) rename {nlpsig_networks => src/sig_networks}/pytorch_utils.py (96%) rename {nlpsig_networks => src/sig_networks/scripts}/__init__.py (100%) rename {nlpsig_networks => src/sig_networks}/scripts/ffn_baseline_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/fine_tune_bert_classification.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/implement_model.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/lstm_baseline_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/seqsignet_attention_bilstm_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/seqsignet_attention_encoder_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/seqsignet_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/swmhau_network_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/scripts/swnu_network_functions.py (99%) rename {nlpsig_networks => src/sig_networks}/seqsignet_attention_bilstm.py (96%) rename {nlpsig_networks => src/sig_networks}/seqsignet_attention_encoder.py (96%) rename {nlpsig_networks => src/sig_networks}/seqsignet_bilstm.py (97%) rename {nlpsig_networks => src/sig_networks}/specific_classification_utils.py (98%) rename {nlpsig_networks => src/sig_networks}/swmhau.py (98%) rename {nlpsig_networks => src/sig_networks}/swmhau_network.py (96%) rename {nlpsig_networks => src/sig_networks}/swnu.py (97%) rename {nlpsig_networks => src/sig_networks}/swnu_network.py (96%) rename {nlpsig_networks => src/sig_networks}/utils.py (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b2c0984 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,86 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + - develop + release: + types: + - published + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 3 + +jobs: + pre-commit: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --hook-stage manual --all-files + + checks: + name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] + strategy: + fail-fast: false + matrix: + python-version: ["3.8"] + runs-on: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package + run: python -m pip install .[test] + + - name: Test package + run: python -m pytest -ra 
--cov=sig_networks
+
+  dist:
+    name: Distribution build
+    runs-on: ubuntu-latest
+    needs: [pre-commit]
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Build sdist and wheel
+        run: pipx run build
+
+      - uses: actions/upload-artifact@v3
+        with:
+          path: dist
+
+      - name: Check products
+        run: pipx run twine check dist/*
+
+      - uses: pypa/gh-action-pypi-publish@v1.8.10
+        if: github.event_name == 'release' && github.event.action == 'published'
+        with:
+          # Remember to generate this and set it in "GitHub Secrets"
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.gitignore b/.gitignore
index 616bb5c..e640039 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,167 @@
-*.DS_Store
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# setuptools_scm
+src/*/_version.py
+
+
+# ruff
+.ruff_cache/
+
+# OS specific stuff
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+
+# Common editor files
+*~
+*.swp
+
+.vscode/
+
+# Miscellaneous
 *.ipynb_checkpoints/
 **__pycache__
 *.npy
 *.pkl
-paths.py
\ No newline at end of file
+paths.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cbc9b24..ae4190d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,12 +44,6 @@ repos:
       - id: blacken-docs
         additional_dependencies: [black==23.1.0]
 
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: "v0.0.260"
-    hooks:
-      - id: ruff
-        args: ["--fix", "--show-fixes"]
-
   - repo: https://github.com/shellcheck-py/shellcheck-py
     rev: "v0.9.0.2"
     hooks:
diff --git a/.readthedocs.yml b/.readthedocs.yml
new file mode 100644
index 0000000..268f4d8
--- /dev/null
+++ b/.readthedocs.yml
@@ -0,0 +1,23 @@
+# .readthedocs.yml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.8"
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+  configuration: docs/conf.py
+
+python:
+  install:
+    - method: pip
+      path: .
+      extra_requirements:
+        - docs
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..52a93a9
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,96 @@
+# Quick development
+
+The fastest way to start with development is to use nox. If you don't have nox,
+you can use `pipx run nox` to run it without installing, or `pipx install nox`.
+If you don't have pipx (pip for applications), then you can install it with
+`pip install pipx` (the only case where installing an application with regular
+pip is reasonable). If you use macOS, then pipx and nox are both in brew, use
+`brew install pipx nox`.
+
+To use, run `nox`. This will lint and test using every supported version of
+Python on your system, skipping ones that are not installed. You can also run
+specific jobs:
+
+```console
+$ nox -s lint   # Lint only
+$ nox -s tests  # Python tests
+$ nox -s docs -- serve  # Build and serve the docs
+$ nox -s build  # Make an SDist and wheel
+```
+
+Nox handles everything for you, including setting up a temporary virtual
+environment for each run.
+
+# Setting up a development environment manually
+
+You can set up a development environment by running:
+
+```bash
+python3 -m venv .venv
+source ./.venv/bin/activate
+pip install -v -e .[dev]
+```
+
+If you have the
+[Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you
+can instead do:
+
+```bash
+py -m venv .venv
+py -m pip install -v -e .[dev]
+```
+
+# Post setup
+
+You should prepare pre-commit, which will help you by checking that commits pass
+required checks:
+
+```bash
+pip install pre-commit # or brew install pre-commit on macOS
+pre-commit install # Will install a pre-commit hook into the git repo
+```
+
+You can also/alternatively run `pre-commit run` (changes only) or
+`pre-commit run --all-files` to check even without installing the hook.
+
+# Testing
+
+Use pytest to run the unit checks:
+
+```bash
+pytest
+```
+
+# Coverage
+
+Use pytest-cov to generate coverage reports:
+
+```bash
+pytest --cov=sig_networks
+```
+
+# Building docs
+
+You can build the docs using:
+
+```bash
+nox -s docs
+```
+
+You can see a preview with:
+
+```bash
+nox -s docs -- serve
+```
+
+# Pre-commit
+
+This project uses pre-commit for all style checking. While you can run it with
+nox, this is such an important tool that it deserves to be installed on its
+own. Install pre-commit and run:
+
+```bash
+pre-commit run -a
+```
+
+to check all files.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..e7995db
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright (c) Talia Tseriotou, Ryan Chan. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the sig-networks package developers nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 52cc5d1..462d9cf 100644
--- a/README.md
+++ b/README.md
@@ -1,28 +1,18 @@
-# nlpsig-networks
+# sig-networks
 
-

-Model architecture -

+
+_sig-networks_ is a package for training and evaluating neural networks for
+longitudinal NLP classification tasks.
 
-Work in progress repo.
+## Installation
 
-First create an environment and install
-[`nlpsig`](https://github.com/datasig-ac-uk/nlpsig), and then install this
-package afterwards:
+...
 
-```
-git clone git@github.com:datasig-ac-uk/nlpsig.git
-git clone git@github.com:ttseriotou/nlpsig-networks.git
-cd nlpsig
-conda env create --name nlpsig-networks python=3.8
-conda activate nlpsig-networks
-pip install -e .
-cd ../nlpsig-networks
-pip install -e .
-```
+## Usage
+
+The library is still under development, but it is possible to train and evaluate
+several models in a few lines of code.
 
-If you want to install a development version, use `pip install -e .` instead in
-the above.
+...
 
 ## Pre-commit and linters
diff --git a/docs/authors.rst b/docs/authors.rst
new file mode 100644
index 0000000..1a678b1
--- /dev/null
+++ b/docs/authors.rst
@@ -0,0 +1,8 @@
+Authors / collaborators
+=======================
+
+``sig-networks`` is a library that applies models first developed in `Sequential Path Signature Networks for Personalised Longitudinal Language Modeling `_ by Tseriotou et al. (2023), which presented a novel extension of neural sequential models using the notion of path signatures from rough path theory. The library is developed in collaboration with the `Research Engineering Team `_ at `The Alan Turing Institute `_, and was originally written by `Talia Tseriotou `_ and `Ryan Chan `_.
+
+We also thank `Kasra Hosseini `_ for his early contributions to the library, and `Adam Tsakalidis `_, `Maria Liakata `_, and `Terry Lyons `_ for their valuable guidance and advice.
+
+For an up-to-date list of contributors to the library, please see the `GitHub repo `_ for this package.
diff --git a/docs/baseline_models.rst b/docs/baseline_models.rst
new file mode 100644
index 0000000..2050d58
--- /dev/null
+++ b/docs/baseline_models.rst
@@ -0,0 +1,12 @@
+Baseline Models
+===============
+
+.. automodule:: sig_networks.ffn_baseline
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: sig_networks.lstm_baseline
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..1972f53
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,61 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+from __future__ import annotations
+
+# Warning: do not change the path here. To use autodoc, you need to install the
+# package first.
+
+# -- Project information -----------------------------------------------------
+
+project = "sig-networks"
+copyright = "2023, Talia Tseriotou"
+author = "Talia Tseriotou, Ryan Chan"
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "myst_parser",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.napoleon",
+    "sphinx_copybutton",
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = [] + +# Include both markdown and rst files +source_suffix = [".rst", ".md"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "**.ipynb_checkpoints", "Thumbs.db", ".DS_Store", ".env"] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "furo" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path: list[str] = [] + + +# -- Extension configuration ------------------------------------------------- +myst_enable_extensions = [ + "colon_fence", + "deflist", +] diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..b990ede --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,36 @@ + +The ``sig-networks`` Python package +====================================== + +``sig-networks`` +------------------- + +The ``sig-networks`` library provides functionality for training and evaluating neural networks for longitudinal NLP classification tasks. + +.. toctree:: + :maxdepth: 2 + :titlesonly: + :caption: API + :glob: + + baseline_models.rst + swnu.rst + sw_attention.rst + seqsignet_models.rst + +.. toctree:: + :maxdepth: 2 + :titlesonly: + :caption: About + :glob: + + authors.rst + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/seqsignet_models.rst b/docs/seqsignet_models.rst new file mode 100644 index 0000000..4751359 --- /dev/null +++ b/docs/seqsignet_models.rst @@ -0,0 +1,17 @@ +SeqSigNet Models +================ + +.. automodule:: sig_networks.seqsignet_bilstm + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: sig_networks.seqsignet_attention_encoder + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: sig_networks.seqsignet_attention_bilstm + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/sw_attention.rst b/docs/sw_attention.rst new file mode 100644 index 0000000..69d44e6 --- /dev/null +++ b/docs/sw_attention.rst @@ -0,0 +1,12 @@ +Signature Window Attention Units +================================ + +.. automodule:: sig_networks.swmhau + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: sig_networks.swmhau_network + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/swnu.rst b/docs/swnu.rst new file mode 100644 index 0000000..26823c6 --- /dev/null +++ b/docs/swnu.rst @@ -0,0 +1,12 @@ +Signature Window Network Units +============================== + +.. automodule:: sig_networks.swnu + :members: + :undoc-members: + :show-inheritance: + +.. 
automodule:: sig_networks.swnu_network
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/notebooks/AnnoMI/anno_mi-client-baseline-bert.py b/notebooks/AnnoMI/anno_mi-client-baseline-bert.py
index 83da6b2..092d301 100644
--- a/notebooks/AnnoMI/anno_mi-client-baseline-bert.py
+++ b/notebooks/AnnoMI/anno_mi-client-baseline-bert.py
@@ -1,20 +1,22 @@
-import numpy as np
-import pickle
+from __future__ import annotations
+
 import os
+import pickle
+
+import numpy as np
 import torch
 import transformers
-
-from nlpsig_networks.scripts.fine_tune_bert_classification import (
-    fine_tune_transformer_average_seed,
-)
-
 from load_anno_mi import (
     anno_mi,
-    output_dim_client,
     client_index,
     client_transcript_id,
-    label_to_id_client,
     id_to_label_client,
+    label_to_id_client,
+    output_dim_client,
+)
+
+from sig_networks.scripts.fine_tune_bert_classification import (
+    fine_tune_transformer_average_seed,
 )
 
 # set to only report critical errors to avoid excessive logging
@@ -65,31 +67,26 @@
 gamma = 2
 for loss in ["focal", "cross_entropy"]:
     if loss == "focal":
-        results_output=f"{output_dir}/bert_classifier_focal.csv",
+        results_output = f"{output_dir}/bert_classifier_focal.csv"
     else:
-        results_output=f"{output_dir}/bert_classifier_ce.csv"
-
+        results_output = f"{output_dir}/bert_classifier_ce.csv"
+
     bert_classifier, best_bert_classifier, _, __ = fine_tune_transformer_average_seed(
         loss=loss,
         gamma=gamma,
         results_output=results_output,
         **kwargs,
     )
-
+
     print(f"F1: {best_bert_classifier['f1'].mean()}")
-    print(
-        f"Precision: {best_bert_classifier['precision'].mean()}"
-    )
+    print(f"Precision: {best_bert_classifier['precision'].mean()}")
     print(f"Recall: {best_bert_classifier['recall'].mean()}")
+    print(f"F1 scores: {np.stack(best_bert_classifier['f1_scores']).mean(axis=0)}")
     print(
-        "F1 scores: "
-        f"{np.stack(best_bert_classifier['f1_scores']).mean(axis=0)}"
-    )
-    print(
-        "Precision scores: "
-        f"{np.stack(best_bert_classifier['precision_scores']).mean(axis=0)}"
+        "Precision scores: "
+        f"{np.stack(best_bert_classifier['precision_scores']).mean(axis=0)}"
     )
     print(
-        "Recall scores: "
-        f"{np.stack(best_bert_classifier['recall_scores']).mean(axis=0)}"
+        "Recall scores: "
+        f"{np.stack(best_bert_classifier['recall_scores']).mean(axis=0)}"
     )
diff --git a/notebooks/AnnoMI/anno_mi-client-baseline-bilstm.py b/notebooks/AnnoMI/anno_mi-client-baseline-bilstm.py
index 5364652..16c193f 100644
--- a/notebooks/AnnoMI/anno_mi-client-baseline-bilstm.py
+++ b/notebooks/AnnoMI/anno_mi-client-baseline-bilstm.py
@@ -1,16 +1,20 @@
-import numpy as np
-import pickle
+from __future__ import annotations
+
 import os
+import pickle
+
+import numpy as np
 import torch
-from nlpsig_networks.scripts.lstm_baseline_functions import (
-    lstm_hyperparameter_search,
-)
 from load_anno_mi import (
     anno_mi,
-    y_data_client,
-    output_dim_client,
     client_index,
     client_transcript_id,
+    output_dim_client,
+    y_data_client,
+)
+
+from sig_networks.scripts.lstm_baseline_functions import (
+    lstm_hyperparameter_search,
 )
 
 seed = 2023
diff --git a/notebooks/AnnoMI/anno_mi-client-baseline-ffn-history.py b/notebooks/AnnoMI/anno_mi-client-baseline-ffn-history.py
index d737ca3..e0632bf 100644
--- a/notebooks/AnnoMI/anno_mi-client-baseline-ffn-history.py
+++ b/notebooks/AnnoMI/anno_mi-client-baseline-ffn-history.py
@@ -1,16 +1,20 @@
-import numpy as np
-import pickle
+from __future__ import annotations
+
 import os
+import pickle
+
+import numpy as np
 import torch
-from nlpsig_networks.scripts.ffn_baseline_functions import (
-
histories_baseline_hyperparameter_search, -) from load_anno_mi import ( anno_mi, - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, +) + +from sig_networks.scripts.ffn_baseline_functions import ( + histories_baseline_hyperparameter_search, ) seed = 2023 @@ -75,14 +79,9 @@ ) print(f"F1: {best_ffn_mean_history_kfold['f1'].mean()}") -print( - f"Precision: {best_ffn_mean_history_kfold['precision'].mean()}" -) +print(f"Precision: {best_ffn_mean_history_kfold['precision'].mean()}") print(f"Recall: {best_ffn_mean_history_kfold['recall'].mean()}") -print( - "F1 scores: " - f"{np.stack(best_ffn_mean_history_kfold['f1_scores']).mean(axis=0)}" -) +print(f"F1 scores: {np.stack(best_ffn_mean_history_kfold['f1_scores']).mean(axis=0)}") print( "Precision scores: " f"{np.stack(best_ffn_mean_history_kfold['precision_scores']).mean(axis=0)}" diff --git a/notebooks/AnnoMI/anno_mi-client-baseline-ffn.py b/notebooks/AnnoMI/anno_mi-client-baseline-ffn.py index 4ee8cb6..62b56c9 100644 --- a/notebooks/AnnoMI/anno_mi-client-baseline-ffn.py +++ b/notebooks/AnnoMI/anno_mi-client-baseline-ffn.py @@ -1,15 +1,19 @@ -import numpy as np -import pickle +from __future__ import annotations + import os +import pickle + +import numpy as np import torch -from nlpsig_networks.scripts.ffn_baseline_functions import ( - ffn_hyperparameter_search, -) from load_anno_mi import ( - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, +) + +from sig_networks.scripts.ffn_baseline_functions import ( + ffn_hyperparameter_search, ) seed = 2023 @@ -64,14 +68,9 @@ ) print(f"F1: {best_ffn_current_kfold['f1'].mean()}") -print( - f"Precision: {best_ffn_current_kfold['precision'].mean()}" -) +print(f"Precision: {best_ffn_current_kfold['precision'].mean()}") print(f"Recall: {best_ffn_current_kfold['recall'].mean()}") -print( - "F1 scores: " - f"{np.stack(best_ffn_current_kfold['f1_scores']).mean(axis=0)}" -) +print("F1 scores: " f"{np.stack(best_ffn_current_kfold['f1_scores']).mean(axis=0)}") print( "Precision scores: " f"{np.stack(best_ffn_current_kfold['precision_scores']).mean(axis=0)}" diff --git a/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-bilstm-script.py b/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-bilstm-script.py index 4f5ab93..1e55d2d 100644 --- a/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-bilstm-script.py +++ b/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-bilstm-script.py @@ -1,16 +1,20 @@ -import numpy as np -import pickle +from __future__ import annotations + import os +import pickle + +import numpy as np import torch -from nlpsig_networks.scripts.seqsignet_attention_bilstm_functions import ( - seqsignet_attention_bilstm_hyperparameter_search, -) from load_anno_mi import ( anno_mi, - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, +) + +from sig_networks.scripts.seqsignet_attention_bilstm_functions import ( + seqsignet_attention_bilstm_hyperparameter_search, ) seed = 2023 @@ -104,9 +108,7 @@ ) print(f"F1: {best_seqsignet_network_umap_kfold['f1'].mean()}") - print( - f"Precision: {best_seqsignet_network_umap_kfold['precision'].mean()}" - ) + print(f"Precision: {best_seqsignet_network_umap_kfold['precision'].mean()}") print(f"Recall: {best_seqsignet_network_umap_kfold['recall'].mean()}") print( "F1 scores: " diff --git a/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-encoder-script.py 
b/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-encoder-script.py index 30c1a05..0a03267 100644 --- a/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-encoder-script.py +++ b/notebooks/AnnoMI/anno_mi-client-seqsignet-attention-encoder-script.py @@ -1,16 +1,20 @@ -import numpy as np -import pickle +from __future__ import annotations + import os +import pickle + +import numpy as np import torch -from nlpsig_networks.scripts.seqsignet_attention_encoder_functions import ( - seqsignet_attention_encoder_hyperparameter_search, -) from load_anno_mi import ( anno_mi, - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, +) + +from sig_networks.scripts.seqsignet_attention_encoder_functions import ( + seqsignet_attention_encoder_hyperparameter_search, ) seed = 2023 diff --git a/notebooks/AnnoMI/anno_mi-client-seqsignet-script.py b/notebooks/AnnoMI/anno_mi-client-seqsignet-script.py index 38f2e53..4288f16 100644 --- a/notebooks/AnnoMI/anno_mi-client-seqsignet-script.py +++ b/notebooks/AnnoMI/anno_mi-client-seqsignet-script.py @@ -1,16 +1,20 @@ -import numpy as np -import pickle +from __future__ import annotations + import os +import pickle + +import numpy as np import torch -from nlpsig_networks.scripts.seqsignet_functions import seqsignet_hyperparameter_search from load_anno_mi import ( anno_mi, - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, ) +from sig_networks.scripts.seqsignet_functions import seqsignet_hyperparameter_search + seed = 2023 # set device diff --git a/notebooks/AnnoMI/anno_mi-client-swmhau-script.py b/notebooks/AnnoMI/anno_mi-client-swmhau-script.py index 6b88c05..96e6921 100644 --- a/notebooks/AnnoMI/anno_mi-client-swmhau-script.py +++ b/notebooks/AnnoMI/anno_mi-client-swmhau-script.py @@ -1,16 +1,20 @@ -import numpy as np -import pickle +from __future__ import annotations + import os +import pickle + +import numpy as np import torch -from nlpsig_networks.scripts.swmhau_network_functions import ( - swmhau_network_hyperparameter_search, -) from load_anno_mi import ( anno_mi, - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, +) + +from sig_networks.scripts.swmhau_network_functions import ( + swmhau_network_hyperparameter_search, ) seed = 2023 diff --git a/notebooks/AnnoMI/anno_mi-client-swnu-script.py b/notebooks/AnnoMI/anno_mi-client-swnu-script.py index 1cb4425..e65d8d3 100644 --- a/notebooks/AnnoMI/anno_mi-client-swnu-script.py +++ b/notebooks/AnnoMI/anno_mi-client-swnu-script.py @@ -1,16 +1,20 @@ -import numpy as np -import pickle +from __future__ import annotations + import os +import pickle + +import numpy as np import torch -from nlpsig_networks.scripts.swnu_network_functions import ( - swnu_network_hyperparameter_search, -) from load_anno_mi import ( anno_mi, - y_data_client, - output_dim_client, client_index, client_transcript_id, + output_dim_client, + y_data_client, +) + +from sig_networks.scripts.swnu_network_functions import ( + swnu_network_hyperparameter_search, ) seed = 2023 diff --git a/notebooks/AnnoMI/anno_mi-client.ipynb b/notebooks/AnnoMI/anno_mi-client.ipynb index ee2f1f1..dc29e04 100644 --- a/notebooks/AnnoMI/anno_mi-client.ipynb +++ b/notebooks/AnnoMI/anno_mi-client.ipynb @@ -671,7 +671,7 @@ ], "metadata": { "kernelspec": { - "display_name": "nlpsig-networks", + "display_name": "sig-networks", "language": "python", "name": "python3" }, diff --git 
a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py new file mode 100644 index 0000000..2025475 --- /dev/null +++ b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import os +import pickle + +import numpy as np +import torch + +from sig_networks.scripts.seqsignet_attention_encoder_functions import ( + seqsignet_attention_encoder_hyperparameter_search, +) + +from ..load_anno_mi import ( + anno_mi, + client_index, + client_transcript_id, + output_dim_client, + y_data_client, +) + +seed = 2023 + +# set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Device: ", device) + +# set output directory +output_dir = "client_talk_type_output" +if not os.path.isdir(output_dir): + os.makedirs(output_dir) + +# load sbert embeddings +with open("../anno_mi_sbert.pkl", "rb") as f: + sbert_embeddings = pickle.load(f) + +# set features +features = ["time_encoding", "timeline_index"] +standardise_method = ["z_score", None] +include_features_in_path = True +include_features_in_input = True + +# set hyperparameters +num_epochs = 100 +dimensions = [15] +# define swmhau parameters: (output_channels, sig_depth, num_heads) +swmhau_parameters = [(12, 3, 10), (10, 3, 5)] +num_layers = [1] +ffn_hidden_dim_sizes = [[32, 32], [128, 128], [512, 512]] +dropout_rates = [0.1] +learning_rates = [5e-4, 3e-4, 1e-4, 1e-5] +seeds = [1, 12, 123] +loss = "focal" +gamma = 2 +validation_metric = "f1" +patience = 3 + +# set kwargs +kwargs = { + "num_epochs": num_epochs, + "df": anno_mi, + "id_column": "transcript_id", + "label_column": "client_talk_type", + "embeddings": sbert_embeddings, + "y_data": y_data_client, + "output_dim": output_dim_client, + "dimensions": dimensions, + "log_signature": True, + "pooling": "signature", + "transformer_encoder_layers": 2, + "swmhau_parameters": swmhau_parameters, + "num_layers": num_layers, + "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, + "dropout_rates": dropout_rates, + "learning_rates": learning_rates, + "seeds": seeds, + "loss": loss, + "gamma": gamma, + "device": device, + "features": features, + "standardise_method": standardise_method, + "include_features_in_path": include_features_in_path, + "include_features_in_input": include_features_in_input, + "path_indices": client_index, + "split_ids": client_transcript_id, + "k_fold": True, + "patience": patience, + "validation_metric": validation_metric, + "verbose": False, +} + +# run hyperparameter search +lengths = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26), (3, 5, 36)] + +for shift, window_size, n in lengths: + print(f"shift: {shift}, window_size: {window_size}, n: {n}") + ( + seqsignet_attention_encoder_umap_kfold, + best_seqsignet_attention_encoder_umap_kfold, + _, + __, + ) = seqsignet_attention_encoder_hyperparameter_search( + shift=shift, + window_size=window_size, + n=n, + dim_reduce_methods=["umap"], + results_output=f"{output_dir}/seqsignet_attention_encoder_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold.csv", + **kwargs, + ) + + print(f"F1: {best_seqsignet_attention_encoder_umap_kfold['f1'].mean()}") + print( + f"Precision: {best_seqsignet_attention_encoder_umap_kfold['precision'].mean()}" + ) + print(f"Recall: {best_seqsignet_attention_encoder_umap_kfold['recall'].mean()}") + print( + f"F1 scores: 
{np.stack(best_seqsignet_attention_encoder_umap_kfold['f1_scores']).mean(axis=0)}" + ) + print( + f"Precision scores: {np.stack(best_seqsignet_attention_encoder_umap_kfold['precision_scores']).mean(axis=0)}" + ) + print( + f"Recall scores: {np.stack(best_seqsignet_attention_encoder_umap_kfold['recall_scores']).mean(axis=0)}" + ) diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy new file mode 100644 index 0000000..8c370d9 --- /dev/null +++ b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy @@ -0,0 +1,101 @@ +import numpy as np +import pickle +import os +import torch +from sig_networks.scripts.seqsignet_functions import seqsignet_hyperparameter_search +from ..load_anno_mi import anno_mi, y_data_client, output_dim_client, client_index, client_transcript_id + +seed = 2023 + +# set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Device: ", device) + +# set output directory +output_dir = "client_talk_type_output" +if not os.path.isdir(output_dir): + os.makedirs(output_dir) + +# load sbert embeddings +with open("../anno_mi_sbert.pkl", "rb") as f: + sbert_embeddings = pickle.load(f) + +# set features +features = ["time_encoding", "timeline_index"] +standardise_method = ["z_score", None] +include_features_in_path = True +include_features_in_input = True + +# set hyperparameters +num_epochs = 100 +dimensions = [15] +swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 3)] +lstm_hidden_dim_sizes = [300, 400] +ffn_hidden_dim_sizes = [[32,32], [128,128], [512,512]] +dropout_rates = [0.1] +learning_rates = [5e-4, 3e-4, 1e-4] +seeds = [1, 12, 123] +loss = "focal" +gamma = 2 +validation_metric = "f1" +patience = 3 + +# set kwargs +kwargs = { + "num_epochs": num_epochs, + "df": anno_mi, + "id_column": "transcript_id", + "label_column": "client_talk_type", + "embeddings": sbert_embeddings, + "y_data": y_data_client, + "output_dim": output_dim_client, + "dimensions": dimensions, + "log_signature": True, + "pooling": "signature", + "swnu_hidden_dim_sizes_and_sig_depths": swnu_hidden_dim_sizes_and_sig_depths, + "lstm_hidden_dim_sizes": lstm_hidden_dim_sizes, + "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, + "dropout_rates": dropout_rates, + "learning_rates": learning_rates, + "BiLSTM": True, + "seeds": seeds, + "loss": loss, + "gamma": gamma, + "device": device, + "features": features, + "standardise_method": standardise_method, + "include_features_in_path": include_features_in_path, + "include_features_in_input": include_features_in_input, + "path_indices": client_index, + "split_ids": client_transcript_id, + "k_fold": True, + "patience": patience, + "validation_metric": validation_metric, + "verbose": False, +} + +# run hyperparameter search +lengths = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26), (3, 5, 36)] + +for shift, window_size, n in lengths: + print(f"shift: {shift}, window_size: {window_size}, n: {n}") + ( + seqsignet_network_umap_kfold, + best_seqsignet_network_umap_kfold, + _, + __, + ) = seqsignet_hyperparameter_search( + shift=shift, + window_size=window_size, + n=n, + dim_reduce_methods=["umap"], + results_output=f"{output_dir}/seqsignet_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold.csv", + **kwargs, + ) + + print(f"F1: {best_seqsignet_network_umap_kfold['f1'].mean()}") + print(f"Precision: {best_seqsignet_network_umap_kfold['precision'].mean()}") + print(f"Recall: 
{best_seqsignet_network_umap_kfold['recall'].mean()}") + print(f"F1 scores: {np.stack(best_seqsignet_network_umap_kfold['f1_scores']).mean(axis=0)}") + print(f"Precision scores: {np.stack(best_seqsignet_network_umap_kfold['precision_scores']).mean(axis=0)}") + print(f"Recall scores: {np.stack(best_seqsignet_network_umap_kfold['recall_scores']).mean(axis=0)}") diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy new file mode 100644 index 0000000..8618fc2 --- /dev/null +++ b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy @@ -0,0 +1,101 @@ +import numpy as np +import pickle +import os +import torch +from sig_networks.scripts.swmhau_network_functions import ( + swmhau_network_hyperparameter_search, +) +from ..load_anno_mi import anno_mi, y_data_client, output_dim_client, client_index, client_transcript_id + +seed = 2023 + +# set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Device: ", device) + +# set output directory +output_dir = "client_talk_type_output" +if not os.path.isdir(output_dir): + os.makedirs(output_dir) + +# load sbert embeddings +with open("../anno_mi_sbert.pkl", "rb") as f: + sbert_embeddings = pickle.load(f) + +# set features +features = ["time_encoding", "timeline_index"] +standardise_method = ["z_score", None] +include_features_in_path = True +include_features_in_input = True + +# set hyperparameters +num_epochs = 100 +dimensions = [15] +# define swmhau parameters: (output_channels, sig_depth, num_heads) +swmhau_parameters = [(12, 3, 10), (10, 3, 5)] +num_layers = [1] +ffn_hidden_dim_sizes = [[32,32], [128,128], [512,512]] +dropout_rates = [0.1] +learning_rates = [5e-4, 3e-4, 1e-4, 1e-5] +seeds = [1, 12, 123] +loss = "focal" +gamma = 2 +validation_metric = "f1" +patience = 3 + +# set kwargs +kwargs = { + "num_epochs": num_epochs, + "df": anno_mi, + "id_column": "transcript_id", + "label_column": "client_talk_type", + "embeddings": sbert_embeddings, + "y_data": y_data_client, + "output_dim": output_dim_client, + "dimensions": dimensions, + "log_signature": True, + "pooling": "signature", + "swmhau_parameters": swmhau_parameters, + "num_layers": num_layers, + "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, + "dropout_rates": dropout_rates, + "learning_rates": learning_rates, + "seeds": seeds, + "loss": loss, + "gamma": gamma, + "device": device, + "features": features, + "standardise_method": standardise_method, + "include_features_in_path": include_features_in_path, + "include_features_in_input": include_features_in_input, + "path_indices": client_index, + "split_ids": client_transcript_id, + "k_fold": True, + "patience": patience, + "validation_metric": validation_metric, + "verbose": False, +} + +# run hyperparameter search +lengths = [5, 11, 20, 35, 80, 110] + +for size in lengths: + print(f"history_length: {size}") + ( + swmhau_network_umap_kfold, + best_swmhau_network_umap_kfold, + _, + __, + ) = swmhau_network_hyperparameter_search( + history_lengths=[size], + dim_reduce_methods=["umap"], + results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv", + **kwargs, + ) + + print(f"F1: {best_swmhau_network_umap_kfold['f1'].mean()}") + print(f"Precision: {best_swmhau_network_umap_kfold['precision'].mean()}") + print(f"Recall: {best_swmhau_network_umap_kfold['recall'].mean()}") + print(f"F1 scores: 
{np.stack(best_swmhau_network_umap_kfold['f1_scores']).mean(axis=0)}") + print(f"Precision scores: {np.stack(best_swmhau_network_umap_kfold['precision_scores']).mean(axis=0)}") + print(f"Recall scores: {np.stack(best_swmhau_network_umap_kfold['recall_scores']).mean(axis=0)}") diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy new file mode 100644 index 0000000..c38bf3d --- /dev/null +++ b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy @@ -0,0 +1,102 @@ +import numpy as np +import pickle +import os +import torch +from sig_networks.scripts.swnu_network_functions import ( + swnu_network_hyperparameter_search, +) +from ..load_anno_mi import anno_mi, y_data_client, output_dim_client, client_index, client_transcript_id + +seed = 2023 + +# set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print("Device: ", device) + +# set output directory +output_dir = "client_talk_type_output" +if not os.path.isdir(output_dir): + os.makedirs(output_dir) + +# load sbert embeddings +with open("../anno_mi_sbert.pkl", "rb") as f: + sbert_embeddings = pickle.load(f) + +# set features +features = ["time_encoding", "timeline_index"] +standardise_method = ["z_score", None] +include_features_in_path = True +include_features_in_input = True + +# set hyperparameters +num_epochs = 100 +dimensions = [15] +swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 3)] +ffn_hidden_dim_sizes = [[32,32], [128,128], [512,512]] +dropout_rates = [0.1] +learning_rates = [5e-4, 3e-4, 1e-4] +seeds = [1, 12, 123] +loss = "focal" +gamma = 2 +validation_metric = "f1" +patience = 3 + +# set kwargs +kwargs = { + "num_epochs": num_epochs, + "df": anno_mi, + "id_column": "transcript_id", + "label_column": "client_talk_type", + "embeddings": sbert_embeddings, + "y_data": y_data_client, + "output_dim": output_dim_client, + "dimensions": dimensions, + "log_signature": True, + "pooling": "signature", + "swnu_hidden_dim_sizes_and_sig_depths": swnu_hidden_dim_sizes_and_sig_depths, + "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, + "dropout_rates": dropout_rates, + "learning_rates": learning_rates, + "BiLSTM": True, + "seeds": seeds, + "loss": loss, + "gamma": gamma, + "device": device, + "features": features, + "standardise_method": standardise_method, + "include_features_in_path": include_features_in_path, + "include_features_in_input": include_features_in_input, + "path_indices": client_index, + "split_ids": client_transcript_id, + "k_fold": True, + "patience": patience, + "validation_metric": validation_metric, + "verbose": False, +} + +# run hyperparameter search +lengths = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26), (3, 5, 36)] + +# run hyperparameter search +lengths = [5, 11, 20, 35, 80, 110] + +for size in lengths: + print(f"history_length: {size}") + ( + swnu_network_umap_kfold, + best_swnu_network_umap_kfold, + _, + __, + ) = swnu_network_hyperparameter_search( + history_lengths=[size], + dim_reduce_methods=["umap"], + results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv", + **kwargs, + ) + + print(f"F1: {best_swnu_network_umap_kfold['f1'].mean()}") + print(f"Precision: {best_swnu_network_umap_kfold['precision'].mean()}") + print(f"Recall: {best_swnu_network_umap_kfold['recall'].mean()}") + print(f"F1 scores: {np.stack(best_swnu_network_umap_kfold['f1_scores']).mean(axis=0)}") + print(f"Precision scores: 
{np.stack(best_swnu_network_umap_kfold['precision_scores']).mean(axis=0)}")
+    print(f"Recall scores: {np.stack(best_swnu_network_umap_kfold['recall_scores']).mean(axis=0)}")
diff --git a/notebooks/Rumours/rumours-baseline-bert.py b/notebooks/Rumours/rumours-baseline-bert.py
index 40a706d..67ccafa 100644
--- a/notebooks/Rumours/rumours-baseline-bert.py
+++ b/notebooks/Rumours/rumours-baseline-bert.py
@@ -1,14 +1,16 @@
-import numpy as np
+from __future__ import annotations
+
 import os
+
+import numpy as np
 import torch
 import transformers
+from load_rumours import df_rumours, id_to_label, label_to_id, output_dim, split_ids
 
-from nlpsig_networks.scripts.fine_tune_bert_classification import (
+from sig_networks.scripts.fine_tune_bert_classification import (
     fine_tune_transformer_average_seed,
 )
 
-from load_rumours import df_rumours, output_dim, split_ids, label_to_id, id_to_label
-
 # set to only report critical errors to avoid excessive logging
 transformers.utils.logging.set_verbosity(50)
@@ -52,31 +54,26 @@
 gamma = 2
 for loss in ["focal", "cross_entropy"]:
     if loss == "focal":
-        results_output=f"{output_dir}/bert_classifier_focal.csv",
+        results_output = f"{output_dir}/bert_classifier_focal.csv"
     else:
-        results_output=f"{output_dir}/bert_classifier_ce.csv"
-
+        results_output = f"{output_dir}/bert_classifier_ce.csv"
+
     bert_classifier, best_bert_classifier, _, __ = fine_tune_transformer_average_seed(
         loss=loss,
         gamma=gamma,
         results_output=results_output,
         **kwargs,
     )
-
+
     print(f"F1: {best_bert_classifier['f1'].mean()}")
-    print(
-        f"Precision: {best_bert_classifier['precision'].mean()}"
-    )
+    print(f"Precision: {best_bert_classifier['precision'].mean()}")
     print(f"Recall: {best_bert_classifier['recall'].mean()}")
+    print("F1 scores: " f"{np.stack(best_bert_classifier['f1_scores']).mean(axis=0)}")
     print(
-        "F1 scores: "
-        f"{np.stack(best_bert_classifier['f1_scores']).mean(axis=0)}"
-    )
-    print(
-        "Precision scores: "
-        f"{np.stack(best_bert_classifier['precision_scores']).mean(axis=0)}"
+        "Precision scores: "
+        f"{np.stack(best_bert_classifier['precision_scores']).mean(axis=0)}"
     )
     print(
-        "Recall scores: "
-        f"{np.stack(best_bert_classifier['recall_scores']).mean(axis=0)}"
+        "Recall scores: "
+        f"{np.stack(best_bert_classifier['recall_scores']).mean(axis=0)}"
     )
diff --git a/notebooks/Rumours/rumours-baseline-bilstm.py b/notebooks/Rumours/rumours-baseline-bilstm.py
index ec42b45..5bcb9c2 100644
--- a/notebooks/Rumours/rumours-baseline-bilstm.py
+++ b/notebooks/Rumours/rumours-baseline-bilstm.py
@@ -1,11 +1,15 @@
-import numpy as np
+from __future__ import annotations
+
 import os
+
+import numpy as np
 import torch
-from nlpsig_networks.scripts.lstm_baseline_functions import (
+from load_rumours import df_rumours, output_dim, split_ids, y_data
+from load_sbert_embeddings import sbert_embeddings
+
+from sig_networks.scripts.lstm_baseline_functions import (
     lstm_hyperparameter_search,
 )
-from load_sbert_embeddings import sbert_embeddings
-from load_rumours import df_rumours, y_data, output_dim, split_ids
 
 seed = 2023
diff --git a/notebooks/Rumours/rumours-baseline-ffn-history.py b/notebooks/Rumours/rumours-baseline-ffn-history.py
index 63fbb68..b9de756 100644
--- a/notebooks/Rumours/rumours-baseline-ffn-history.py
+++ b/notebooks/Rumours/rumours-baseline-ffn-history.py
@@ -1,11 +1,15 @@
-import numpy as np
+from __future__ import annotations
+
 import os
+
+import numpy as np
 import torch
-from nlpsig_networks.scripts.ffn_baseline_functions import (
+from load_rumours import df_rumours, output_dim,
split_ids, y_data +from load_sbert_embeddings import sbert_embeddings + +from sig_networks.scripts.ffn_baseline_functions import ( histories_baseline_hyperparameter_search, ) -from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids seed = 2023 @@ -64,14 +68,9 @@ ) print(f"F1: {best_ffn_mean_history_kfold['f1'].mean()}") -print( - f"Precision: {best_ffn_mean_history_kfold['precision'].mean()}" -) +print(f"Precision: {best_ffn_mean_history_kfold['precision'].mean()}") print(f"Recall: {best_ffn_mean_history_kfold['recall'].mean()}") -print( - "F1 scores: " - f"{np.stack(best_ffn_mean_history_kfold['f1_scores']).mean(axis=0)}" -) +print(f"F1 scores: {np.stack(best_ffn_mean_history_kfold['f1_scores']).mean(axis=0)}") print( "Precision scores: " f"{np.stack(best_ffn_mean_history_kfold['precision_scores']).mean(axis=0)}" diff --git a/notebooks/Rumours/rumours-baseline-ffn.py b/notebooks/Rumours/rumours-baseline-ffn.py index 56bd700..c7a8656 100644 --- a/notebooks/Rumours/rumours-baseline-ffn.py +++ b/notebooks/Rumours/rumours-baseline-ffn.py @@ -1,12 +1,15 @@ -import numpy as np +from __future__ import annotations import os + +import numpy as np import torch -from nlpsig_networks.scripts.ffn_baseline_functions import ( +from load_rumours import output_dim, split_ids, y_data +from load_sbert_embeddings import sbert_embeddings + +from sig_networks.scripts.ffn_baseline_functions import ( ffn_hyperparameter_search, ) -from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids seed = 2023 @@ -56,14 +59,9 @@ ) print(f"F1: {best_ffn_current_kfold['f1'].mean()}") -print( - f"Precision: {best_ffn_current_kfold['precision'].mean()}" -) +print(f"Precision: {best_ffn_current_kfold['precision'].mean()}") print(f"Recall: {best_ffn_current_kfold['recall'].mean()}") -print( - "F1 scores: " - f"{np.stack(best_ffn_current_kfold['f1_scores']).mean(axis=0)}" -) +print("F1 scores: " f"{np.stack(best_ffn_current_kfold['f1_scores']).mean(axis=0)}") print( "Precision scores: " f"{np.stack(best_ffn_current_kfold['precision_scores']).mean(axis=0)}" diff --git a/notebooks/Rumours/rumours-seqsignet-attention-bilstm-script.py b/notebooks/Rumours/rumours-seqsignet-attention-bilstm-script.py index 9130185..0bea94a 100644 --- a/notebooks/Rumours/rumours-seqsignet-attention-bilstm-script.py +++ b/notebooks/Rumours/rumours-seqsignet-attention-bilstm-script.py @@ -1,11 +1,15 @@ -import numpy as np +from __future__ import annotations + import os + +import numpy as np import torch -from nlpsig_networks.scripts.seqsignet_attention_bilstm_functions import ( +from load_rumours import df_rumours, output_dim, split_ids, y_data +from load_sbert_embeddings import sbert_embeddings + +from sig_networks.scripts.seqsignet_attention_bilstm_functions import ( seqsignet_attention_bilstm_hyperparameter_search, ) -from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids seed = 2023 @@ -93,9 +97,7 @@ ) print(f"F1: {best_seqsignet_network_umap_kfold['f1'].mean()}") - print( - f"Precision: {best_seqsignet_network_umap_kfold['precision'].mean()}" - ) + print(f"Precision: {best_seqsignet_network_umap_kfold['precision'].mean()}") print(f"Recall: {best_seqsignet_network_umap_kfold['recall'].mean()}") print( "F1 scores: " diff --git a/notebooks/Rumours/rumours-seqsignet-attention-encoder-script.py b/notebooks/Rumours/rumours-seqsignet-attention-encoder-script.py index 
5cc4b58..0d0f23c 100644 --- a/notebooks/Rumours/rumours-seqsignet-attention-encoder-script.py +++ b/notebooks/Rumours/rumours-seqsignet-attention-encoder-script.py @@ -1,11 +1,15 @@ -import numpy as np +from __future__ import annotations + import os + +import numpy as np import torch -from nlpsig_networks.scripts.seqsignet_attention_encoder_functions import ( +from load_rumours import df_rumours, output_dim, split_ids, y_data +from load_sbert_embeddings import sbert_embeddings + +from sig_networks.scripts.seqsignet_attention_encoder_functions import ( seqsignet_attention_encoder_hyperparameter_search, ) -from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids seed = 2023 diff --git a/notebooks/Rumours/rumours-seqsignet-script.py b/notebooks/Rumours/rumours-seqsignet-script.py index 80fefb3..1361faa 100644 --- a/notebooks/Rumours/rumours-seqsignet-script.py +++ b/notebooks/Rumours/rumours-seqsignet-script.py @@ -1,9 +1,13 @@ -import numpy as np +from __future__ import annotations + import os + +import numpy as np import torch -from nlpsig_networks.scripts.seqsignet_functions import seqsignet_hyperparameter_search +from load_rumours import df_rumours, output_dim, split_ids, y_data from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids + +from sig_networks.scripts.seqsignet_functions import seqsignet_hyperparameter_search seed = 2023 diff --git a/notebooks/Rumours/rumours-swmhau-script.py b/notebooks/Rumours/rumours-swmhau-script.py index 31e3465..8ca41c4 100644 --- a/notebooks/Rumours/rumours-swmhau-script.py +++ b/notebooks/Rumours/rumours-swmhau-script.py @@ -1,11 +1,15 @@ -import numpy as np +from __future__ import annotations + import os + +import numpy as np import torch -from nlpsig_networks.scripts.swmhau_network_functions import ( +from load_rumours import df_rumours, output_dim, split_ids, y_data +from load_sbert_embeddings import sbert_embeddings + +from sig_networks.scripts.swmhau_network_functions import ( swmhau_network_hyperparameter_search, ) -from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids seed = 2023 diff --git a/notebooks/Rumours/rumours-swnu-script.py b/notebooks/Rumours/rumours-swnu-script.py index 4715dce..00cfe33 100644 --- a/notebooks/Rumours/rumours-swnu-script.py +++ b/notebooks/Rumours/rumours-swnu-script.py @@ -1,11 +1,15 @@ -import numpy as np +from __future__ import annotations + import os + +import numpy as np import torch -from nlpsig_networks.scripts.swnu_network_functions import ( +from load_rumours import df_rumours, output_dim, split_ids, y_data +from load_sbert_embeddings import sbert_embeddings + +from sig_networks.scripts.swnu_network_functions import ( swnu_network_hyperparameter_search, ) -from load_sbert_embeddings import sbert_embeddings -from load_rumours import df_rumours, y_data, output_dim, split_ids seed = 2023 diff --git a/notebooks/feed-forward-mnist.ipynb b/notebooks/feed-forward-mnist.ipynb index 83ea293..4fc8547 100644 --- a/notebooks/feed-forward-mnist.ipynb +++ b/notebooks/feed-forward-mnist.ipynb @@ -14,10 +14,10 @@ "import matplotlib.pyplot as plt\n", "from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR\n", "\n", - "import nlpsig_networks\n", - "from nlpsig_networks.ffn import FeedforwardNeuralNetModel\n", - "from nlpsig_networks.focal_loss import FocalLoss\n", - "from nlpsig_networks.pytorch_utils import training_pytorch, 
testing_pytorch\n", + "import sig_networks\n", + "from sig_networks.ffn import FeedforwardNeuralNetModel\n", + "from sig_networks.focal_loss import FocalLoss\n", + "from sig_networkstorch_utils import training_pytorch, testing_pytorch\n", "\n", "seed = 2023" ] diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..55a3580 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import argparse +import shutil +from pathlib import Path + +import nox + +DIR = Path(__file__).parent.resolve() + +nox.options.sessions = ["lint", "tests"] + + +@nox.session(python=["3.8"]) +def lint(session: nox.Session) -> None: + """ + Run the linter. + """ + session.install("pre-commit") + session.run("pre-commit", "run", "--all-files", *session.posargs) + + +@nox.session(python=["3.8"]) +def tests(session: nox.Session) -> None: + """ + Run the unit and regular tests. + """ + session.install(".[test]") + session.run("pytest", *session.posargs) + + +@nox.session(python=["3.8"]) +def coverage(session: nox.Session) -> None: + """ + Run tests and compute coverage. + """ + + session.posargs.append("--cov=nlpsig") + tests(session) + + +@nox.session(python=["3.8"]) +def docs(session: nox.Session) -> None: + """ + Build the docs. Pass "--serve" to serve. + """ + + parser = argparse.ArgumentParser() + parser.add_argument("--serve", action="store_true", help="Serve after building") + args = parser.parse_args(session.posargs) + + session.install(".[docs]") + session.chdir("docs") + session.run("sphinx-build", "-M", "html", ".", "_build") + + if args.serve: + print("Launching docs at http://localhost:8000/ - use Ctrl-C to quit") + session.run("python", "-m", "http.server", "8000", "-d", "_build/html") + + +@nox.session(python=["3.8"]) +def build_api_docs(session: nox.Session) -> None: + """ + Build (regenerate) API docs. + """ + + session.install("sphinx") + session.chdir("docs") + session.run( + "sphinx-apidoc", + "-o", + "api/", + "--no-toc", + "--force", + "--module-first", + "../src/nlpsig", + ) + + +@nox.session(python=["3.8"]) +def build(session: nox.Session) -> None: + """ + Build an SDist and wheel. + """ + + build_p = DIR.joinpath("build") + if build_p.exists(): + shutil.rmtree(build_p) + + session.install("build") + session.run("python", "-m", "build") diff --git a/poetry.lock b/poetry.lock deleted file mode 100644 index 22c679f..0000000 --- a/poetry.lock +++ /dev/null @@ -1,1077 +0,0 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. 
- -[[package]] -name = "accelerate" -version = "0.20.1" -description = "Accelerate" -category = "main" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "accelerate-0.20.1-py3-none-any.whl", hash = "sha256:bbc1f06879e724079a4e564691dae4021591531fd5558864fa4a644d70bede88"}, - {file = "accelerate-0.20.1.tar.gz", hash = "sha256:5d26bfdad7e31f4e479839c6afc2bdac4e153cfc3a06d52d9af17461e3c72121"}, -] - -[package.dependencies] -numpy = ">=1.17" -packaging = ">=20.0" -psutil = "*" -pyyaml = "*" -torch = ">=1.6.0" - -[package.extras] -dev = ["black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.0.241)", "scikit-learn", "scipy", "tqdm", "transformers", "urllib3 (<2.0.0)"] -quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)", "urllib3 (<2.0.0)"] -rich = ["rich"] -sagemaker = ["sagemaker"] -test-dev = ["datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"] -test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] -test-trackers = ["comet-ml", "tensorboard", "wandb"] -testing = ["datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"] - -[[package]] -name = "certifi" -version = "2023.7.22" -description = "Python package for providing Mozilla's CA Bundle." -category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, - {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, -] - -[[package]] -name = "charset-normalizer" -version = "3.2.0" -description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
-category = "main" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, - {file = 
"charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, - {file = 
"charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, - {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, -] - -[[package]] -name = "colorama" -version = "0.4.6" -description = "Cross-platform colored terminal text." -category = "main" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -files = [ - {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, - {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, -] - -[[package]] -name = "filelock" -version = "3.12.2" -description = "A platform independent file lock." 
-category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, - {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, -] - -[package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] - -[[package]] -name = "fsspec" -version = "2023.6.0" -description = "File-system specification" -category = "main" -optional = false -python-versions = ">=3.8" -files = [ - {file = "fsspec-2023.6.0-py3-none-any.whl", hash = "sha256:1cbad1faef3e391fba6dc005ae9b5bdcbf43005c9167ce78c915549c352c869a"}, - {file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"}, -] - -[package.extras] -abfs = ["adlfs"] -adl = ["adlfs"] -arrow = ["pyarrow (>=1)"] -dask = ["dask", "distributed"] -devel = ["pytest", "pytest-cov"] -dropbox = ["dropbox", "dropboxdrivefs", "requests"] -full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] -fuse = ["fusepy"] -gcs = ["gcsfs"] -git = ["pygit2"] -github = ["requests"] -gs = ["gcsfs"] -gui = ["panel"] -hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] -libarchive = ["libarchive-c"] -oci = ["ocifs"] -s3 = ["s3fs"] -sftp = ["paramiko"] -smb = ["smbprotocol"] -ssh = ["paramiko"] -tqdm = ["tqdm"] - -[[package]] -name = "huggingface-hub" -version = "0.16.4" -description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -category = "main" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "huggingface_hub-0.16.4-py3-none-any.whl", hash = "sha256:0d3df29932f334fead024afc7cb4cc5149d955238b8b5e42dcf9740d6995a349"}, - {file = "huggingface_hub-0.16.4.tar.gz", hash = "sha256:608c7d4f3d368b326d1747f91523dbd1f692871e8e2e7a4750314a2dd8b63e14"}, -] - -[package.dependencies] -filelock = "*" -fsspec = "*" -packaging = ">=20.9" -pyyaml = ">=5.1" -requests = "*" -tqdm = ">=4.42.1" -typing-extensions = ">=3.7.4.3" - -[package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] -cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (>=23.1,<24.0)", "gradio", "jedi", "mypy (==0.982)", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] -fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -inference = ["aiohttp", "pydantic"] -quality = ["black (>=23.1,<24.0)", "mypy (==0.982)", "ruff (>=0.0.241)"] -tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy 
(==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] -torch = ["torch"] -typing = ["pydantic", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] - -[[package]] -name = "idna" -version = "3.4" -description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" -optional = false -python-versions = ">=3.5" -files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, -] - -[[package]] -name = "joblib" -version = "1.2.0" -description = "Lightweight pipelining with Python functions" -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"}, - {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"}, -] - -[[package]] -name = "numpy" -version = "1.24.2" -description = "Fundamental package for array computing in Python" -category = "main" -optional = false -python-versions = ">=3.8" -files = [ - {file = "numpy-1.24.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eef70b4fc1e872ebddc38cddacc87c19a3709c0e3e5d20bf3954c147b1dd941d"}, - {file = "numpy-1.24.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d2859428712785e8a8b7d2b3ef0a1d1565892367b32f915c4a4df44d0e64f5"}, - {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6524630f71631be2dabe0c541e7675db82651eb998496bbe16bc4f77f0772253"}, - {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a51725a815a6188c662fb66fb32077709a9ca38053f0274640293a14fdd22978"}, - {file = "numpy-1.24.2-cp310-cp310-win32.whl", hash = "sha256:2620e8592136e073bd12ee4536149380695fbe9ebeae845b81237f986479ffc9"}, - {file = "numpy-1.24.2-cp310-cp310-win_amd64.whl", hash = "sha256:97cf27e51fa078078c649a51d7ade3c92d9e709ba2bfb97493007103c741f1d0"}, - {file = "numpy-1.24.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7de8fdde0003f4294655aa5d5f0a89c26b9f22c0a58790c38fae1ed392d44a5a"}, - {file = "numpy-1.24.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4173bde9fa2a005c2c6e2ea8ac1618e2ed2c1c6ec8a7657237854d42094123a0"}, - {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cecaed30dc14123020f77b03601559fff3e6cd0c048f8b5289f4eeabb0eb281"}, - {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a23f8440561a633204a67fb44617ce2a299beecf3295f0d13c495518908e910"}, - {file = "numpy-1.24.2-cp311-cp311-win32.whl", hash = "sha256:e428c4fbfa085f947b536706a2fc349245d7baa8334f0c5723c56a10595f9b95"}, - {file = "numpy-1.24.2-cp311-cp311-win_amd64.whl", hash = "sha256:557d42778a6869c2162deb40ad82612645e21d79e11c1dc62c6e82a2220ffb04"}, - {file = "numpy-1.24.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d0a2db9d20117bf523dde15858398e7c0858aadca7c0f088ac0d6edd360e9ad2"}, - {file = "numpy-1.24.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c72a6b2f4af1adfe193f7beb91ddf708ff867a3f977ef2ec53c0ffb8283ab9f5"}, - {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:c29e6bd0ec49a44d7690ecb623a8eac5ab8a923bce0bea6293953992edf3a76a"}, - {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2eabd64ddb96a1239791da78fa5f4e1693ae2dadc82a76bc76a14cbb2b966e96"}, - {file = "numpy-1.24.2-cp38-cp38-win32.whl", hash = "sha256:e3ab5d32784e843fc0dd3ab6dcafc67ef806e6b6828dc6af2f689be0eb4d781d"}, - {file = "numpy-1.24.2-cp38-cp38-win_amd64.whl", hash = "sha256:76807b4063f0002c8532cfeac47a3068a69561e9c8715efdad3c642eb27c0756"}, - {file = "numpy-1.24.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4199e7cfc307a778f72d293372736223e39ec9ac096ff0a2e64853b866a8e18a"}, - {file = "numpy-1.24.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:adbdce121896fd3a17a77ab0b0b5eedf05a9834a18699db6829a64e1dfccca7f"}, - {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889b2cc88b837d86eda1b17008ebeb679d82875022200c6e8e4ce6cf549b7acb"}, - {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64bb98ac59b3ea3bf74b02f13836eb2e24e48e0ab0145bbda646295769bd780"}, - {file = "numpy-1.24.2-cp39-cp39-win32.whl", hash = "sha256:63e45511ee4d9d976637d11e6c9864eae50e12dc9598f531c035265991910468"}, - {file = "numpy-1.24.2-cp39-cp39-win_amd64.whl", hash = "sha256:a77d3e1163a7770164404607b7ba3967fb49b24782a6ef85d9b5f54126cc39e5"}, - {file = "numpy-1.24.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92011118955724465fb6853def593cf397b4a1367495e0b59a7e69d40c4eb71d"}, - {file = "numpy-1.24.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9006288bcf4895917d02583cf3411f98631275bc67cce355a7f39f8c14338fa"}, - {file = "numpy-1.24.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:150947adbdfeceec4e5926d956a06865c1c690f2fd902efede4ca6fe2e657c3f"}, - {file = "numpy-1.24.2.tar.gz", hash = "sha256:003a9f530e880cb2cd177cba1af7220b9aa42def9c4afc2a2fc3ee6be7eb2b22"}, -] - -[[package]] -name = "packaging" -version = "23.1" -description = "Core utilities for Python packages" -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, - {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, -] - -[[package]] -name = "pandas" -version = "1.5.3" -description = "Powerful data structures for data analysis, time series, and statistics" -category = "main" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3749077d86e3a2f0ed51367f30bf5b82e131cc0f14260c4d3e499186fccc4406"}, - {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:972d8a45395f2a2d26733eb8d0f629b2f90bebe8e8eddbb8829b180c09639572"}, - {file = "pandas-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50869a35cbb0f2e0cd5ec04b191e7b12ed688874bd05dd777c19b28cbea90996"}, - {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ac844a0fe00bfaeb2c9b51ab1424e5c8744f89860b138434a363b1f620f354"}, - {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a0a56cef15fd1586726dace5616db75ebcfec9179a3a55e78f72c5639fa2a23"}, - {file = "pandas-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:478ff646ca42b20376e4ed3fa2e8d7341e8a63105586efe54fa2508ee087f328"}, - {file = 
"pandas-1.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6973549c01ca91ec96199e940495219c887ea815b2083722821f1d7abfa2b4dc"}, - {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c39a8da13cede5adcd3be1182883aea1c925476f4e84b2807a46e2775306305d"}, - {file = "pandas-1.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f76d097d12c82a535fda9dfe5e8dd4127952b45fea9b0276cb30cca5ea313fbc"}, - {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae"}, - {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6"}, - {file = "pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003"}, - {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:14e45300521902689a81f3f41386dc86f19b8ba8dd5ac5a3c7010ef8d2932813"}, - {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9842b6f4b8479e41968eced654487258ed81df7d1c9b7b870ceea24ed9459b31"}, - {file = "pandas-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26d9c71772c7afb9d5046e6e9cf42d83dd147b5cf5bcb9d97252077118543792"}, - {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fbcb19d6fceb9e946b3e23258757c7b225ba450990d9ed63ccceeb8cae609f7"}, - {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565fa34a5434d38e9d250af3c12ff931abaf88050551d9fbcdfafca50d62babf"}, - {file = "pandas-1.5.3-cp38-cp38-win32.whl", hash = "sha256:87bd9c03da1ac870a6d2c8902a0e1fd4267ca00f13bc494c9e5a9020920e1d51"}, - {file = "pandas-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:41179ce559943d83a9b4bbacb736b04c928b095b5f25dd2b7389eda08f46f373"}, - {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c74a62747864ed568f5a82a49a23a8d7fe171d0c69038b38cedf0976831296fa"}, - {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4c00e0b0597c8e4f59e8d461f797e5d70b4d025880516a8261b2817c47759ee"}, - {file = "pandas-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a50d9a4336a9621cab7b8eb3fb11adb82de58f9b91d84c2cd526576b881a0c5a"}, - {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd05f7783b3274aa206a1af06f0ceed3f9b412cf665b7247eacd83be41cf7bf0"}, - {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f69c4029613de47816b1bb30ff5ac778686688751a5e9c99ad8c7031f6508e5"}, - {file = "pandas-1.5.3-cp39-cp39-win32.whl", hash = "sha256:7cec0bee9f294e5de5bbfc14d0573f65526071029d036b753ee6507d2a21480a"}, - {file = "pandas-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:dfd681c5dc216037e0b0a2c821f5ed99ba9f03ebcf119c7dac0e9a7b960b9ec9"}, - {file = "pandas-1.5.3.tar.gz", hash = "sha256:74a3fd7e5a7ec052f183273dc7b0acd3a863edf7520f5d3a1765c04ffdb3b0b1"}, -] - -[package.dependencies] -numpy = {version = ">=1.20.3", markers = "python_version < \"3.10\""} -python-dateutil = ">=2.8.1" -pytz = ">=2020.1" - -[package.extras] -test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] - -[[package]] -name = "pillow" -version = "9.5.0" -description = "Python Imaging Library (Fork)" -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "Pillow-9.5.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = 
"sha256:ace6ca218308447b9077c14ea4ef381ba0b67ee78d64046b3f19cf4e1139ad16"}, - {file = "Pillow-9.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d3d403753c9d5adc04d4694d35cf0391f0f3d57c8e0030aac09d7678fa8030aa"}, - {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ba1b81ee69573fe7124881762bb4cd2e4b6ed9dd28c9c60a632902fe8db8b38"}, - {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe7e1c262d3392afcf5071df9afa574544f28eac825284596ac6db56e6d11062"}, - {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f36397bf3f7d7c6a3abdea815ecf6fd14e7fcd4418ab24bae01008d8d8ca15e"}, - {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:252a03f1bdddce077eff2354c3861bf437c892fb1832f75ce813ee94347aa9b5"}, - {file = "Pillow-9.5.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85ec677246533e27770b0de5cf0f9d6e4ec0c212a1f89dfc941b64b21226009d"}, - {file = "Pillow-9.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b416f03d37d27290cb93597335a2f85ed446731200705b22bb927405320de903"}, - {file = "Pillow-9.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1781a624c229cb35a2ac31cc4a77e28cafc8900733a864870c49bfeedacd106a"}, - {file = "Pillow-9.5.0-cp310-cp310-win32.whl", hash = "sha256:8507eda3cd0608a1f94f58c64817e83ec12fa93a9436938b191b80d9e4c0fc44"}, - {file = "Pillow-9.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:d3c6b54e304c60c4181da1c9dadf83e4a54fd266a99c70ba646a9baa626819eb"}, - {file = "Pillow-9.5.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:7ec6f6ce99dab90b52da21cf0dc519e21095e332ff3b399a357c187b1a5eee32"}, - {file = "Pillow-9.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:560737e70cb9c6255d6dcba3de6578a9e2ec4b573659943a5e7e4af13f298f5c"}, - {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96e88745a55b88a7c64fa49bceff363a1a27d9a64e04019c2281049444a571e3"}, - {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d9c206c29b46cfd343ea7cdfe1232443072bbb270d6a46f59c259460db76779a"}, - {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfcc2c53c06f2ccb8976fb5c71d448bdd0a07d26d8e07e321c103416444c7ad1"}, - {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0f9bb6c80e6efcde93ffc51256d5cfb2155ff8f78292f074f60f9e70b942d99"}, - {file = "Pillow-9.5.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:8d935f924bbab8f0a9a28404422da8af4904e36d5c33fc6f677e4c4485515625"}, - {file = "Pillow-9.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fed1e1cf6a42577953abbe8e6cf2fe2f566daebde7c34724ec8803c4c0cda579"}, - {file = "Pillow-9.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c1170d6b195555644f0616fd6ed929dfcf6333b8675fcca044ae5ab110ded296"}, - {file = "Pillow-9.5.0-cp311-cp311-win32.whl", hash = "sha256:54f7102ad31a3de5666827526e248c3530b3a33539dbda27c6843d19d72644ec"}, - {file = "Pillow-9.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfa4561277f677ecf651e2b22dc43e8f5368b74a25a8f7d1d4a3a243e573f2d4"}, - {file = "Pillow-9.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:965e4a05ef364e7b973dd17fc765f42233415974d773e82144c9bbaaaea5d089"}, - {file = "Pillow-9.5.0-cp312-cp312-win32.whl", hash = "sha256:22baf0c3cf0c7f26e82d6e1adf118027afb325e703922c8dfc1d5d0156bb2eeb"}, - {file = "Pillow-9.5.0-cp312-cp312-win_amd64.whl", hash = 
"sha256:432b975c009cf649420615388561c0ce7cc31ce9b2e374db659ee4f7d57a1f8b"}, - {file = "Pillow-9.5.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:5d4ebf8e1db4441a55c509c4baa7a0587a0210f7cd25fcfe74dbbce7a4bd1906"}, - {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:375f6e5ee9620a271acb6820b3d1e94ffa8e741c0601db4c0c4d3cb0a9c224bf"}, - {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99eb6cafb6ba90e436684e08dad8be1637efb71c4f2180ee6b8f940739406e78"}, - {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dfaaf10b6172697b9bceb9a3bd7b951819d1ca339a5ef294d1f1ac6d7f63270"}, - {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:763782b2e03e45e2c77d7779875f4432e25121ef002a41829d8868700d119392"}, - {file = "Pillow-9.5.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:35f6e77122a0c0762268216315bf239cf52b88865bba522999dc38f1c52b9b47"}, - {file = "Pillow-9.5.0-cp37-cp37m-win32.whl", hash = "sha256:aca1c196f407ec7cf04dcbb15d19a43c507a81f7ffc45b690899d6a76ac9fda7"}, - {file = "Pillow-9.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:322724c0032af6692456cd6ed554bb85f8149214d97398bb80613b04e33769f6"}, - {file = "Pillow-9.5.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:a0aa9417994d91301056f3d0038af1199eb7adc86e646a36b9e050b06f526597"}, - {file = "Pillow-9.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8286396b351785801a976b1e85ea88e937712ee2c3ac653710a4a57a8da5d9c"}, - {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c830a02caeb789633863b466b9de10c015bded434deb3ec87c768e53752ad22a"}, - {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fbd359831c1657d69bb81f0db962905ee05e5e9451913b18b831febfe0519082"}, - {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8fc330c3370a81bbf3f88557097d1ea26cd8b019d6433aa59f71195f5ddebbf"}, - {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:7002d0797a3e4193c7cdee3198d7c14f92c0836d6b4a3f3046a64bd1ce8df2bf"}, - {file = "Pillow-9.5.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:229e2c79c00e85989a34b5981a2b67aa079fd08c903f0aaead522a1d68d79e51"}, - {file = "Pillow-9.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9adf58f5d64e474bed00d69bcd86ec4bcaa4123bfa70a65ce72e424bfb88ed96"}, - {file = "Pillow-9.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:662da1f3f89a302cc22faa9f14a262c2e3951f9dbc9617609a47521c69dd9f8f"}, - {file = "Pillow-9.5.0-cp38-cp38-win32.whl", hash = "sha256:6608ff3bf781eee0cd14d0901a2b9cc3d3834516532e3bd673a0a204dc8615fc"}, - {file = "Pillow-9.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:e49eb4e95ff6fd7c0c402508894b1ef0e01b99a44320ba7d8ecbabefddcc5569"}, - {file = "Pillow-9.5.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:482877592e927fd263028c105b36272398e3e1be3269efda09f6ba21fd83ec66"}, - {file = "Pillow-9.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3ded42b9ad70e5f1754fb7c2e2d6465a9c842e41d178f262e08b8c85ed8a1d8e"}, - {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c446d2245ba29820d405315083d55299a796695d747efceb5717a8b450324115"}, - {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8aca1152d93dcc27dc55395604dcfc55bed5f25ef4c98716a928bacba90d33a3"}, - {file = 
"Pillow-9.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:608488bdcbdb4ba7837461442b90ea6f3079397ddc968c31265c1e056964f1ef"}, - {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:60037a8db8750e474af7ffc9faa9b5859e6c6d0a50e55c45576bf28be7419705"}, - {file = "Pillow-9.5.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:07999f5834bdc404c442146942a2ecadd1cb6292f5229f4ed3b31e0a108746b1"}, - {file = "Pillow-9.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a127ae76092974abfbfa38ca2d12cbeddcdeac0fb71f9627cc1135bedaf9d51a"}, - {file = "Pillow-9.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:489f8389261e5ed43ac8ff7b453162af39c3e8abd730af8363587ba64bb2e865"}, - {file = "Pillow-9.5.0-cp39-cp39-win32.whl", hash = "sha256:9b1af95c3a967bf1da94f253e56b6286b50af23392a886720f563c547e48e964"}, - {file = "Pillow-9.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:77165c4a5e7d5a284f10a6efaa39a0ae8ba839da344f20b111d62cc932fa4e5d"}, - {file = "Pillow-9.5.0-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:833b86a98e0ede388fa29363159c9b1a294b0905b5128baf01db683672f230f5"}, - {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaf305d6d40bd9632198c766fb64f0c1a83ca5b667f16c1e79e1661ab5060140"}, - {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0852ddb76d85f127c135b6dd1f0bb88dbb9ee990d2cd9aa9e28526c93e794fba"}, - {file = "Pillow-9.5.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:91ec6fe47b5eb5a9968c79ad9ed78c342b1f97a091677ba0e012701add857829"}, - {file = "Pillow-9.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cb841572862f629b99725ebaec3287fc6d275be9b14443ea746c1dd325053cbd"}, - {file = "Pillow-9.5.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:c380b27d041209b849ed246b111b7c166ba36d7933ec6e41175fd15ab9eb1572"}, - {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c9af5a3b406a50e313467e3565fc99929717f780164fe6fbb7704edba0cebbe"}, - {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5671583eab84af046a397d6d0ba25343c00cd50bce03787948e0fff01d4fd9b1"}, - {file = "Pillow-9.5.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:84a6f19ce086c1bf894644b43cd129702f781ba5751ca8572f08aa40ef0ab7b7"}, - {file = "Pillow-9.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:1e7723bd90ef94eda669a3c2c19d549874dd5badaeefabefd26053304abe5799"}, - {file = "Pillow-9.5.0.tar.gz", hash = "sha256:bf548479d336726d7a0eceb6e767e179fbde37833ae42794602631a070d630f1"}, -] - -[package.extras] -docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] -tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] - -[[package]] -name = "psutil" -version = "5.9.5" -description = "Cross-platform lib for process and system monitoring in Python." 
-category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, - {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, - {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, - {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"}, - {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"}, - {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"}, - {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"}, - {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"}, - {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"}, - {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"}, - {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"}, - {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"}, - {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, - {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, -] - -[package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] - -[[package]] -name = "python-dateutil" -version = "2.8.2" -description = "Extensions to the standard Python datetime module" -category = "main" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, -] - -[package.dependencies] -six = ">=1.5" - -[[package]] -name = "pytz" -version = "2023.3" -description = "World timezone definitions, modern and historical" -category = "main" -optional = false -python-versions = "*" -files = [ - {file = "pytz-2023.3-py2.py3-none-any.whl", hash = "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"}, - {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, -] - -[[package]] -name = "pyyaml" -version = "6.0.1" -description = "YAML parser and emitter for Python" -category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = 
"PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = 
"PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, -] - -[[package]] -name = "regex" -version = "2023.6.3" -description = "Alternative regular expression module, to replace re." 
-category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "regex-2023.6.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:824bf3ac11001849aec3fa1d69abcb67aac3e150a933963fb12bda5151fe1bfd"}, - {file = "regex-2023.6.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:05ed27acdf4465c95826962528f9e8d41dbf9b1aa8531a387dee6ed215a3e9ef"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b49c764f88a79160fa64f9a7b425620e87c9f46095ef9c9920542ab2495c8bc"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e3f1316c2293e5469f8f09dc2d76efb6c3982d3da91ba95061a7e69489a14ef"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43e1dd9d12df9004246bacb79a0e5886b3b6071b32e41f83b0acbf293f820ee8"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4959e8bcbfda5146477d21c3a8ad81b185cd252f3d0d6e4724a5ef11c012fb06"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af4dd387354dc83a3bff67127a124c21116feb0d2ef536805c454721c5d7993d"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2239d95d8e243658b8dbb36b12bd10c33ad6e6933a54d36ff053713f129aa536"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:890e5a11c97cf0d0c550eb661b937a1e45431ffa79803b942a057c4fb12a2da2"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a8105e9af3b029f243ab11ad47c19b566482c150c754e4c717900a798806b222"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:25be746a8ec7bc7b082783216de8e9473803706723b3f6bef34b3d0ed03d57e2"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:3676f1dd082be28b1266c93f618ee07741b704ab7b68501a173ce7d8d0d0ca18"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:10cb847aeb1728412c666ab2e2000ba6f174f25b2bdc7292e7dd71b16db07568"}, - {file = "regex-2023.6.3-cp310-cp310-win32.whl", hash = "sha256:dbbbfce33cd98f97f6bffb17801b0576e653f4fdb1d399b2ea89638bc8d08ae1"}, - {file = "regex-2023.6.3-cp310-cp310-win_amd64.whl", hash = "sha256:c5f8037000eb21e4823aa485149f2299eb589f8d1fe4b448036d230c3f4e68e0"}, - {file = "regex-2023.6.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c123f662be8ec5ab4ea72ea300359023a5d1df095b7ead76fedcd8babbedf969"}, - {file = "regex-2023.6.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9edcbad1f8a407e450fbac88d89e04e0b99a08473f666a3f3de0fd292badb6aa"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcba6dae7de533c876255317c11f3abe4907ba7d9aa15d13e3d9710d4315ec0e"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29cdd471ebf9e0f2fb3cac165efedc3c58db841d83a518b082077e612d3ee5df"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12b74fbbf6cbbf9dbce20eb9b5879469e97aeeaa874145517563cca4029db65c"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c29ca1bd61b16b67be247be87390ef1d1ef702800f91fbd1991f5c4421ebae8"}, - {file = 
"regex-2023.6.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77f09bc4b55d4bf7cc5eba785d87001d6757b7c9eec237fe2af57aba1a071d9"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ea353ecb6ab5f7e7d2f4372b1e779796ebd7b37352d290096978fea83c4dba0c"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:10590510780b7541969287512d1b43f19f965c2ece6c9b1c00fc367b29d8dce7"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e2fbd6236aae3b7f9d514312cdb58e6494ee1c76a9948adde6eba33eb1c4264f"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:6b2675068c8b56f6bfd5a2bda55b8accbb96c02fd563704732fd1c95e2083461"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74419d2b50ecb98360cfaa2974da8689cb3b45b9deff0dcf489c0d333bcc1477"}, - {file = "regex-2023.6.3-cp311-cp311-win32.whl", hash = "sha256:fb5ec16523dc573a4b277663a2b5a364e2099902d3944c9419a40ebd56a118f9"}, - {file = "regex-2023.6.3-cp311-cp311-win_amd64.whl", hash = "sha256:09e4a1a6acc39294a36b7338819b10baceb227f7f7dbbea0506d419b5a1dd8af"}, - {file = "regex-2023.6.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0654bca0cdf28a5956c83839162692725159f4cda8d63e0911a2c0dc76166525"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:463b6a3ceb5ca952e66550a4532cef94c9a0c80dc156c4cc343041951aec1697"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87b2a5bb5e78ee0ad1de71c664d6eb536dc3947a46a69182a90f4410f5e3f7dd"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6343c6928282c1f6a9db41f5fd551662310e8774c0e5ebccb767002fcf663ca9"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6192d5af2ccd2a38877bfef086d35e6659566a335b1492786ff254c168b1693"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74390d18c75054947e4194019077e243c06fbb62e541d8817a0fa822ea310c14"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:742e19a90d9bb2f4a6cf2862b8b06dea5e09b96c9f2df1779e53432d7275331f"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:8abbc5d54ea0ee80e37fef009e3cec5dafd722ed3c829126253d3e22f3846f1e"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c2b867c17a7a7ae44c43ebbeb1b5ff406b3e8d5b3e14662683e5e66e6cc868d3"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:d831c2f8ff278179705ca59f7e8524069c1a989e716a1874d6d1aab6119d91d1"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:ee2d1a9a253b1729bb2de27d41f696ae893507c7db224436abe83ee25356f5c1"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:61474f0b41fe1a80e8dfa70f70ea1e047387b7cd01c85ec88fa44f5d7561d787"}, - {file = "regex-2023.6.3-cp36-cp36m-win32.whl", hash = "sha256:0b71e63226e393b534105fcbdd8740410dc6b0854c2bfa39bbda6b0d40e59a54"}, - {file = "regex-2023.6.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bbb02fd4462f37060122e5acacec78e49c0fbb303c30dd49c7f493cf21fc5b27"}, - {file = "regex-2023.6.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = 
"sha256:b862c2b9d5ae38a68b92e215b93f98d4c5e9454fa36aae4450f61dd33ff48487"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:976d7a304b59ede34ca2921305b57356694f9e6879db323fd90a80f865d355a3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:83320a09188e0e6c39088355d423aa9d056ad57a0b6c6381b300ec1a04ec3d16"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9427a399501818a7564f8c90eced1e9e20709ece36be701f394ada99890ea4b3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178bbc1b2ec40eaca599d13c092079bf529679bf0371c602edaa555e10b41c3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:837328d14cde912af625d5f303ec29f7e28cdab588674897baafaf505341f2fc"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d44dc13229905ae96dd2ae2dd7cebf824ee92bc52e8cf03dcead37d926da019"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d54af539295392611e7efbe94e827311eb8b29668e2b3f4cadcfe6f46df9c777"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7117d10690c38a622e54c432dfbbd3cbd92f09401d622902c32f6d377e2300ee"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bb60b503ec8a6e4e3e03a681072fa3a5adcbfa5479fa2d898ae2b4a8e24c4591"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:65ba8603753cec91c71de423a943ba506363b0e5c3fdb913ef8f9caa14b2c7e0"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:271f0bdba3c70b58e6f500b205d10a36fb4b58bd06ac61381b68de66442efddb"}, - {file = "regex-2023.6.3-cp37-cp37m-win32.whl", hash = "sha256:9beb322958aaca059f34975b0df135181f2e5d7a13b84d3e0e45434749cb20f7"}, - {file = "regex-2023.6.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fea75c3710d4f31389eed3c02f62d0b66a9da282521075061ce875eb5300cf23"}, - {file = "regex-2023.6.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f56fcb7ff7bf7404becdfc60b1e81a6d0561807051fd2f1860b0d0348156a07"}, - {file = "regex-2023.6.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d2da3abc88711bce7557412310dfa50327d5769a31d1c894b58eb256459dc289"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a99b50300df5add73d307cf66abea093304a07eb017bce94f01e795090dea87c"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5708089ed5b40a7b2dc561e0c8baa9535b77771b64a8330b684823cfd5116036"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:687ea9d78a4b1cf82f8479cab23678aff723108df3edeac098e5b2498879f4a7"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3850beab9f527f06ccc94b446c864059c57651b3f911fddb8d9d3ec1d1b25d"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8915cc96abeb8983cea1df3c939e3c6e1ac778340c17732eb63bb96247b91d2"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:841d6e0e5663d4c7b4c8099c9997be748677d46cbf43f9f471150e560791f7ff"}, - {file = 
"regex-2023.6.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9edce5281f965cf135e19840f4d93d55b3835122aa76ccacfd389e880ba4cf82"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b956231ebdc45f5b7a2e1f90f66a12be9610ce775fe1b1d50414aac1e9206c06"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:36efeba71c6539d23c4643be88295ce8c82c88bbd7c65e8a24081d2ca123da3f"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:cf67ca618b4fd34aee78740bea954d7c69fdda419eb208c2c0c7060bb822d747"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b4598b1897837067a57b08147a68ac026c1e73b31ef6e36deeeb1fa60b2933c9"}, - {file = "regex-2023.6.3-cp38-cp38-win32.whl", hash = "sha256:f415f802fbcafed5dcc694c13b1292f07fe0befdb94aa8a52905bd115ff41e88"}, - {file = "regex-2023.6.3-cp38-cp38-win_amd64.whl", hash = "sha256:d4f03bb71d482f979bda92e1427f3ec9b220e62a7dd337af0aa6b47bf4498f72"}, - {file = "regex-2023.6.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccf91346b7bd20c790310c4147eee6ed495a54ddb6737162a36ce9dbef3e4751"}, - {file = "regex-2023.6.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b28f5024a3a041009eb4c333863d7894d191215b39576535c6734cd88b0fcb68"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0bb18053dfcfed432cc3ac632b5e5e5c5b7e55fb3f8090e867bfd9b054dbcbf"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5bfb3004f2144a084a16ce19ca56b8ac46e6fd0651f54269fc9e230edb5e4a"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c6b48d0fa50d8f4df3daf451be7f9689c2bde1a52b1225c5926e3f54b6a9ed1"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:051da80e6eeb6e239e394ae60704d2b566aa6a7aed6f2890a7967307267a5dc6"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a4c3b7fa4cdaa69268748665a1a6ff70c014d39bb69c50fda64b396c9116cf77"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:457b6cce21bee41ac292d6753d5e94dcbc5c9e3e3a834da285b0bde7aa4a11e9"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:aad51907d74fc183033ad796dd4c2e080d1adcc4fd3c0fd4fd499f30c03011cd"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0385e73da22363778ef2324950e08b689abdf0b108a7d8decb403ad7f5191938"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c6a57b742133830eec44d9b2290daf5cbe0a2f1d6acee1b3c7b1c7b2f3606df7"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3e5219bf9e75993d73ab3d25985c857c77e614525fac9ae02b1bebd92f7cecac"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e5087a3c59eef624a4591ef9eaa6e9a8d8a94c779dade95d27c0bc24650261cd"}, - {file = "regex-2023.6.3-cp39-cp39-win32.whl", hash = "sha256:20326216cc2afe69b6e98528160b225d72f85ab080cbdf0b11528cbbaba2248f"}, - {file = "regex-2023.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:bdff5eab10e59cf26bc479f565e25ed71a7d041d1ded04ccf9aee1d9f208487a"}, - {file = "regex-2023.6.3.tar.gz", hash = "sha256:72d1a25bf36d2050ceb35b517afe13864865268dfb45910e2e17a84be6cbfeb0"}, -] - -[[package]] -name = "requests" -version = "2.31.0" -description = "Python HTTP for 
Humans." -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, -] - -[package.dependencies] -certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" -idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<3" - -[package.extras] -socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] - -[[package]] -name = "safetensors" -version = "0.3.1" -description = "Fast and Safe Tensor serialization" -category = "main" -optional = false -python-versions = "*" -files = [ - {file = "safetensors-0.3.1-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:2ae9b7dd268b4bae6624729dac86deb82104820e9786429b0583e5168db2f770"}, - {file = "safetensors-0.3.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:08c85c1934682f1e2cd904d38433b53cd2a98245a7cc31f5689f9322a2320bbf"}, - {file = "safetensors-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba625c7af9e1c5d0d91cb83d2fba97d29ea69d4db2015d9714d24c7f6d488e15"}, - {file = "safetensors-0.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b57d5890c619ec10d9f1b6426b8690d0c9c2868a90dc52f13fae6f6407ac141f"}, - {file = "safetensors-0.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c9f562ea696d50b95cadbeb1716dc476714a87792ffe374280c0835312cbfe2"}, - {file = "safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c115951b3a865ece8d98ee43882f2fd0a999c0200d6e6fec24134715ebe3b57"}, - {file = "safetensors-0.3.1-cp310-cp310-win32.whl", hash = "sha256:118f8f7503ea312fc7af27e934088a1b589fb1eff5a7dea2cd1de6c71ee33391"}, - {file = "safetensors-0.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:54846eaae25fded28a7bebbb66be563cad221b4c80daee39e2f55df5e5e0266f"}, - {file = "safetensors-0.3.1-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:5af82e10946c4822506db0f29269f43147e889054704dde994d4e22f0c37377b"}, - {file = "safetensors-0.3.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:626c86dd1d930963c8ea7f953a3787ae85322551e3a5203ac731d6e6f3e18f44"}, - {file = "safetensors-0.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12e30677e6af1f4cc4f2832546e91dbb3b0aa7d575bfa473d2899d524e1ace08"}, - {file = "safetensors-0.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d534b80bc8d39945bb902f34b0454773971fe9e5e1f2142af451759d7e52b356"}, - {file = "safetensors-0.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ddd0ddd502cf219666e7d30f23f196cb87e829439b52b39f3e7da7918c3416df"}, - {file = "safetensors-0.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997a2cc14023713f423e6d16536d55cb16a3d72850f142e05f82f0d4c76d383b"}, - {file = "safetensors-0.3.1-cp311-cp311-win32.whl", hash = "sha256:6ae9ca63d9e22f71ec40550207bd284a60a6b4916ae6ca12c85a8d86bf49e0c3"}, - {file = "safetensors-0.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:62aa7421ca455418423e35029524489480adda53e3f702453580180ecfebe476"}, - {file = "safetensors-0.3.1-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:6d54b3ed367b6898baab75dfd057c24f36ec64d3938ffff2af981d56bfba2f42"}, - {file = "safetensors-0.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:262423aeda91117010f8c607889066028f680fbb667f50cfe6eae96f22f9d150"}, - {file = "safetensors-0.3.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10efe2513a8327fd628cea13167089588acc23093ba132aecfc536eb9a4560fe"}, - {file = "safetensors-0.3.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:689b3d6a7ebce70ee9438267ee55ea89b575c19923876645e927d08757b552fe"}, - {file = "safetensors-0.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14cd9a87bc73ce06903e9f8ee8b05b056af6f3c9f37a6bd74997a16ed36ff5f4"}, - {file = "safetensors-0.3.1-cp37-cp37m-win32.whl", hash = "sha256:a77cb39624480d5f143c1cc272184f65a296f573d61629eff5d495d2e0541d3e"}, - {file = "safetensors-0.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9eff3190bfbbb52eef729911345c643f875ca4dbb374aa6c559675cfd0ab73db"}, - {file = "safetensors-0.3.1-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:05cbfef76e4daa14796db1bbb52072d4b72a44050c368b2b1f6fd3e610669a89"}, - {file = "safetensors-0.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:c49061461f4a81e5ec3415070a3f135530834c89cbd6a7db7cd49e3cb9d9864b"}, - {file = "safetensors-0.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22cf7e73ca42974f098ce0cf4dd8918983700b6b07a4c6827d50c8daefca776e"}, - {file = "safetensors-0.3.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04f909442d6223ff0016cd2e1b2a95ef8039b92a558014627363a2e267213f62"}, - {file = "safetensors-0.3.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c573c5a0d5d45791ae8c179e26d74aff86e719056591aa7edb3ca7be55bc961"}, - {file = "safetensors-0.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6994043b12e717cf2a6ba69077ac41f0d3675b2819734f07f61819e854c622c7"}, - {file = "safetensors-0.3.1-cp38-cp38-win32.whl", hash = "sha256:158ede81694180a0dbba59422bc304a78c054b305df993c0c6e39c6330fa9348"}, - {file = "safetensors-0.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:afdc725beff7121ea8d39a7339f5a6abcb01daa189ea56290b67fe262d56e20f"}, - {file = "safetensors-0.3.1-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:cba910fcc9e5e64d32d62b837388721165e9c7e45d23bc3a38ad57694b77f40d"}, - {file = "safetensors-0.3.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a4f7dbfe7285573cdaddd85ef6fa84ebbed995d3703ab72d71257944e384612f"}, - {file = "safetensors-0.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54aed0802f9eaa83ca7b1cbb986bfb90b8e2c67b6a4bcfe245627e17dad565d4"}, - {file = "safetensors-0.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:34b75a766f3cfc99fd4c33e329b76deae63f5f388e455d863a5d6e99472fca8e"}, - {file = "safetensors-0.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a0f31904f35dc14919a145b2d7a2d8842a43a18a629affe678233c4ea90b4af"}, - {file = "safetensors-0.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcf527ecc5f58907fd9031510378105487f318cc91ecdc5aee3c7cc8f46030a8"}, - {file = "safetensors-0.3.1-cp39-cp39-win32.whl", hash = "sha256:e2f083112cf97aa9611e2a05cc170a2795eccec5f6ff837f4565f950670a9d83"}, - {file = "safetensors-0.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f4f614b8e8161cd8a9ca19c765d176a82b122fa3d3387b77862145bfe9b4e93"}, - {file = "safetensors-0.3.1.tar.gz", hash = "sha256:571da56ff8d0bec8ae54923b621cda98d36dcef10feb36fd492c4d0c2cd0e869"}, -] - -[package.extras] -all = ["black (==22.3)", "click 
(==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] -dev = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "flax (>=0.6.3)", "h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "isort (>=5.5.4)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "numpy (>=1.21.6)", "paddlepaddle (>=2.4.1)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)", "tensorflow (>=2.11.0)", "torch (>=1.10)"] -jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)"] -numpy = ["numpy (>=1.21.6)"] -paddlepaddle = ["paddlepaddle (>=2.4.1)"] -quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] -tensorflow = ["tensorflow (>=2.11.0)"] -testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] -torch = ["torch (>=1.10)"] - -[[package]] -name = "scikit-learn" -version = "1.2.2" -description = "A set of python modules for machine learning and data mining" -category = "main" -optional = false -python-versions = ">=3.8" -files = [ - {file = "scikit-learn-1.2.2.tar.gz", hash = "sha256:8429aea30ec24e7a8c7ed8a3fa6213adf3814a6efbea09e16e0a0c71e1a1a3d7"}, - {file = "scikit_learn-1.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99cc01184e347de485bf253d19fcb3b1a3fb0ee4cea5ee3c43ec0cc429b6d29f"}, - {file = "scikit_learn-1.2.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e6e574db9914afcb4e11ade84fab084536a895ca60aadea3041e85b8ac963edb"}, - {file = "scikit_learn-1.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fe83b676f407f00afa388dd1fdd49e5c6612e551ed84f3b1b182858f09e987d"}, - {file = "scikit_learn-1.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e2642baa0ad1e8f8188917423dd73994bf25429f8893ddbe115be3ca3183584"}, - {file = "scikit_learn-1.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ad66c3848c0a1ec13464b2a95d0a484fd5b02ce74268eaa7e0c697b904f31d6c"}, - {file = "scikit_learn-1.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dfeaf8be72117eb61a164ea6fc8afb6dfe08c6f90365bde2dc16456e4bc8e45f"}, - {file = "scikit_learn-1.2.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:fe0aa1a7029ed3e1dcbf4a5bc675aa3b1bc468d9012ecf6c6f081251ca47f590"}, - {file = "scikit_learn-1.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:065e9673e24e0dc5113e2dd2b4ca30c9d8aa2fa90f4c0597241c93b63130d233"}, - {file = "scikit_learn-1.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf036ea7ef66115e0d49655f16febfa547886deba20149555a41d28f56fd6d3c"}, - {file = "scikit_learn-1.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:8b0670d4224a3c2d596fd572fb4fa673b2a0ccfb07152688ebd2ea0b8c61025c"}, - {file = "scikit_learn-1.2.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9c710ff9f9936ba8a3b74a455ccf0dcf59b230caa1e9ba0223773c490cab1e51"}, - {file = "scikit_learn-1.2.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:2dd3ffd3950e3d6c0c0ef9033a9b9b32d910c61bd06cb8206303fb4514b88a49"}, - {file = "scikit_learn-1.2.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44b47a305190c28dd8dd73fc9445f802b6ea716669cfc22ab1eb97b335d238b1"}, - {file = 
"scikit_learn-1.2.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:953236889928d104c2ef14027539f5f2609a47ebf716b8cbe4437e85dce42744"}, - {file = "scikit_learn-1.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:7f69313884e8eb311460cc2f28676d5e400bd929841a2c8eb8742ae78ebf7c20"}, - {file = "scikit_learn-1.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8156db41e1c39c69aa2d8599ab7577af53e9e5e7a57b0504e116cc73c39138dd"}, - {file = "scikit_learn-1.2.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fe175ee1dab589d2e1033657c5b6bec92a8a3b69103e3dd361b58014729975c3"}, - {file = "scikit_learn-1.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d5312d9674bed14f73773d2acf15a3272639b981e60b72c9b190a0cffed5bad"}, - {file = "scikit_learn-1.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea061bf0283bf9a9f36ea3c5d3231ba2176221bbd430abd2603b1c3b2ed85c89"}, - {file = "scikit_learn-1.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:6477eed40dbce190f9f9e9d0d37e020815825b300121307942ec2110302b66a3"}, -] - -[package.dependencies] -joblib = ">=1.1.1" -numpy = ">=1.17.3" -scipy = ">=1.3.2" -threadpoolctl = ">=2.0.0" - -[package.extras] -benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.10.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] -examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.10.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"] -tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy (>=0.961)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=5.3.1)", "pytest-cov (>=2.9.0)", "scikit-image (>=0.16.2)"] - -[[package]] -name = "scipy" -version = "1.10.1" -description = "Fundamental algorithms for scientific computing in Python" -category = "main" -optional = false -python-versions = "<3.12,>=3.8" -files = [ - {file = "scipy-1.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7354fd7527a4b0377ce55f286805b34e8c54b91be865bac273f527e1b839019"}, - {file = "scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4b3f429188c66603a1a5c549fb414e4d3bdc2a24792e061ffbd607d3d75fd84e"}, - {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1553b5dcddd64ba9a0d95355e63fe6c3fc303a8fd77c7bc91e77d61363f7433f"}, - {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c0ff64b06b10e35215abce517252b375e580a6125fd5fdf6421b98efbefb2d2"}, - {file = "scipy-1.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:fae8a7b898c42dffe3f7361c40d5952b6bf32d10c4569098d276b4c547905ee1"}, - {file = "scipy-1.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f1564ea217e82c1bbe75ddf7285ba0709ecd503f048cb1236ae9995f64217bd"}, - {file = "scipy-1.10.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d925fa1c81b772882aa55bcc10bf88324dadb66ff85d548c71515f6689c6dac5"}, - {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaea0a6be54462ec027de54fca511540980d1e9eea68b2d5c1dbfe084797be35"}, - {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d"}, - {file = "scipy-1.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:43b8e0bcb877faf0abfb613d51026cd5cc78918e9530e375727bf0625c82788f"}, - {file = "scipy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5678f88c68ea866ed9ebe3a989091088553ba12c6090244fdae3e467b1139c35"}, - {file = "scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:39becb03541f9e58243f4197584286e339029e8908c46f7221abeea4b749fa88"}, - {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bce5869c8d68cf383ce240e44c1d9ae7c06078a9396df68ce88a1230f93a30c1"}, - {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07c3457ce0b3ad5124f98a86533106b643dd811dd61b548e78cf4c8786652f6f"}, - {file = "scipy-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:049a8bbf0ad95277ffba9b3b7d23e5369cc39e66406d60422c8cfef40ccc8415"}, - {file = "scipy-1.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cd9f1027ff30d90618914a64ca9b1a77a431159df0e2a195d8a9e8a04c78abf9"}, - {file = "scipy-1.10.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:79c8e5a6c6ffaf3a2262ef1be1e108a035cf4f05c14df56057b64acc5bebffb6"}, - {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51af417a000d2dbe1ec6c372dfe688e041a7084da4fdd350aeb139bd3fb55353"}, - {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b4735d6c28aad3cdcf52117e0e91d6b39acd4272f3f5cd9907c24ee931ad601"}, - {file = "scipy-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:7ff7f37b1bf4417baca958d254e8e2875d0cc23aaadbe65b3d5b3077b0eb23ea"}, - {file = "scipy-1.10.1.tar.gz", hash = "sha256:2cf9dfb80a7b4589ba4c40ce7588986d6d5cebc5457cad2c2880f6bc2d42f3a5"}, -] - -[package.dependencies] -numpy = ">=1.19.5,<1.27.0" - -[package.extras] -dev = ["click", "doit (>=0.36.0)", "flake8", "mypy", "pycodestyle", "pydevtool", "rich-click", "typing_extensions"] -doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] - -[[package]] -name = "signatory" -version = "1.2.6.1.9.0" -description = "Differentiable computations of the signature and logsignature transforms, on both CPU and GPU." 
-category = "main" -optional = false -python-versions = "~=3.6" -files = [ - {file = "signatory-1.2.6.1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:f29f3f7a95881f053074481c6d0dbddb5720dae53b4402874a2b6579cb07861a"}, - {file = "signatory-1.2.6.1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8ce39d1f7712fd6d8c6b0c6d10c225345dd71fd2a97cf81306c7b93b8a9d671d"}, - {file = "signatory-1.2.6.1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:8c423539c676fb777c7de094b0fd9b97d54eb063dc507d12ec02e48709a0f96b"}, - {file = "signatory-1.2.6.1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:9e7dc939064ab7b7576450585cdb3fac2d22adb507bc80e87e1acded2d248b53"}, - {file = "signatory-1.2.6.1.9.0.tar.gz", hash = "sha256:c9ab17df6286688c1c7b12c0e906031a1fe7681dfca551876c1ff616e12807d8"}, -] - -[[package]] -name = "six" -version = "1.16.0" -description = "Python 2 and 3 compatibility utilities" -category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, -] - -[[package]] -name = "threadpoolctl" -version = "3.1.0" -description = "threadpoolctl" -category = "main" -optional = false -python-versions = ">=3.6" -files = [ - {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, - {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, -] - -[[package]] -name = "tokenizers" -version = "0.13.3" -description = "Fast and Customizable Tokenizers" -category = "main" -optional = false -python-versions = "*" -files = [ - {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, - {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, - {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, - {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, - {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, - {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, - {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, - {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, - {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, - {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, - {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, - {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, - {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, - {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, - {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, - {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, - {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, - {file = 
"tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, - {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, - {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, - {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, - {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, -] - -[package.extras] -dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] -testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] - -[[package]] -name = "torch" -version = "1.9.0" -description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "main" -optional = false -python-versions = ">=3.6.2" -files = [ - {file = "torch-1.9.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:3a2d070cf28860d285d4ab156f3954c0c1d12f4c037aa312a7c029227c0d106b"}, - {file = "torch-1.9.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:b296e65e25081af147af936f1e3a1f17f583a9afacfa5309742678ffef728ace"}, - {file = "torch-1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:117098d4924b260a24a47c6b3fe37f2ae41f04a2ea2eff9f553ae9210b12fa54"}, - {file = "torch-1.9.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:d6103b9a634993bd967337a1149f9d8b23922f42a3660676239399e15c1b4515"}, - {file = "torch-1.9.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0164673908e6b291ace592d382eba3e258b3bad009b8078cad8f3b9e00d8f23e"}, - {file = "torch-1.9.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:52548b45efff772fe3810fe91daf34f981ac0ca1a7227f6226fd5693f53b5b88"}, - {file = "torch-1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62c0a7e433681d0861494d1ede96d2485e4dbb3ea8fd867e8419addebf5de1af"}, - {file = "torch-1.9.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:d88333091fd1627894bbf0d6dcef58a90e36bdf0d90a5d4675b5e07e72075511"}, - {file = "torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1d8139dcc864f48dc316376384f50e47a459284ad1cb84449242f4964e25aaec"}, - {file = "torch-1.9.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0aa4cca3f16fab40cb8dae6a49d0eccdc8f4ead9d1a6428cd9ba12befe082b2a"}, - {file = "torch-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:646de1bef85d6c7590e98f8ea52e47acdcf58330982e4f5d73f5ca28dea2d552"}, - {file = "torch-1.9.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:e596f0105f748cf09d4763152d8157aaf58d5231232eaf2c5673d4562ba86ad3"}, - {file = "torch-1.9.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:ecc7193fff7741ced3db1f760666c8454d6664956288c54d1b49613b987a42f4"}, - {file = "torch-1.9.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:95eeec3a6c42fd35aca552777b7d9979ed489760423de97c0118a45e849a61f4"}, - {file = "torch-1.9.0-cp39-cp39-manylinux2014_aarch64.whl", hash = 
"sha256:8a2b2012b3c7d6019e189496688fa77de7029a220840b406d8302d1c8021a11c"}, - {file = "torch-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:7e2b14fe5b3a8266cbe2f6740c0195497507974ced7bc21e99971561913a0c28"}, - {file = "torch-1.9.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0a9e74b5057463ce4e55d9332a5670993fc9e1299c52e1740e505eda106fb355"}, - {file = "torch-1.9.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:569ead6ae6bb0e636df0fc8af660ef03260e630dc5f2f4cf3198027e7b6bb481"}, -] - -[package.dependencies] -typing-extensions = "*" - -[[package]] -name = "torchvision" -version = "0.10.0" -description = "image and video datasets and models for torch deep learning" -category = "main" -optional = false -python-versions = "*" -files = [ - {file = "torchvision-0.10.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f26c5ba220445ee8e892033234485c9276304874e87ec9d5146779167be3148d"}, - {file = "torchvision-0.10.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d5c80cf058c0ecb010a97dc71225f5274e45c596bfad7505e0d000abcccb7063"}, - {file = "torchvision-0.10.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:6d69bf15a6e885e3b8c674a524bc6b53016879cb8b0fd8537327edd2d243cab1"}, - {file = "torchvision-0.10.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6b917d4762deaaa4c0cdd106403ea8384a1fdf93de424097bd71f3ebfdc76b41"}, - {file = "torchvision-0.10.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a405d968e7e94c0744870eef31c977635e2123b0b46becc1461a28b7c27d3c0c"}, - {file = "torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:9597da592f76d22d9e80a4a072294e093f8c3a06c404f3ff237f359b9225e097"}, - {file = "torchvision-0.10.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ce2aae5567522f2a877c6334796721af07c164e94ff75876821fadb3310cfe7e"}, - {file = "torchvision-0.10.0-cp37-cp37m-win_amd64.whl", hash = "sha256:576d7b070f25cbfc78a710960fd8fa6d3961d640db05f7ace69d9a3e5bbf754a"}, - {file = "torchvision-0.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad24107b2ed0ccc372af92822f1f8f5530907b6fb7520a08195cf0bb07446923"}, - {file = "torchvision-0.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bb04708354fb6d639f6e47d8066b0d546fbe0a3a68685cf8d413a6370c8f63ad"}, - {file = "torchvision-0.10.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:aa709fa63896f93e03a03976230a51050fcd5f1b45cf663f62d91b7eaaf8ac09"}, - {file = "torchvision-0.10.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:4796e1fb2995c6e495c5ea4e2b0fe0e4be44bd9416ef4a1349c1a406675cbdee"}, - {file = "torchvision-0.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:487bbfd89575a52cd18bca8a33e24c373570e060f801265051c3a0aafc769720"}, - {file = "torchvision-0.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f68291559a4cf1245c95efc5e47ebe158819aceec4e1f585d2fe133cd2c9d8e8"}, - {file = "torchvision-0.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ae9606cc248de9b3a077710529b11c57315d2914c8ee3099fbd93a62f56a1661"}, - {file = "torchvision-0.10.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:20a57ce42fa20c26d800c65d5b88dbaaa115a01f4f5623d41abfb182b854f199"}, - {file = "torchvision-0.10.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a30466893a5b97073c992859f3645e3e1f41daf2c1b4db6cb2ac8ec7d0e1f6bc"}, - {file = "torchvision-0.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:da65af25b51caf43327ecb3ccf550eedfd62d1f73511db44370b4b9522569b8d"}, -] - -[package.dependencies] -numpy = "*" -pillow = ">=5.3.0" -torch = "1.9.0" - -[package.extras] -scipy = ["scipy"] - -[[package]] -name = "tqdm" -version = 
"4.65.0" -description = "Fast, Extensible Progress Meter" -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, - {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -dev = ["py-make (>=0.1.0)", "twine", "wheel"] -notebook = ["ipywidgets (>=6)"] -slack = ["slack-sdk"] -telegram = ["requests"] - -[[package]] -name = "transformers" -version = "4.30.2" -description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -category = "main" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "transformers-4.30.2-py3-none-any.whl", hash = "sha256:c332e3a3097f9ed89ce556b403251235931c00237b8bc2d7adaa19d226c13f1d"}, - {file = "transformers-4.30.2.tar.gz", hash = "sha256:f4a8aac4e1baffab4033f4a345b0d7dc7957d12a4f1ba969afea08205a513045"}, -] - -[package.dependencies] -filelock = "*" -huggingface-hub = ">=0.14.1,<1.0" -numpy = ">=1.17" -packaging = ">=20.0" -pyyaml = ">=5.1" -regex = "!=2019.12.17" -requests = "*" -safetensors = ">=0.3.1" -tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14" -tqdm = ">=4.27" - -[package.extras] -accelerate = ["accelerate (>=0.20.2)"] -agents = ["Pillow", "accelerate (>=0.20.2)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.9,!=1.12.0)"] -all = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] -audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.20.2)", "deepspeed (>=0.8.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.8.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.6.9)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", 
"pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.20.2)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.3)", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow", "accelerate (>=0.20.2)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.6.9)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf (<=3.20.3)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "torchaudio", "torchvision"] -docs-specific = ["hf-doc-builder"] -fairscale = ["fairscale (>0.3)"] -flax = ["flax (>=0.4.1,<=0.6.9)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8,<=0.1.4)"] -flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune]", "sigopt"] -ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", 
"sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] -modelcreation = ["cookiecutter (==1.7.3)"] -natten = ["natten (>=0.14.6)"] -onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] -onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] -optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] -ray = ["ray[tune]"] -retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] -sagemaker = ["sagemaker (>=2.31.0)"] -sentencepiece = ["protobuf (<=3.20.3)", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic", "starlette", "uvicorn"] -sigopt = ["sigopt"] -sklearn = ["scikit-learn"] -speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf (<=3.20.3)", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.4,<2.13)", "tensorflow-text (<2.13)", "tf2onnx"] -tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -timm = ["timm"] -tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"] -torch = ["accelerate (>=0.20.2)", "torch (>=1.9,!=1.12.0)"] -torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.14.1,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.3)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.9,!=1.12.0)", "tqdm (>=4.27)"] -video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow"] - -[[package]] -name = "typing-extensions" -version = "4.5.0" -description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"}, - {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, -] - -[[package]] -name = "urllib3" -version = "2.0.4" -description = "HTTP library with thread-safe connection pooling, file post, and more." 
-category = "main" -optional = false -python-versions = ">=3.7" -files = [ - {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, - {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, -] - -[package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] -socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] -zstd = ["zstandard (>=0.18.0)"] - -[metadata] -lock-version = "2.0" -python-versions = ">=3.8, <3.9" -content-hash = "19d083b8c1dd876301a7caae2ee4a5875ad7454497511d530fb7d04ab7377f9d" diff --git a/pyproject.toml b/pyproject.toml index 65b785c..a5f3772 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,26 +1,122 @@ -[tool.poetry] -name = "nlpsig_networks" -version = "0.1.0" -description = "" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + +[project] +name = "sig_networks" authors = [ - "Ryan Chan ", - "Talia Tseriotou ", - "Kasra Hosseini " + { name = "Ryan Chan", email = "rchan@turing.ac.uk" }, + { name = "Talia Tseriotou", email = "t.tseriotou@qmul.ac.uk" }, + { name = "Kasra Hosseini", email = "khosseini@turing.ac.uk" }, ] +description = "Neural networks for longitudinal NLP classification tasks." readme = "README.md" -packages = [{include = "nlpsig_networks"}] - -[tool.poetry.dependencies] -python = ">=3.8, <3.9" -nlpsig = "^0.2.0" -torch = "1.9.0" -torchvision = "0.10.0" -signatory = "1.2.6.1.9.0" -scikit-learn = "^1.2.2" -pandas = "^1.5.3" -accelerate = "0.20.1" -transformers = "4.30.2" +requires-python = ">=3.8, <3.9" +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", +] +dynamic = ["version"] +dependencies = [ + "torch == 1.9.0", + "torchvision == 0.10.0", + "signatory == 1.2.6.1.9.0", + "tdqm >= 0.0.1", + "pandas >= 1.5.3", + "umap-learn >= 0.5.3", + "scikit-learn >= 1.2.2", + "datasets >= 2.6.1", + "distinctipy >= 1.2.2", + "evaluate >= 0.4.0", + "accelerate == 0.20.1", + "transformers == 4.30.2" +] -[build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +[project.optional-dependencies] +test = [ + "pytest >=6", + "pytest-cov >=3", +] +dev = [ + "pytest >=6", + "pytest-cov >=3", +] +docs = [ + "sphinx>=4.0", + "myst_parser>=0.13", + "sphinx-book-theme>=0.1.0", + "sphinx_copybutton", + "furo", +] + +[project.urls] +Homepage = "https://github.com/datasig-ac-uk/nlpsig" +"Bug Tracker" = "https://github.com/datasig-ac-uk/nlpsig/issues" +Discussions = "https://github.com/datasig-ac-uk/nlpsig/discussions" +Changelog = "https://github.com/datasig-ac-uk/nlpsig/releases" +[tool.hatch] +version.path = "src/sig_networks/__init__.py" +envs.default.dependencies = [ + "pytest", + "pytest-cov", +] + + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +filterwarnings = [ + "error", + 
"ignore::UserWarning", + "ignore:Tensorflow not installed; ParametricUMAP will be unavailable:ImportWarning", # umap + "ignore:pkg_resources is deprecated as an API:DeprecationWarning", # umap + "ignore:Deprecated call to *:DeprecationWarning", + "ignore:numba.core.errors.NumbaDeprecationWarning", # umap using numba + "ignore:numba.core.errors.NumbaPendingDeprecationWarning", # umap using numba +] +log_cli_level = "INFO" +testpaths = [ + "tests", +] + + +[tool.mypy] +files = "src" +python_version = "3.8" +# warn_unused_configs = true +# strict = true +# show_error_codes = true +# enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +# warn_unreachable = true + + +[tool.ruff.per-file-ignores] +"tests/**" = ["T20"] +"noxfile.py" = ["T20"] + + +[tool.pylint] +py-version = "3.8" +ignore-paths= ["src/nlpsig/_version.py"] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "design", + "fixme", + "line-too-long", + "missing-module-docstring", + "wrong-import-position", +] diff --git a/src/sig_networks/__init__.py b/src/sig_networks/__init__.py new file mode 100644 index 0000000..f11ec6a --- /dev/null +++ b/src/sig_networks/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +__version__ = "0.1.0" diff --git a/nlpsig_networks/feature_concatenation.py b/src/sig_networks/feature_concatenation.py similarity index 99% rename from nlpsig_networks/feature_concatenation.py rename to src/sig_networks/feature_concatenation.py index e7395a9..e2dd2dc 100644 --- a/nlpsig_networks/feature_concatenation.py +++ b/src/sig_networks/feature_concatenation.py @@ -42,7 +42,7 @@ def __init__( - scaled_concatenation: concatenation of single value scaled path signature and embedding vector """ - super(FeatureConcatenation, self).__init__() + super().__init__() if comb_method not in [ "concatenation", diff --git a/nlpsig_networks/ffn_baseline.py b/src/sig_networks/ffn_baseline.py similarity index 97% rename from nlpsig_networks/ffn_baseline.py rename to src/sig_networks/ffn_baseline.py index 9303b7c..18a3d83 100644 --- a/nlpsig_networks/ffn_baseline.py +++ b/src/sig_networks/ffn_baseline.py @@ -32,7 +32,7 @@ def __init__( dropout_rate : float Probability of dropout. 
""" - super(FeedforwardNeuralNetModel, self).__init__() + super().__init__() if type(hidden_dim) == int: hidden_dim = [hidden_dim] diff --git a/nlpsig_networks/focal_loss.py b/src/sig_networks/focal_loss.py similarity index 95% rename from nlpsig_networks/focal_loss.py rename to src/sig_networks/focal_loss.py index 97f69e6..8427845 100644 --- a/nlpsig_networks/focal_loss.py +++ b/src/sig_networks/focal_loss.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import math -from typing import List, Optional, Union import numpy as np import torch @@ -18,10 +19,10 @@ class FocalLoss(nn.Module): def __init__( self, gamma: float = 0.0, - alpha: Optional[Union[float, list]] = None, + alpha: float | list | None = None, size_average: bool = True, ): - super(FocalLoss, self).__init__() + super().__init__() self.gamma = gamma self.alpha = alpha if isinstance(alpha, (float, int)): @@ -73,9 +74,9 @@ def __init__( gamma: float, beta: float, no_of_classes: int, - samples_per_cls: Optional[List] = None, + samples_per_cls: list | None = None, ): - super(ClassBalanced_FocalLoss, self).__init__() + super().__init__() self.gamma = gamma self.beta = beta self.no_of_classes = no_of_classes diff --git a/nlpsig_networks/huggingface_loader.py b/src/sig_networks/huggingface_loader.py similarity index 97% rename from nlpsig_networks/huggingface_loader.py rename to src/sig_networks/huggingface_loader.py index 3f0c9e3..cd77fcf 100644 --- a/nlpsig_networks/huggingface_loader.py +++ b/src/sig_networks/huggingface_loader.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from random import choice, randrange -from typing import List import pandas as pd from datasets import load_dataset @@ -86,7 +87,7 @@ def default_preprocess_newspop(self) -> None: self.dataset_df = dataset_df.replace(encode_labels) print("[INFO] preprocessed dataframe can be accessed: .dataset_df") - def _list_default_datetimes(self) -> List[str]: + def _list_default_datetimes(self) -> list[str]: list_datetimes = [ "2015-01-01 00:00:00", "2015-01-01 00:12:00", diff --git a/nlpsig_networks/lstm_baseline.py b/src/sig_networks/lstm_baseline.py similarity index 98% rename from nlpsig_networks/lstm_baseline.py rename to src/sig_networks/lstm_baseline.py index 864ddaa..14a8c2b 100644 --- a/nlpsig_networks/lstm_baseline.py +++ b/src/sig_networks/lstm_baseline.py @@ -37,7 +37,7 @@ def __init__( dropout_rate : float Probability of dropout. 
""" - super(LSTMModel, self).__init__() + super().__init__() self.hidden_dim = hidden_dim self.bidirectional = bidirectional diff --git a/nlpsig_networks/pytorch_utils.py b/src/sig_networks/pytorch_utils.py similarity index 96% rename from nlpsig_networks/pytorch_utils.py rename to src/sig_networks/pytorch_utils.py index de5495e..d368a0e 100644 --- a/nlpsig_networks/pytorch_utils.py +++ b/src/sig_networks/pytorch_utils.py @@ -2,7 +2,7 @@ import datetime import os -from typing import Optional +from uuid import uuid4 import numpy as np import pandas as pd @@ -14,9 +14,8 @@ from torch.optim.optimizer import Optimizer from torch.utils.data.dataloader import DataLoader from tqdm.auto import tqdm -from uuid import uuid4 -from nlpsig_networks.focal_loss import ClassBalanced_FocalLoss, FocalLoss +from sig_networks.focal_loss import ClassBalanced_FocalLoss, FocalLoss def _get_timestamp(add_runid: bool = True) -> str: @@ -370,16 +369,15 @@ def validation_pytorch( # compute macro recall score recall = sum(recall_scores) / len(recall_scores) - if verbose: - if epoch % verbose_epoch == 0: - print( - f"[Validation] || Epoch: {epoch+1} || " - f"Loss: {total_loss / len(valid_loader)} || " - f"Accuracy: {accuracy} || " - f"F1-score: {f1} || " - f"Precision: {precision} ||" - f"Recall: {recall}" - ) + if verbose and epoch % verbose_epoch == 0: + print( + f"[Validation] || Epoch: {epoch+1} || " + f"Loss: {total_loss / len(valid_loader)} || " + f"Accuracy: {accuracy} || " + f"F1-score: {f1} || " + f"Precision: {precision} ||" + f"Recall: {recall}" + ) return { "loss": total_loss / len(valid_loader), @@ -399,15 +397,15 @@ def training_pytorch( criterion: nn.Module, optimizer: Optimizer, num_epochs: int, - scheduler: Optional[_LRScheduler] = None, - valid_loader: Optional[DataLoader] = None, - seed: Optional[int] = 42, + scheduler: _LRScheduler | None = None, + valid_loader: DataLoader | None = None, + seed: int | None = 42, return_best: bool = False, save_best: bool = False, output: str = f"best_model_{_get_timestamp()}.pkl", early_stopping: bool = False, validation_metric: str = "loss", - patience: Optional[int] = 10, + patience: int | None = 10, device: str | None = None, verbose: bool = False, verbose_epoch: int = 100, @@ -475,10 +473,7 @@ def training_pytorch( ) # model train (& validation) per epoch - if verbose: - epochs_loop = tqdm(range(num_epochs)) - else: - epochs_loop = range(num_epochs) + epochs_loop = tqdm(range(num_epochs)) if verbose else range(num_epochs) for epoch in epochs_loop: # iterate through the training dataloader @@ -496,13 +491,11 @@ def training_pytorch( optimizer.step() # show training progress - if verbose: - if epoch % verbose_epoch == 0: - print("-" * 50) - print( - f"[Train] | Epoch: {epoch+1}/{num_epochs} || " - + f"Loss: {loss.item()}" - ) + if verbose and epoch % verbose_epoch == 0: + print("-" * 50) + print( + f"[Train] | Epoch: {epoch+1}/{num_epochs} || " + f"Loss: {loss.item()}" + ) if isinstance(valid_loader, DataLoader): # compute loss, accuracy and F1 on validation set @@ -559,12 +552,11 @@ def training_pytorch( if save_best | return_best: checkpoint = torch.load(f=output) model.load_state_dict(checkpoint["model_state_dict"]) - if save_best: - if verbose: - print( - f"Returning the best model which occurred at " - f"epoch {checkpoint['epoch']}" - ) + if save_best and verbose: + print( + f"Returning the best model which occurred at " + f"epoch {checkpoint['epoch']}" + ) if return_best: if not save_best: os.remove(output) @@ -684,10 +676,10 @@ def KFold_pytorch( 
num_epochs: int, batch_size: int = 64, return_metric_for_each_fold: bool = False, - seed: Optional[int] = 42, + seed: int | None = 42, return_best: bool = False, early_stopping: bool = False, - patience: Optional[int] = 10, + patience: int | None = 10, device: str | None = None, verbose: bool = False, verbose_epoch: int = 100, diff --git a/nlpsig_networks/__init__.py b/src/sig_networks/scripts/__init__.py similarity index 100% rename from nlpsig_networks/__init__.py rename to src/sig_networks/scripts/__init__.py diff --git a/nlpsig_networks/scripts/ffn_baseline_functions.py b/src/sig_networks/scripts/ffn_baseline_functions.py similarity index 99% rename from nlpsig_networks/scripts/ffn_baseline_functions.py rename to src/sig_networks/scripts/ffn_baseline_functions.py index b8ee5d7..e70d2bb 100644 --- a/nlpsig_networks/scripts/ffn_baseline_functions.py +++ b/src/sig_networks/scripts/ffn_baseline_functions.py @@ -10,9 +10,9 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model +from sig_networks.ffn_baseline import FeedforwardNeuralNetModel +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model def implement_ffn( diff --git a/nlpsig_networks/scripts/fine_tune_bert_classification.py b/src/sig_networks/scripts/fine_tune_bert_classification.py similarity index 99% rename from nlpsig_networks/scripts/fine_tune_bert_classification.py rename to src/sig_networks/scripts/fine_tune_bert_classification.py index ec2251e..d08890e 100644 --- a/nlpsig_networks/scripts/fine_tune_bert_classification.py +++ b/src/sig_networks/scripts/fine_tune_bert_classification.py @@ -12,16 +12,16 @@ from nlpsig import TextEncoder from nlpsig.classification_utils import DataSplits, Folds from sklearn import metrics +from tqdm.auto import tqdm from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, ) -from tqdm.auto import tqdm -from nlpsig_networks.focal_loss import FocalLoss -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.focal_loss import FocalLoss +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed def testing_transformer( @@ -483,10 +483,10 @@ def fine_tune_transformer_for_classification( valid_recall = [] valid_recall_scores = [] - labels = torch.empty((0)) - predicted = torch.empty((0)) - valid_labels = torch.empty((0)) - valid_predicted = torch.empty((0)) + labels = torch.empty(0) + predicted = torch.empty(0) + valid_labels = torch.empty(0) + valid_predicted = torch.empty(0) for k in range(n_splits): # compute how well the model performs on this fold results_for_fold = _fine_tune_transformer_for_data_split( diff --git a/nlpsig_networks/scripts/implement_model.py b/src/sig_networks/scripts/implement_model.py similarity index 99% rename from nlpsig_networks/scripts/implement_model.py rename to src/sig_networks/scripts/implement_model.py index 5296790..c8ea405 100644 --- a/nlpsig_networks/scripts/implement_model.py +++ b/src/sig_networks/scripts/implement_model.py @@ -9,8 +9,8 @@ import torch.nn as nn from nlpsig.classification_utils import DataSplits, Folds -from nlpsig_networks.focal_loss import FocalLoss -from nlpsig_networks.pytorch_utils import ( +from sig_networks.focal_loss import 
FocalLoss +from sig_networks.pytorch_utils import ( KFold_pytorch, _get_timestamp, testing_pytorch, diff --git a/nlpsig_networks/scripts/lstm_baseline_functions.py b/src/sig_networks/scripts/lstm_baseline_functions.py similarity index 99% rename from nlpsig_networks/scripts/lstm_baseline_functions.py rename to src/sig_networks/scripts/lstm_baseline_functions.py index 08381bb..9aeae95 100644 --- a/nlpsig_networks/scripts/lstm_baseline_functions.py +++ b/src/sig_networks/scripts/lstm_baseline_functions.py @@ -9,9 +9,9 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.lstm_baseline import LSTMModel -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model +from sig_networks.lstm_baseline import LSTMModel +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model def implement_lstm( diff --git a/nlpsig_networks/scripts/seqsignet_attention_bilstm_functions.py b/src/sig_networks/scripts/seqsignet_attention_bilstm_functions.py similarity index 99% rename from nlpsig_networks/scripts/seqsignet_attention_bilstm_functions.py rename to src/sig_networks/scripts/seqsignet_attention_bilstm_functions.py index a8b5eb9..48bed76 100644 --- a/nlpsig_networks/scripts/seqsignet_attention_bilstm_functions.py +++ b/src/sig_networks/scripts/seqsignet_attention_bilstm_functions.py @@ -8,10 +8,10 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model -from nlpsig_networks.scripts.seqsignet_functions import obtain_SeqSigNet_input -from nlpsig_networks.seqsignet_attention_bilstm import SeqSigNetAttentionBiLSTM +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model +from sig_networks.scripts.seqsignet_functions import obtain_SeqSigNet_input +from sig_networks.seqsignet_attention_bilstm import SeqSigNetAttentionBiLSTM def implement_seqsignet_attention_bilstm( diff --git a/nlpsig_networks/scripts/seqsignet_attention_encoder_functions.py b/src/sig_networks/scripts/seqsignet_attention_encoder_functions.py similarity index 99% rename from nlpsig_networks/scripts/seqsignet_attention_encoder_functions.py rename to src/sig_networks/scripts/seqsignet_attention_encoder_functions.py index e70bcf8..d7b7a53 100644 --- a/nlpsig_networks/scripts/seqsignet_attention_encoder_functions.py +++ b/src/sig_networks/scripts/seqsignet_attention_encoder_functions.py @@ -8,10 +8,10 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model -from nlpsig_networks.scripts.seqsignet_functions import obtain_SeqSigNet_input -from nlpsig_networks.seqsignet_attention_encoder import ( +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model +from sig_networks.scripts.seqsignet_functions import obtain_SeqSigNet_input +from sig_networks.seqsignet_attention_encoder import ( SeqSigNetAttentionEncoder, ) diff --git a/nlpsig_networks/scripts/seqsignet_functions.py b/src/sig_networks/scripts/seqsignet_functions.py similarity index 99% rename from nlpsig_networks/scripts/seqsignet_functions.py rename to 
src/sig_networks/scripts/seqsignet_functions.py index b936c54..a579fc5 100644 --- a/nlpsig_networks/scripts/seqsignet_functions.py +++ b/src/sig_networks/scripts/seqsignet_functions.py @@ -9,9 +9,9 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model -from nlpsig_networks.seqsignet_bilstm import SeqSigNet +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model +from sig_networks.seqsignet_bilstm import SeqSigNet def obtain_SeqSigNet_input( diff --git a/nlpsig_networks/scripts/swmhau_network_functions.py b/src/sig_networks/scripts/swmhau_network_functions.py similarity index 99% rename from nlpsig_networks/scripts/swmhau_network_functions.py rename to src/sig_networks/scripts/swmhau_network_functions.py index 2356592..4018ef0 100644 --- a/nlpsig_networks/scripts/swmhau_network_functions.py +++ b/src/sig_networks/scripts/swmhau_network_functions.py @@ -8,10 +8,10 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model -from nlpsig_networks.scripts.swnu_network_functions import obtain_SWNUNetwork_input -from nlpsig_networks.swmhau_network import SWMHAUNetwork +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model +from sig_networks.scripts.swnu_network_functions import obtain_SWNUNetwork_input +from sig_networks.swmhau_network import SWMHAUNetwork def implement_swmhau_network( diff --git a/nlpsig_networks/scripts/swnu_network_functions.py b/src/sig_networks/scripts/swnu_network_functions.py similarity index 99% rename from nlpsig_networks/scripts/swnu_network_functions.py rename to src/sig_networks/scripts/swnu_network_functions.py index 137baf2..cf87354 100644 --- a/nlpsig_networks/scripts/swnu_network_functions.py +++ b/src/sig_networks/scripts/swnu_network_functions.py @@ -9,9 +9,9 @@ import torch from tqdm.auto import tqdm -from nlpsig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed -from nlpsig_networks.scripts.implement_model import implement_model -from nlpsig_networks.swnu_network import SWNUNetwork +from sig_networks.pytorch_utils import SaveBestModel, _get_timestamp, set_seed +from sig_networks.scripts.implement_model import implement_model +from sig_networks.swnu_network import SWNUNetwork def obtain_SWNUNetwork_input( @@ -864,7 +864,7 @@ def swnu_network_hyperparameter_search( None if checkpoint["extra_info"]["hidden_dim_aug"] is None else [ - tuple([checkpoint["extra_info"]["hidden_dim_aug"]]) + (checkpoint["extra_info"]["hidden_dim_aug"],) for _ in range(len(test_results.index)) ] if (type(checkpoint["extra_info"]["hidden_dim_aug"]) == int) diff --git a/nlpsig_networks/seqsignet_attention_bilstm.py b/src/sig_networks/seqsignet_attention_bilstm.py similarity index 96% rename from nlpsig_networks/seqsignet_attention_bilstm.py rename to src/sig_networks/seqsignet_attention_bilstm.py index 4bc1de1..84212c6 100644 --- a/nlpsig_networks/seqsignet_attention_bilstm.py +++ b/src/sig_networks/seqsignet_attention_bilstm.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn -from nlpsig_networks.feature_concatenation import FeatureConcatenation -from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel 
-from nlpsig_networks.swmhau import SWMHAU +from sig_networks.feature_concatenation import FeatureConcatenation +from sig_networks.ffn_baseline import FeedforwardNeuralNetModel +from sig_networks.swmhau import SWMHAU class SeqSigNetAttentionBiLSTM(nn.Module): @@ -98,7 +98,7 @@ def __init__( - scaled_concatenation: concatenation of single value scaled path signature and embedding vector """ - super(SeqSigNetAttentionBiLSTM, self).__init__() + super().__init__() # SWMHAU applied to the input (the unit includes the convolution layer) self.swmhau = SWMHAU( diff --git a/nlpsig_networks/seqsignet_attention_encoder.py b/src/sig_networks/seqsignet_attention_encoder.py similarity index 96% rename from nlpsig_networks/seqsignet_attention_encoder.py rename to src/sig_networks/seqsignet_attention_encoder.py index 554d77d..239f352 100644 --- a/nlpsig_networks/seqsignet_attention_encoder.py +++ b/src/sig_networks/seqsignet_attention_encoder.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn -from nlpsig_networks.feature_concatenation import FeatureConcatenation -from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel -from nlpsig_networks.swmhau import SWMHAU -from nlpsig_networks.utils import obtain_signatures_mask +from sig_networks.feature_concatenation import FeatureConcatenation +from sig_networks.ffn_baseline import FeedforwardNeuralNetModel +from sig_networks.swmhau import SWMHAU +from sig_networks.utils import obtain_signatures_mask class SeqSigNetAttentionEncoder(nn.Module): @@ -101,7 +101,7 @@ def __init__( - scaled_concatenation: concatenation of single value scaled path signature and embedding vector """ - super(SeqSigNetAttentionEncoder, self).__init__() + super().__init__() if transformer_encoder_layers < 1: raise ValueError( diff --git a/nlpsig_networks/seqsignet_bilstm.py b/src/sig_networks/seqsignet_bilstm.py similarity index 97% rename from nlpsig_networks/seqsignet_bilstm.py rename to src/sig_networks/seqsignet_bilstm.py index ca32107..47e28cf 100644 --- a/nlpsig_networks/seqsignet_bilstm.py +++ b/src/sig_networks/seqsignet_bilstm.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn -from nlpsig_networks.feature_concatenation import FeatureConcatenation -from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel -from nlpsig_networks.swnu import SWNU +from sig_networks.feature_concatenation import FeatureConcatenation +from sig_networks.ffn_baseline import FeedforwardNeuralNetModel +from sig_networks.swnu import SWNU class SeqSigNet(nn.Module): @@ -101,7 +101,7 @@ def __init__( - scaled_concatenation: concatenation of single value scaled path signature and embedding vector """ - super(SeqSigNet, self).__init__() + super().__init__() if pooling not in ["signature", "lstm"]: raise ValueError( diff --git a/nlpsig_networks/specific_classification_utils.py b/src/sig_networks/specific_classification_utils.py similarity index 98% rename from nlpsig_networks/specific_classification_utils.py rename to src/sig_networks/specific_classification_utils.py index 2eba93e..a838051 100644 --- a/nlpsig_networks/specific_classification_utils.py +++ b/src/sig_networks/specific_classification_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from os import listdir from os.path import isfile, join @@ -183,8 +185,8 @@ def process_model_results(model_code_name, FOLDER_results): results0 = pickle.load(fin) for my_ran_seed in results0["classifier_params"]["RANDOM_SEED_list"]: - labels_final = torch.empty((0)) - predicted_final = torch.empty((0)) + labels_final 
= torch.empty(0)
+        predicted_final = torch.empty(0)
 
         seed_files = [f for f in per_model_files if (str(my_ran_seed) + "seed") in f]
         for sf in seed_files:
diff --git a/nlpsig_networks/swmhau.py b/src/sig_networks/swmhau.py
similarity index 98%
rename from nlpsig_networks/swmhau.py
rename to src/sig_networks/swmhau.py
index 76193c3..93e44c1 100644
--- a/nlpsig_networks/swmhau.py
+++ b/src/sig_networks/swmhau.py
@@ -10,8 +10,8 @@
     signature_channels,
 )
 
-from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel
-from nlpsig_networks.utils import obtain_signatures_mask
+from sig_networks.ffn_baseline import FeedforwardNeuralNetModel
+from sig_networks.utils import obtain_signatures_mask
 
 
 class SWMHA(nn.Module):
@@ -60,7 +60,7 @@ def __init__(
             Whether or not to reverse the path before passing it through the
             signature layers, by default False.
         """
-        super(SWMHA, self).__init__()
+        super().__init__()
         self.signature_terms = None
 
         # check if the parameters are compatible with each other
@@ -317,7 +317,7 @@ def __init__(
             Passed into `Augment` class from `signatory` package if
             `augmentation_type='signatory'`, by default None.
         """
-        super(SWMHAU, self).__init__()
+        super().__init__()
 
         self.input_channels = input_channels
         self.output_channels = output_channels
@@ -351,7 +351,7 @@ def __init__(
             # alternative to convolution: using Augment from signatory
             self.augment = Augment(
                 in_channels=self.input_channels,
-                layer_sizes=self.hidden_dim_aug + [self.output_channels],
+                layer_sizes=[*self.hidden_dim_aug, self.output_channels],
                 include_original=False,
                 include_time=False,
                 kernel_size=3,
diff --git a/nlpsig_networks/swmhau_network.py b/src/sig_networks/swmhau_network.py
similarity index 96%
rename from nlpsig_networks/swmhau_network.py
rename to src/sig_networks/swmhau_network.py
index 8890116..044d569 100644
--- a/nlpsig_networks/swmhau_network.py
+++ b/src/sig_networks/swmhau_network.py
@@ -3,9 +3,9 @@
 import torch
 import torch.nn as nn
 
-from nlpsig_networks.feature_concatenation import FeatureConcatenation
-from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel
-from nlpsig_networks.swmhau import SWMHAU
+from sig_networks.feature_concatenation import FeatureConcatenation
+from sig_networks.ffn_baseline import FeedforwardNeuralNetModel
+from sig_networks.swmhau import SWMHAU
 
 
 class SWMHAUNetwork(nn.Module):
@@ -90,7 +90,7 @@ def __init__(
             - scaled_concatenation: concatenation of single value scaled path
               signature and embedding vector
         """
-        super(SWMHAUNetwork, self).__init__()
+        super().__init__()
 
         self.swmhau = SWMHAU(
             input_channels=input_channels,
diff --git a/nlpsig_networks/swnu.py b/src/sig_networks/swnu.py
similarity index 97%
rename from nlpsig_networks/swnu.py
rename to src/sig_networks/swnu.py
index 0ecaa85..09f9095 100644
--- a/nlpsig_networks/swnu.py
+++ b/src/sig_networks/swnu.py
@@ -10,7 +10,7 @@
     signature_channels,
 )
 
-from nlpsig_networks.utils import obtain_signatures_mask
+from sig_networks.utils import obtain_signatures_mask
 
 
 class SWLSTM(nn.Module):
@@ -57,7 +57,7 @@ def __init__(
             Whether or not a bidirectional LSTM is used for the final SWLSTM block,
             by default False (unidirectional LSTM is used in this case).
""" - super(SWLSTM, self).__init__() + super().__init__() # logging inputs to the class self.input_size = input_size @@ -205,10 +205,7 @@ def forward(self, x: torch.Tensor): out = self.final_signature(x) elif self.pooling == "lstm": # add element-wise the forward and backward LSTM states - if self.BiLSTM: - out = h_n[-1, :, :] + h_n[-2, :, :] - else: - out = h_n[-1, :, :] + out = h_n[-1, :, :] + h_n[-2, :, :] if self.BiLSTM else h_n[-1, :, :] # reverse sequence padding out = out[inverse_perm] else: @@ -276,7 +273,7 @@ def __init__( Passed into `Augment` class from `signatory` package if `augmentation_type='signatory'`, by default None. """ - super(SWNU, self).__init__() + super().__init__() self.input_channels = input_channels self.log_signature = log_signature @@ -316,7 +313,7 @@ def __init__( # alternative to convolution: using Augment from signatory self.augment = Augment( in_channels=self.input_channels, - layer_sizes=self.hidden_dim_aug + [self.output_channels], + layer_sizes=[*self.hidden_dim_aug, self.output_channels], include_original=False, include_time=False, kernel_size=3, diff --git a/nlpsig_networks/swnu_network.py b/src/sig_networks/swnu_network.py similarity index 96% rename from nlpsig_networks/swnu_network.py rename to src/sig_networks/swnu_network.py index 411123d..d7c15ca 100644 --- a/nlpsig_networks/swnu_network.py +++ b/src/sig_networks/swnu_network.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn -from nlpsig_networks.feature_concatenation import FeatureConcatenation -from nlpsig_networks.ffn_baseline import FeedforwardNeuralNetModel -from nlpsig_networks.swnu import SWNU +from sig_networks.feature_concatenation import FeatureConcatenation +from sig_networks.ffn_baseline import FeedforwardNeuralNetModel +from sig_networks.swnu import SWNU class SWNUNetwork(nn.Module): @@ -92,7 +92,7 @@ def __init__( - scaled_concatenation: concatenation of single value scaled path signature and embedding vector """ - super(SWNUNetwork, self).__init__() + super().__init__() if pooling not in ["signature", "lstm"]: raise ValueError( diff --git a/nlpsig_networks/utils.py b/src/sig_networks/utils.py similarity index 100% rename from nlpsig_networks/utils.py rename to src/sig_networks/utils.py From 31516934be8736086c72e391492a431bf4677558 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 19:39:54 +0000 Subject: [PATCH 02/16] apply pre-commit --- .../ffn_current_focal_2_kfold_best_model.csv | 2 +- ...tm_history_11_focal_2_kfold_best_model.csv | 2 +- ...seqsignet-attention-encoder-script copy.py | 123 ------------------ .../anno_mi-client-seqsignet-script copy | 101 -------------- .../anno_mi-client-swmhau-script copy | 101 -------------- .../anno_mi-client-swnu-script copy | 102 --------------- notebooks/Reddit_MoC/load_redditmoc.py | 38 +++--- notebooks/Reddit_MoC/load_sbert-embeddings.py | 8 +- .../Reddit_MoC/redditmoc-baseline-bert.ipynb | 16 ++- .../redditmoc-baseline-bilstm.ipynb | 17 +-- .../Reddit_MoC/redditmoc-baseline-ffn.ipynb | 18 ++- .../Reddit_MoC/redditmoc-swmhau-masked.ipynb | 29 +++-- notebooks/Rumours/load_rumours.py | 4 +- notebooks/Stance/load_stance.py | 16 ++- .../Stance/stance-baseline-ffn-history.ipynb | 2 +- .../Talklife_MoC/load_sbert-embeddings.py | 8 +- notebooks/Talklife_MoC/load_talklifemoc.py | 42 +++--- .../talklifemoc-baseline-bert.ipynb | 14 +- .../talklifemoc-baseline-bilstm-w11.ipynb | 44 +++++-- .../talklifemoc-baseline-bilstm-w20.ipynb | 44 +++++-- .../talklifemoc-baseline-bilstm-w35.ipynb | 44 +++++-- 
.../talklifemoc-baseline-bilstm-w5.ipynb | 42 ++++-- .../talklifemoc-baseline-ffn-history.ipynb | 44 +++++-- .../talklifemoc-baseline-ffn.ipynb | 48 ++++--- .../Talklife_MoC/talklifemoc-swmhau-w11.ipynb | 92 +++++++++---- .../Talklife_MoC/talklifemoc-swmhau-w20.ipynb | 63 ++++++--- .../Talklife_MoC/talklifemoc-swmhau-w5.ipynb | 65 ++++++--- .../talklifemoc-swnu-bilstm-35.ipynb | 50 +++++-- .../talklifemoc-swnu-bilstm.ipynb | 52 ++++++-- .../Talklife_MoC/talklifemoc-swnu-w11.ipynb | 50 +++++-- .../Talklife_MoC/talklifemoc-swnu-w20.ipynb | 48 +++++-- .../Talklife_MoC/talklifemoc-swnu-w5.ipynb | 50 +++++-- notebooks/results-readout.py | 81 +++++++++--- 33 files changed, 712 insertions(+), 748 deletions(-) delete mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py delete mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy delete mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy delete mode 100644 notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy diff --git a/notebooks/AnnoMI/client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv b/notebooks/AnnoMI/client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv index 16cef9a..e5935e5 100644 --- a/notebooks/AnnoMI/client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv +++ b/notebooks/AnnoMI/client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv @@ -1,4 +1,4 @@ ,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,valid_precision,valid_precision_scores,valid_recall,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size 0,,0.6402298808097839,0.518987928780223,[0.76754191 0.44606164 0.34336023],0.5131149305011911,[0.78799878 0.46352313 0.28782288],0.534480944335035,[0.7481203 0.42986799 0.42545455],,0.6575144529342651,0.5473814347074913,[0.78115157 0.46930134 0.39169139],0.5454692032240676,[0.78125 0.50996933 0.34518828],0.5561228606669005,[0.78105316 0.43464052 0.4526749 ],"(512, 512)",0.1,0.001,1,focal,2,True,5,64 0,,0.6229885220527649,0.5063565905105594,[0.75261428 0.41402619 0.3524293 ],0.49859894367141694,[0.77843016 0.42424242 0.29312425],0.5248547885442102,[0.72845575 0.40429043 0.44181818],,0.6631342172622681,0.5574780764400712,[0.78449375 0.4919713 0.39596918],0.5527498325638676,[0.79421637 0.51539012 0.34864301],0.5679187998930711,[0.7750063 0.47058824 0.45816187],"(512, 512)",0.1,0.001,12,focal,2,True,5,64 -0,,0.6406130194664001,0.5096616660227287,[0.77081222 0.41295167 0.34522111],0.5115455869756562,[0.77181792 0.47878128 0.28403756],0.5242818139535179,[0.76980914 0.3630363 0.44 ],,0.6640976071357727,0.5468888875022938,[0.79088771 0.45092838 0.39885057],0.5524170732756318,[0.7775073 0.53651939 0.34322453],0.5565400371396844,[0.80473671 0.38888889 0.47599451],"(512, 512)",0.1,0.001,123,focal,2,True,5,64 \ No newline at end of file +0,,0.6406130194664001,0.5096616660227287,[0.77081222 0.41295167 0.34522111],0.5115455869756562,[0.77181792 0.47878128 0.28403756],0.5242818139535179,[0.76980914 0.3630363 0.44 ],,0.6640976071357727,0.5468888875022938,[0.79088771 0.45092838 0.39885057],0.5524170732756318,[0.7775073 0.53651939 0.34322453],0.5565400371396844,[0.80473671 0.38888889 0.47599451],"(512, 512)",0.1,0.001,123,focal,2,True,5,64 diff --git 
a/notebooks/AnnoMI/client_talk_type_output/lstm_history_11_focal_2_kfold_best_model.csv b/notebooks/AnnoMI/client_talk_type_output/lstm_history_11_focal_2_kfold_best_model.csv index 7d1e85c..bec8616 100644 --- a/notebooks/AnnoMI/client_talk_type_output/lstm_history_11_focal_2_kfold_best_model.csv +++ b/notebooks/AnnoMI/client_talk_type_output/lstm_history_11_focal_2_kfold_best_model.csv @@ -1,4 +1,4 @@ ,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,valid_precision,valid_precision_scores,valid_recall,valid_recall_scores,num_layers,bidirectional,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size 0,,0.6379310488700867,0.5256797956726735,[0.75807416 0.45041486 0.36855037],0.5174319011572314,[0.78482972 0.43214556 0.33532042],0.5374902151869323,[0.73308271 0.47029703 0.40909091],,0.6901091933250427,0.593650607458289,[0.79871134 0.54642745 0.43581303],0.5878228582820997,[0.81746241 0.52701882 0.41898734],0.6007227033474569,[0.78080121 0.56732026 0.45404664],1,True,300,0.1,0.0005,1,focal,2,True,5,64 0,,0.6245210766792297,0.5160405336922796,[0.74690799 0.43505976 0.36615385],0.5061868528186951,[0.78058008 0.42064715 0.31733333],0.5330810478387341,[0.71602082 0.45049505 0.43272727],,0.6796724200248718,0.5892014363908511,[0.78868814 0.5440806 0.43483557],0.5789532080062053,[0.82093213 0.52490887 0.39101862],0.6044330489397405,[0.75888133 0.56470588 0.48971193],1,True,300,0.1,0.0005,12,focal,2,True,5,64 -0,,0.6093869805335999,0.512373581596779,[0.73194423 0.451341 0.35383552],0.5017814694134902,[0.79863248 0.42131617 0.28539576],0.5423210447129154,[0.67553499 0.4859736 0.46545455],,0.6687540411949158,0.5868382641499507,[0.77656859 0.5515737 0.43237251],0.5739280772923682,[0.83189407 0.52709946 0.3627907 ],0.6138513018376075,[0.72814311 0.57843137 0.53497942],1,True,300,0.1,0.0005,123,focal,2,True,5,64 \ No newline at end of file +0,,0.6093869805335999,0.512373581596779,[0.73194423 0.451341 0.35383552],0.5017814694134902,[0.79863248 0.42131617 0.28539576],0.5423210447129154,[0.67553499 0.4859736 0.46545455],,0.6687540411949158,0.5868382641499507,[0.77656859 0.5515737 0.43237251],0.5739280772923682,[0.83189407 0.52709946 0.3627907 ],0.6138513018376075,[0.72814311 0.57843137 0.53497942],1,True,300,0.1,0.0005,123,focal,2,True,5,64 diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py deleted file mode 100644 index 2025475..0000000 --- a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-attention-encoder-script copy.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import os -import pickle - -import numpy as np -import torch - -from sig_networks.scripts.seqsignet_attention_encoder_functions import ( - seqsignet_attention_encoder_hyperparameter_search, -) - -from ..load_anno_mi import ( - anno_mi, - client_index, - client_transcript_id, - output_dim_client, - y_data_client, -) - -seed = 2023 - -# set device -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print("Device: ", device) - -# set output directory -output_dir = "client_talk_type_output" -if not os.path.isdir(output_dir): - os.makedirs(output_dir) - -# load sbert embeddings -with open("../anno_mi_sbert.pkl", "rb") as f: - sbert_embeddings = pickle.load(f) - -# set features -features = ["time_encoding", 
"timeline_index"] -standardise_method = ["z_score", None] -include_features_in_path = True -include_features_in_input = True - -# set hyperparameters -num_epochs = 100 -dimensions = [15] -# define swmhau parameters: (output_channels, sig_depth, num_heads) -swmhau_parameters = [(12, 3, 10), (10, 3, 5)] -num_layers = [1] -ffn_hidden_dim_sizes = [[32, 32], [128, 128], [512, 512]] -dropout_rates = [0.1] -learning_rates = [5e-4, 3e-4, 1e-4, 1e-5] -seeds = [1, 12, 123] -loss = "focal" -gamma = 2 -validation_metric = "f1" -patience = 3 - -# set kwargs -kwargs = { - "num_epochs": num_epochs, - "df": anno_mi, - "id_column": "transcript_id", - "label_column": "client_talk_type", - "embeddings": sbert_embeddings, - "y_data": y_data_client, - "output_dim": output_dim_client, - "dimensions": dimensions, - "log_signature": True, - "pooling": "signature", - "transformer_encoder_layers": 2, - "swmhau_parameters": swmhau_parameters, - "num_layers": num_layers, - "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, - "dropout_rates": dropout_rates, - "learning_rates": learning_rates, - "seeds": seeds, - "loss": loss, - "gamma": gamma, - "device": device, - "features": features, - "standardise_method": standardise_method, - "include_features_in_path": include_features_in_path, - "include_features_in_input": include_features_in_input, - "path_indices": client_index, - "split_ids": client_transcript_id, - "k_fold": True, - "patience": patience, - "validation_metric": validation_metric, - "verbose": False, -} - -# run hyperparameter search -lengths = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26), (3, 5, 36)] - -for shift, window_size, n in lengths: - print(f"shift: {shift}, window_size: {window_size}, n: {n}") - ( - seqsignet_attention_encoder_umap_kfold, - best_seqsignet_attention_encoder_umap_kfold, - _, - __, - ) = seqsignet_attention_encoder_hyperparameter_search( - shift=shift, - window_size=window_size, - n=n, - dim_reduce_methods=["umap"], - results_output=f"{output_dir}/seqsignet_attention_encoder_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold.csv", - **kwargs, - ) - - print(f"F1: {best_seqsignet_attention_encoder_umap_kfold['f1'].mean()}") - print( - f"Precision: {best_seqsignet_attention_encoder_umap_kfold['precision'].mean()}" - ) - print(f"Recall: {best_seqsignet_attention_encoder_umap_kfold['recall'].mean()}") - print( - f"F1 scores: {np.stack(best_seqsignet_attention_encoder_umap_kfold['f1_scores']).mean(axis=0)}" - ) - print( - f"Precision scores: {np.stack(best_seqsignet_attention_encoder_umap_kfold['precision_scores']).mean(axis=0)}" - ) - print( - f"Recall scores: {np.stack(best_seqsignet_attention_encoder_umap_kfold['recall_scores']).mean(axis=0)}" - ) diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy deleted file mode 100644 index 8c370d9..0000000 --- a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-seqsignet-script copy +++ /dev/null @@ -1,101 +0,0 @@ -import numpy as np -import pickle -import os -import torch -from sig_networks.scripts.seqsignet_functions import seqsignet_hyperparameter_search -from ..load_anno_mi import anno_mi, y_data_client, output_dim_client, client_index, client_transcript_id - -seed = 2023 - -# set device -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print("Device: ", device) - -# set output directory -output_dir = "client_talk_type_output" -if not os.path.isdir(output_dir): - os.makedirs(output_dir) - -# 
load sbert embeddings -with open("../anno_mi_sbert.pkl", "rb") as f: - sbert_embeddings = pickle.load(f) - -# set features -features = ["time_encoding", "timeline_index"] -standardise_method = ["z_score", None] -include_features_in_path = True -include_features_in_input = True - -# set hyperparameters -num_epochs = 100 -dimensions = [15] -swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 3)] -lstm_hidden_dim_sizes = [300, 400] -ffn_hidden_dim_sizes = [[32,32], [128,128], [512,512]] -dropout_rates = [0.1] -learning_rates = [5e-4, 3e-4, 1e-4] -seeds = [1, 12, 123] -loss = "focal" -gamma = 2 -validation_metric = "f1" -patience = 3 - -# set kwargs -kwargs = { - "num_epochs": num_epochs, - "df": anno_mi, - "id_column": "transcript_id", - "label_column": "client_talk_type", - "embeddings": sbert_embeddings, - "y_data": y_data_client, - "output_dim": output_dim_client, - "dimensions": dimensions, - "log_signature": True, - "pooling": "signature", - "swnu_hidden_dim_sizes_and_sig_depths": swnu_hidden_dim_sizes_and_sig_depths, - "lstm_hidden_dim_sizes": lstm_hidden_dim_sizes, - "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, - "dropout_rates": dropout_rates, - "learning_rates": learning_rates, - "BiLSTM": True, - "seeds": seeds, - "loss": loss, - "gamma": gamma, - "device": device, - "features": features, - "standardise_method": standardise_method, - "include_features_in_path": include_features_in_path, - "include_features_in_input": include_features_in_input, - "path_indices": client_index, - "split_ids": client_transcript_id, - "k_fold": True, - "patience": patience, - "validation_metric": validation_metric, - "verbose": False, -} - -# run hyperparameter search -lengths = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26), (3, 5, 36)] - -for shift, window_size, n in lengths: - print(f"shift: {shift}, window_size: {window_size}, n: {n}") - ( - seqsignet_network_umap_kfold, - best_seqsignet_network_umap_kfold, - _, - __, - ) = seqsignet_hyperparameter_search( - shift=shift, - window_size=window_size, - n=n, - dim_reduce_methods=["umap"], - results_output=f"{output_dir}/seqsignet_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold.csv", - **kwargs, - ) - - print(f"F1: {best_seqsignet_network_umap_kfold['f1'].mean()}") - print(f"Precision: {best_seqsignet_network_umap_kfold['precision'].mean()}") - print(f"Recall: {best_seqsignet_network_umap_kfold['recall'].mean()}") - print(f"F1 scores: {np.stack(best_seqsignet_network_umap_kfold['f1_scores']).mean(axis=0)}") - print(f"Precision scores: {np.stack(best_seqsignet_network_umap_kfold['precision_scores']).mean(axis=0)}") - print(f"Recall scores: {np.stack(best_seqsignet_network_umap_kfold['recall_scores']).mean(axis=0)}") diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy deleted file mode 100644 index 8618fc2..0000000 --- a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swmhau-script copy +++ /dev/null @@ -1,101 +0,0 @@ -import numpy as np -import pickle -import os -import torch -from sig_networks.scripts.swmhau_network_functions import ( - swmhau_network_hyperparameter_search, -) -from ..load_anno_mi import anno_mi, y_data_client, output_dim_client, client_index, client_transcript_id - -seed = 2023 - -# set device -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print("Device: ", device) - -# set output directory -output_dir = "client_talk_type_output" -if not os.path.isdir(output_dir): - 
os.makedirs(output_dir) - -# load sbert embeddings -with open("../anno_mi_sbert.pkl", "rb") as f: - sbert_embeddings = pickle.load(f) - -# set features -features = ["time_encoding", "timeline_index"] -standardise_method = ["z_score", None] -include_features_in_path = True -include_features_in_input = True - -# set hyperparameters -num_epochs = 100 -dimensions = [15] -# define swmhau parameters: (output_channels, sig_depth, num_heads) -swmhau_parameters = [(12, 3, 10), (10, 3, 5)] -num_layers = [1] -ffn_hidden_dim_sizes = [[32,32], [128,128], [512,512]] -dropout_rates = [0.1] -learning_rates = [5e-4, 3e-4, 1e-4, 1e-5] -seeds = [1, 12, 123] -loss = "focal" -gamma = 2 -validation_metric = "f1" -patience = 3 - -# set kwargs -kwargs = { - "num_epochs": num_epochs, - "df": anno_mi, - "id_column": "transcript_id", - "label_column": "client_talk_type", - "embeddings": sbert_embeddings, - "y_data": y_data_client, - "output_dim": output_dim_client, - "dimensions": dimensions, - "log_signature": True, - "pooling": "signature", - "swmhau_parameters": swmhau_parameters, - "num_layers": num_layers, - "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, - "dropout_rates": dropout_rates, - "learning_rates": learning_rates, - "seeds": seeds, - "loss": loss, - "gamma": gamma, - "device": device, - "features": features, - "standardise_method": standardise_method, - "include_features_in_path": include_features_in_path, - "include_features_in_input": include_features_in_input, - "path_indices": client_index, - "split_ids": client_transcript_id, - "k_fold": True, - "patience": patience, - "validation_metric": validation_metric, - "verbose": False, -} - -# run hyperparameter search -lengths = [5, 11, 20, 35, 80, 110] - -for size in lengths: - print(f"history_length: {size}") - ( - swmhau_network_umap_kfold, - best_swmhau_network_umap_kfold, - _, - __, - ) = swmhau_network_hyperparameter_search( - history_lengths=[size], - dim_reduce_methods=["umap"], - results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv", - **kwargs, - ) - - print(f"F1: {best_swmhau_network_umap_kfold['f1'].mean()}") - print(f"Precision: {best_swmhau_network_umap_kfold['precision'].mean()}") - print(f"Recall: {best_swmhau_network_umap_kfold['recall'].mean()}") - print(f"F1 scores: {np.stack(best_swmhau_network_umap_kfold['f1_scores']).mean(axis=0)}") - print(f"Precision scores: {np.stack(best_swmhau_network_umap_kfold['precision_scores']).mean(axis=0)}") - print(f"Recall scores: {np.stack(best_swmhau_network_umap_kfold['recall_scores']).mean(axis=0)}") diff --git a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy b/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy deleted file mode 100644 index c38bf3d..0000000 --- a/notebooks/AnnoMI/client_talk_type_prediction/anno_mi-client-swnu-script copy +++ /dev/null @@ -1,102 +0,0 @@ -import numpy as np -import pickle -import os -import torch -from sig_networks.scripts.swnu_network_functions import ( - swnu_network_hyperparameter_search, -) -from ..load_anno_mi import anno_mi, y_data_client, output_dim_client, client_index, client_transcript_id - -seed = 2023 - -# set device -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print("Device: ", device) - -# set output directory -output_dir = "client_talk_type_output" -if not os.path.isdir(output_dir): - os.makedirs(output_dir) - -# load sbert embeddings -with open("../anno_mi_sbert.pkl", "rb") as f: - sbert_embeddings = pickle.load(f) - -# set features 
-features = ["time_encoding", "timeline_index"] -standardise_method = ["z_score", None] -include_features_in_path = True -include_features_in_input = True - -# set hyperparameters -num_epochs = 100 -dimensions = [15] -swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 3)] -ffn_hidden_dim_sizes = [[32,32], [128,128], [512,512]] -dropout_rates = [0.1] -learning_rates = [5e-4, 3e-4, 1e-4] -seeds = [1, 12, 123] -loss = "focal" -gamma = 2 -validation_metric = "f1" -patience = 3 - -# set kwargs -kwargs = { - "num_epochs": num_epochs, - "df": anno_mi, - "id_column": "transcript_id", - "label_column": "client_talk_type", - "embeddings": sbert_embeddings, - "y_data": y_data_client, - "output_dim": output_dim_client, - "dimensions": dimensions, - "log_signature": True, - "pooling": "signature", - "swnu_hidden_dim_sizes_and_sig_depths": swnu_hidden_dim_sizes_and_sig_depths, - "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes, - "dropout_rates": dropout_rates, - "learning_rates": learning_rates, - "BiLSTM": True, - "seeds": seeds, - "loss": loss, - "gamma": gamma, - "device": device, - "features": features, - "standardise_method": standardise_method, - "include_features_in_path": include_features_in_path, - "include_features_in_input": include_features_in_input, - "path_indices": client_index, - "split_ids": client_transcript_id, - "k_fold": True, - "patience": patience, - "validation_metric": validation_metric, - "verbose": False, -} - -# run hyperparameter search -lengths = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26), (3, 5, 36)] - -# run hyperparameter search -lengths = [5, 11, 20, 35, 80, 110] - -for size in lengths: - print(f"history_length: {size}") - ( - swnu_network_umap_kfold, - best_swnu_network_umap_kfold, - _, - __, - ) = swnu_network_hyperparameter_search( - history_lengths=[size], - dim_reduce_methods=["umap"], - results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv", - **kwargs, - ) - - print(f"F1: {best_swnu_network_umap_kfold['f1'].mean()}") - print(f"Precision: {best_swnu_network_umap_kfold['precision'].mean()}") - print(f"Recall: {best_swnu_network_umap_kfold['recall'].mean()}") - print(f"F1 scores: {np.stack(best_swnu_network_umap_kfold['f1_scores']).mean(axis=0)}") - print(f"Precision scores: {np.stack(best_swnu_network_umap_kfold['precision_scores']).mean(axis=0)}") - print(f"Recall scores: {np.stack(best_swnu_network_umap_kfold['recall_scores']).mean(axis=0)}") diff --git a/notebooks/Reddit_MoC/load_redditmoc.py b/notebooks/Reddit_MoC/load_redditmoc.py index 242ac4d..14c27ed 100644 --- a/notebooks/Reddit_MoC/load_redditmoc.py +++ b/notebooks/Reddit_MoC/load_redditmoc.py @@ -1,8 +1,8 @@ -import pandas as pd +import pandas as pd import os import datetime import torch -import paths +import paths ################################################## ########## Load in data ########################## @@ -11,27 +11,29 @@ df = pd.read_pickle(reddit_fname) dict3 = {} -dict3['0'] = '0' -dict3['E'] = '1' -dict3['IE'] = '1' -dict3['S'] = '2' -dict3['IS'] = '2' +dict3["0"] = "0" +dict3["E"] = "1" +dict3["IE"] = "1" +dict3["S"] = "2" +dict3["IS"] = "2" -#GET THE FLAT LABELS +# GET THE FLAT LABELS df = df.replace({"label_3": dict3}) -df['label'] = df['label_3'] +df["label"] = df["label_3"] -#GET TRAIN/DEV/TEST SET -df['set'] = None -df.loc[(df.train_or_test =='test'),'set'] = 'test' -df.loc[(df.train_or_test !='test') & (df.fold !=0),'set'] = 'train' -df.loc[(df.fold ==0),'set'] = 'dev' +# GET TRAIN/DEV/TEST SET +df["set"] = None +df.loc[(df.train_or_test == "test"), 
"set"] = "test" +df.loc[(df.train_or_test != "test") & (df.fold != 0), "set"] = "train" +df.loc[(df.fold == 0), "set"] = "dev" -#KEEP SPECIFIC COLUMNS AND RESET INDEX -df = df[['user_id', 'timeline_id', 'postid', 'content','label', 'datetime', 'set']].reset_index(drop=True) +# KEEP SPECIFIC COLUMNS AND RESET INDEX +df = df[ + ["user_id", "timeline_id", "postid", "content", "label", "datetime", "set"] +].reset_index(drop=True) ################################################## ########## Dimensions and Y labels ############### ################################################## -output_dim = len(df['label'].unique()) -y_data = torch.tensor(df['label'].astype(float).values, dtype=torch.int64) \ No newline at end of file +output_dim = len(df["label"].unique()) +y_data = torch.tensor(df["label"].astype(float).values, dtype=torch.int64) diff --git a/notebooks/Reddit_MoC/load_sbert-embeddings.py b/notebooks/Reddit_MoC/load_sbert-embeddings.py index 68b4475..151b91f 100644 --- a/notebooks/Reddit_MoC/load_sbert-embeddings.py +++ b/notebooks/Reddit_MoC/load_sbert-embeddings.py @@ -3,9 +3,9 @@ import os import paths -#read embeddings -emb_sbert_filename= paths.embeddings_fname -with open(emb_sbert_filename, 'rb') as f: +# read embeddings +emb_sbert_filename = paths.embeddings_fname +with open(emb_sbert_filename, "rb") as f: sbert_embeddings = pickle.load(f) -sbert_embeddings = torch.tensor(sbert_embeddings) \ No newline at end of file +sbert_embeddings = torch.tensor(sbert_embeddings) diff --git a/notebooks/Reddit_MoC/redditmoc-baseline-bert.ipynb b/notebooks/Reddit_MoC/redditmoc-baseline-bert.ipynb index 9bcbbca..4e2ca85 100644 --- a/notebooks/Reddit_MoC/redditmoc-baseline-bert.ipynb +++ b/notebooks/Reddit_MoC/redditmoc-baseline-bert.ipynb @@ -131,13 +131,15 @@ "outputs": [], "source": [ "num_epochs = 10\n", - "seeds = [12] #[1, 12, 123]\n", + "seeds = [12] # [1, 12, 123]\n", "loss = \"focal\"\n", "gamma = 2\n", "validation_metric = \"f1\"\n", - "split_indices = (df[df['set']=='train'].index, \n", - " df[df['set']=='dev'].index, \n", - " df[df['set']=='test'].index)" + "split_indices = (\n", + " df[df[\"set\"] == \"train\"].index,\n", + " df[df[\"set\"] == \"dev\"].index,\n", + " df[df[\"set\"] == \"test\"].index,\n", + ")" ] }, { @@ -177,10 +179,10 @@ " split_indices=split_indices,\n", " k_fold=False,\n", " validation_metric=validation_metric,\n", - " results_output=None, #f\"{output_dir}/bert_classifier.csv\",\n", + " results_output=None, # f\"{output_dir}/bert_classifier.csv\",\n", " device=device,\n", - " verbose=False\n", - ")\n" + " verbose=False,\n", + ")" ] }, { diff --git a/notebooks/Reddit_MoC/redditmoc-baseline-bilstm.ipynb b/notebooks/Reddit_MoC/redditmoc-baseline-bilstm.ipynb index f20c8b5..69f51ae 100644 --- a/notebooks/Reddit_MoC/redditmoc-baseline-bilstm.ipynb +++ b/notebooks/Reddit_MoC/redditmoc-baseline-bilstm.ipynb @@ -44,8 +44,7 @@ "metadata": {}, "outputs": [], "source": [ - "from nlpsig_networks.scripts.lstm_baseline_functions import (\n", - " lstm_hyperparameter_search)" + "from nlpsig_networks.scripts.lstm_baseline_functions import lstm_hyperparameter_search" ] }, { @@ -128,9 +127,11 @@ "gamma = 2\n", "validation_metric = \"f1\"\n", "patience = 5\n", - "split_indices = (df[df['set']=='train'].index, \n", - " df[df['set']=='dev'].index, \n", - " df[df['set']=='test'].index)" + "split_indices = (\n", + " df[df[\"set\"] == \"train\"].index,\n", + " df[df[\"set\"] == \"dev\"].index,\n", + " df[df[\"set\"] == \"test\"].index,\n", + ")" ] }, { @@ -391,13 +392,13 @@ " 
gamma=gamma,\n", " device=device,\n", " path_indices=None,\n", - " split_ids= None,\n", - " split_indices = split_indices,\n", + " split_ids=None,\n", + " split_indices=split_indices,\n", " k_fold=False,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/lstm_history_{size}_focal_{gamma}.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, diff --git a/notebooks/Reddit_MoC/redditmoc-baseline-ffn.ipynb b/notebooks/Reddit_MoC/redditmoc-baseline-ffn.ipynb index 0a287ba..206e6ec 100644 --- a/notebooks/Reddit_MoC/redditmoc-baseline-ffn.ipynb +++ b/notebooks/Reddit_MoC/redditmoc-baseline-ffn.ipynb @@ -98,9 +98,9 @@ "source": [ "num_epochs = 100\n", "batch = 32\n", - "hidden_dim_sizes = [[32,32]]#[[32,32], [64,64], [128,128], [256,256]]\n", - "dropout_rates = [0.2] #[0.5, 0.2, 0.1]\n", - "learning_rates = [1e-4]#[1e-3, 1e-4, 5e-4]\n", + "hidden_dim_sizes = [[32, 32]] # [[32,32], [64,64], [128,128], [256,256]]\n", + "dropout_rates = [0.2] # [0.5, 0.2, 0.1]\n", + "learning_rates = [1e-4] # [1e-3, 1e-4, 5e-4]\n", "seeds = [1, 12, 123]\n", "loss = \"focal\"\n", "gamma = 2\n", @@ -156,7 +156,7 @@ } ], "source": [ - "ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search( \n", + "ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " x_data=sbert_embeddings,\n", " y_data=y_data,\n", @@ -170,12 +170,16 @@ " device=device,\n", " batch_size=batch,\n", " data_split_seed=123,\n", - " split_ids= None, #torch.tensor(df_rumours['timeline_id'].astype(int)),\n", - " split_indices = (df[df['set']=='train'].index, df[df['set']=='dev'].index, df[df['set']=='test'].index),\n", + " split_ids=None, # torch.tensor(df_rumours['timeline_id'].astype(int)),\n", + " split_indices=(\n", + " df[df[\"set\"] == \"train\"].index,\n", + " df[df[\"set\"] == \"dev\"].index,\n", + " df[df[\"set\"] == \"test\"].index,\n", + " ),\n", " k_fold=False,\n", " validation_metric=validation_metric,\n", " results_output=None,\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, diff --git a/notebooks/Reddit_MoC/redditmoc-swmhau-masked.ipynb b/notebooks/Reddit_MoC/redditmoc-swmhau-masked.ipynb index 4cf36f2..5a1fa69 100644 --- a/notebooks/Reddit_MoC/redditmoc-swmhau-masked.ipynb +++ b/notebooks/Reddit_MoC/redditmoc-swmhau-masked.ipynb @@ -45,7 +45,7 @@ "outputs": [], "source": [ "from nlpsig_networks.scripts.swmhau_network_functions import (\n", - " swmhau_network_hyperparameter_search\n", + " swmhau_network_hyperparameter_search,\n", ")" ] }, @@ -118,8 +118,8 @@ "metadata": {}, "outputs": [], "source": [ - "features = [\"time_encoding\"]#[\"time_encoding\", \"timeline_index\"]\n", - "standardise_method = [\"z_score\"]#[\"z_score\", None]\n", + "features = [\"time_encoding\"] # [\"time_encoding\", \"timeline_index\"]\n", + "standardise_method = [\"z_score\"] # [\"z_score\", None]\n", "num_features = len(features)\n", "add_time_in_path = False" ] @@ -132,11 +132,11 @@ "source": [ "num_epochs = 100\n", "embedding_dim = 384\n", - "dimensions = [15] # [50, 15]\n", + "dimensions = [15] # [50, 15]\n", "# define swmhau parameters: (output_channels, sig_depth, num_heads)\n", "swmhau_parameters = [(12, 3, 10), (8, 4, 6), (8, 4, 12)]\n", "num_layers = [1]\n", - "ffn_hidden_dim_sizes = [[256,256],[512,512]]\n", + "ffn_hidden_dim_sizes = [[256, 256], [512, 512]]\n", "dropout_rates = [0.5, 0.1]\n", "learning_rates = [1e-3, 1e-4, 5e-4]\n", "seeds = [1, 12, 123]\n", @@ -144,9 +144,11 @@ "gamma = 2\n", "validation_metric = \"f1\"\n", 
"patience = 5\n", - "split_indices = (df[df['set']=='train'].index, \n", - " df[df['set']=='dev'].index, \n", - " df[df['set']=='test'].index)" + "split_indices = (\n", + " df[df[\"set\"] == \"train\"].index,\n", + " df[df[\"set\"] == \"dev\"].index,\n", + " df[df[\"set\"] == \"test\"].index,\n", + ")" ] }, { @@ -624,7 +626,12 @@ ], "source": [ "size = 20\n", - "swmhau_network_umap, best_swmhau_network_umap, _, __ = swmhau_network_hyperparameter_search(\n", + "(\n", + " swmhau_network_umap,\n", + " best_swmhau_network_umap,\n", + " _,\n", + " __,\n", + ") = swmhau_network_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " df=df,\n", " id_column=\"timeline_id\",\n", @@ -654,7 +661,7 @@ " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -1255,7 +1262,7 @@ } ], "source": [ - "best_swmhau_network_umap['f1'].mean()" + "best_swmhau_network_umap[\"f1\"].mean()" ] }, { diff --git a/notebooks/Rumours/load_rumours.py b/notebooks/Rumours/load_rumours.py index 7c64819..bbfa6ab 100644 --- a/notebooks/Rumours/load_rumours.py +++ b/notebooks/Rumours/load_rumours.py @@ -169,9 +169,7 @@ def time_fraction(x): ########## Dimensions and Y labels ############### ################################################## y_data = df_rumours["label"] -label_to_id = { - str(y_data.unique()[i]): i for i in range(len(y_data.unique())) -} +label_to_id = {str(y_data.unique()[i]): i for i in range(len(y_data.unique()))} id_to_label = {v: k for k, v in label_to_id.items()} output_dim = len(label_to_id.keys()) diff --git a/notebooks/Stance/load_stance.py b/notebooks/Stance/load_stance.py index 91d534d..4e85b8b 100644 --- a/notebooks/Stance/load_stance.py +++ b/notebooks/Stance/load_stance.py @@ -18,6 +18,7 @@ ########## Conversion to Timeline#### ############ ################################################## + # Convert conversation thread to linear timeline: we use timestamps # of each post in the twitter thread to obtain a chronologically ordered list. 
def tree2timeline(conversation): @@ -43,6 +44,7 @@ def tree2timeline(conversation): timeline.extend(sorted_replies) return timeline + stance_timelines = {"dev": [], "train": [], "test": []} count_threads = 0 @@ -92,18 +94,18 @@ def time_fraction(x): ################################################## ########## Label Mapping ######################### ################################################## -#labels in numbers +# labels in numbers dictl = {} -dictl['support'] = 0 -dictl['deny'] = 1 -dictl['comment'] = 2 -dictl['query'] = 3 +dictl["support"] = 0 +dictl["deny"] = 1 +dictl["comment"] = 2 +dictl["query"] = 3 -#Get the numerical labels +# Get the numerical labels df = df.replace({"label": dictl}) ################################################## ########## Dimensions and Y labels ############### ################################################## output_dim = len(df["label"].unique()) -y_data = torch.tensor(df["label"].astype(float).values, dtype=torch.int64) \ No newline at end of file +y_data = torch.tensor(df["label"].astype(float).values, dtype=torch.int64) diff --git a/notebooks/Stance/stance-baseline-ffn-history.ipynb b/notebooks/Stance/stance-baseline-ffn-history.ipynb index 86b94a6..917909a 100644 --- a/notebooks/Stance/stance-baseline-ffn-history.ipynb +++ b/notebooks/Stance/stance-baseline-ffn-history.ipynb @@ -25,7 +25,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device='cuda:2'" + "device = \"cuda:2\"" ] }, { diff --git a/notebooks/Talklife_MoC/load_sbert-embeddings.py b/notebooks/Talklife_MoC/load_sbert-embeddings.py index 68b4475..151b91f 100644 --- a/notebooks/Talklife_MoC/load_sbert-embeddings.py +++ b/notebooks/Talklife_MoC/load_sbert-embeddings.py @@ -3,9 +3,9 @@ import os import paths -#read embeddings -emb_sbert_filename= paths.embeddings_fname -with open(emb_sbert_filename, 'rb') as f: +# read embeddings +emb_sbert_filename = paths.embeddings_fname +with open(emb_sbert_filename, "rb") as f: sbert_embeddings = pickle.load(f) -sbert_embeddings = torch.tensor(sbert_embeddings) \ No newline at end of file +sbert_embeddings = torch.tensor(sbert_embeddings) diff --git a/notebooks/Talklife_MoC/load_talklifemoc.py b/notebooks/Talklife_MoC/load_talklifemoc.py index 8770d2d..5d05e3a 100644 --- a/notebooks/Talklife_MoC/load_talklifemoc.py +++ b/notebooks/Talklife_MoC/load_talklifemoc.py @@ -1,43 +1,47 @@ -import pandas as pd +import pandas as pd import sys import pickle import torch import paths -sys.path.append("..")# Adds higher directory to python modules path +sys.path.append("..") # Adds higher directory to python modules path ################################################## ########## Load in data ########################## ################################################## -#TalkLifeDataset = data_handler.TalkLifeDataset() -#df = TalkLifeDataset.return_annotated_timelines(load_from_pickle=True) +# TalkLifeDataset = data_handler.TalkLifeDataset() +# df = TalkLifeDataset.return_annotated_timelines(load_from_pickle=True) df = pd.read_pickle(paths.data_fname) -df = df[df['content']!='nan'] +df = df[df["content"] != "nan"] -#labels in numbers +# labels in numbers dictl = {} -dictl['0'] = '0' -dictl['IE'] = '1' -dictl['IEP'] = '1' -dictl['ISB'] = '2' -dictl['IS'] = '2'# +dictl["0"] = "0" +dictl["IE"] = "1" +dictl["IEP"] = "1" +dictl["ISB"] = "2" +dictl["IS"] = "2" # -#GET THE FLAT LABELS +# GET THE FLAT LABELS df = df.replace({"label": dictl}) -#read pickle of folds +# read pickle of folds folds_fname = 
paths.folds_fname -with open(folds_fname, 'rb') as f: +with open(folds_fname, "rb") as f: folds = pickle.load(f) -#obtain columns for train/dev/test for each fold +# obtain columns for train/dev/test for each fold for fold in folds.keys(): - df['fold'+str(fold)] = df['timeline_id'].map(lambda x: 'train' if x in folds[fold]['train'] else ('dev' if x in folds[fold]['dev'] else 'test')) + df["fold" + str(fold)] = df["timeline_id"].map( + lambda x: "train" + if x in folds[fold]["train"] + else ("dev" if x in folds[fold]["dev"] else "test") + ) -#rest index +# rest index df = df.reset_index(drop=True) ################################################## ########## Dimensions and Y labels ############### ################################################## -output_dim = len(df['label'].unique()) -y_data = torch.tensor(df['label'].astype(float).values, dtype=torch.int64) \ No newline at end of file +output_dim = len(df["label"].unique()) +y_data = torch.tensor(df["label"].astype(float).values, dtype=torch.int64) diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-bert.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-bert.ipynb index 817325b..c8fd1d9 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-bert.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-bert.ipynb @@ -131,11 +131,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -219,7 +225,7 @@ " seeds=seeds,\n", " loss=loss,\n", " gamma=gamma,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " validation_metric=validation_metric,\n", diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w11.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w11.ipynb index 9dc888b..ef88d15 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w11.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w11.ipynb @@ -23,7 +23,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device = 'cuda:1'" + "device = \"cuda:1\"" ] }, { @@ -124,11 +124,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -394,13 +400,13 @@ " loss=loss,\n", " gamma=gamma,\n", " device=device,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " validation_metric=validation_metric,\n", " 
results_output=f\"{output_dir}/lstm_history_{size}_focal_{gamma}_kfold.csv\",\n", " verbose=False,\n", - ") " + ")" ] }, { @@ -1450,13 +1456,25 @@ } ], "source": [ - "best_bilstm_history_11[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_bilstm_history_11[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w20.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w20.ipynb index a1c5966..d82a880 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w20.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w20.ipynb @@ -23,7 +23,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device = 'cuda:1'" + "device = \"cuda:1\"" ] }, { @@ -124,11 +124,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -394,13 +400,13 @@ " loss=loss,\n", " gamma=gamma,\n", " device=device,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/lstm_history_{size}_focal_{gamma}_kfold.csv\",\n", " verbose=False,\n", - ") " + ")" ] }, { @@ -1450,13 +1456,25 @@ } ], "source": [ - "best_bilstm_history_20[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_bilstm_history_20[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w35.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w35.ipynb index 338693c..2440d24 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w35.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w35.ipynb @@ -23,7 +23,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device = 'cuda:0'" + "device = \"cuda:0\"" ] }, { @@ 
-124,11 +124,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -394,13 +400,13 @@ " loss=loss,\n", " gamma=gamma,\n", " device=device,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/lstm_history_{size}_focal_{gamma}_kfold.csv\",\n", " verbose=False,\n", - ") " + ")" ] }, { @@ -1450,13 +1456,25 @@ } ], "source": [ - "best_bilstm_history_35[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_bilstm_history_35[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w5.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w5.ipynb index b5261a2..5c0424e 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w5.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-bilstm-w5.ipynb @@ -136,11 +136,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -406,13 +412,13 @@ " loss=loss,\n", " gamma=gamma,\n", " device=device,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/lstm_history_{size}_focal_{gamma}_kfold.csv\",\n", " verbose=False,\n", - ") " + ")" ] }, { @@ -1462,13 +1468,25 @@ } ], "source": [ - "best_bilstm_history_5[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_bilstm_history_5[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " 
\"valid_recall\",\n", + " \"hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-ffn-history.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-ffn-history.ipynb index b7d1c49..e6c7a1d 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-ffn-history.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-ffn-history.ipynb @@ -125,11 +125,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -405,13 +411,13 @@ " loss=loss,\n", " gamma=gamma,\n", " device=device,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", - " results_output= f\"{output_dir}/ffn_mean_history_focal_{gamma}.csv\",\n", - " verbose=False\n", + " results_output=f\"{output_dir}/ffn_mean_history_focal_{gamma}.csv\",\n", + " verbose=False,\n", ")" ] }, @@ -1125,13 +1131,25 @@ } ], "source": [ - "best_ffn_mean_history[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_ffn_mean_history[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-baseline-ffn.ipynb b/notebooks/Talklife_MoC/talklifemoc-baseline-ffn.ipynb index 8834234..3ead992 100644 --- a/notebooks/Talklife_MoC/talklifemoc-baseline-ffn.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-baseline-ffn.ipynb @@ -120,7 +120,7 @@ "outputs": [], "source": [ "num_epochs = 100\n", - "hidden_dim_sizes = [[64,64],[128,128],[256,256],[512, 512]]\n", + "hidden_dim_sizes = [[64, 64], [128, 128], [256, 256], [512, 512]]\n", "dropout_rates = [0.1, 0.2]\n", "learning_rates = [1e-3, 1e-4, 5e-4]\n", "seeds = [1, 12, 123]\n", @@ -136,11 +136,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", 
"fold_list = tuple(fold_list)" ] }, @@ -341,7 +347,7 @@ } ], "source": [ - "ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search( \n", + "ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " x_data=sbert_embeddings,\n", " y_data=y_data,\n", @@ -353,13 +359,13 @@ " loss=loss,\n", " gamma=gamma,\n", " device=device,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", - " results_output= f\"{output_dir}/ffn_current_focal_{gamma}.csv\",\n", - " verbose=False\n", + " results_output=f\"{output_dir}/ffn_current_focal_{gamma}.csv\",\n", + " verbose=False,\n", ")" ] }, @@ -668,13 +674,25 @@ } ], "source": [ - "best_ffn_current[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_ffn_current[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swmhau-w11.ipynb b/notebooks/Talklife_MoC/talklifemoc-swmhau-w11.ipynb index 2034592..4f9f73d 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swmhau-w11.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swmhau-w11.ipynb @@ -33,7 +33,7 @@ "outputs": [], "source": [ "from nlpsig_networks.scripts.swmhau_network_functions import (\n", - " swmhau_network_hyperparameter_search\n", + " swmhau_network_hyperparameter_search,\n", ")" ] }, @@ -118,11 +118,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -134,13 +140,13 @@ "source": [ "num_epochs = 100\n", "embedding_dim = 384\n", - "dimensions = [15] \n", + "dimensions = [15]\n", "dimreduction_method = [\"umap\"]\n", "# define swmhau parameters: (output_channels, sig_depth, num_heads)\n", "swmhau_parameters = [(12, 3, 10), (8, 4, 6)]\n", "num_layers = [1]\n", - "ffn_hidden_dim_sizes = [[256,256],[512,512]]\n", - "dropout_rates = [0.1, 0.2] \n", + "ffn_hidden_dim_sizes = [[256, 256], [512, 512]]\n", + "dropout_rates = [0.1, 0.2]\n", "learning_rates = [1e-3, 1e-4, 5e-4]\n", "seeds = [1, 12, 123]\n", "loss = \"focal\"\n", @@ -512,7 +518,12 @@ ], "source": [ "size = 11\n", - "swmhau_network_umap_11, best_swmhau_network_umap_11, _, __ = swmhau_network_hyperparameter_search(\n", + "(\n", + " swmhau_network_umap_11,\n", + " best_swmhau_network_umap_11,\n", + " _,\n", + " __,\n", + ") = swmhau_network_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " df=df,\n", " 
id_column=\"timeline_id\",\n", @@ -537,13 +548,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -553,7 +564,9 @@ "metadata": {}, "outputs": [], "source": [ - "best_swmhau_network_umap_11 = pd.read_csv('talklife_moc_output/swmhau_network_umap_focal_2_11_best_model.csv')" + "best_swmhau_network_umap_11 = pd.read_csv(\n", + " \"talklife_moc_output/swmhau_network_umap_focal_2_11_best_model.csv\"\n", + ")" ] }, { @@ -711,14 +724,30 @@ } ], "source": [ - "best_swmhau_network_umap_11[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', 'sig_depth',\n", - " 'num_heads', \n", - " 'ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', 'augmentation_type',\n", - " 'hidden_dim_aug', 'comb_method', 'batch_size']]" + "best_swmhau_network_umap_11[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"sig_depth\",\n", + " \"num_heads\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"augmentation_type\",\n", + " \"hidden_dim_aug\",\n", + " \"comb_method\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { @@ -787,9 +816,28 @@ "metadata": {}, "outputs": [], "source": [ - "best_swmhau_network_umap_11['f1_scores'] = best_swmhau_network_umap_11['f1_scores'].map(lambda x: [float(idx.replace(\" \", \"\")) for idx in x.replace(\"[\", \"\").replace(\"]\", \"\").replace(\" 0\", \",0\").split(',')])\n", - "best_swmhau_network_umap_11['precision_scores'] = best_swmhau_network_umap_11['precision_scores'].map(lambda x: [float(idx.replace(\" \", \"\")) for idx in x.replace(\"[\", \"\").replace(\"]\", \"\").replace(\" 0\", \",0\").split(',')])\n", - "best_swmhau_network_umap_11['recall_scores'] = best_swmhau_network_umap_11['recall_scores'].map(lambda x: [float(idx.replace(\" \", \"\")) for idx in x.replace(\"[\", \"\").replace(\"]\", \"\").replace(\" 0\", \",0\").split(',')])" + "best_swmhau_network_umap_11[\"f1_scores\"] = best_swmhau_network_umap_11[\"f1_scores\"].map(\n", + " lambda x: [\n", + " float(idx.replace(\" \", \"\"))\n", + " for idx in x.replace(\"[\", \"\").replace(\"]\", \"\").replace(\" 0\", \",0\").split(\",\")\n", + " ]\n", + ")\n", + "best_swmhau_network_umap_11[\"precision_scores\"] = best_swmhau_network_umap_11[\n", + " \"precision_scores\"\n", + "].map(\n", + " lambda x: [\n", + " float(idx.replace(\" \", \"\"))\n", + " for idx in x.replace(\"[\", \"\").replace(\"]\", \"\").replace(\" 0\", \",0\").split(\",\")\n", + " ]\n", + ")\n", + "best_swmhau_network_umap_11[\"recall_scores\"] = best_swmhau_network_umap_11[\n", + " \"recall_scores\"\n", + "].map(\n", + " lambda x: [\n", + " float(idx.replace(\" \", \"\"))\n", + " for idx in x.replace(\"[\", \"\").replace(\"]\", \"\").replace(\" 0\", \",0\").split(\",\")\n", + " ]\n", + ")" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swmhau-w20.ipynb 
b/notebooks/Talklife_MoC/talklifemoc-swmhau-w20.ipynb index cfce3bd..83c9903 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swmhau-w20.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swmhau-w20.ipynb @@ -33,7 +33,7 @@ "outputs": [], "source": [ "from nlpsig_networks.scripts.swmhau_network_functions import (\n", - " swmhau_network_hyperparameter_search\n", + " swmhau_network_hyperparameter_search,\n", ")" ] }, @@ -118,11 +118,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -134,13 +140,13 @@ "source": [ "num_epochs = 100\n", "embedding_dim = 384\n", - "dimensions = [15] \n", + "dimensions = [15]\n", "dimreduction_method = [\"umap\"]\n", "# define swmhau parameters: (output_channels, sig_depth, num_heads)\n", "swmhau_parameters = [(12, 3, 10), (8, 4, 6)]\n", "num_layers = [1]\n", - "ffn_hidden_dim_sizes = [[256,256],[512,512]]\n", - "dropout_rates = [0.1, 0.2] \n", + "ffn_hidden_dim_sizes = [[256, 256], [512, 512]]\n", + "dropout_rates = [0.1, 0.2]\n", "learning_rates = [1e-3, 1e-4, 5e-4]\n", "seeds = [1, 12, 123]\n", "loss = \"focal\"\n", @@ -512,7 +518,12 @@ ], "source": [ "size = 20\n", - "swmhau_network_umap_20, best_swmhau_network_umap_20, _, __ = swmhau_network_hyperparameter_search(\n", + "(\n", + " swmhau_network_umap_20,\n", + " best_swmhau_network_umap_20,\n", + " _,\n", + " __,\n", + ") = swmhau_network_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " df=df,\n", " id_column=\"timeline_id\",\n", @@ -537,13 +548,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -702,14 +713,30 @@ } ], "source": [ - "best_swmhau_network_umap_20[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', 'sig_depth',\n", - " 'num_heads', \n", - " 'ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', 'augmentation_type',\n", - " 'hidden_dim_aug', 'comb_method', 'batch_size']]" + "best_swmhau_network_umap_20[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"sig_depth\",\n", + " \"num_heads\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"augmentation_type\",\n", + " \"hidden_dim_aug\",\n", + " \"comb_method\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swmhau-w5.ipynb 
b/notebooks/Talklife_MoC/talklifemoc-swmhau-w5.ipynb index 00d6c34..1a7f4d2 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swmhau-w5.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swmhau-w5.ipynb @@ -24,7 +24,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device = 'cuda:1'" + "device = \"cuda:1\"" ] }, { @@ -34,7 +34,7 @@ "outputs": [], "source": [ "from nlpsig_networks.scripts.swmhau_network_functions import (\n", - " swmhau_network_hyperparameter_search\n", + " swmhau_network_hyperparameter_search,\n", ")" ] }, @@ -119,11 +119,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -135,13 +141,13 @@ "source": [ "num_epochs = 100\n", "embedding_dim = 384\n", - "dimensions = [15] \n", + "dimensions = [15]\n", "dimreduction_method = [\"umap\"]\n", "# define swmhau parameters: (output_channels, sig_depth, num_heads)\n", "swmhau_parameters = [(12, 3, 10), (8, 4, 6)]\n", "num_layers = [1]\n", - "ffn_hidden_dim_sizes = [[256,256],[512,512]]\n", - "dropout_rates = [0.1, 0.2] \n", + "ffn_hidden_dim_sizes = [[256, 256], [512, 512]]\n", + "dropout_rates = [0.1, 0.2]\n", "learning_rates = [1e-3, 1e-4, 5e-4]\n", "seeds = [1, 12, 123]\n", "loss = \"focal\"\n", @@ -513,7 +519,12 @@ ], "source": [ "size = 5\n", - "swmhau_network_umap_5, best_swmhau_network_umap_5, _, __ = swmhau_network_hyperparameter_search(\n", + "(\n", + " swmhau_network_umap_5,\n", + " best_swmhau_network_umap_5,\n", + " _,\n", + " __,\n", + ") = swmhau_network_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " df=df,\n", " id_column=\"timeline_id\",\n", @@ -538,13 +549,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -703,14 +714,30 @@ } ], "source": [ - "best_swmhau_network_umap_5[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', 'sig_depth',\n", - " 'num_heads', \n", - " 'ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', 'augmentation_type',\n", - " 'hidden_dim_aug', 'comb_method', 'batch_size']]" + "best_swmhau_network_umap_5[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"sig_depth\",\n", + " \"num_heads\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"augmentation_type\",\n", + " 
\"hidden_dim_aug\",\n", + " \"comb_method\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm-35.ipynb b/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm-35.ipynb index e864a2a..3603aef 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm-35.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm-35.ipynb @@ -116,11 +116,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -132,7 +138,7 @@ "source": [ "num_epochs = 100\n", "embedding_dim = 384\n", - "dimensions = [15] \n", + "dimensions = [15]\n", "dimreduction_method = [\"umap\"]\n", "swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 4)]\n", "lstm_hidden_dim_sizes = [384]\n", @@ -567,7 +573,7 @@ " best_seqsignet_network_umap_35,\n", " _,\n", " __,\n", - ") = seqsignet_hyperparameter_search( \n", + ") = seqsignet_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " df=df,\n", " id_column=\"timeline_id\",\n", @@ -595,13 +601,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/seqsignet_umap_focal_{gamma}_{size}_kfold.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -620,12 +626,30 @@ "metadata": {}, "outputs": [], "source": [ - "best_seqsignet_network_umap_35[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', 'features','standardise_method' ,'output_channels','lstm_hidden_dim',\n", - " 'dimensions','swnu_hidden_dim','sig_depth','ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'batch_size']]" + "best_seqsignet_network_umap_35[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"features\",\n", + " \"standardise_method\",\n", + " \"output_channels\",\n", + " \"lstm_hidden_dim\",\n", + " \"dimensions\",\n", + " \"swnu_hidden_dim\",\n", + " \"sig_depth\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm.ipynb b/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm.ipynb index ed57fde..5c461b5 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swnu-bilstm.ipynb @@ -24,7 +24,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device='cuda:1'" + "device = \"cuda:1\"" ] }, { @@ -117,11 +117,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create 
indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -133,7 +139,7 @@ "source": [ "num_epochs = 100\n", "embedding_dim = 384\n", - "dimensions = [15] \n", + "dimensions = [15]\n", "dimreduction_method = [\"umap\"]\n", "swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 4)]\n", "lstm_hidden_dim_sizes = [384]\n", @@ -570,7 +576,7 @@ " best_seqsignet_network_umap_20,\n", " _,\n", " __,\n", - ") = seqsignet_hyperparameter_search( \n", + ") = seqsignet_hyperparameter_search(\n", " num_epochs=num_epochs,\n", " df=df,\n", " id_column=\"timeline_id\",\n", @@ -598,13 +604,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/seqsignet_umap_focal_{gamma}_{size}_kfold.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -1178,12 +1184,30 @@ } ], "source": [ - "best_seqsignet_network_umap_20[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', 'features','standardise_method' ,'output_channels','lstm_hidden_dim',\n", - " 'dimensions','swnu_hidden_dim','sig_depth','ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'batch_size']]" + "best_seqsignet_network_umap_20[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"features\",\n", + " \"standardise_method\",\n", + " \"output_channels\",\n", + " \"lstm_hidden_dim\",\n", + " \"dimensions\",\n", + " \"swnu_hidden_dim\",\n", + " \"sig_depth\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swnu-w11.ipynb b/notebooks/Talklife_MoC/talklifemoc-swnu-w11.ipynb index 502bba2..7cca10a 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swnu-w11.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swnu-w11.ipynb @@ -24,7 +24,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device = 'cuda:1'" + "device = \"cuda:1\"" ] }, { @@ -120,11 +120,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == 
\"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -540,7 +546,8 @@ ], "source": [ "size = 11\n", - "(swnu_network_umap_11,\n", + "(\n", + " swnu_network_umap_11,\n", " best_swnu_network_umap_11,\n", " _,\n", " __,\n", @@ -569,13 +576,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -2164,13 +2171,28 @@ } ], "source": [ - "best_swnu_network_umap_11[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'dimensions','swnu_hidden_dim','sig_depth','ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_swnu_network_umap_11[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"dimensions\",\n", + " \"swnu_hidden_dim\",\n", + " \"sig_depth\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swnu-w20.ipynb b/notebooks/Talklife_MoC/talklifemoc-swnu-w20.ipynb index 3cba6fc..df649fb 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swnu-w20.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swnu-w20.ipynb @@ -119,11 +119,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -539,7 +545,8 @@ ], "source": [ "size = 20\n", - "(swnu_network_umap_20,\n", + "(\n", + " swnu_network_umap_20,\n", " best_swnu_network_umap_20,\n", " _,\n", " __,\n", @@ -568,13 +575,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -2163,13 +2170,28 @@ } ], "source": [ - "best_swnu_network_umap_20[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'dimensions','swnu_hidden_dim','sig_depth','ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 
'batch_size']]" + "best_swnu_network_umap_20[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"dimensions\",\n", + " \"swnu_hidden_dim\",\n", + " \"sig_depth\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/Talklife_MoC/talklifemoc-swnu-w5.ipynb b/notebooks/Talklife_MoC/talklifemoc-swnu-w5.ipynb index 8db4486..a93ce03 100644 --- a/notebooks/Talklife_MoC/talklifemoc-swnu-w5.ipynb +++ b/notebooks/Talklife_MoC/talklifemoc-swnu-w5.ipynb @@ -24,7 +24,7 @@ "\n", "# set device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "device = 'cuda:2'" + "device = \"cuda:2\"" ] }, { @@ -120,11 +120,17 @@ "metadata": {}, "outputs": [], "source": [ - "#create indices for kfold\n", - "fold_col_names = [c for c in df.columns if 'fold' in c ]\n", + "# create indices for kfold\n", + "fold_col_names = [c for c in df.columns if \"fold\" in c]\n", "fold_list = []\n", "for foldc in fold_col_names:\n", - " fold_list.append((df[df[foldc]=='train'].index, df[df[foldc]=='dev'].index, df[df[foldc]=='test'].index))\n", + " fold_list.append(\n", + " (\n", + " df[df[foldc] == \"train\"].index,\n", + " df[df[foldc] == \"dev\"].index,\n", + " df[df[foldc] == \"test\"].index,\n", + " )\n", + " )\n", "fold_list = tuple(fold_list)" ] }, @@ -540,7 +546,8 @@ ], "source": [ "size = 5\n", - "(swnu_network_umap_5,\n", + "(\n", + " swnu_network_umap_5,\n", " best_swnu_network_umap_5,\n", " _,\n", " __,\n", @@ -569,13 +576,13 @@ " standardise_method=standardise_method,\n", " include_features_in_path=include_features_in_path,\n", " include_features_in_input=include_features_in_input,\n", - " split_ids=None, \n", + " split_ids=None,\n", " split_indices=fold_list,\n", " k_fold=True,\n", " patience=patience,\n", " validation_metric=validation_metric,\n", " results_output=f\"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv\",\n", - " verbose=False\n", + " verbose=False,\n", ")" ] }, @@ -2164,13 +2171,28 @@ } ], "source": [ - "best_swnu_network_umap_5[['f1', 'f1_scores', 'precision', \n", - " 'recall', 'valid_f1',\n", - " 'valid_f1_scores', 'valid_precision', \n", - " 'valid_recall', \n", - " 'dimensions','swnu_hidden_dim','sig_depth','ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',\n", - " 'loss_function', 'k_fold', \n", - " 'batch_size']]" + "best_swnu_network_umap_5[\n", + " [\n", + " \"f1\",\n", + " \"f1_scores\",\n", + " \"precision\",\n", + " \"recall\",\n", + " \"valid_f1\",\n", + " \"valid_f1_scores\",\n", + " \"valid_precision\",\n", + " \"valid_recall\",\n", + " \"dimensions\",\n", + " \"swnu_hidden_dim\",\n", + " \"sig_depth\",\n", + " \"ffn_hidden_dim\",\n", + " \"dropout_rate\",\n", + " \"learning_rate\",\n", + " \"seed\",\n", + " \"loss_function\",\n", + " \"k_fold\",\n", + " \"batch_size\",\n", + " ]\n", + "]" ] }, { diff --git a/notebooks/results-readout.py b/notebooks/results-readout.py index 46e5ce0..ab4c438 100644 --- a/notebooks/results-readout.py +++ b/notebooks/results-readout.py @@ -6,6 +6,7 @@ sizes = [5, 11, 20, 35, 80] seqsignet_sizes = [(3, 5, 3), (3, 5, 6), (3, 5, 11), (3, 5, 26)] + def readout_results_from_csv(csv_filename: str, model_name: str, digits: int): try: results_df = pd.read_csv(csv_filename) @@ -15,18 +16,37 @@ def 
readout_results_from_csv(csv_filename: str, model_name: str, digits: int): print(f"Precision: {round(results_df['precision'].mean(), digits)}") print(f"Recall: {round(results_df['recall'].mean(), digits)}") # print individual class F1, precision and recall scores averaged - f1_scores_stacked = np.stack(results_df['f1_scores'].apply(lambda x: list(np.fromstring(x[1:-1], sep=' ')))) - print(f"F1 scores: {[round(x, digits) for x in f1_scores_stacked.mean(axis=0)]}") - precision_scores_stacked = np.stack(results_df['precision_scores'].apply(lambda x: list(np.fromstring(x[1:-1], sep=' ')))) - print(f"Precision scores: {[round(x, digits) for x in precision_scores_stacked.mean(axis=0)]}") - recall_scores_stacked = np.stack(results_df['recall_scores'].apply(lambda x: list(np.fromstring(x[1:-1], sep=' ')))) - print(f"Recall scores: {[round(x, digits) for x in recall_scores_stacked.mean(axis=0)]}") + f1_scores_stacked = np.stack( + results_df["f1_scores"].apply( + lambda x: list(np.fromstring(x[1:-1], sep=" ")) + ) + ) + print( + f"F1 scores: {[round(x, digits) for x in f1_scores_stacked.mean(axis=0)]}" + ) + precision_scores_stacked = np.stack( + results_df["precision_scores"].apply( + lambda x: list(np.fromstring(x[1:-1], sep=" ")) + ) + ) + print( + f"Precision scores: {[round(x, digits) for x in precision_scores_stacked.mean(axis=0)]}" + ) + recall_scores_stacked = np.stack( + results_df["recall_scores"].apply( + lambda x: list(np.fromstring(x[1:-1], sep=" ")) + ) + ) + print( + f"Recall scores: {[round(x, digits) for x in recall_scores_stacked.mean(axis=0)]}" + ) print("\n") except: print(f"Error reading {csv_filename}") + def main(): - # parse command line arguments + # parse command line arguments parser = argparse.ArgumentParser() parser.add_argument( "--results-dir", @@ -47,14 +67,18 @@ def main(): gamma = 2 sizes = [5, 11, 20, 35] seqsignet_sizes = [(3, 5, 3), (3, 5, 6), (3, 5, 11)] - + # readout FFN file = f"{args.results_dir}/ffn_current_focal_{gamma}_kfold_best_model.csv" readout_results_from_csv(file, model_name="FFN with current", digits=args.digits) # readout FFN with history concatenation file = f"{args.results_dir}/ffn_mean_history_focal_{gamma}_kfold_best_model.csv" - readout_results_from_csv(file, model_name="FFN with mean history concatenated with current", digits=args.digits) + readout_results_from_csv( + file, + model_name="FFN with mean history concatenated with current", + digits=args.digits, + ) # readout BERT (focal loss) file = f"{args.results_dir}/bert_classifier_focal.csv" @@ -62,40 +86,57 @@ def main(): # readout BERT (ce) file = f"{args.results_dir}/bert_classifier_ce.csv" - readout_results_from_csv(file, model_name="BERT (cross-entropy)", digits=args.digits) + readout_results_from_csv( + file, model_name="BERT (cross-entropy)", digits=args.digits + ) # readout LSTM for size in sizes: - file = f"{args.results_dir}/lstm_history_{size}_focal_{gamma}_kfold_best_model.csv" - readout_results_from_csv(file, model_name=f"BiLSTM (size={size})", digits=args.digits) - + file = ( + f"{args.results_dir}/lstm_history_{size}_focal_{gamma}_kfold_best_model.csv" + ) + readout_results_from_csv( + file, model_name=f"BiLSTM (size={size})", digits=args.digits + ) + # readout SWNU-Network for size in sizes: file = f"{args.results_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold_best_model.csv" - readout_results_from_csv(file, model_name=f"SWNU-Network (size={size})", digits=args.digits) + readout_results_from_csv( + file, model_name=f"SWNU-Network (size={size})", digits=args.digits + ) # 
readout SWMHAU-Network for size in sizes: file = f"{args.results_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold_best_model.csv" - readout_results_from_csv(file, model_name=f"SWMHAU-Network (size={size})", digits=args.digits) + readout_results_from_csv( + file, model_name=f"SWMHAU-Network (size={size})", digits=args.digits + ) # readout SeqSigNet for shift, window_size, n in seqsignet_sizes: file = f"{args.results_dir}/seqsignet_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold_best_model.csv" k = shift * n + (window_size - shift) - readout_results_from_csv(file, model_name=f"SeqSigNet (size={k})", digits=args.digits) + readout_results_from_csv( + file, model_name=f"SeqSigNet (size={k})", digits=args.digits + ) # readout SeqSigNetAttentionBiLSTM for shift, window_size, n in seqsignet_sizes: file = f"{args.results_dir}/seqsignet_attention_bilstm_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold_best_model.csv" k = shift * n + (window_size - shift) - readout_results_from_csv(file, model_name=f"SeqSigNetAttentionBiLSTM (size={k})", digits=args.digits) - + readout_results_from_csv( + file, model_name=f"SeqSigNetAttentionBiLSTM (size={k})", digits=args.digits + ) + # readout SeqSigNetAttentionEncoder for shift, window_size, n in seqsignet_sizes: file = f"{args.results_dir}/seqsignet_attention_encoder_umap_focal_{gamma}_{shift}_{window_size}_{n}_kfold_best_model.csv" k = shift * n + (window_size - shift) - readout_results_from_csv(file, model_name=f"SeqSigNetAttentionEncoder (size={k})", digits=args.digits) + readout_results_from_csv( + file, model_name=f"SeqSigNetAttentionEncoder (size={k})", digits=args.digits + ) + if __name__ == "__main__": - main() \ No newline at end of file + main() From 2ca51b869a850386c8f23cad890fc6c2d5be4110 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 19:53:52 +0000 Subject: [PATCH 03/16] develop readme --- README.md | 48 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 462d9cf..81c4607 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,54 @@ -# sig-networks +# SigNetworks: Sequential Path Signature Networks -_sig-networks_ is a package for training and evaluating neural networks for -longitudinal NLP classification tasks. +SigNetworks (`sig-networks`) is a PyTorch package for training and evaluating +neural networks for longitudinal NLP classification tasks. `sig-networks` is a +library that applies models first developed in +[Sequential Path Signature Networks for Personalised Longitudinal Language Modeling](https://aclanthology.org/2023.findings-acl.310/) +by Tseriotou et al. (2023) which presented a novel extension of neural +sequential models using the notion of path signatures from rough path theory. ## Installation -... +SigNetworks is available on PyPI and can be installed with pip: + +```bash +pip install sig_networks +``` + +### Signatory/Torch + +SigNetworks depends on the +[`patrick-kidger/signatory`](https://github.com/patrick-kidger/signatory) +library for differentiable computation of path signatures/log-signatures in +PyTorch. Please see the +[signatory documentation](https://signatory.readthedocs.io/en/latest/) for +installation instructions of the signatory library. ## Usage -The library is still under development but it is possible to train and evaluate -several models in a few lines of code. 
+The key components in the _signature-window_ model s presented in (see +[Sequential Path Signature Networks for Personalised Longitudinal Language Modeling](https://aclanthology.org/2023.findings-acl.310/) +for full details) are written as PyTorch modules which can be used in a modular +fashion. The key components are: + +- The Signature Window Network Units (SWNUs): + [`sig_networks.SWNU`](src/sig_networks/swnu.py) +- The Signature Window (Multihead-)Attention Units (SWMHAUs): + [`sig_networks.SWMHAU`](src/sig_networks/swmhau.py) +- The SWNU-Network model: + [`sig_networks.SWNUNetwork`](src/sig_networks/swnu_network.py) +- The SWMHAU-Network model: + [`sig_networks.SWMHAUNetwork`](src/sig_networks/swmhau_network.py) +- The SeqSigNet model: + [`sig_networks.SeqSigNet`](src/sig_networks/seqsignet_bilstm.py) +- The SeqSigNet-Attention-Encoder model: + [`sig_networks.SeqSigNetAttentionEncoder`](src/sig_networks/seqsignet_attention.py) +- The SeqSigNet-Attention-BiLSTM model: + [`sig_networks.SeqSigNetAttentionBiLSTM`](src/sig_networks/seqsignet_attention_bilstm.py) +```python ... +``` ## Pre-commit and linters From b103c1432ff17a734155cefeb8a4b69ec2ca0b75 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 20:07:39 +0000 Subject: [PATCH 04/16] add test --- src/sig_networks/_compat/__init__.py | 1 + src/sig_networks/_compat/typing.py | 14 ++++++++++++++ tests/test_package.py | 23 +++++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 src/sig_networks/_compat/__init__.py create mode 100644 src/sig_networks/_compat/typing.py create mode 100644 tests/test_package.py diff --git a/src/sig_networks/_compat/__init__.py b/src/sig_networks/_compat/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/src/sig_networks/_compat/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/sig_networks/_compat/typing.py b/src/sig_networks/_compat/typing.py new file mode 100644 index 0000000..9c94ee1 --- /dev/null +++ b/src/sig_networks/_compat/typing.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import sys + +if sys.version_info < (3, 8): + from typing import Literal, Protocol, runtime_checkable +else: + from typing import Literal, Protocol, runtime_checkable + +__all__ = ["Protocol", "runtime_checkable", "Literal"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 0000000..a833527 --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import sig_networks as m +from sig_networks._compat.typing import Protocol, runtime_checkable + + +def test_version(): + assert m.__version__ + + +@runtime_checkable +class HasQuack(Protocol): + def quack() -> str: + ... 
+ + +class Duck: + def quack() -> str: + return "quack" + + +def test_has_typing(): + assert isinstance(Duck(), HasQuack) From da2ac22250be510eb2bf43ef11e8c69799b74154 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 20:35:16 +0000 Subject: [PATCH 05/16] install torch before library --- .github/workflows/ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b2c0984..e95ffe5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,6 +52,9 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install torch + run: python -m pip install torch==1.9.0 + - name: Install package run: python -m pip install .[test] From b80b3670294edfadc2b6a82b97826463e8982d2c Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 20:41:13 +0000 Subject: [PATCH 06/16] signatory installation issue help --- .github/workflows/ci.yml | 5 +---- README.md | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e95ffe5..5a3d686 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,11 +52,8 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install torch - run: python -m pip install torch==1.9.0 - - name: Install package - run: python -m pip install .[test] + run: python -m pip install torch==1.9.0 & python -m pip install .[test] - name: Test package run: python -m pytest -ra --cov=nlpsig diff --git a/README.md b/README.md index 81c4607..706b930 100644 --- a/README.md +++ b/README.md @@ -24,17 +24,35 @@ PyTorch. Please see the [signatory documentation](https://signatory.readthedocs.io/en/latest/) for installation instructions of the signatory library. +A common `signatory` installation issue is that the installation requires that +you already have PyTorch installed. In this case, you can try the following: + +```bash +# install PyTorch +pip install torch==1.9.0 +# install signatory +pip install signatory==1.2.6.1.9.0 +# install sig_networks +pip install sig_networks +``` + ## Usage -The key components in the _signature-window_ model s presented in (see +The key components in the _signature-window_ models presented in (see [Sequential Path Signature Networks for Personalised Longitudinal Language Modeling](https://aclanthology.org/2023.findings-acl.310/) for full details) are written as PyTorch modules which can be used in a modular fashion. 
The key components are: - The Signature Window Network Units (SWNUs): [`sig_networks.SWNU`](src/sig_networks/swnu.py) + - There also exists the SWLSTM module which does not include the + 1D-convolution layer at the start of the SWNU: + [`sig_networks.SWLSTM`](src/sig_networks/swnu.py) - The Signature Window (Multihead-)Attention Units (SWMHAUs): [`sig_networks.SWMHAU`](src/sig_networks/swmhau.py) + - As with the SWNU, there also exists the SWMHA module which does not include + the 1D-convolution layer at the start of the SWMHAU: + [`sig_networks.SWMHA`](src/sig_networks/swmhau.py) - The SWNU-Network model: [`sig_networks.SWNUNetwork`](src/sig_networks/swnu_network.py) - The SWMHAU-Network model: From e32538913cde07454c6f46663f8ccb58aba57e11 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 20:46:22 +0000 Subject: [PATCH 07/16] attempt to install torch first in ci --- .github/workflows/ci.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a3d686..9d78a45 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,8 +52,12 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install torch + run: | + python -m pip install torch==1.9.0 + - name: Install package - run: python -m pip install torch==1.9.0 & python -m pip install .[test] + run: python -m pip install .[test] - name: Test package run: python -m pytest -ra --cov=nlpsig From 8a20507b8b12f026d79f78f18f101eaafe0e5f58 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 20:48:53 +0000 Subject: [PATCH 08/16] attempt to install torch first in ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d78a45..402d2c5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: - name: Install torch run: | - python -m pip install torch==1.9.0 + pip install torch==1.9.0 - name: Install package run: python -m pip install .[test] From 3033a62fc0f2f9a8968bd63f5503fc3e26ec7688 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 20:53:07 +0000 Subject: [PATCH 09/16] explicitly install signatory after torch in ci --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 402d2c5..e7ce184 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,6 +56,10 @@ jobs: run: | pip install torch==1.9.0 + - name: Install signatory + run: | + pip install signatory==1.2.6.1.9.0 + - name: Install package run: python -m pip install .[test] From d13c54b0c698bafd8cac62998ef0b89075059eb9 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 21:13:25 +0000 Subject: [PATCH 10/16] explicitly install signatory after torch in ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e7ce184..1f5507b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,7 +58,7 @@ jobs: - name: Install signatory run: | - pip install signatory==1.2.6.1.9.0 + pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall - name: Install package run: python -m pip install .[test] From 66f4b19e7dab7b5f2510a1e2f7bc89f7c117c750 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 21:20:19 +0000 Subject: [PATCH 11/16] debug signatory issue --- .github/workflows/ci.yml | 6 ++++++ 1 file changed, 
6 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f5507b..f7b4d87 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,6 +56,12 @@ jobs: run: | pip install torch==1.9.0 + - name: Debug installed packages + run: pip list + + - name: Debug Python environment + run: python -V + - name: Install signatory run: | pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall From fea652a811edb76eb536a7c69d89d5884446bb6c Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 21:31:16 +0000 Subject: [PATCH 12/16] add more torch install --- .github/workflows/ci.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f7b4d87..f39894f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,8 +53,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install torch - run: | - pip install torch==1.9.0 + run: pip install torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 - name: Debug installed packages run: pip list From 018f49e40c1f22735aa836b51456eeacaf2db193 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 16 Nov 2023 22:21:33 +0000 Subject: [PATCH 13/16] update readme --- .github/workflows/ci.yml | 74 ++++++++++++------------- README.md | 110 ++++++++++++++++++++++++++++++++++++- src/sig_networks/swmhau.py | 2 +- 3 files changed, 147 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f39894f..541d546 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,43 +33,43 @@ jobs: with: extra_args: --hook-stage manual --all-files - checks: - name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} - runs-on: ${{ matrix.runs-on }} - needs: [pre-commit] - strategy: - fail-fast: false - matrix: - python-version: ["3.8"] - runs-on: [ubuntu-latest, macos-latest] - - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install torch - run: pip install torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 - - - name: Debug installed packages - run: pip list - - - name: Debug Python environment - run: python -V - - - name: Install signatory - run: | - pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall - - - name: Install package - run: python -m pip install .[test] - - - name: Test package - run: python -m pytest -ra --cov=nlpsig + # checks: + # name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + # runs-on: ${{ matrix.runs-on }} + # needs: [pre-commit] + # strategy: + # fail-fast: false + # matrix: + # python-version: ["3.8"] + # runs-on: [ubuntu-latest, macos-latest] + + # steps: + # - uses: actions/checkout@v4 + # with: + # fetch-depth: 0 + + # - uses: actions/setup-python@v4 + # with: + # python-version: ${{ matrix.python-version }} + + # - name: Install torch + # run: pip install torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 + + # - name: Debug installed packages + # run: pip list + + # - name: Debug Python environment + # run: python -V + + # - name: Install signatory + # run: | + # pip install signatory==1.2.6.1.9.0 --no-cache-dir --force-reinstall + + # - name: Install package + # run: python -m pip install .[test] + + # - name: Test package + # run: python -m pytest -ra --cov=nlpsig dist: name: Distribution build diff --git a/README.md b/README.md index 706b930..e14eeb3 100644 --- a/README.md +++ 
@@ -15,6 +15,13 @@ SigNetworks is available on PyPI and can be installed with pip:
 pip install sig_networks
 ```
 
+Note that `sig_networks` currently only supports Python 3.8 since it relies on
+the [`signatory`](https://github.com/patrick-kidger/signatory) library. However,
+it is possible to install `signatory` with more recent Python and PyTorch
+versions if you install it from source. See the installation guide in the
+[signatory documentation](https://signatory.readthedocs.io/en/latest/pages/usage/installation.html)
+for more details.
+
 ### Signatory/Torch
 
 SigNetworks depends on the
@@ -36,6 +43,10 @@
 pip install signatory==1.2.6.1.9.0
 pip install sig_networks
 ```
 
+If you encounter any issues with the installation of `signatory`, please see the
+FAQs in the
+[signatory documentation](https://signatory.readthedocs.io/en/latest/pages/miscellaneous/faq.html).
+
 ## Usage
 
 The key components in the _signature-window_ models presented in (see
@@ -64,10 +75,107 @@ fashion. The key components are:
 - The SeqSigNet-Attention-BiLSTM model:
   [`sig_networks.SeqSigNetAttentionBiLSTM`](src/sig_networks/seqsignet_attention_bilstm.py)
 
+### Using the SWNU and SWMHAU modules
+
+The Signature Window units (SWNU and SWMHAU) accept a batch of streams and
+return a batch of feature representations. For example:
+
 ```python
-...
+from sig_networks.swnu import SWNU
+import torch
+
+# initialise a SWNU object
+swnu = SWNU(
+    input_channels=10,
+    hidden_dim=5,
+    log_signature=False,
+    sig_depth=3,
+    pooling="signature",
+    BiLSTM=True,
+)
+
+# create a three-dimensional tensor of batched streams
+# shape [batch, length, channels] where channels = 10
+streams = torch.randn(2, 20, 10)
+
+# pass the streams through the SWNU
+features = swnu(streams)
+
+# features is a two-dimensional tensor of shape [batch, signature_channels]
+features.shape
 ```
 
+The SWMHAU is similar to the SWNU, but rather than an LSTM processing the
+signature streams, it has a multihead-attention layer. For example:
+
+```python
+from sig_networks.swmhau import SWMHAU
+import torch
+
+# initialise a SWMHAU object
+swmhau = SWMHAU(
+    input_channels=10,
+    output_channels=5,
+    log_signature=False,
+    sig_depth=3,
+    num_heads=5,
+    num_layers=1,
+    dropout_rate=0.1,
+    pooling="signature",
+)
+
+# create a three-dimensional tensor of batched streams
+# shape [batch, length, channels] where channels = 10
+streams = torch.randn(2, 20, 10)
+
+# pass the streams through the SWMHAU
+features = swmhau(streams)
+
+# features is a two-dimensional tensor of shape [batch, signature_channels]
+features.shape
+```
+
+Note that in the above we used the `pooling="signature"` option. This means
+that at the end of the SWNU/SWMHAU, a final signature of the streams is taken
+to obtain a fixed-length feature representation for each item in the batch.
+There are other options, such as taking the final LSTM hidden state for the
+SWNU (set `pooling="lstm"`) or using CLS pooling for the SWMHAU (set
+`pooling="cls"`). There is also a `pooling=None` option, in which case no
+pooling is applied at the end of the SWNU/SWMHAU and the output is a
+three-dimensional tensor of shape `[batch, length, hidden_dim]`.
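+
+For instance, as a minimal sketch (reusing the SWNU constructor arguments from
+the example above - the values themselves are illustrative only), `pooling=None`
+turns the output from a fixed-length feature vector into a stream:
+
+```python
+from sig_networks.swnu import SWNU
+import torch
+
+# the same SWNU as above, but with no pooling at the end
+swnu_no_pooling = SWNU(
+    input_channels=10,
+    hidden_dim=5,
+    log_signature=False,
+    sig_depth=3,
+    pooling=None,
+    BiLSTM=True,
+)
+
+streams = torch.randn(2, 20, 10)
+out = swnu_no_pooling(streams)
+
+# without pooling, the stream dimension is kept: out has shape
+# [batch, length, hidden_dim] rather than [batch, signature_channels]
+out.shape
+```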
+
+### Using the network models
+
+The library also has the SWNU-Network and SeqSigNet models as introduced in
+[Sequential Path Signature Networks for Personalised Longitudinal Language Modeling](https://aclanthology.org/2023.findings-acl.310/).
+
+Since then, other models which utilise the SWNUs and SWMHAUs discussed above
+have been developed. Each of these models is available as a PyTorch module
+which can be initialised and trained in the usual way.
+
+The SWNU-Network and SWMHAU-Network models expect two inputs:
+
+1. `path`: a batch of streams of shape `[batch, length, channels]` - these get
+   processed by the SWNU/SWMHAU
+2. `features`: a batch of features of shape `[batch, features]` - these get
+   concatenated with the output of the SWNU/SWMHAU to be fed into a FFN layer
+
+The SeqSigNet models (e.g. SeqSigNet, SeqSigNet-Attention-Encoder,
+SeqSigNet-Attention-BiLSTM) also expect two inputs, but the path is
+slightly different:
+
+1. `path`: a batch of streams of shape `[batch, units, length, channels]` - each
+   of the units for each batch will get processed by the SWNU/SWMHAU.
+   Afterwards, a global network processes the outputs of the SWNU/SWMHAU in
+   order to pool them into a single fixed-length feature representation for the
+   history. The global network can either be a BiLSTM (in the case of SeqSigNet
+   and SeqSigNet-Attention-BiLSTM) or a Transformer Encoder (in the case of
+   SeqSigNet-Attention-Encoder).
+2. `features`: a batch of features of shape `[batch, features]` - these get
+   concatenated with the output of the global network (either BiLSTM or a
+   Transformer Encoder) that processes the outputs of SWNU and SWMHAU to be fed
+   into a FFN layer (see the shape sketch below)
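+
+As a rough sketch of these expected input shapes (the tensor sizes below are
+arbitrary, and `swnu_network`/`seqsignet` stand in for already-initialised
+model objects - see the module files linked above for the actual constructor
+arguments):
+
+```python
+import torch
+
+# inputs for a SWNU-Network / SWMHAU-Network model
+path = torch.randn(4, 20, 10)  # [batch, length, channels]
+features = torch.randn(4, 32)  # [batch, features]
+# out = swnu_network(path, features)
+
+# inputs for a SeqSigNet-family model: the path gains a units dimension
+seqsignet_path = torch.randn(4, 5, 20, 10)  # [batch, units, length, channels]
+# out = seqsignet(seqsignet_path, features)
+```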
+
 ## Pre-commit and linters
 
 To take advantage of `pre-commit`, which will automatically format your code and

diff --git a/src/sig_networks/swmhau.py b/src/sig_networks/swmhau.py
index 93e44c1..eed7f30 100644
--- a/src/sig_networks/swmhau.py
+++ b/src/sig_networks/swmhau.py
@@ -266,7 +266,7 @@ class SWMHAU(nn.Module):
     def __init__(
         self,
         input_channels: int,
-        output_channels: int | None,
+        output_channels: int,
         log_signature: bool,
         sig_depth: int,
         num_heads: int,

From 6d5b16843a82d3e9821e8ded524f77d65f04e5ea Mon Sep 17 00:00:00 2001
From: rchan
Date: Thu, 16 Nov 2023 22:30:41 +0000
Subject: [PATCH 14/16] fix link to seqsignet-attention-encoder

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index e14eeb3..badb93b 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ fashion. The key components are:
 - The SeqSigNet model:
   [`sig_networks.SeqSigNet`](src/sig_networks/seqsignet_bilstm.py)
 - The SeqSigNet-Attention-Encoder model:
-  [`sig_networks.SeqSigNetAttentionEncoder`](src/sig_networks/seqsignet_attention.py)
+  [`sig_networks.SeqSigNetAttentionEncoder`](src/sig_networks/seqsignet_attention_encoder.py)
 - The SeqSigNet-Attention-BiLSTM model:
   [`sig_networks.SeqSigNetAttentionBiLSTM`](src/sig_networks/seqsignet_attention_bilstm.py)

From 6c311ab41ce7a7f4110e9b3904f75f7d8d9ff65a Mon Sep 17 00:00:00 2001
From: rchan
Date: Fri, 17 Nov 2023 13:58:37 +0000
Subject: [PATCH 15/16] add info on repo structure in readme

---
 README.md | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/README.md b/README.md
index badb93b..c32a0d8 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,16 @@
 If you encounter any issues with the installation of `signatory`, please see the
 FAQs in the
 [signatory documentation](https://signatory.readthedocs.io/en/latest/pages/miscellaneous/faq.html).
 
+## Repo structure
+
+The key parts of the library are found in [`src/`](src/):
+
+- [`src/sig_networks/`](src/sig_networks/) contains the source code for the
+  models and includes PyTorch modules for the various components of the models
+  (see below for more usage details)
+- [`src/scripts/`](src/scripts/) contains some helper scripts for training and
+  evaluating the models
+
 ## Usage
 
 The key components in the _signature-window_ models presented in (see

From 774b35aaffc29d73c16752f158ee37576401b509 Mon Sep 17 00:00:00 2001
From: rchan
Date: Fri, 17 Nov 2023 14:52:37 +0000
Subject: [PATCH 16/16] add some nlpsig info in readme

---
 README.md | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/README.md b/README.md
index c32a0d8..4341dc1 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,41 @@ fashion. The key components are:
 - The SeqSigNet-Attention-BiLSTM model:
   [`sig_networks.SeqSigNetAttentionBiLSTM`](src/sig_networks/seqsignet_attention_bilstm.py)
 
+### `nlpsig`: Preparing the data
+
+The functionality for preparing the data for the SW-models and for constructing
+paths and inputs is found in the
+[`nlpsig`](https://github.com/datasig-ac-uk/nlpsig) library, which can be easily
+installed using `pip` and comes as a dependency of `sig-networks`.
+
+Paths can be constructed using the
+[`nlpsig.PrepareData`](https://nlpsig.readthedocs.io/en/latest/data_preparation.html)
+class. Furthermore, there is functionality within the
+[`nlpsig.TextEncoder`](https://nlpsig.readthedocs.io/en/latest/encode_text.html#nlpsig.encode_text.TextEncoder)
+and
+[`nlpsig.SentenceEncoder`](https://nlpsig.readthedocs.io/en/latest/encode_text.html#nlpsig.encode_text.SentenceEncoder)
+classes to obtain embeddings using transformers, to be used as the channels in
+the paths. Since we want to take path signatures within the SW-models, we need
+to ensure that the number of channels in the path is low enough that we can
+take the path signatures efficiently. To enable this, there are also a number of
+dimensionality reduction methods in the `nlpsig` library - see
+[`nlpsig.DimReduce`](https://nlpsig.readthedocs.io/en/latest/dimensionality_reduction.html).
+
+For full details, see the
+[`nlpsig` GitHub repo](https://github.com/datasig-ac-uk/nlpsig); there are also
+examples of using the library in the [`examples/`](examples/) directory.
+
+Note that for obtaining inputs to the SWNU-/SWMHAU-Networks and the SeqSigNet
+family models, there are helper functions in the scripts (see e.g.
+`obtain_SWNUNetwork_input` in
+[`src/scripts/swnu_network_functions.py`](src/scripts/swnu_network_functions.py)
+and `obtain_SeqSigNet_input` in
+[`src/scripts/seqsignet_functions.py`](src/scripts/seqsignet_functions.py)).
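+
+As a purely illustrative sketch of what such a constructed path looks like (the
+sizes below are hypothetical, not `nlpsig` defaults):
+
+```python
+import torch
+
+# after encoding texts with nlpsig.TextEncoder/nlpsig.SentenceEncoder and
+# reducing the embeddings with nlpsig.DimReduce, nlpsig.PrepareData stacks
+# each id's history into a stream; the SW-models then consume a batch of
+# these streams, e.g.:
+emb_dim = 10   # embedding dimension after dimensionality reduction
+history = 20   # number of texts in each history window
+paths = torch.randn(8, history, emb_dim)  # [batch, length, channels]
+```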
+There are also example run-throughs in the [`examples/`](examples/) directory:
+
+- [Training a SWNU-Network model for Anno-MI client-talk-type prediction](examples/AnnoMI/anno_mi-client-swnu-example.ipynb)
+- [Training a SeqSigNet model for Anno-MI client-talk-type prediction](examples/AnnoMI/anno_mi-client-seqsignet-example.ipynb)
+
 ### Using the SWNU and SWMHAU modules
 
 The Signature Window units (SWNU and SWMHAU) accept a batch of streams and
@@ -186,6 +221,13 @@ slightly different:
    Transformer Encoder) that processes the outputs of SWNU and SWMHAU to be fed
    into a FFN layer (see the shape sketch below)
 
+### Example experiments
+
+In the [`examples/`](examples/) directory, there are some example experiments
+using the library to compare the SW-models with other baseline models, such as
+a simple FFN, a BiLSTM model on the sentence-transformer representations, and
+a pre-trained Transformer model for classification.
+
 ## Pre-commit and linters
 
 To take advantage of `pre-commit`, which will automatically format your code and