From 8769eebebd5034d3e59552dd7d3add4784bee5a2 Mon Sep 17 00:00:00 2001 From: Eszti Date: Thu, 27 Jan 2022 19:12:14 +0100 Subject: [PATCH] Enable multi-label prediction --- xpotato/graph_extractor/extract.py | 42 ++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/xpotato/graph_extractor/extract.py b/xpotato/graph_extractor/extract.py index 79786c0..d08ffa0 100644 --- a/xpotato/graph_extractor/extract.py +++ b/xpotato/graph_extractor/extract.py @@ -43,7 +43,7 @@ def init_nlp(self): def parse_iterable(self, iterable, graph_type="fourlang"): if graph_type == "fourlang": with TextTo4lang( - lang=self.lang, nlp_cache=self.cache_fn, cache_dir=self.cache_dir + lang=self.lang, nlp_cache=self.cache_fn, cache_dir=self.cache_dir ) as tfl: for sen in tqdm(iterable): fl_graphs = list(tfl(sen)) @@ -74,7 +74,7 @@ class FeatureEvaluator: def __init__(self, graph_format="ud"): self.graph_format = graph_format - def match_features(self, dataset, features): + def match_features(self, dataset, features, multi=False): graphs = dataset.graph.tolist() matches = [] @@ -84,13 +84,10 @@ def match_features(self, dataset, features): for i, g in tqdm(enumerate(graphs)): feats = matcher.match(g) - for key, feature in feats: - matches.append(features[feature]) - predicted.append(key) - break + if multi: + self.match_multi(feats, features, matches, predicted) else: - matches.append("") - predicted.append("") + self.match_not_multi(feats, features, matches, predicted) d = { "Sentence": dataset.text.tolist(), @@ -100,6 +97,29 @@ def match_features(self, dataset, features): df = pd.DataFrame(d) return df + def match_multi(self, feats, features, matches, predicted): + keys = [] + matched_rules = [] + for key, feature in feats: + if key not in keys: + matched_rules.append(features[feature]) + keys.append(key) + if not keys: + matches.append("") + predicted.append("") + else: + matches.append(matched_rules) + predicted.append(keys) + + def match_not_multi(self, feats, features, matches, predicted): + for key, feature in feats: + matches.append(features[feature]) + predicted.append(key) + break + else: + matches.append("") + predicted.append("") + def one_versus_rest(self, df, entity): mapper = {entity: 1} @@ -254,8 +274,8 @@ def select_words(self, trained_features): for word in words_to_measures: if words_to_measures[word]["precision"] > 0.9 and ( - words_to_measures[word]["TP"] > 1 - or words_to_measures[word]["recall"] > 0.01 + words_to_measures[word]["TP"] > 1 + or words_to_measures[word]["recall"] > 0.01 ): selected_words.add(word) @@ -291,7 +311,7 @@ def evaluate_feature(self, cl, features, data, graph_format="ud"): accuracy = [] for pcf in precision_recall_fscore_support( - labels, whole_predicted, average=None + labels, whole_predicted, average=None ): if len(pcf) > 1: accuracy.append(pcf[1])