diff --git a/opennmt/tests/tokenizer_test.py b/opennmt/tests/tokenizer_test.py
index 869985cbd..bf3d627a5 100644
--- a/opennmt/tests/tokenizer_test.py
+++ b/opennmt/tests/tokenizer_test.py
@@ -103,6 +103,15 @@ def testOpenNMTTokenizer(self):
         [["Hello", "world", "■!"], ["Test"], ["My", "name"]],
         ["Hello world!", "Test", "My name"])
 
+  def testOpenNMTTokenizerEmptyTensor(self):
+    tokenizer = tokenizers.OpenNMTTokenizer()
+    tokens = tokenizer.tokenize(tf.constant(""))
+    self.assertIs(tokens.dtype, tf.string)
+    self.assertListEqual(tokens.shape.as_list(), [0])
+    text = tokenizer.detokenize(tokens)
+    self.assertIs(text.dtype, tf.string)
+    self.assertListEqual(text.shape.as_list(), [])
+
   def testOpenNMTTokenizerArguments(self):
     tokenizer = tokenizers.OpenNMTTokenizer(
         mode="aggressive", spacer_annotate=True, spacer_new=True)
diff --git a/opennmt/tokenizers/tokenizer.py b/opennmt/tokenizers/tokenizer.py
index 23b9cb888..f7b817c6c 100644
--- a/opennmt/tokenizers/tokenizer.py
+++ b/opennmt/tokenizers/tokenizer.py
@@ -159,7 +159,7 @@ def _tokenize_tensor(self, text):
     def _python_wrapper(string_t):
       string = tf.compat.as_text(string_t.numpy())
      tokens = self._tokenize_string(string)
-      return tf.constant(tokens)
+      return tf.constant(tokens, dtype=tf.string)
    tokens = tf.py_function(_python_wrapper, [text], tf.string)
    tokens.set_shape([None])
    return tokens
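
For context (not part of the patch), a minimal sketch of why the explicit dtype is needed: when the input text yields no tokens, _tokenize_string returns an empty list, and tf.constant infers float32 for an empty list, which conflicts with the tf.string output type declared for tf.py_function. Pinning dtype=tf.string keeps the empty case consistent:

    import tensorflow as tf

    # Without an explicit dtype, an empty list is inferred as float32 ...
    print(tf.constant([]).dtype)                   # <dtype: 'float32'>

    # ... which would clash with the tf.string output declared for tf.py_function.
    print(tf.constant([], dtype=tf.string).dtype)  # <dtype: 'string'>

    # Non-empty string lists are inferred as tf.string either way,
    # which is why the bug only surfaced for empty inputs.
    print(tf.constant(["Hello", "world"]).dtype)   # <dtype: 'string'>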