From 930946cd2fca12e923bc645433dca3629c11a2bd Mon Sep 17 00:00:00 2001 From: xingchensong Date: Thu, 7 Dec 2023 20:49:25 +0800 Subject: [PATCH 1/3] feat(all): format all files --- itn/chinese/inverse_normalizer.py | 51 +++++---- itn/chinese/rules/cardinal.py | 80 ++++++------- itn/chinese/rules/date.py | 6 +- itn/chinese/rules/fraction.py | 7 +- itn/chinese/rules/license_plate.py | 3 +- itn/chinese/rules/measure.py | 21 ++-- itn/chinese/rules/money.py | 6 +- itn/chinese/rules/time.py | 13 +-- itn/chinese/test/data/normalizer.txt | 2 +- itn/chinese/test/normalizer_test.py | 99 +++++++--------- itn/main.py | 20 +++- .../app/src/main/cpp/wetextprocessing.cc | 22 ++-- runtime/patch/openfst/src/include/fst/flags.h | 107 ++++++++---------- runtime/patch/openfst/src/include/fst/log.h | 16 +-- runtime/patch/openfst/src/lib/flags.cc | 38 +++---- tn/chinese/data/number/teen.tsv | 4 +- tn/chinese/data/time/minute.tsv | 2 +- tn/chinese/data/time/second.tsv | 2 +- tn/chinese/rules/postprocessor.py | 7 +- tn/chinese/rules/whitelist.py | 4 +- tn/chinese/test/data/postprocessor.txt | 2 +- tn/main.py | 41 ++++--- tn/token_parser.py | 8 +- 23 files changed, 282 insertions(+), 279 deletions(-) diff --git a/itn/chinese/inverse_normalizer.py b/itn/chinese/inverse_normalizer.py index 35c6e08..3861f4a 100644 --- a/itn/chinese/inverse_normalizer.py +++ b/itn/chinese/inverse_normalizer.py @@ -31,7 +31,9 @@ class InverseNormalizer(Processor): - def __init__(self, cache_dir=None, overwrite_cache=False, + def __init__(self, + cache_dir=None, + overwrite_cache=False, enable_standalone_number=True, enable_0_to_9=False, enable_million=False): @@ -44,32 +46,39 @@ def __init__(self, cache_dir=None, overwrite_cache=False, self.build_fst('zh_itn', cache_dir, overwrite_cache) def build_tagger(self): - tagger = (add_weight(Date().tagger, 1.02) - | add_weight(Whitelist().tagger, 1.01) - | add_weight(Fraction().tagger, 1.05) - | add_weight(Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05) # noqa - | add_weight(Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04) # noqa - | add_weight(Time().tagger, 1.05) - | add_weight(Cardinal(self.convert_number, self.enable_0_to_9, self.enable_million).tagger, 1.06) # noqa - | add_weight(Math().tagger, 1.10) - | add_weight(LicensePlate().tagger, 1.0) - | add_weight(Char().tagger, 100)).optimize() + tagger = ( + add_weight(Date().tagger, 1.02) + | add_weight(Whitelist().tagger, 1.01) + | add_weight(Fraction().tagger, 1.05) + | add_weight( + Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05) # noqa + | add_weight(Money(enable_0_to_9=self.enable_0_to_9).tagger, + 1.04) # noqa + | add_weight(Time().tagger, 1.05) + | add_weight( + Cardinal(self.convert_number, self.enable_0_to_9, + self.enable_million).tagger, 1.06) # noqa + | add_weight(Math().tagger, 1.10) + | add_weight(LicensePlate().tagger, 1.0) + | add_weight(Char().tagger, 100)).optimize() tagger = tagger.star # remove the last space self.tagger = tagger @ self.build_rule(delete(' '), '', '[EOS]') def build_verbalizer(self): - verbalizer = (Cardinal(self.convert_number, self.enable_0_to_9, self.enable_million).verbalizer # noqa - | Char().verbalizer - | Date().verbalizer - | Fraction().verbalizer - | Math().verbalizer - | Measure(enable_0_to_9=self.enable_0_to_9).verbalizer - | Money(enable_0_to_9=self.enable_0_to_9).verbalizer - | Time().verbalizer - | LicensePlate().verbalizer - | Whitelist().verbalizer).optimize() + verbalizer = ( + Cardinal(self.convert_number, self.enable_0_to_9, + self.enable_million).verbalizer # noqa + | Char().verbalizer + | Date().verbalizer + | Fraction().verbalizer + | Math().verbalizer + | Measure(enable_0_to_9=self.enable_0_to_9).verbalizer + | Money(enable_0_to_9=self.enable_0_to_9).verbalizer + | Time().verbalizer + | LicensePlate().verbalizer + | Whitelist().verbalizer).optimize() postprocessor = PostProcessor(remove_interjections=True).processor self.verbalizer = (verbalizer @ postprocessor).star diff --git a/itn/chinese/rules/cardinal.py b/itn/chinese/rules/cardinal.py index c863fd7..85cdab9 100644 --- a/itn/chinese/rules/cardinal.py +++ b/itn/chinese/rules/cardinal.py @@ -20,7 +20,9 @@ class Cardinal(Processor): - def __init__(self, enable_standalone_number=True, enable_0_to_9=True, + def __init__(self, + enable_standalone_number=True, + enable_0_to_9=True, enable_million=False): super().__init__('cardinal') self.number = None @@ -32,10 +34,10 @@ def __init__(self, enable_standalone_number=True, enable_0_to_9=True, self.build_verbalizer() def build_tagger(self): - zero = string_file('itn/chinese/data/number/zero.tsv') # 0 + zero = string_file('itn/chinese/data/number/zero.tsv') # 0 digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9 - sign = string_file('itn/chinese/data/number/sign.tsv') # + - - dot = string_file('itn/chinese/data/number/dot.tsv') # . + sign = string_file('itn/chinese/data/number/sign.tsv') # + - + dot = string_file('itn/chinese/data/number/dot.tsv') # . addzero = insert('0') digits = zero | digit # 0 ~ 9 @@ -52,33 +54,33 @@ def build_tagger(self): | add_weight(addzero**2, 1.0))) # 一千一百一十一 => 1111, 一千零一十一 => 1011, 一千零一 => 1001 # 一千一 => 1100, 一千 => 1000 - thousand = ((hundred | teen | tens | digits) + delete('千') + ( - hundred - | add_weight(zero + (tens | teen), 0.1) - | add_weight(addzero + zero + digit, 0.5) - | add_weight(digit + addzero**2, 0.8) - | add_weight(addzero**3, 1.0))) + thousand = ((hundred | teen | tens | digits) + delete('千') + + (hundred + | add_weight(zero + (tens | teen), 0.1) + | add_weight(addzero + zero + digit, 0.5) + | add_weight(digit + addzero**2, 0.8) + | add_weight(addzero**3, 1.0))) # 10001111, 1001111, 101111, 11111, 10111, 10011, 10001, 10000 if self.enable_million: - ten_thousand = ((thousand | hundred | teen | tens | digits) - + delete('万') - + (thousand - | add_weight(zero + hundred, 0.1) - | add_weight(addzero + zero + (tens | teen), 0.5) - | add_weight(addzero + addzero + zero + digit, 0.5) - | add_weight(digit + addzero**3, 0.8) - | add_weight(addzero**4, 1.0))) + ten_thousand = ( + (thousand | hundred | teen | tens | digits) + delete('万') + + (thousand + | add_weight(zero + hundred, 0.1) + | add_weight(addzero + zero + (tens | teen), 0.5) + | add_weight(addzero + addzero + zero + digit, 0.5) + | add_weight(digit + addzero**3, 0.8) + | add_weight(addzero**4, 1.0))) else: - ten_thousand = ((teen | tens | digits) - + delete('万') - + (thousand - | add_weight(zero + hundred, 0.1) - | add_weight(addzero + zero + (tens | teen), 0.5) - | add_weight(addzero + addzero + zero + digit, 0.5) - | add_weight(digit + addzero**3, 0.8) - | add_weight(addzero**4, 1.0))) - ten_thousand |= (thousand | hundred) + accep("万") + delete("零").ques + ( - thousand | hundred | tens | teen | digits).ques + ten_thousand = ( + (teen | tens | digits) + delete('万') + + (thousand + | add_weight(zero + hundred, 0.1) + | add_weight(addzero + zero + (tens | teen), 0.5) + | add_weight(addzero + addzero + zero + digit, 0.5) + | add_weight(digit + addzero**3, 0.8) + | add_weight(addzero**4, 1.0))) + ten_thousand |= (thousand | hundred) + accep("万") + delete( + "零").ques + (thousand | hundred | tens | teen | digits).ques # 个/十/百/千/万 number = digits | teen | tens | hundred | thousand | ten_thousand # 兆/亿 @@ -107,23 +109,22 @@ def build_tagger(self): # 十/百/千/万 number_exclude_0_to_9 = teen | tens | hundred | thousand | ten_thousand # 兆/亿 - number_exclude_0_to_9 = ( - ((number_exclude_0_to_9 | digits) + accep('兆') + delete('零').ques).ques + - ((number_exclude_0_to_9 | digits) + accep('亿') + delete('零').ques).ques + - number_exclude_0_to_9 - ) + number_exclude_0_to_9 = (((number_exclude_0_to_9 | digits) + + accep('兆') + delete('零').ques).ques + + ((number_exclude_0_to_9 | digits) + + accep('亿') + delete('零').ques).ques + + number_exclude_0_to_9) # 负的xxx 1.11, 1.01 - number_exclude_0_to_9 |= ( - (number_exclude_0_to_9 | digits) + - (dot + digits.plus).plus - ) + number_exclude_0_to_9 |= ((number_exclude_0_to_9 | digits) + + (dot + digits.plus).plus) # 五六万,三五千,六七百,三四十 # 十七八美元 => $17~18, 四十五六岁 => 45-6岁, # 三百七八公里 => 370-80km, 三百七八十千克 => 370-80kg number_exclude_0_to_9 |= special_2number number_exclude_0_to_9 |= add_weight(special_3number, -100.0) - self.number_exclude_0_to_9 = (sign.ques + number_exclude_0_to_9).optimize() # noqa + self.number_exclude_0_to_9 = (sign.ques + + number_exclude_0_to_9).optimize() # noqa # cardinal string like 127.0.0.1, used in ID, IP, etc. cardinal = digits.plus + (dot + digits.plus).plus @@ -131,7 +132,8 @@ def build_tagger(self): cardinal |= (number + dot + digits.plus) # cardinal string like 110 or 12306 or 13125617878, used in phone, # 340621199806051223, used in ID card - cardinal |= (digits**3 | digits**4 | digits**5 | digits**11 | digits**18) + cardinal |= (digits**3 | digits**4 | digits**5 | digits**11 + | digits**18) # cardinal string like 23 if self.enable_standalone_number: if self.enable_0_to_9: diff --git a/itn/chinese/rules/date.py b/itn/chinese/rules/date.py index 2f52665..13d1d73 100644 --- a/itn/chinese/rules/date.py +++ b/itn/chinese/rules/date.py @@ -27,11 +27,11 @@ def __init__(self): def build_tagger(self): digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9 - zero = string_file('itn/chinese/data/number/zero.tsv') # 0 + zero = string_file('itn/chinese/data/number/zero.tsv') # 0 yyyy = digit + (digit | zero)**3 # 二零零八年 - yyy = digit + (digit | zero)**2 # 公元一六八年 - yy = (digit | zero)**2 # 零八年奥运会 + yyy = digit + (digit | zero)**2 # 公元一六八年 + yy = (digit | zero)**2 # 零八年奥运会 mm = string_file('itn/chinese/data/date/mm.tsv') dd = string_file('itn/chinese/data/date/dd.tsv') diff --git a/itn/chinese/rules/fraction.py b/itn/chinese/rules/fraction.py index f61a1cd..7435132 100644 --- a/itn/chinese/rules/fraction.py +++ b/itn/chinese/rules/fraction.py @@ -28,16 +28,15 @@ def __init__(self): def build_tagger(self): number = Cardinal().number - sign = string_file('itn/chinese/data/number/sign.tsv') # + - + sign = string_file('itn/chinese/data/number/sign.tsv') # + - # NOTE(xcsong): default weight = 1.0, set to -1.0 means higher priority # For example, # 1.0, 负二分之三 -> { sign: "" denominator: "-2" numerator: "3" } # -1.0,负二分之三 -> { sign: "-" denominator: "2" numerator: "3" } tagger = (insert('sign: "') + add_weight(sign, -1.0).ques + - insert('" denominator: "') + number + - delete('分之') + insert('" numerator: "') + - number + insert('"')) + insert('" denominator: "') + number + delete('分之') + + insert('" numerator: "') + number + insert('"')) self.tagger = self.add_tokens(tagger) def build_verbalizer(self): diff --git a/itn/chinese/rules/license_plate.py b/itn/chinese/rules/license_plate.py index 228c3c1..5d66911 100644 --- a/itn/chinese/rules/license_plate.py +++ b/itn/chinese/rules/license_plate.py @@ -27,7 +27,8 @@ def __init__(self): def build_tagger(self): digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9 - province = string_file('itn/chinese/data/license_plate/province.tsv') # 皖 + province = string_file( + 'itn/chinese/data/license_plate/province.tsv') # 皖 license_plate = province + self.ALPHA + (self.ALPHA | digit)**5 tagger = insert('value: "') + license_plate + insert('"') self.tagger = self.add_tokens(tagger) diff --git a/itn/chinese/rules/measure.py b/itn/chinese/rules/measure.py index 58e41ea..671d1d0 100644 --- a/itn/chinese/rules/measure.py +++ b/itn/chinese/rules/measure.py @@ -32,14 +32,14 @@ def build_tagger(self): units_en = string_file('itn/chinese/data/measure/units_en.tsv') units_zh = string_file('itn/chinese/data/measure/units_zh.tsv') digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9 - sign = string_file('itn/chinese/data/number/sign.tsv') # + - + sign = string_file('itn/chinese/data/number/sign.tsv') # + - to = cross('到', '~') | cross('到百分之', '~') - units = add_weight((accep('亿') | accep('兆') | accep('万')), -0.5).ques + units_zh - units |= add_weight((cross('亿', '00M') | cross('兆', 'T') | - cross('万', 'W')), -0.5).ques + ( - add_weight(units_en, -1.0) - ) + units = add_weight( + (accep('亿') | accep('兆') | accep('万')), -0.5).ques + units_zh + units |= add_weight( + (cross('亿', '00M') | cross('兆', 'T') | cross('万', 'W')), + -0.5).ques + (add_weight(units_en, -1.0)) number = Cardinal().number if self.enable_0_to_9 else \ Cardinal().number_exclude_0_to_9 @@ -47,8 +47,8 @@ def build_tagger(self): percent = ((sign + delete('的').ques).ques + delete('百分') + delete('之').ques + ((Cardinal().number + (to + Cardinal().number).ques) | - ((Cardinal().number + to).ques + cross('百', '100'))) - + insert('%')) + ((Cardinal().number + to).ques + cross('百', '100'))) + + insert('%')) # 十千米每小时 => 10km/h, 十一到一百千米每小时 => 11~100km/h measure = number + (to + number).ques + units @@ -57,9 +57,8 @@ def build_tagger(self): tagger = insert('value: "') + (measure | percent) + insert('"') # 每小时十千米 => 10km/h, 每小时三十到三百一十一千米 => 30~311km/h - tagger |= ( - insert('denominator: "') + delete('每') + units + - insert('" numerator: "') + measure + insert('"')) + tagger |= (insert('denominator: "') + delete('每') + units + + insert('" numerator: "') + measure + insert('"')) self.tagger = self.add_tokens(tagger) diff --git a/itn/chinese/rules/money.py b/itn/chinese/rules/money.py index 23d118b..d000387 100644 --- a/itn/chinese/rules/money.py +++ b/itn/chinese/rules/money.py @@ -39,9 +39,9 @@ def build_tagger(self): # 三千三百八十元五毛八分 => ¥3380.58 tagger = (insert('value: "') + number + insert('"') + insert(' currency: "') + (code | symbol) + insert('"') + - insert(' decimal: "') + ( - insert(".") + digit + (delete("毛") | delete("角")) + (digit + delete("分")).ques - ).ques + insert('"')) + insert(' decimal: "') + + (insert(".") + digit + (delete("毛") | delete("角")) + + (digit + delete("分")).ques).ques + insert('"')) self.tagger = self.add_tokens(tagger) def build_verbalizer(self): diff --git a/itn/chinese/rules/time.py b/itn/chinese/rules/time.py index 27d6901..f81c340 100644 --- a/itn/chinese/rules/time.py +++ b/itn/chinese/rules/time.py @@ -31,11 +31,10 @@ def build_tagger(self): s = string_file('itn/chinese/data/time/second.tsv') noon = string_file('itn/chinese/data/time/noon.tsv') - tagger = ( - (insert('noon: "') + noon + insert('" ')).ques + - insert('hour: "') + h + insert('"') + - insert(' minute: "') + m + delete('分').ques + insert('"') + - (insert(' second: "') + s + insert('"')).ques) + tagger = ((insert('noon: "') + noon + insert('" ')).ques + + insert('hour: "') + h + insert('"') + insert(' minute: "') + + m + delete('分').ques + insert('"') + + (insert(' second: "') + s + insert('"')).ques) self.tagger = self.add_tokens(tagger) def build_verbalizer(self): @@ -44,6 +43,6 @@ def build_verbalizer(self): minute = delete(' minute: "') + self.SIGMA + delete('"') second = delete(' second: "') + self.SIGMA + delete('"') noon = delete(' noon: "') + self.SIGMA + delete('"') - verbalizer = (hour + addcolon + minute + - (addcolon + second).ques + noon.ques) + verbalizer = (hour + addcolon + minute + (addcolon + second).ques + + noon.ques) self.verbalizer = self.delete_tokens(verbalizer) diff --git a/itn/chinese/test/data/normalizer.txt b/itn/chinese/test/data/normalizer.txt index fc2a123..ac1b088 100644 --- a/itn/chinese/test/data/normalizer.txt +++ b/itn/chinese/test/data/normalizer.txt @@ -1,6 +1,6 @@ 一共有多少人 => 一共有多少人 呃这个呃啊我不知道 => 这个我不知道 -呃呃啊 => +呃呃啊 => 共四百六十五篇,约三百一十五万字 => 共465篇,约315万字 共计六点四二万人 => 共计6.42万人 同比升高零点六个百分点 => 同比升高0.6个百分点 diff --git a/itn/chinese/test/normalizer_test.py b/itn/chinese/test/normalizer_test.py index 879742e..b3efc14 100644 --- a/itn/chinese/test/normalizer_test.py +++ b/itn/chinese/test/normalizer_test.py @@ -22,25 +22,23 @@ class TestNormalizer: - normalizer = InverseNormalizer( - overwrite_cache=True, - enable_standalone_number=True, - enable_0_to_9=True, - enable_million=False) - - normalizer_cases = chain( - parse_test_case('data/cardinal.txt'), - parse_test_case('data/char.txt'), - parse_test_case('data/date.txt'), - parse_test_case('data/fraction.txt'), - parse_test_case('data/math.txt'), - parse_test_case('data/measure.txt'), - parse_test_case('data/money.txt'), - parse_test_case('data/time.txt'), - parse_test_case('data/whitelist.txt'), - parse_test_case('data/number.txt'), - parse_test_case('data/license_plate.txt'), - parse_test_case('data/normalizer.txt')) + normalizer = InverseNormalizer(overwrite_cache=True, + enable_standalone_number=True, + enable_0_to_9=True, + enable_million=False) + + normalizer_cases = chain(parse_test_case('data/cardinal.txt'), + parse_test_case('data/char.txt'), + parse_test_case('data/date.txt'), + parse_test_case('data/fraction.txt'), + parse_test_case('data/math.txt'), + parse_test_case('data/measure.txt'), + parse_test_case('data/money.txt'), + parse_test_case('data/time.txt'), + parse_test_case('data/whitelist.txt'), + parse_test_case('data/number.txt'), + parse_test_case('data/license_plate.txt'), + parse_test_case('data/normalizer.txt')) @pytest.mark.parametrize("spoken, written", normalizer_cases) def test_normalizer(self, spoken, written): @@ -49,23 +47,20 @@ def test_normalizer(self, spoken, written): class TestNormalizerDisablestandalonenumberEnable0to9: - normalizer = InverseNormalizer( - overwrite_cache=True, - enable_standalone_number=False, - enable_0_to_9=True, - enable_million=False) + normalizer = InverseNormalizer(overwrite_cache=True, + enable_standalone_number=False, + enable_0_to_9=True, + enable_million=False) normalizer_cases = chain( - parse_test_case('data/char.txt'), - parse_test_case('data/date.txt'), - parse_test_case('data/fraction.txt'), - parse_test_case('data/math.txt'), - parse_test_case('data/measure.txt'), - parse_test_case('data/money.txt'), + parse_test_case('data/char.txt'), parse_test_case('data/date.txt'), + parse_test_case('data/fraction.txt'), parse_test_case('data/math.txt'), + parse_test_case('data/measure.txt'), parse_test_case('data/money.txt'), parse_test_case('data/time.txt'), parse_test_case('data/whitelist.txt'), parse_test_case('data/license_plate.txt'), - parse_test_case('data/normalizer_disable_standalone_number_enable_0_to_9.txt')) + parse_test_case( + 'data/normalizer_disable_standalone_number_enable_0_to_9.txt')) @pytest.mark.parametrize("spoken, written", normalizer_cases) def test_normalizer(self, spoken, written): @@ -74,22 +69,19 @@ def test_normalizer(self, spoken, written): class TestNormalizerEnablestandalonenumberDisable0to9: - normalizer = InverseNormalizer( - overwrite_cache=True, - enable_standalone_number=True, - enable_0_to_9=False, - enable_million=False) + normalizer = InverseNormalizer(overwrite_cache=True, + enable_standalone_number=True, + enable_0_to_9=False, + enable_million=False) normalizer_cases = chain( - parse_test_case('data/char.txt'), - parse_test_case('data/date.txt'), - parse_test_case('data/fraction.txt'), - parse_test_case('data/math.txt'), - parse_test_case('data/money.txt'), - parse_test_case('data/time.txt'), + parse_test_case('data/char.txt'), parse_test_case('data/date.txt'), + parse_test_case('data/fraction.txt'), parse_test_case('data/math.txt'), + parse_test_case('data/money.txt'), parse_test_case('data/time.txt'), parse_test_case('data/whitelist.txt'), parse_test_case('data/license_plate.txt'), - parse_test_case('data/normalizer_enable_standalone_number_disable_0_to_9.txt')) + parse_test_case( + 'data/normalizer_enable_standalone_number_disable_0_to_9.txt')) @pytest.mark.parametrize("spoken, written", normalizer_cases) def test_normalizer(self, spoken, written): @@ -98,22 +90,19 @@ def test_normalizer(self, spoken, written): class TestNormalizerDisablestandalonenumberDisable0to9: - normalizer = InverseNormalizer( - overwrite_cache=True, - enable_standalone_number=False, - enable_0_to_9=False, - enable_million=False) + normalizer = InverseNormalizer(overwrite_cache=True, + enable_standalone_number=False, + enable_0_to_9=False, + enable_million=False) normalizer_cases = chain( - parse_test_case('data/char.txt'), - parse_test_case('data/date.txt'), - parse_test_case('data/fraction.txt'), - parse_test_case('data/math.txt'), - parse_test_case('data/money.txt'), - parse_test_case('data/time.txt'), + parse_test_case('data/char.txt'), parse_test_case('data/date.txt'), + parse_test_case('data/fraction.txt'), parse_test_case('data/math.txt'), + parse_test_case('data/money.txt'), parse_test_case('data/time.txt'), parse_test_case('data/whitelist.txt'), parse_test_case('data/license_plate.txt'), - parse_test_case('data/normalizer_disable_standalone_number_disable_0_to_9.txt')) + parse_test_case( + 'data/normalizer_disable_standalone_number_disable_0_to_9.txt')) @pytest.mark.parametrize("spoken, written", normalizer_cases) def test_normalizer(self, spoken, written): diff --git a/itn/main.py b/itn/main.py index 02990df..a188938 100644 --- a/itn/main.py +++ b/itn/main.py @@ -17,6 +17,7 @@ # TODO(xcsong): multi-language support from itn.chinese.inverse_normalizer import InverseNormalizer + def str2bool(s, default=False): s = s.lower() if s == 'true': @@ -26,28 +27,35 @@ def str2bool(s, default=False): else: return default + def main(): parser = argparse.ArgumentParser() parser.add_argument('--text', help='input string') parser.add_argument('--file', help='input file path') - parser.add_argument('--cache_dir', type=str, + parser.add_argument('--cache_dir', + type=str, default=None, help='cache dir containing *.fst') - parser.add_argument('--overwrite_cache', action='store_true', + parser.add_argument('--overwrite_cache', + action='store_true', help='rebuild *.fst') - parser.add_argument('--enable_standalone_number', type=str, + parser.add_argument('--enable_standalone_number', + type=str, default='True', help='enable standalone number') - parser.add_argument('--enable_0_to_9', type=str, + parser.add_argument('--enable_0_to_9', + type=str, default='False', help='enable convert number 0 to 9') - parser.add_argument('--enable_million', type=str, + parser.add_argument('--enable_million', + type=str, default='False', help='六百万 = 6000000 if True else 600万') args = parser.parse_args() normalizer = InverseNormalizer( - cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache, + cache_dir=args.cache_dir, + overwrite_cache=args.overwrite_cache, enable_standalone_number=str2bool(args.enable_standalone_number), enable_0_to_9=str2bool(args.enable_0_to_9), enable_million=str2bool(args.enable_million)) diff --git a/runtime/android/app/src/main/cpp/wetextprocessing.cc b/runtime/android/app/src/main/cpp/wetextprocessing.cc index b76ef25..4c3fbc5 100644 --- a/runtime/android/app/src/main/cpp/wetextprocessing.cc +++ b/runtime/android/app/src/main/cpp/wetextprocessing.cc @@ -34,7 +34,7 @@ void init(JNIEnv* env, jobject, jstring jModelDir) { processorITN = std::make_shared(itnTagger, itnVerbalizer); } -jstring normalize(JNIEnv *env, jobject, jstring input) { +jstring normalize(JNIEnv* env, jobject, jstring input) { std::string input_text = std::string(env->GetStringUTFChars(input, nullptr)); std::string tagged_text = processorTN->Tag(input_text); std::string normalized_text = processorTN->Verbalize(tagged_text); @@ -42,7 +42,7 @@ jstring normalize(JNIEnv *env, jobject, jstring input) { return env->NewStringUTF(normalized_text.c_str()); } -jstring inverse_normalize(JNIEnv *env, jobject, jstring input) { +jstring inverse_normalize(JNIEnv* env, jobject, jstring input) { std::string input_text = std::string(env->GetStringUTFChars(input, nullptr)); std::string tagged_text = processorITN->Tag(input_text); std::string normalized_text = processorITN->Verbalize(tagged_text); @@ -51,9 +51,9 @@ jstring inverse_normalize(JNIEnv *env, jobject, jstring input) { } } // namespace wetextprocessing -JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { - JNIEnv *env; - if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { return JNI_ERR; } @@ -63,12 +63,12 @@ JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { } static const JNINativeMethod methods[] = { - {"init", "(Ljava/lang/String;)V", reinterpret_cast(wetextprocessing::init)}, - {"normalize", "(Ljava/lang/String;)Ljava/lang/String;", - reinterpret_cast(wetextprocessing::normalize)}, - {"inverse_normalize", "(Ljava/lang/String;)Ljava/lang/String;", - reinterpret_cast(wetextprocessing::inverse_normalize)} - }; + {"init", "(Ljava/lang/String;)V", + reinterpret_cast(wetextprocessing::init)}, + {"normalize", "(Ljava/lang/String;)Ljava/lang/String;", + reinterpret_cast(wetextprocessing::normalize)}, + {"inverse_normalize", "(Ljava/lang/String;)Ljava/lang/String;", + reinterpret_cast(wetextprocessing::inverse_normalize)}}; int rc = env->RegisterNatives(c, methods, sizeof(methods) / sizeof(JNINativeMethod)); diff --git a/runtime/patch/openfst/src/include/fst/flags.h b/runtime/patch/openfst/src/include/fst/flags.h index b5ec8ff..6ef1257 100644 --- a/runtime/patch/openfst/src/include/fst/flags.h +++ b/runtime/patch/openfst/src/include/fst/flags.h @@ -26,8 +26,8 @@ #include #include -#include #include +#include #include "gflags/gflags.h" #include "glog/logging.h" @@ -59,94 +59,90 @@ using std::string; template struct FlagDescription { - FlagDescription(T *addr, const char *doc, const char *type, - const char *file, const T val) + FlagDescription(T* addr, const char* doc, const char* type, const char* file, + const T val) : address(addr), - doc_string(doc), - type_name(type), - file_name(file), - default_value(val) {} - - T *address; - const char *doc_string; - const char *type_name; - const char *file_name; + doc_string(doc), + type_name(type), + file_name(file), + default_value(val) {} + + T* address; + const char* doc_string; + const char* type_name; + const char* file_name; const T default_value; }; template class FlagRegister { public: - static FlagRegister *GetRegister() { + static FlagRegister* GetRegister() { static auto reg = new FlagRegister; return reg; } - const FlagDescription &GetFlagDescription(const string &name) const { + const FlagDescription& GetFlagDescription(const string& name) const { fst::MutexLock l(&flag_lock_); auto it = flag_table_.find(name); return it != flag_table_.end() ? it->second : 0; } - void SetDescription(const string &name, - const FlagDescription &desc) { + void SetDescription(const string& name, const FlagDescription& desc) { fst::MutexLock l(&flag_lock_); flag_table_.insert(make_pair(name, desc)); } - bool SetFlag(const string &val, bool *address) const { + bool SetFlag(const string& val, bool* address) const { if (val == "true" || val == "1" || val.empty()) { *address = true; return true; } else if (val == "false" || val == "0") { *address = false; return true; - } - else { + } else { return false; } } - bool SetFlag(const string &val, string *address) const { + bool SetFlag(const string& val, string* address) const { *address = val; return true; } - bool SetFlag(const string &val, int32 *address) const { - char *p = 0; + bool SetFlag(const string& val, int32* address) const { + char* p = 0; *address = strtol(val.c_str(), &p, 0); return !val.empty() && *p == '\0'; } - bool SetFlag(const string &val, int64 *address) const { - char *p = 0; + bool SetFlag(const string& val, int64* address) const { + char* p = 0; *address = strtoll(val.c_str(), &p, 0); return !val.empty() && *p == '\0'; } - bool SetFlag(const string &val, double *address) const { - char *p = 0; + bool SetFlag(const string& val, double* address) const { + char* p = 0; *address = strtod(val.c_str(), &p); return !val.empty() && *p == '\0'; } - bool SetFlag(const string &arg, const string &val) const { - for (typename std::map< string, FlagDescription >::const_iterator it = - flag_table_.begin(); - it != flag_table_.end(); - ++it) { - const string &name = it->first; - const FlagDescription &desc = it->second; - if (arg == name) - return SetFlag(val, desc.address); + bool SetFlag(const string& arg, const string& val) const { + for (typename std::map>::const_iterator it = + flag_table_.begin(); + it != flag_table_.end(); ++it) { + const string& name = it->first; + const FlagDescription& desc = it->second; + if (arg == name) return SetFlag(val, desc.address); } return false; } - void GetUsage(std::set> *usage_set) const { + void GetUsage(std::set>* usage_set) const { for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) { - const string &name = it->first; - const FlagDescription &desc = it->second; + const string& name = it->first; + const FlagDescription& desc = it->second; string usage = " --" + name; usage += ": type = "; usage += desc.type_name; @@ -162,43 +158,39 @@ class FlagRegister { return default_value ? "true" : "false"; } - string GetDefault(const string &default_value) const { + string GetDefault(const string& default_value) const { return "\"" + default_value + "\""; } template - string GetDefault(const V &default_value) const { + string GetDefault(const V& default_value) const { std::ostringstream strm; strm << default_value; return strm.str(); } - mutable fst::Mutex flag_lock_; // Multithreading lock. + mutable fst::Mutex flag_lock_; // Multithreading lock. std::map> flag_table_; }; template class FlagRegisterer { public: - FlagRegisterer(const string &name, const FlagDescription &desc) { + FlagRegisterer(const string& name, const FlagDescription& desc) { auto registr = FlagRegister::GetRegister(); registr->SetDescription(name, desc); } private: - FlagRegisterer(const FlagRegisterer &) = delete; - FlagRegisterer &operator=(const FlagRegisterer &) = delete; + FlagRegisterer(const FlagRegisterer&) = delete; + FlagRegisterer& operator=(const FlagRegisterer&) = delete; }; - -#define DEFINE_VAR(type, name, value, doc) \ - type FLAGS_ ## name = value; \ - static FlagRegisterer \ - name ## _flags_registerer(#name, FlagDescription(&FLAGS_ ## name, \ - doc, \ - #type, \ - __FILE__, \ - value)) +#define DEFINE_VAR(type, name, value, doc) \ + type FLAGS_##name = value; \ + static FlagRegisterer name##_flags_registerer( \ + #name, \ + FlagDescription(&FLAGS_##name, doc, #type, __FILE__, value)) // #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc) // #define DEFINE_string(name, value, doc) \ @@ -207,19 +199,18 @@ class FlagRegisterer { // #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc) // #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc) - // Temporary directory. DECLARE_string(tmpdir); -void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, - const char *src = ""); +void SetFlags(const char* usage, int* argc, char*** argv, bool remove_flags, + const char* src = ""); #define SET_FLAGS(usage, argc, argv, rmflags) \ -gflags::ParseCommandLineFlags(argc, argv, true) + gflags::ParseCommandLineFlags(argc, argv, true) // SetFlags(usage, argc, argv, rmflags, __FILE__) // Deprecated; for backward compatibility. -inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) { +inline void InitFst(const char* usage, int* argc, char*** argv, bool rmflags) { return SetFlags(usage, argc, argv, rmflags); } diff --git a/runtime/patch/openfst/src/include/fst/log.h b/runtime/patch/openfst/src/include/fst/log.h index bf041c5..6a8f301 100644 --- a/runtime/patch/openfst/src/include/fst/log.h +++ b/runtime/patch/openfst/src/include/fst/log.h @@ -22,8 +22,8 @@ #include #include -#include #include +#include using std::string; @@ -31,15 +31,14 @@ DECLARE_int32(v); class LogMessage { public: - LogMessage(const string &type) : fatal_(type == "FATAL") { + LogMessage(const string& type) : fatal_(type == "FATAL") { std::cerr << type << ": "; } ~LogMessage() { std::cerr << std::endl; - if(fatal_) - exit(1); + if (fatal_) exit(1); } - std::ostream &stream() { return std::cerr; } + std::ostream& stream() { return std::cerr; } private: bool fatal_; @@ -49,11 +48,9 @@ class LogMessage { // #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO) // Checks -inline void FstCheck(bool x, const char* expr, - const char *file, int line) { +inline void FstCheck(bool x, const char* expr, const char* file, int line) { if (!x) { - LOG(FATAL) << "Check failed: \"" << expr - << "\" file: " << file + LOG(FATAL) << "Check failed: \"" << expr << "\" file: " << file << " line: " << line; } } @@ -75,7 +72,6 @@ inline void FstCheck(bool x, const char* expr, // #define DCHECK_GE(x, y) DCHECK((x) >= (y)) // #define DCHECK_NE(x, y) DCHECK((x) != (y)) - // Ports #define ATTRIBUTE_DEPRECATED __attribute__((deprecated)) diff --git a/runtime/patch/openfst/src/lib/flags.cc b/runtime/patch/openfst/src/lib/flags.cc index 95f7e2e..24dcd67 100644 --- a/runtime/patch/openfst/src/lib/flags.cc +++ b/runtime/patch/openfst/src/lib/flags.cc @@ -15,14 +15,14 @@ #include #if _MSC_VER -#include #include +#include #endif #include #include -static const char *private_tmpdir = getenv("TMPDIR"); +static const char* private_tmpdir = getenv("TMPDIR"); // DEFINE_int32(v, 0, "verbosity level"); // DEFINE_bool(help, false, "show usage information"); @@ -33,7 +33,7 @@ DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp", #else DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"), "temporary directory"); -#endif // !_MSC_VER +#endif // !_MSC_VER using namespace std; @@ -41,7 +41,7 @@ static string flag_usage; static string prog_src; // Sets prog_src to src. -static void SetProgSrc(const char *src) { +static void SetProgSrc(const char* src) { prog_src = src; #if _MSC_VER // This common code is invoked by all FST binaries, and only by them. Switch @@ -65,8 +65,8 @@ static void SetProgSrc(const char *src) { } } -void SetFlags(const char *usage, int *argc, char ***argv, - bool remove_flags, const char *src) { +void SetFlags(const char* usage, int* argc, char*** argv, bool remove_flags, + const char* src) { flag_usage = usage; SetProgSrc(src); @@ -84,20 +84,15 @@ void SetFlags(const char *usage, int *argc, char ***argv, val = argval.substr(pos + 1); } auto bool_register = FlagRegister::GetRegister(); - if (bool_register->SetFlag(arg, val)) - continue; + if (bool_register->SetFlag(arg, val)) continue; auto string_register = FlagRegister::GetRegister(); - if (string_register->SetFlag(arg, val)) - continue; + if (string_register->SetFlag(arg, val)) continue; auto int32_register = FlagRegister::GetRegister(); - if (int32_register->SetFlag(arg, val)) - continue; + if (int32_register->SetFlag(arg, val)) continue; auto int64_register = FlagRegister::GetRegister(); - if (int64_register->SetFlag(arg, val)) - continue; + if (int64_register->SetFlag(arg, val)) continue; auto double_register = FlagRegister::GetRegister(); - if (double_register->SetFlag(arg, val)) - continue; + if (double_register->SetFlag(arg, val)) continue; LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index]; } if (remove_flags) { @@ -118,15 +113,14 @@ void SetFlags(const char *usage, int *argc, char ***argv, // If flag is defined in file 'src' and 'in_src' true or is not // defined in file 'src' and 'in_src' is false, then print usage. -static void -ShowUsageRestrict(const std::set> &usage_set, - const string &src, bool in_src, bool show_file) { +static void ShowUsageRestrict(const std::set>& usage_set, + const string& src, bool in_src, bool show_file) { string old_file; bool file_out = false; bool usage_out = false; - for (const auto &pair : usage_set) { - const auto &file = pair.first; - const auto &usage = pair.second; + for (const auto& pair : usage_set) { + const auto& file = pair.first; + const auto& usage = pair.second; bool match = file == src; if ((match && !in_src) || (!match && in_src)) continue; if (file != old_file) { diff --git a/tn/chinese/data/number/teen.tsv b/tn/chinese/data/number/teen.tsv index 55620af..07c01df 100644 --- a/tn/chinese/data/number/teen.tsv +++ b/tn/chinese/data/number/teen.tsv @@ -1,4 +1,4 @@ -1 +1 2 二 3 三 4 四 @@ -7,7 +7,7 @@ 7 七 8 八 9 九 -1 +1 2 二 3 三 4 四 diff --git a/tn/chinese/data/time/minute.tsv b/tn/chinese/data/time/minute.tsv index 38879de..c0fc2da 100644 --- a/tn/chinese/data/time/minute.tsv +++ b/tn/chinese/data/time/minute.tsv @@ -1,4 +1,4 @@ -00 +00 01 零一分 02 零二分 03 零三分 diff --git a/tn/chinese/data/time/second.tsv b/tn/chinese/data/time/second.tsv index ac6280d..fba72c6 100644 --- a/tn/chinese/data/time/second.tsv +++ b/tn/chinese/data/time/second.tsv @@ -1,4 +1,4 @@ -00 +00 01 一秒 02 二秒 03 三秒 diff --git a/tn/chinese/rules/postprocessor.py b/tn/chinese/rules/postprocessor.py index 79ee462..3f19914 100644 --- a/tn/chinese/rules/postprocessor.py +++ b/tn/chinese/rules/postprocessor.py @@ -21,8 +21,11 @@ class PostProcessor(Processor): - def __init__(self, remove_interjections=True, remove_puncts=False, - full_to_half=True, tag_oov=False): + def __init__(self, + remove_interjections=True, + remove_puncts=False, + full_to_half=True, + tag_oov=False): super().__init__(name='postprocessor') blacklist = string_file('tn/chinese/data/default/blacklist.tsv') puncts = string_file('tn/chinese/data/char/punctuations_zh.tsv') diff --git a/tn/chinese/rules/whitelist.py b/tn/chinese/rules/whitelist.py index c081fd0..b7d99c4 100644 --- a/tn/chinese/rules/whitelist.py +++ b/tn/chinese/rules/whitelist.py @@ -39,6 +39,6 @@ def build_verbalizer(self): if self.remove_erhua: verbalizer = self.delete_tokens(delete('erhua: "儿"')) else: - verbalizer = self.delete_tokens(delete('erhua: \"') + - accep('儿') + delete('\"')) + verbalizer = self.delete_tokens( + delete('erhua: \"') + accep('儿') + delete('\"')) self.verbalizer |= verbalizer diff --git a/tn/chinese/test/data/postprocessor.txt b/tn/chinese/test/data/postprocessor.txt index baca4b2..5fb00c1 100644 --- a/tn/chinese/test/data/postprocessor.txt +++ b/tn/chinese/test/data/postprocessor.txt @@ -1,5 +1,5 @@ 好! => 好! 好啊 => 好 -啊呃呃 => +啊呃呃 => 我们안녕 => 我们 雪の花 => 雪花 diff --git a/tn/main.py b/tn/main.py index 0312044..1c3de5a 100644 --- a/tn/main.py +++ b/tn/main.py @@ -23,39 +23,48 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--text', help='input string') parser.add_argument('--file', help='input file path') - parser.add_argument('--cache_dir', type=str, + parser.add_argument('--cache_dir', + type=str, default=None, help='cache dir containing *.fst') - parser.add_argument('--overwrite_cache', action='store_true', + parser.add_argument('--overwrite_cache', + action='store_true', help='rebuild *.fst') - parser.add_argument('--remove_interjections', type=str, + parser.add_argument('--remove_interjections', + type=str, default='True', help='remove interjections like "啊"') - parser.add_argument('--remove_erhua', type=str, + parser.add_argument('--remove_erhua', + type=str, default='True', help='remove "儿"') - parser.add_argument('--traditional_to_simple', type=str, + parser.add_argument('--traditional_to_simple', + type=str, default='True', help='i.e., "喆" -> "哲"') - parser.add_argument('--remove_puncts', type=str, + parser.add_argument('--remove_puncts', + type=str, default='False', help='remove punctuations like "。" and ","') - parser.add_argument('--full_to_half', type=str, + parser.add_argument('--full_to_half', + type=str, default='True', help='i.e., "A" -> "A"') - parser.add_argument('--tag_oov', type=str, + parser.add_argument('--tag_oov', + type=str, default='False', help='tag OOV with "OOV"') args = parser.parse_args() - normalizer = Normalizer(cache_dir=args.cache_dir, - overwrite_cache=args.overwrite_cache, - remove_interjections=str2bool(args.remove_interjections), - remove_erhua=str2bool(args.remove_erhua), - traditional_to_simple=str2bool(args.traditional_to_simple), - remove_puncts=str2bool(args.remove_puncts), - full_to_half=str2bool(args.full_to_half), - tag_oov=str2bool(args.tag_oov)) + normalizer = Normalizer( + cache_dir=args.cache_dir, + overwrite_cache=args.overwrite_cache, + remove_interjections=str2bool(args.remove_interjections), + remove_erhua=str2bool(args.remove_erhua), + traditional_to_simple=str2bool(args.traditional_to_simple), + remove_puncts=str2bool(args.remove_puncts), + full_to_half=str2bool(args.full_to_half), + tag_oov=str2bool(args.tag_oov)) if args.text: print(normalizer.tag(args.text)) diff --git a/tn/token_parser.py b/tn/token_parser.py index deb92fd..d047d86 100644 --- a/tn/token_parser.py +++ b/tn/token_parser.py @@ -20,16 +20,19 @@ 'fraction': ['denominator', 'numerator'], 'measure': ['denominator', 'numerator', 'value'], 'money': ['value', 'currency'], - 'time': ['noon', 'hour', 'minute', 'second']} + 'time': ['noon', 'hour', 'minute', 'second'] +} ITN_ORDERS = { 'date': ['year', 'month', 'day'], 'fraction': ['sign', 'numerator', 'denominator'], 'measure': ['numerator', 'denominator', 'value'], 'money': ['currency', 'value', 'decimal'], - 'time': ['hour', 'minute', 'second', 'noon']} + 'time': ['hour', 'minute', 'second', 'noon'] +} class Token: + def __init__(self, name): self.name = name self.order = [] @@ -52,6 +55,7 @@ def string(self, orders): class TokenParser: + def __init__(self, ordertype="tn"): if ordertype == "tn": self.orders = TN_ORDERS From 93b78f4d36cecda6470b873b749c6547f4ce91e6 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Thu, 7 Dec 2023 21:10:02 +0800 Subject: [PATCH 2/3] feat(all): format all files --- .pre-commit-config.yaml | 1 + itn/chinese/test/data/normalizer.txt | 2 +- tn/chinese/data/number/teen.tsv | 4 ++-- tn/chinese/data/time/minute.tsv | 2 +- tn/chinese/data/time/second.tsv | 2 +- tn/chinese/test/data/postprocessor.txt | 2 +- 6 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bee2fc6..6c09afb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +exclude: '.*\.(txt|tsv)$' repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 diff --git a/itn/chinese/test/data/normalizer.txt b/itn/chinese/test/data/normalizer.txt index ac1b088..fc2a123 100644 --- a/itn/chinese/test/data/normalizer.txt +++ b/itn/chinese/test/data/normalizer.txt @@ -1,6 +1,6 @@ 一共有多少人 => 一共有多少人 呃这个呃啊我不知道 => 这个我不知道 -呃呃啊 => +呃呃啊 => 共四百六十五篇,约三百一十五万字 => 共465篇,约315万字 共计六点四二万人 => 共计6.42万人 同比升高零点六个百分点 => 同比升高0.6个百分点 diff --git a/tn/chinese/data/number/teen.tsv b/tn/chinese/data/number/teen.tsv index 07c01df..55620af 100644 --- a/tn/chinese/data/number/teen.tsv +++ b/tn/chinese/data/number/teen.tsv @@ -1,4 +1,4 @@ -1 +1 2 二 3 三 4 四 @@ -7,7 +7,7 @@ 7 七 8 八 9 九 -1 +1 2 二 3 三 4 四 diff --git a/tn/chinese/data/time/minute.tsv b/tn/chinese/data/time/minute.tsv index c0fc2da..38879de 100644 --- a/tn/chinese/data/time/minute.tsv +++ b/tn/chinese/data/time/minute.tsv @@ -1,4 +1,4 @@ -00 +00 01 零一分 02 零二分 03 零三分 diff --git a/tn/chinese/data/time/second.tsv b/tn/chinese/data/time/second.tsv index fba72c6..ac6280d 100644 --- a/tn/chinese/data/time/second.tsv +++ b/tn/chinese/data/time/second.tsv @@ -1,4 +1,4 @@ -00 +00 01 一秒 02 二秒 03 三秒 diff --git a/tn/chinese/test/data/postprocessor.txt b/tn/chinese/test/data/postprocessor.txt index 5fb00c1..baca4b2 100644 --- a/tn/chinese/test/data/postprocessor.txt +++ b/tn/chinese/test/data/postprocessor.txt @@ -1,5 +1,5 @@ 好! => 好! 好啊 => 好 -啊呃呃 => +啊呃呃 => 我们안녕 => 我们 雪の花 => 雪花 From d16bc8f4f5335bfc2e5cbe0947cccbf880a050d0 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Thu, 7 Dec 2023 21:25:27 +0800 Subject: [PATCH 3/3] feat(all): format all files --- runtime/processor/wetext_processor.h | 3 +++ runtime/processor/wetext_token_parser.cc | 2 +- runtime/processor/wetext_token_parser.h | 6 +++--- runtime/test/processor_test.cc | 2 +- tn/chinese/rules/sport.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/runtime/processor/wetext_processor.h b/runtime/processor/wetext_processor.h index e010aea..e11d307 100644 --- a/runtime/processor/wetext_processor.h +++ b/runtime/processor/wetext_processor.h @@ -15,6 +15,9 @@ #ifndef PROCESSOR_WETEXT_PROCESSOR_H_ #define PROCESSOR_WETEXT_PROCESSOR_H_ +#include +#include + #include "fst/fstlib.h" #include "processor/wetext_token_parser.h" diff --git a/runtime/processor/wetext_token_parser.cc b/runtime/processor/wetext_token_parser.cc index ad696ba..cea4f3f 100644 --- a/runtime/processor/wetext_token_parser.cc +++ b/runtime/processor/wetext_token_parser.cc @@ -18,7 +18,7 @@ #include "utils/wetext_string.h" namespace wetext { -const std::string EOS = ""; +const char EOS[] = ""; const std::set UTF8_WHITESPACE = {" ", "\t", "\n", "\r", "\x0b\x0c"}; const std::set ASCII_LETTERS = { diff --git a/runtime/processor/wetext_token_parser.h b/runtime/processor/wetext_token_parser.h index e035864..766ea7a 100644 --- a/runtime/processor/wetext_token_parser.h +++ b/runtime/processor/wetext_token_parser.h @@ -22,7 +22,7 @@ namespace wetext { -extern const std::string EOS; +extern const char EOS[]; extern const std::set UTF8_WHITESPACE; extern const std::set ASCII_LETTERS; extern const std::unordered_map> @@ -35,7 +35,7 @@ struct Token { std::vector order; std::unordered_map members; - Token(const std::string& name) : name(name) {} + explicit Token(const std::string& name) : name(name) {} void Append(const std::string& key, const std::string& value) { order.emplace_back(key); @@ -66,7 +66,7 @@ enum ParseType { class TokenParser { public: - TokenParser(ParseType type); + explicit TokenParser(ParseType type); std::string Reorder(const std::string& input); private: diff --git a/runtime/test/processor_test.cc b/runtime/test/processor_test.cc index 5d77749..e6df32f 100644 --- a/runtime/test/processor_test.cc +++ b/runtime/test/processor_test.cc @@ -57,7 +57,7 @@ class ProcessorTest processor = new wetext::Processor(tagger_path, verbalizer_path); written = GetParam().first; spoken = GetParam().second; - }; + } virtual void TearDown() { delete processor; } }; diff --git a/tn/chinese/rules/sport.py b/tn/chinese/rules/sport.py index 2cb8ba7..97e5ed4 100644 --- a/tn/chinese/rules/sport.py +++ b/tn/chinese/rules/sport.py @@ -15,7 +15,7 @@ from tn.chinese.rules.cardinal import Cardinal from tn.processor import Processor -from pynini import cross, string_file +from pynini import string_file from pynini.lib.pynutil import delete, insert