From 9b2435cb0e92836d49bbdda9c18fcfae21262f77 Mon Sep 17 00:00:00 2001 From: heinpa Date: Tue, 20 Aug 2024 09:45:45 +0200 Subject: [PATCH] allow configuration of source and target languages --- qanary-component-MT-Python-MBart/Dockerfile | 2 +- qanary-component-MT-Python-MBart/boot.sh | 6 +- .../component/__init__.py | 2 +- .../component/mt_mbart_nlp.py | 172 ++++++++---------- qanary-component-MT-Python-MBart/pytest.ini | 10 +- .../requirements.txt | 20 +- .../tests/test_lang_utils.py | 70 +++++++ .../tests/test_mt_mbart_nlp.py | 63 +++++-- .../utils/__init__.py | 0 .../utils/lang_utils.py | 74 ++++++++ .../utils/model_utils.py | 7 + 11 files changed, 292 insertions(+), 134 deletions(-) create mode 100644 qanary-component-MT-Python-MBart/tests/test_lang_utils.py create mode 100644 qanary-component-MT-Python-MBart/utils/__init__.py create mode 100644 qanary-component-MT-Python-MBart/utils/lang_utils.py create mode 100644 qanary-component-MT-Python-MBart/utils/model_utils.py diff --git a/qanary-component-MT-Python-MBart/Dockerfile b/qanary-component-MT-Python-MBart/Dockerfile index 61eb53cca..7fad2a6ce 100644 --- a/qanary-component-MT-Python-MBart/Dockerfile +++ b/qanary-component-MT-Python-MBart/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.7 +FROM python:3.10 COPY requirements.txt ./ diff --git a/qanary-component-MT-Python-MBart/boot.sh b/qanary-component-MT-Python-MBart/boot.sh index 8ef76b030..e40993cc3 100755 --- a/qanary-component-MT-Python-MBart/boot.sh +++ b/qanary-component-MT-Python-MBart/boot.sh @@ -2,8 +2,10 @@ export $(grep -v '^#' .env | xargs) -echo Downloading the model -python -c 'from transformers import MBartForConditionalGeneration, MBart50TokenizerFast; model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt"); tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")' +echo Downloading the models + +python -c "from utils.model_utils import load_models_and_tokenizers; 
load_models_and_tokenizers(); " + echo Downloading the model finished echo SERVER_PORT: $SERVER_PORT diff --git a/qanary-component-MT-Python-MBart/component/__init__.py b/qanary-component-MT-Python-MBart/component/__init__.py index 40da0f2f5..6cd66870e 100644 --- a/qanary-component-MT-Python-MBart/component/__init__.py +++ b/qanary-component-MT-Python-MBart/component/__init__.py @@ -1,7 +1,7 @@ from component.mt_mbart_nlp import mt_mbart_nlp_bp from flask import Flask -version = "0.1.2" +version = "0.2.0" # default config file configfile = "app.conf" diff --git a/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py b/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py index 9e4ad5520..afd88cff4 100644 --- a/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py +++ b/qanary-component-MT-Python-MBart/component/mt_mbart_nlp.py @@ -1,60 +1,25 @@ -from langdetect import detect import logging import os from flask import Blueprint, jsonify, request from qanary_helpers.qanary_queries import get_text_question_in_graph, insert_into_triplestore +from qanary_helpers.language_queries import get_translated_texts_in_triplestore, get_texts_with_detected_language_in_triplestore, question_text_with_language, create_annotation_of_question_translation +from utils.model_utils import load_models_and_tokenizers +from utils.lang_utils import translation_options, LANG_CODE_MAP -from transformers import MBartForConditionalGeneration, MBart50TokenizerFast logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) - mt_mbart_nlp_bp = Blueprint("mt_mbart_nlp_bp", __name__, template_folder="templates") SERVICE_NAME_COMPONENT = os.environ["SERVICE_NAME_COMPONENT"] -SOURCE_LANG = os.environ["SOURCE_LANGUAGE"] -TARGET_LANG = os.environ["TARGET_LANGUAGE"] - -model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") -tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") -lang_code_map = 
{ - "en": "en_XX", - "de": "de_DE", - "ru": "ru_RU", - "fr": "fr_XX", - "es": "ex_XX", - "pt": "pr_XX" -} -target_lang = TARGET_LANG - -supported_langs = lang_code_map.keys() # TODO: check supported languages for LibreTranslate - - -@mt_mbart_nlp_bp.route("/annotatequestion", methods=["POST"]) -def qanary_service(): - """the POST endpoint required for a Qanary service""" - triplestore_endpoint = request.json["values"]["urn:qanary#endpoint"] - triplestore_ingraph = request.json["values"]["urn:qanary#inGraph"] - triplestore_outgraph = request.json["values"]["urn:qanary#outGraph"] - logging.info("endpoint: %s, inGraph: %s, outGraph: %s" % \ - (triplestore_endpoint, triplestore_ingraph, triplestore_outgraph)) +model, tokenizer = load_models_and_tokenizers() - text = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint, - graph=triplestore_ingraph)[0]["text"] - question_uri = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint, - graph=triplestore_ingraph)[0]["uri"] - logging.info(f"Question text: {text}") - if SOURCE_LANG != None and len(SOURCE_LANG.strip()) > 0: - lang = SOURCE_LANG - logging.info("Using custom SOURCE_LANGUAGE") - else: - lang = detect(text) - logging.info("No SOURCE_LANGUAGE specified, using langdetect!") - logging.info(f"source language: {lang}") +def translate_input(text:str, source_lang: str, target_lang: str) -> str: + logging.info(f"translating \"{text}\" from \"{source_lang}\" to \"{target_lang}\"") ## MAIN FUNCTIONALITY - tokenizer.src_lang = lang_code_map[lang] # TODO: do formats match? + tokenizer.src_lang = LANG_CODE_MAP[source_lang] # TODO: do formats match? 
logging.info(f"source language mapped code: {tokenizer.src_lang}") batch = tokenizer(text, return_tensors="pt") @@ -66,73 +31,78 @@ def qanary_service(): # Perform the translation and decode the output generated_tokens = model.generate( **batch, - forced_bos_token_id=tokenizer.lang_code_to_id[lang_code_map[target_lang]]) # TODO: defined target lang + forced_bos_token_id=tokenizer.convert_tokens_to_ids(LANG_CODE_MAP[target_lang])) result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] + translation = result.replace("\"", "\\\"") #keep quotation marks that are part of the translation + logging.info(f"result: \"{translation}\"") + return translation + +@mt_mbart_nlp_bp.route("/annotatequestion", methods=["POST"]) +def qanary_service(): + """the POST endpoint required for a Qanary service""" + triplestore_endpoint = request.json["values"]["urn:qanary#endpoint"] + triplestore_ingraph = request.json["values"]["urn:qanary#inGraph"] + triplestore_outgraph = request.json["values"]["urn:qanary#outGraph"] + logging.info("endpoint: %s, inGraph: %s, outGraph: %s" % \ + (triplestore_endpoint, triplestore_ingraph, triplestore_outgraph)) - # building SPARQL query TODO: verify this annotation AnnotationOfQuestionTranslation ?? - SPARQLqueryAnnotationOfQuestionTranslation = """ - PREFIX qa: - PREFIX oa: - PREFIX xsd: - - INSERT {{ - GRAPH <{uuid}> {{ - ?a a qa:AnnotationOfQuestionTranslation ; - oa:hasTarget <{qanary_question_uri}> ; - oa:hasBody "{translation_result}"@{target_lang} ; - oa:annotatedBy ; - oa:annotatedAt ?time . - - }} - }} - WHERE {{ - BIND (IRI(str(RAND())) AS ?a) . 
- BIND (now() as ?time) - }} - """.format( - uuid=triplestore_ingraph, - qanary_question_uri=question_uri, - translation_result=result.replace("\"", "\\\""), #keep quotation marks that are part of the translation - target_lang=TARGET_LANG, - app_name=SERVICE_NAME_COMPONENT - ) - - SPARQLqueryAnnotationOfQuestionLanguage = """ - PREFIX qa: - PREFIX oa: - PREFIX xsd: - - INSERT {{ - GRAPH <{uuid}> {{ - ?b a qa:AnnotationOfQuestionLanguage ; - oa:hasTarget <{qanary_question_uri}> ; - oa:hasBody "{src_lang}"^^xsd:string ; - oa:annotatedBy ; - oa:annotatedAt ?time . - }} - }} - WHERE {{ - BIND (IRI(str(RAND())) AS ?b) . - BIND (now() as ?time) - }} - """.format( - uuid=triplestore_ingraph, - qanary_question_uri=question_uri, - src_lang=lang, - app_name=SERVICE_NAME_COMPONENT - ) - - logging.info(f'SPARQL: {SPARQLqueryAnnotationOfQuestionTranslation}') - logging.info(f'SPARQL: {SPARQLqueryAnnotationOfQuestionLanguage}') - # inserting new data to the triplestore - insert_into_triplestore(triplestore_endpoint, SPARQLqueryAnnotationOfQuestionTranslation) - insert_into_triplestore(triplestore_endpoint, SPARQLqueryAnnotationOfQuestionLanguage) + text_question_in_graph = get_text_question_in_graph(triplestore_endpoint=triplestore_endpoint, graph=triplestore_ingraph) + question_text = text_question_in_graph[0]['text'] + logging.info(f'Original question text: {question_text}') + + + # Collect texts to be translated (group by source language) + + for source_lang in translation_options.keys(): + source_texts = find_source_texts_in_triplestore( + triplestore_endpoint=triplestore_endpoint, + graph_uri=triplestore_ingraph, + lang=source_lang + ) + + # translate source texts into specified target languages + for target_lang in translation_options[source_lang]: + for source_text in source_texts: + translation = translate_input(source_text.get_text(), source_lang, target_lang) + if len(translation.strip()) > 0: + SPARQLqueryAnnotationOfQuestionTranslation = 
create_annotation_of_question_translation( + graph_uri=triplestore_ingraph, + question_uri=source_text.get_uri(), + translation=translation, + translation_language=target_lang, + app_name=SERVICE_NAME_COMPONENT + ) + insert_into_triplestore(triplestore_endpoint, SPARQLqueryAnnotationOfQuestionTranslation) + else: + logging.error(f"result is empty string!") return jsonify(request.get_json()) +def find_source_texts_in_triplestore(triplestore_endpoint: str, graph_uri: str, lang: str) -> list[question_text_with_language]: + source_texts = [] + + # check if supported languages have been determined already (LD) + # (use filters) + # if so, use the target uris to find the question text to translate + ld_source_texts = get_texts_with_detected_language_in_triplestore(triplestore_endpoint, graph_uri, lang) + source_texts.extend(ld_source_texts) + + # check if there are translations into the relevant language (MT) + # (use filters) + # if so, use the translation texts + mt_source_texts = get_translated_texts_in_triplestore(triplestore_endpoint, graph_uri, lang) + source_texts.extend(mt_source_texts) + + # TODO: what if nothing found? 
+ if len(source_texts) == 0: + logging.warning(f"No source texts with language {lang} could be found In the triplestore!") + + return source_texts + + @mt_mbart_nlp_bp.route("/", methods=["GET"]) def index(): """examplary GET endpoint""" diff --git a/qanary-component-MT-Python-MBart/pytest.ini b/qanary-component-MT-Python-MBart/pytest.ini index a51bc97cf..c612b0784 100644 --- a/qanary-component-MT-Python-MBart/pytest.ini +++ b/qanary-component-MT-Python-MBart/pytest.ini @@ -1,16 +1,14 @@ [pytest] -log_cli = 0 +log_cli = 1 log_cli_level = INFO log_cli_format = %(asctime)s [%(levelname)8s] [%(filename)s:%(lineno)s] %(message)s log_cli_date_format=%Y-%m-%d %H:%M:%S env = - SERVER_PORT=40120 + SERVICE_PORT=40120 + SERVICE_HOST=http://public-component-host SPRING_BOOT_ADMIN_URL=http://qanary-pipeline-host:40111 - SERVER_HOST=http://public-component-host SPRING_BOOT_ADMIN_CLIENT_INSTANCE_SERVICE-BASE-URL=http://public-component-host:40120 SPRING_BOOT_ADMIN_USERNAME=admin SPRING_BOOT_ADMIN_PASSWORD=admin - SERVICE_NAME_COMPONENT=MT-MBart + SERVICE_NAME_COMPONENT=MT-MBart-Component SERVICE_DESCRIPTION_COMPONENT=Translates question to English - SOURCE_LANGUAGE=de - TARGET_LANGUAGE=en diff --git a/qanary-component-MT-Python-MBart/requirements.txt b/qanary-component-MT-Python-MBart/requirements.txt index 7e37256a2..3fdfe2226 100644 --- a/qanary-component-MT-Python-MBart/requirements.txt +++ b/qanary-component-MT-Python-MBart/requirements.txt @@ -1,13 +1,9 @@ -Flask -langdetect==1.0.9 -langid==1.1.6 -mock==3.0.5 -python-dotenv==0.21.1 +Flask==3.0.3 +pytest==8.3.2 +pytest-env==1.1.3 qanary_helpers==0.2.2 -transformers==4.41.0 -sentencepiece==0.1.97 -torch==2.3.0 -gunicorn==20.1.0 -protobuf==3.20.* -pytest -pytest-env +SentencePiece==0.2.0 +SPARQLWrapper==2.0.0 +torch==2.4.0 +transformers==4.44.0 +qanary-helpers==0.2.2 diff --git a/qanary-component-MT-Python-MBart/tests/test_lang_utils.py b/qanary-component-MT-Python-MBart/tests/test_lang_utils.py new file mode 100644 index 
000000000..6f81e7117 --- /dev/null +++ b/qanary-component-MT-Python-MBart/tests/test_lang_utils.py @@ -0,0 +1,70 @@ +import logging +from unittest import mock +from unittest import TestCase +import os +import importlib + +class TestLangUtils(TestCase): + + logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO) + + @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'fr'}) + def test_only_one_source_language(self): + import utils.lang_utils + importlib.reload(utils.lang_utils) + from utils.lang_utils import translation_options + assert 'fr' in translation_options.keys() + assert len(translation_options.keys()) == 1 + + + @mock.patch.dict(os.environ, {'TARGET_LANGUAGE': 'ru'}) + def test_only_one_target_language(self): + import utils.lang_utils + importlib.reload(utils.lang_utils) + from utils.lang_utils import translation_options + # all 5 non-russian source languages should support 'ru' + assert len(translation_options.items()) == 5 + # but each item should only contain the one target language! 
+ assert ('en', ['ru']) in translation_options.items() + assert ('de', ['ru']) in translation_options.items() + assert ('es', ['ru']) in translation_options.items() + assert ('fr', ['ru']) in translation_options.items() + assert ('pt', ['ru']) in translation_options.items() + + + @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'en', 'TARGET_LANGUAGE': 'es'}) + def test_specific_source_and_target_language(self): + import utils.lang_utils + importlib.reload(utils.lang_utils) + from utils.lang_utils import translation_options + assert translation_options == {'en': ['es']} + + + @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'zh'}) + def test_unsupported_source_language_raises_error(self): + try: + import utils.lang_utils + importlib.reload(utils.lang_utils) + except ValueError as ve: + logging.error(ve) + pass + + + @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'en', 'TARGET_LANGUAGE': 'zh'}) + def test_unsupported_target_for_source_language_raises_error(self): + try: + import utils.lang_utils + importlib.reload(utils.lang_utils) + except ValueError as ve: + logging.error(ve) + pass + + + @mock.patch.dict(os.environ, {'TARGET_LANGUAGE': 'zh'}) + def test_unsupported_target_language_raises_error(self): + try: + import utils.lang_utils + importlib.reload(utils.lang_utils) + except ValueError as ve: + logging.error(ve) + pass diff --git a/qanary-component-MT-Python-MBart/tests/test_mt_mbart_nlp.py b/qanary-component-MT-Python-MBart/tests/test_mt_mbart_nlp.py index eedfce342..ed180fe33 100644 --- a/qanary-component-MT-Python-MBart/tests/test_mt_mbart_nlp.py +++ b/qanary-component-MT-Python-MBart/tests/test_mt_mbart_nlp.py @@ -1,15 +1,18 @@ -from component.mt_mbart_nlp import * -from component import app +import logging from unittest.mock import patch +from unittest import mock import re from unittest import TestCase +from qanary_helpers.language_queries import question_text_with_language +import os +import importlib class TestComponent(TestCase): 
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO) - questions = list([{"uri": "urn:test-uri", "text": "was ist ein Test?"}]) + questions = list([{"uri": "urn:test-uri", "text": "Was ist die Hauptstadt von Deutschland?"}]) endpoint = "urn:qanary#test-endpoint" in_graph = "urn:qanary#test-inGraph" out_graph = "urn:qanary#test-outGraph" @@ -17,6 +20,10 @@ class TestComponent(TestCase): source_language = "de" target_language = "en" + source_texts = [ + question_text_with_language("uri", "Was ist die Hauptstadt von Deutschland?", "de") + ] + request_data = '''{ "values": { "urn:qanary#endpoint": "urn:qanary#test-endpoint", @@ -33,14 +40,27 @@ class TestComponent(TestCase): } + @mock.patch.dict(os.environ, {'SOURCE_LANGUAGE': 'de', 'TARGET_LANGUAGE': 'en'}) def test_qanary_service(self): + import utils.lang_utils + importlib.reload(utils.lang_utils) + import component.mt_mbart_nlp + importlib.reload(component.mt_mbart_nlp) + from component import app + + logging.info("port: %s" % (os.environ["SERVICE_PORT"])) + assert os.environ["SERVICE_NAME_COMPONENT"] == "MT-MBart-Component" + assert os.environ["SOURCE_LANGUAGE"] == self.source_language + assert os.environ["TARGET_LANGUAGE"] == self.target_language with app.test_client() as client, \ patch('component.mt_mbart_nlp.get_text_question_in_graph') as mocked_get_text_question_in_graph, \ + patch('component.mt_mbart_nlp.find_source_texts_in_triplestore') as mocked_find_source_texts_in_triplestore, \ patch('component.mt_mbart_nlp.insert_into_triplestore') as mocked_insert_into_triplestore: # given a non-english question is present in the current graph mocked_get_text_question_in_graph.return_value = self.questions + mocked_find_source_texts_in_triplestore.return_value = self.source_texts mocked_insert_into_triplestore.return_value = None # when a call to /annotatequestion is made @@ -49,25 +69,21 @@ def test_qanary_service(self): # then the text question is retrieved from the triplestore 
mocked_get_text_question_in_graph.assert_called_with(triplestore_endpoint=self.endpoint, graph=self.in_graph) + mocked_find_source_texts_in_triplestore.assert_called_with(triplestore_endpoint=self.endpoint, graph_uri=self.in_graph, lang=self.source_language) + assert mocked_find_source_texts_in_triplestore.call_count == 1 + # get arguments of the (2) separate insert calls arg_list = mocked_insert_into_triplestore.call_args_list # get the call arguments for question translation call_args_translation = [a.args for a in arg_list if "AnnotationOfQuestionTranslation" in a.args[1]] assert len(call_args_translation) == 1 - # get the call arguments for question language - call_args_language = [a.args for a in arg_list if "AnnotationOfQuestionLanguage" in a.args[1]] - assert len(call_args_language) == 1 # clean query strings query_translation = re.sub(r"(\\n\W*|\n\W*)", " ", call_args_translation[0][1]) - query_language = re.sub(r"(\\n\W*|\n\W*)", " ", call_args_language[0][1]) # then the triplestore is updated twice # (question language and translation) - assert mocked_insert_into_triplestore.call_count == 2 - - # then the source language is correctly identified and annotated - self.assertRegex(query_language, r".*AnnotationOfQuestionLanguage(.*;\W?)*oa:hasBody \""+self.source_language+r"\".*\.") + assert mocked_insert_into_triplestore.call_count == 1 # then the question is translated and the result is annotated self.assertRegex(query_translation, r".*AnnotationOfQuestionTranslation(.*;\W?)*oa:hasBody \".*\"@" + self.target_language + r".*\.") @@ -75,3 +91,28 @@ def test_qanary_service(self): # then the response is not empty assert response_json != None + + + # test with all supported languages enabled + def test_translate_input(self): + import component.mt_mbart_nlp + from component.mt_mbart_nlp import translate_input + import utils.lang_utils + importlib.reload(utils.lang_utils) + importlib.reload(component.mt_mbart_nlp) + translations = [ + {"text": "Was ist die 
Hauptstadt von Deutschland?", + "translation": "What is the capital of Germany?", + "source_lang": "de", "target_lang": "en"}, + {"text": "What is the capital of Germany?", + "translation": "Quelle est la capitale de l'Allemagne?", + "source_lang": "en", "target_lang": "fr"}, +# {"text": "What is the capital of Germany?", TODO: MBart answers: "Что такое столица Германии?" +# "translation": "Какая столица Германии?", +# "source_lang": "en", "target_lang": "ru"}, + ] + + for translation in translations: + expected = translation["translation"] + actual = translate_input(translation["text"], translation["source_lang"], translation["target_lang"]) + assert expected == actual diff --git a/qanary-component-MT-Python-MBart/utils/__init__.py b/qanary-component-MT-Python-MBart/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qanary-component-MT-Python-MBart/utils/lang_utils.py b/qanary-component-MT-Python-MBart/utils/lang_utils.py new file mode 100644 index 000000000..e682bf5b6 --- /dev/null +++ b/qanary-component-MT-Python-MBart/utils/lang_utils.py @@ -0,0 +1,74 @@ +import os +import logging + + +logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO) + +SOURCE_LANGUAGE = os.getenv("SOURCE_LANGUAGE") +TARGET_LANGUAGE = os.getenv("TARGET_LANGUAGE") +SUPPORTED_LANGS = { +# source: targets + 'en': ['de', 'ru', 'fr', 'es', 'pt'], + 'de': ['en', 'ru', 'fr', 'es', 'pt'], + 'ru': ['en', 'de', 'fr', 'es', 'pt'], + 'fr': ['en', 'de', 'ru', 'es', 'pt'], + 'es': ['en', 'de', 'ru', 'fr', 'pt'], + 'pt': ['en', 'de', 'ru', 'fr', 'es'] +} + +LANG_CODE_MAP = { + "en": "en_XX", + "de": "de_DE", + "ru": "ru_RU", + "fr": "fr_XX", + "es": "es_XX", + "pt": "pt_XX" +} + + +def setup_translation_options() -> dict: + """Create a dictionary of possible source and target languages, based on SUPPORTED_LANGS + and configured languages.""" + + logging.info("SETTING UP TRANSLATION OPTIONS") + translation_options = dict() # init empty + + # check if a 
source language is specified + if SOURCE_LANGUAGE != None and len(SOURCE_LANGUAGE.strip()) > 0: + # pre-select appropriate translation options from the list of supported source languages + try: + translation_options[SOURCE_LANGUAGE] = SUPPORTED_LANGS[SOURCE_LANGUAGE] + # this will fail for invalid keys! + except KeyError: + raise ValueError(f"The source language \"{SOURCE_LANGUAGE}\" is not supported!") + # if no source language is specified, use all source languages that are supported by the models + # (copy the mapping so later pop/assignment does not mutate the SUPPORTED_LANGS constant) + else: + translation_options = {sl: list(tl) for sl, tl in SUPPORTED_LANGS.items()} + + # check if a target language is specified + if TARGET_LANGUAGE != None and len(TARGET_LANGUAGE.strip()) > 0: + discard_keys = list() + # remove instances where source == target + translation_options.pop(TARGET_LANGUAGE, None) + for source_lang in translation_options.keys(): + if TARGET_LANGUAGE in translation_options[source_lang]: + translation_options[source_lang] = [TARGET_LANGUAGE] + else: + discard_keys.append(source_lang) + # cleanup keys + translation_options = {sl:tl for sl,tl in translation_options.items() if sl not in discard_keys} + # check for empty translation options, if all keys dropped + if len(translation_options.keys()) == 0: + raise ValueError("The target language \"{tl}\" is not supported for any configured source languages! \nValid language pairs (source: [targets]) are: \n{slk}!" + .format(tl=TARGET_LANGUAGE, slk=SUPPORTED_LANGS)) + # check if only some keys dropped + elif len(discard_keys) > 0: + logging.warning("Specific target language \"{tl}\" is not supported for these source languages: {dk}! \nThese language pairs will be ignored." 
+ .format(tl=TARGET_LANGUAGE, dk=discard_keys)) + # else do nothing, the lists are already complete + + logging.info(translation_options) + return translation_options + + +translation_options = setup_translation_options() diff --git a/qanary-component-MT-Python-MBart/utils/model_utils.py b/qanary-component-MT-Python-MBart/utils/model_utils.py new file mode 100644 index 000000000..c3ba9531a --- /dev/null +++ b/qanary-component-MT-Python-MBart/utils/model_utils.py @@ -0,0 +1,7 @@ +from transformers import MBartForConditionalGeneration, MBart50TokenizerFast + + +def load_models_and_tokenizers(): + model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") + tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") + return model, tokenizer