Skip to content

Commit

Permalink
Enable multi-label prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
Eszti committed Jan 27, 2022
1 parent ff78e58 commit 8769eeb
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions xpotato/graph_extractor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def init_nlp(self):
def parse_iterable(self, iterable, graph_type="fourlang"):
if graph_type == "fourlang":
with TextTo4lang(
lang=self.lang, nlp_cache=self.cache_fn, cache_dir=self.cache_dir
lang=self.lang, nlp_cache=self.cache_fn, cache_dir=self.cache_dir
) as tfl:
for sen in tqdm(iterable):
fl_graphs = list(tfl(sen))
Expand Down Expand Up @@ -74,7 +74,7 @@ class FeatureEvaluator:
def __init__(self, graph_format="ud"):
self.graph_format = graph_format

def match_features(self, dataset, features):
def match_features(self, dataset, features, multi=False):
graphs = dataset.graph.tolist()

matches = []
Expand All @@ -84,13 +84,10 @@ def match_features(self, dataset, features):

for i, g in tqdm(enumerate(graphs)):
feats = matcher.match(g)
for key, feature in feats:
matches.append(features[feature])
predicted.append(key)
break
if multi:
self.match_multi(feats, features, matches, predicted)
else:
matches.append("")
predicted.append("")
self.match_not_multi(feats, features, matches, predicted)

d = {
"Sentence": dataset.text.tolist(),
Expand All @@ -100,6 +97,29 @@ def match_features(self, dataset, features):
df = pd.DataFrame(d)
return df

def match_multi(self, feats, features, matches, predicted):
keys = []
matched_rules = []
for key, feature in feats:
if key not in keys:
matched_rules.append(features[feature])
keys.append(key)
if not keys:
matches.append("")
predicted.append("")
else:
matches.append(matched_rules)
predicted.append(keys)

def match_not_multi(self, feats, features, matches, predicted):
for key, feature in feats:
matches.append(features[feature])
predicted.append(key)
break
else:
matches.append("")
predicted.append("")

def one_versus_rest(self, df, entity):
mapper = {entity: 1}

Expand Down Expand Up @@ -254,8 +274,8 @@ def select_words(self, trained_features):

for word in words_to_measures:
if words_to_measures[word]["precision"] > 0.9 and (
words_to_measures[word]["TP"] > 1
or words_to_measures[word]["recall"] > 0.01
words_to_measures[word]["TP"] > 1
or words_to_measures[word]["recall"] > 0.01
):
selected_words.add(word)

Expand Down Expand Up @@ -291,7 +311,7 @@ def evaluate_feature(self, cl, features, data, graph_format="ud"):

accuracy = []
for pcf in precision_recall_fscore_support(
labels, whole_predicted, average=None
labels, whole_predicted, average=None
):
if len(pcf) > 1:
accuracy.append(pcf[1])
Expand Down

0 comments on commit 8769eeb

Please sign in to comment.