-
Notifications
You must be signed in to change notification settings - Fork 5
/
generate_features.py
107 lines (90 loc) · 4.24 KB
/
generate_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import pickle
import itertools
import numpy as np
import pandas as pd
import rdflib
import tqdm
from SemanticProcessor.decoder import decode
from SemanticProcessor import encoder, generator
from WFL.kernel import wf_kernel
from sklearn.metrics import accuracy_score, f1_score, classification_report
import warnings
warnings.filterwarnings('ignore')
# --------------------------------------------------------------------- #
# Load the migbase headache KG and pull out the gold-standard labels    #
# --------------------------------------------------------------------- #
print("Creating the migbase + ICHD knowledge graph (rdflib)...")
g = rdflib.Graph()
g.parse('data/headache_KG.ttl', format='turtle')

# Each headache individual carries its diagnosed type via chron:isType;
# collect those pairs into a {headache_uri: class_uri} mapping.
label_query = """SELECT ?headache ?label WHERE {
    ?headache chron:isType ?label .
}"""
labels = {
    headache: label
    for headache, label in g.query(
        label_query,
        initNs={'chron': rdflib.Namespace('http://chronicals.ugent.be/')})
}
# Rebuild the graph WITHOUT any chron:isType triple: leaving the diagnosis
# inside the KG would leak the target label straight into the features.
filtered_graph = rdflib.Graph()
keep_query = """SELECT ?s ?p ?o WHERE {
    ?s ?p ?o .
    MINUS {
        ?s chron:isType ?o .
    }
}"""
kept_triples = filtered_graph.query  # placeholder avoided; run query on g below
kept_triples = g.query(
    keep_query,
    initNs={'chron': rdflib.Namespace('http://chronicals.ugent.be/')})
for subj, pred, obj in kept_triples:
    # Sanity check — the MINUS clause should have removed every isType
    # triple already, so this branch should never fire.
    if 'ugent' in str(pred) and 'isType' in str(pred):
        print('We added the label to the graph...')
    filtered_graph.add((subj, pred, obj))
g = filtered_graph
# Build a 'prototype' subgraph per diagnosis class from the ICHD knowledge
# base: unpack each OWL restriction (property + allowed items) attached to
# a diagnosis class and merge those triples into the working graph.
ichd_kg = rdflib.Graph()
ichd_kg.parse('data/ICHD_KB.ttl', format='turtle')
restriction_query = """SELECT ?diagnose ?property ?item WHERE {
    ?diagnose rdfs:subClassOf ?bnode1 .
    ?bnode1 owl:intersectionOf ?bnode2 .
    ?bnode2 rdf:type owl:Restriction .
    ?bnode2 owl:onProperty ?property .
    ?bnode2 (owl:oneValueFrom|owl:someValuesFrom)/rdf:rest*/rdf:first ?item .
}
"""
namespaces = {
    'rdf': rdflib.Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
    'rdfs': rdflib.Namespace('http://www.w3.org/2000/01/rdf-schema#'),
    'owl': rdflib.Namespace('http://www.w3.org/2002/07/owl#'),
    'chron': rdflib.Namespace('http://chronicals.ugent.be/'),
}
qres = ichd_kg.query(restriction_query, initNs=namespaces)

# Roots of the prototype subgraphs — one URI per headache class.
prototypes = [rdflib.URIRef('http://chronicals.ugent.be/' + cls_name)
              for cls_name in ('Cluster', 'Tension', 'Migraine')]

for subj, pred, obj in qres:
    # Guard against accidentally re-introducing the label (data leak).
    if 'ugent' in str(pred) and 'isType' in str(pred):
        print('We added the label to the graph...')
    g.add((subj, pred, obj))

# Map class URIs to integer ids so the sklearn metrics can consume them.
uri_to_int = {uri: idx for idx, uri in enumerate(prototypes)}
# ################################## #
#      Create feature vectors        #
# ################################## #
print('Generating distances from each sample (represented as a graph) to each class concept (graph)...')
correct = 0
total = 0
real_labels = []
predicted_labels = []
wf_features = {}
for headache in tqdm.tqdm(labels.keys()):
    # One kernel evaluation per prototype. The original recomputed the
    # full wf_kernel sweep three times per headache (once for the feature
    # vector, twice more for the argmax); reuse the vector instead.
    feature_vector = [sum(wf_kernel(g, prototype, headache)[1:])
                      for prototype in prototypes]
    # Headache URIs end in '#<id>'; key the stored features by that id.
    wf_features[int(str(headache).split('#')[-1])] = feature_vector
    # Unsupervised prediction = the prototype with the largest kernel score.
    prediction = prototypes[int(np.argmax(feature_vector))]
    correct += prediction == labels[headache]
    total += 1
    real_labels.append(uri_to_int[labels[headache]])
    predicted_labels.append(uri_to_int[prediction])
print('Metrics...')
print('Unsupervised accuracy:', accuracy_score(real_labels, predicted_labels))
print('Unsupervised F1:', f1_score(real_labels, predicted_labels, average='micro'))
print(classification_report(real_labels, predicted_labels))
print('Writing features to data/...')
# Context manager closes (and flushes) the pickle file; the original
# left the handle returned by open() dangling.
with open('data/wf_features.p', 'wb') as feature_file:
    pickle.dump(wf_features, feature_file)